0%

Python Code Snippets

工作中最常用的是 Python, 因此,把 Python 常用的代码单独列出来。

使用邮件监控 mxnet 的训练过程

目前公司 GPU 集群的工作方式是把调试好的代码提交到集群排队,等集群调度到任务之后开始训练,因此,使用扫码登录微信监控的方法已不适用,于是摸索了一套使用邮件监控的方法。简单来讲,就是在训练完一个 epoch 之后给指定邮箱发送一封邮件,邮件中包含必要的信息,例如任务名称、各个评估 metric 的结果、当前训练的 epoch 数等。过程很简单,主要是两个方面:一是使用 python 调用命令行的 sendmail 命令发送邮件,二是使用 mx.mod.Module.fit 的回调。代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
def sendmail(param):
    """Mail a plain-text summary of the current epoch's evaluation metrics.

    Args:
        param: Callback argument; this function reads ``param.epoch`` and
            iterates ``param.eval_metric.get_name_value()`` as
            ``(name, value)`` pairs.

    The job name (also used as the subject) is "local test" on host
    yz-gpu002.hogpu.cc and the PBS_JOBNAME environment variable elsewhere.
    """
    if socket.gethostname() == "yz-gpu002.hogpu.cc":
        job_name = "local test"
    else:
        job_name = os.environ["PBS_JOBNAME"]

    # One line per item: job name, epoch number, then every metric.
    lines = [job_name, "EPOCH: " + str(param.epoch)]
    for metric_name, metric_value in param.eval_metric.get_name_value():
        lines.append(metric_name + ": " + str(metric_value))
    body = "\n".join(lines) + "\n"

    msg = MIMEText(body)
    msg["From"] = "yushu.gao@hogpu.cc"
    msg["To"] = "hobot@hobot.cc"
    msg["Subject"] = job_name

    # Hand the rendered message to the local sendmail binary on its stdin.
    mailer = Popen(["/usr/sbin/sendmail", "-t", "-oi"], stdin=PIPE)
    mailer.communicate(msg.as_string())
# Usage sketch: pass sendmail to fit as eval_end_callback ("..." stands for the remaining arguments).
model.fit(...,eval_end_callback=sendmail,...)

其中,msg["From"]可以任意指定

SQLite3 数据库

1
2
3
4
5
6
7
8
9
10
11
12
import sqlite3

# Recreate the results table and insert one row.
sqlite_file = "./db.sqlite"
conn = sqlite3.connect(sqlite_file)
try:
    c = conn.cursor()

    # IF EXISTS keeps the very first run (no table yet) from failing.
    # Identifiers cannot be bound as SQL parameters, so the table name is
    # still interpolated; it must come from trusted configuration.
    cmd = "DROP TABLE IF EXISTS {table_name}".format(table_name=cfg.table_name)
    c.execute(cmd)

    cmd = (
        "CREATE TABLE {table_name} "
        "(image TEXT, render TEXT, label TEXT, predict TEXT, miou REAL)"
    ).format(table_name=cfg.table_name)
    c.execute(cmd)

    # Values are bound with "?" placeholders so quotes or other special
    # characters in them cannot break (or inject into) the statement.
    cmd = (
        "INSERT INTO {table_name} (image, render, label, predict, miou) "
        "VALUES (?, ?, ?, ?, ?)"
    ).format(table_name=cfg.table_name)
    c.execute(cmd, (image, render, label, predict, miou))

    conn.commit()
finally:
    # Close the connection even when one of the statements raises.
    conn.close()

LMDB

这里需要特别注意的是 env.begin() 一定要在所有的 db 全部打开之后调用,否则会出错

1
2
3
4
5
6
7
8
9
10
11
# Write to the database.
# All named sub-databases are opened before the first env.begin() call.
# Keys and database names are bytes literals: identical to str on Python 2,
# and the form py-lmdb expects on Python 3.
env = lmdb.open("data.lmdb/", max_dbs=4, map_size=int(1e12))
train_data = env.open_db(b"train_data")
train_label = env.open_db(b"train_label")
with env.begin(write=True) as txn:
    txn.put(b"key", data, db=train_data)

# Read from the database. A read-only transaction is enough here, so the
# writer lock is not taken. The value comes back as raw bytes and still has
# to be converted into an array.
with env.begin() as txn:
    raw = txn.get(b"key", db=train_data)
    data = np.frombuffer(raw, dtype='uint8')

requests

下载文件

1
2
3
4
5
# Stream the response body to disk in 1 KiB chunks.
res = requests.get(gift_url, stream=True)
try:
    # Fail on 4xx/5xx instead of saving an error page as the downloaded file.
    res.raise_for_status()
    with open(os.path.join(config.img_dir, fname), "wb") as f:
        for chunk in res.iter_content(1024):
            f.write(chunk)
finally:
    # Release the connection even if the download or the write fails.
    res.close()

上传 (POST) 文件

1
2
3
4
# Upload a local file as the multipart form field "image".
url = "http://10.84.145.69:8080"
# The with-block closes the file handle once the request has been sent.
with open("./example.jpeg", "rb") as image_file:
    img = {"image": image_file}
    res = requests.post(url, files=img)
# Parenthesised single-argument print works on both Python 2 and Python 3.
print(res.text)

设置编码

1
# -*- coding: utf-8 -*-

matplotlib

在没有显示设备的机器上保存图像

1
2
import matplotlib
# Select the 'Agg' backend; per the accompanying note this is what allows
# figures to be saved on a machine that has no display device.
matplotlib.use('Agg')

调用系统命令

1
2
import subprocess
# Run `ls -l /home/`. The command is passed as an argument list, and the
# return code of the call is not captured here.
subprocess.call(["ls", "-l", "/home/"])

