Python Code Snippets

工作中最常用的是Python, 因此, 把Python常用的代码单独列出来.

sqlite3 数据库

1
2
3
4
5
6
7
8
9
10
11
12
import sqlite3

# Recreate the results table and insert one row.
# `cfg`, `image`, `render`, `label`, `predict` and `miou` are expected to be
# defined by the surrounding code.
sqlite_file = "./db.sqlite"
conn = sqlite3.connect(sqlite_file)
try:
    c = conn.cursor()

    # Identifiers such as table names cannot be bound as SQL parameters, so the
    # table name is still formatted into the statement. cfg.table_name must
    # therefore come from trusted configuration, never from user input.
    # IF EXISTS keeps the very first run (no table yet) from raising.
    cmd = "DROP TABLE IF EXISTS {table_name}".format(table_name=cfg.table_name)
    c.execute(cmd)

    cmd = (
        "CREATE TABLE {table_name} "
        "(image TEXT, render TEXT, label TEXT, predict TEXT, miou REAL)"
    ).format(table_name=cfg.table_name)
    c.execute(cmd)

    # Row values are passed as bound parameters ("?") instead of being pasted
    # into the SQL text, so quotes inside a value cannot break or alter the
    # statement.
    cmd = (
        "INSERT INTO {table_name} (image, render, label, predict, miou) "
        "VALUES (?, ?, ?, ?, ?)"
    ).format(table_name=cfg.table_name)
    c.execute(cmd, (image, render, label, predict, miou))

    conn.commit()
finally:
    # Always release the database file, even if a statement above failed.
    conn.close()

lmdb

这里需要特别注意的是 env.begin() 一定要在所有的 db 全部打开之后调用,否则出错

1
2
3
4
5
6
7
8
9
10
11
# Write to the database.
# Every named database is opened before the first env.begin() call.
env = lmdb.open("data.lmdb/", max_dbs=4, map_size=int(1e12))
# Database names and keys are byte strings; the b"" prefix is valid on both
# Python 2 and Python 3.
train_data = env.open_db(b"train_data")
train_label = env.open_db(b"train_label")
with env.begin(write=True) as txn:
    # NOTE(review): `data` is assumed to be a bytes-like value - confirm at the
    # call site.
    txn.put(b"key", data, db=train_data)

# Read from the database. A lookup does not modify anything, so a default
# (read-only) transaction is used instead of a write transaction.
with env.begin() as txn:
    buffer = txn.get(b"key", db=train_data)
    # The stored value comes back as raw bytes and still has to be converted;
    # here it is reinterpreted as a flat uint8 array.
    data = np.frombuffer(buffer, dtype='uint8')

requests

下载文件

1
2
3
4
5
# Download gift_url to config.img_dir/fname, streaming the body in 1024-byte
# chunks so the whole file is never held in memory at once.
res = requests.get(gift_url, stream=True)
try:
    # Fail on an HTTP error status instead of saving the error page as the file.
    res.raise_for_status()
    with open(os.path.join(config.img_dir, fname), "wb") as f:
        for chunk in res.iter_content(1024):
            f.write(chunk)
finally:
    # Close the response even when the request or the write fails.
    res.close()

上传(POST)文件

1
2
3
4
# POST ./example.jpeg as the multipart form field "image".
url = "http://10.84.145.69:8080"
# Open the file in a with-block so the handle is closed once the request is done.
with open("./example.jpeg", "rb") as image_file:
    img = {"image": image_file}
    res = requests.post(url, files=img)
# Single-argument print(...) is valid on both Python 2 and Python 3.
print(res.text)

设置编码

1
# -*- coding: utf-8 -*-

matplotlib

在没有显示设备的机器上保存图像

1
2
import matplotlib
# Select the 'Agg' backend so that figures can be saved on a machine that has
# no display device.
# NOTE(review): this presumably has to run before matplotlib.pyplot is
# imported - confirm against the matplotlib version in use.
matplotlib.use('Agg')

调用系统命令

1
2
import subprocess
# Run `ls -l /home/` as a child process. The command is passed as an argument
# list rather than a single shell string; the value returned by call() is not
# used here.
subprocess.call(["ls", "-l", "/home/"])

配置日志输出格式

1
2
3
4
5
6
7
8
9
10
11
12
13
import logging

# Name the logger after the module instead of the file path: logger names form
# a dot-separated hierarchy, so a path such as "app.py" would be interpreted as
# a child "py" of a logger called "app".
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# File handler; mode "w" truncates file.log each time this code runs.
fh = logging.FileHandler("file.log", "w")
fh.setLevel(logging.DEBUG)

# Stream handler for console output.
ch = logging.StreamHandler()
ch.setLevel(logging.DEBUG)

# Both handlers share one format: timestamp, source file, line, level, message.
formatter = logging.Formatter("%(asctime)s %(filename)s %(lineno)i: %(levelname)s: %(message)s")
for handler in (fh, ch):
    handler.setFormatter(formatter)
    logger.addHandler(handler)

logger.info("message")

遍历目录下面的所有指定类型的文件,包括子目录的文件

