0%

MXNet IO 接口使用方法

MXNet 的 DataIter 输入的 Python 接口,不仅适用于图片,而且适用于其它的文件格式。DataIter 比 numpy.array 的巨大的优点是 DataIter 不需要一次性把数据读入到内存。想象一下在 CNN 的训练中动辄几百 G 的训练数据,一次性载入内存比较不现实。另一方面,通过 MXNet 的 IO 接口调用,可以更灵活的控制数据的 IO 过程。

单进程操作

优点:代码简单,容易操作
缺点:每次跑一个新的 Batch 的时候,MXNet 都要停下来等待读取/处理数据, 时间浪费非常严重

要定义一个 batch 中的数据

1
2
3
4
5
class DataBatch(object):
def __init__(self, data, label):
# 这个两属性是必须的,因为 Executor Manager 在 load 数据的时候,load 的对象就是 data 和 label.
self.data = data
self.label = label

使用 Python 接口定义 MXNet 的 Iter

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
class DataIter(mx.io.DataIter):
def __init__(self, count, batch_size, num_label, height, width):
super(DataIter, self).__init__()
self.batch_size = batch_size
self.count = count # 图片总量
self.height = height # 图片的高度
self.width = width # 图片的宽度
# 这非常重要,第一个参数 'data' 一定要和网络中的 `mxnet.symbol.Variable 的 name 对应起来`
self.provide_data = [('data', (batch_size, 3, height, width))]
self.provide_label = [('softmax_label', (self.batch_size, num_label))] # 同上

def __iter__(self):
for k in range(self.count / self.batch_size):
data = []
label = []
for i in range(self.batch_size):
one_img=func_get_img_data() # 获取一张图片数据,numpy.array 格式
one_label= func_get_label_data() # 获取该图片的 label 数据,numpy.array 格式
data.append(one_img)
label.append(one_label)

data_all = [mx.nd.array(data)]
label_all = [mx.nd.array(label)]
# 创建一个 Batch 的数据
data_batch = OneBatch(data_all, label_all)
# yield 的作用是每次返回一个 batch, 而不是所有的 batch, 上文所谓的节省内存的秘密就在这里
yield data_batch

多进程操作

优点:不需要 MXNet 停下来等待读取/处理数据
缺点:代码有点小麻烦
MXNet 中使用的是 threading 来操作的 Prefetching, 我这里使用 multiprocessing

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import multiprocessing
class DataIter(mx.io.DataIter):
def __init__(self, count, batch_size, num_label, height, width):
super(DataIter, self).__init__()
self.batch_size = batch_size
self.cursor = -batch_size
self.batch_size = batch_size
self.num_data = count
self.height = height # 图片的高度
self.width = width # 图片的宽度
# 这非常重要,第一个参数 'data' 一定要和网络中的 `mxnet.symbol.Variable 的 name 对应起来`
self.provide_data = [('data', (batch_size, 3, height, width))]
self.provide_label = [('softmax_label', (self.batch_size, num_label))] # 同上
# 首先定义一个 Queue 对象,用来存储 Prefetch 的数据
self.q = multiprocessing.Queue(maxsize=2)
# 创建 4 个读取数据的进程,具体的读取方法在 self.write 中
self.pws = [multiprocessing.Process(target=self.write) for in in range(4)]
for pw in self.pws:
pw.daemon = True
pw.start()

def write(self):
while True:
for i in range(self.batch_size):
# 获取一张图片数据,numpy.array 格式
one_img = func_get_img_data()
# 获取该图片的 label 数据,numpy.array 格式
one_label = func_get_label_data()
data.append(one_img)
label.append(one_label)
data_all = [mx.nd.array(data)]
label_all = [mx.nd.array(label)]
# 创建一个 Batch 的数据
data_batch = OneBatch(data_all, label_all)
# block=True, 允许在队列满的时候阻塞,timeout=None, 永不超时
self.q.put(obj=data_batch, block=True, timeout=None)
def iter_next(self):
self.cursor += self.batch_size
return self.cursor < self.num_data

def next(self):
if self.q.empty():
logging.debug("waiting for data")
if self.iter_next():
# block=True, 允许在队列空的时候阻塞,timeout=None, 永不超时
return self.q.get(block=True, timeout=None)
else:
raise StopIteration

def reset(self):
self.cursor = -self.batch_size + (self.cursor%self.num_data) % self.batch_size