logging 日志输出格式

极简方法

1
# One-call setup: DEBUG level, and each record formatted as
# "<time> <source file> <line number>: <level name>: <message>".
logging.basicConfig(level=logging.DEBUG, format="%(asctime)s %(filename)s %(lineno)i: %(levelname)s: %(message)s")

同时写文件和终端

1
2
3
4
5
6
7
8
9
10
11
12
# Send every record of the root logger both to a log file and to the terminal.
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

# On host yz-gpu002.hogpu.cc the log goes to the working directory;
# elsewhere it goes to the directory named by the LOCAL_OUTPUT variable.
if socket.gethostname() == "yz-gpu002.hogpu.cc":
    local_dir = "./"
else:
    local_dir = os.environ["LOCAL_OUTPUT"]

formatter = logging.Formatter("%(asctime)s %(filename)s %(lineno)i: %(levelname)s: %(message)s")

# "w" truncates the log file on every run. No per-handler level is set, so
# the logger's DEBUG level alone decides which records are emitted.
fh = logging.FileHandler(local_dir + "/pruing.log", "w")
ch = logging.StreamHandler()
for _handler in (fh, ch):
    _handler.setFormatter(formatter)
    logger.addHandler(_handler)

遍历目录下面的所有指定类型的文件,包括子目录的文件

1
2
3
4
5
6
def mk_list(target=None, pattern='*.jpg'):
    """Recursively collect the files under *target* whose names match *pattern*.

    Args:
        target: Root directory to walk. ``None`` (the default) means the
            current directory.
        pattern: ``fnmatch``-style glob matched against each file name
            (not the full path). Defaults to ``'*.jpg'``.

    Returns:
        A list of paths, each the walked directory joined with a matching
        file name, in the order ``os.walk`` produces them.
    """
    if target is None:
        # os.walk cannot walk None, so the bare default used to be unusable.
        target = os.curdir

    matches = []
    for root, _dirnames, filenames in os.walk(target):
        matches.extend(
            os.path.join(root, filename)
            for filename in fnmatch.filter(filenames, pattern)
        )
    return matches

tornado

提供 POST 服务 (接收上传的文件)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import tornado
import tornado.web
import tornado.ioloop
import os
import uuid

class Upload(tornado.web.RequestHandler):
    """Handle a POST that uploads one file in the form field ``image``.

    The file is stored under ``__UPLOADS__`` with a random name, its absolute
    path is passed to ``web_demo`` and the result is written as the response.
    ``__UPLOADS__`` and ``web_demo`` are not defined in this snippet and are
    expected to be provided by the surrounding module.
    """

    def post(self):
        fileinfo = self.request.files['image'][0]
        fname = fileinfo['filename']
        # Keep only the extension of the client-supplied name; the stored
        # base name is a fresh UUID.
        extn = os.path.splitext(fname)[1]
        cname = str(uuid.uuid4()) + extn
        saved_path = os.path.abspath(__UPLOADS__ + cname)
        # Binary mode: the body is raw bytes. The with-block flushes and
        # closes the file before web_demo is asked to read it.
        with open(saved_path, 'wb') as fh:
            fh.write(fileinfo['body'])
        print(saved_path)
        ret = web_demo([saved_path])
        self.write(ret)

# Route table: /upload is handled by Upload; /images/<path> is served by
# StaticFileHandler from the directory __static__path__.
application = tornado.web.Application([
(r"/upload", Upload),
(r'/images/(.*)', tornado.web.StaticFileHandler,{'path': __static__path__}) # fetch stored images directly by URL
], debug= True)

if __name__ == "__main__":
    # Bind the application to port 8888, then run the IOLoop.
    application.listen(8888)
    io_loop = tornado.ioloop.IOLoop.instance()
    io_loop.start()

提供 POST 服务 (不接收上传的文件)

1
2
3
4
5
6
7
8
9
10
11
12
13
import tornado.web
import tornado.ioloop
import json
import os, sys
class GetTaskHandler(tornado.web.RequestHandler):
    """Handle a POST whose JSON body carries a ``task_type`` key.

    The task type is passed to ``self.GetTask``; a truthy result is written
    back as the response, otherwise the text 'no more data' is written.
    NOTE(review): ``GetTask`` is not defined in this snippet; presumably it
    is provided elsewhere. Confirm.
    """

    def post(self):
        # The body comes from the client: answer malformed JSON or a missing
        # "task_type" with 400 instead of an unhandled exception (500).
        try:
            task_type = json.loads(self.request.body)['task_type']
        except (ValueError, KeyError, TypeError):
            raise tornado.web.HTTPError(400)

        row = self.GetTask(task_req=task_type)
        if row:
            self.write(row)
        else:
            self.write('no more data')

multiprocessing

1
2
import multiprocessing
# func is the per-item worker function and items is the iterable of inputs.
# (The placeholder is deliberately not called `list`, which would shadow the builtin.)
pool = multiprocessing.Pool(processes=multiprocessing.cpu_count())
try:
    results = pool.map(func, items)
finally:
    # Shut the worker processes down explicitly once the map has finished.
    pool.close()
    pool.join()

获取当前文件所在的路径

1
# Absolute path of the directory that contains this source file.
cur_path = os.path.dirname(os.path.abspath(__file__))

argparse

1
2
3
# Command-line interface with a single optional --network flag.
parser = argparse.ArgumentParser(
    description='train an image classifer on cifar10',
)
parser.add_argument(
    '--network',
    type=str,
    default='inception-bn-28-small',
    help='the cnn to use',
)
# Parses sys.argv as soon as this statement runs.
args = parser.parse_args()