1
2
3
4
5
6
import fnmatch
import os


def mk_list(target=None, pattern='*.jpg'):
    """Recursively collect files under *target* whose names match *pattern*.

    Args:
        target: Directory to walk, including all of its sub-directories.
            ``None`` (the default) means the current working directory.
        pattern: ``fnmatch``-style file-name pattern. Defaults to ``'*.jpg'``.

    Returns:
        A list of paths, each built by joining the walked directory with the
        matching file name. Empty if nothing matches.
    """
    if target is None:
        # os.walk() cannot take None; fall back to the current directory.
        target = os.curdir
    matches = []
    for root, _dirnames, filenames in os.walk(target):
        matches.extend(
            os.path.join(root, filename)
            for filename in fnmatch.filter(filenames, pattern)
        )
    return matches

tornado

提供POST服务(接收上传的文件)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import tornado
import tornado.web
import tornado.ioloop
import os
import uuid

class Upload(tornado.web.RequestHandler):
    """Receive a file posted under the form field ``image`` and process it."""

    def post(self):
        """Save the uploaded file under a random name and reply with web_demo's result.

        The file is stored in the ``__UPLOADS__`` directory, its absolute path
        is printed and passed (as a one-element list) to ``web_demo``, and the
        value returned by ``web_demo`` is written to the response.

        Raises:
            tornado.web.HTTPError: 400 when the request has no ``image`` file.
        """
        try:
            fileinfo = self.request.files['image'][0]
        except (KeyError, IndexError):
            # A missing upload is a client error, not a server failure.
            raise tornado.web.HTTPError(400, "missing 'image' file field")

        fname = fileinfo['filename']
        # Only the extension of the client-supplied name is kept; the stored
        # name itself is a fresh UUID.
        extn = os.path.splitext(fname)[1]
        cname = str(uuid.uuid4()) + extn
        # os.path.join gives the same result as concatenation when __UPLOADS__
        # ends with a separator, and still works when it does not.
        saved_path = os.path.abspath(os.path.join(__UPLOADS__, cname))

        # The body is raw bytes, so write in binary mode; the with-block closes
        # (and flushes) the file before its path is handed to web_demo.
        with open(saved_path, 'wb') as fh:
            fh.write(fileinfo['body'])

        print(saved_path)
        ret = web_demo([saved_path])
        self.write(ret)

# Route table: POST /upload goes to the Upload handler, and /images/<path>
# serves files straight from the __static__path__ directory.
application = tornado.web.Application([
    (r"/upload", Upload),
    # Serve stored images directly.
    (r'/images/(.*)', tornado.web.StaticFileHandler, {'path': __static__path__}),
], debug=True)

if __name__ == "__main__":
    application.listen(8888)
    # IOLoop.instance() is a deprecated alias of IOLoop.current() since
    # Tornado 5.0.
    tornado.ioloop.IOLoop.current().start()

提供POST服务(不接收上传的文件)

1
2
3
4
5
6
7
8
9
10
11
12
13
import tornado.web
import tornado.ioloop
import json
import os, sys
class GetTaskHandler(tornado.web.RequestHandler):
    """Handle POST requests whose body is a JSON object with a ``task_type`` key."""

    def post(self):
        """Look up a task for the requested type and write it to the response.

        Writes the value returned by ``self.GetTask`` when it is truthy,
        otherwise the text ``'no more data'``.

        Raises:
            tornado.web.HTTPError: 400 when the body is not valid JSON or has
                no ``task_type`` entry.
        """
        try:
            parsed_json = json.loads(self.request.body)
            task_type = parsed_json['task_type']
        except (ValueError, KeyError, TypeError):
            # A malformed request is the client's fault: answer 400 rather than
            # letting the exception surface as a 500.
            raise tornado.web.HTTPError(
                400, "request body must be a JSON object with a 'task_type' key"
            )

        # NOTE(review): GetTask is not defined in the visible part of this
        # class - confirm it is provided elsewhere.
        row = self.GetTask(task_req=task_type)
        if row:
            self.write(row)
        else:
            self.write('no more data')

multiprocessing

1
2
import multiprocessing

# `func` is the worker function and `items` is the iterable of inputs; one
# worker process is started per CPU core.
# NOTE(review): on platforms that spawn worker processes this presumably needs
# to run under `if __name__ == "__main__":` - confirm for the target platform.
pool = multiprocessing.Pool(processes=multiprocessing.cpu_count())
try:
    # map() blocks until every item is processed; results keep the input order.
    results = pool.map(func, items)
finally:
    # Stop accepting work and wait for the workers to exit, so no worker
    # processes are left behind.
    pool.close()
    pool.join()

获取当前文件所在的路径

1
# Absolute path of the directory that contains this source file.
cur_path = os.path.dirname(os.path.abspath(__file__))

argparse

1
2
3
import argparse

# Command-line interface with a single optional --network flag.
parser = argparse.ArgumentParser(description='train an image classifier on cifar10')
parser.add_argument('--network', type=str, default='inception-bn-28-small', help='the cnn to use')
# Parses sys.argv; the chosen network name is available as args.network.
args = parser.parse_args()
0%