MXNet IO 接口使用方法

MXNet 的 DataIter 输入的 Python 接口, 不仅适用于图片, 而且适用于其它的文件格式. DataIter 比 numpy.array 的巨大的优点是 DataIter 不需要一次性把数据读入到内存. 想象一下在 CNN 的训练中动辄几百 G 的训练数据, 一次性载入内存比较不现实. 另一方面, 通过 MXNet 的 IO 接口调用, 可以更灵活的控制数据的 IO 过程.

单进程操作

优点: 代码简单, 容易操作
缺点: 每次跑一个新的 Batch 的时候, MXNet 都要停下来等待读取/处理数据, 时间浪费非常严重

  1. 要定义一个 batch 中的数据,

    1
    2
    3
    4
    5
    class DataBatch(object):
    def __init__(self, data, label):
    # 这个两属性是必须的, 因为 Executor Manager 在 load 数据的时候, load 的对象就是 data 和 label.
    self.data = data
    self.label = label
  2. 使用 Python 接口定义 MXNet 的 Iter

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    class DataIter(mx.io.DataIter):
    def __init__(self, count, batch_size, num_label, height, width):
    super(DataIter, self).__init__()
    self.batch_size = batch_size
    self.count = count # 图片总量
    self.height = height # 图片的高度
    self.width = width # 图片的宽度
    # 这非常重要, 第一个参数 'data' 一定要和网络中的 `mxnet.symbol.Variable 的 name 对应起来`
    self.provide_data = [('data', (batch_size, 3, height, width))]
    self.provide_label = [('softmax_label', (self.batch_size, num_label))] # 同上
    def __iter__(self):
    for k in range(self.count / self.batch_size):
    data = []
    label = []
    for i in range(self.batch_size):
    one_img=func_get_img_data() # 获取一张图片数据, numpy.array 格式
    one_label= func_get_label_data() # 获取该图片的 label 数据, numpy.array 格式
    data.append(one_img)
    label.append(one_label)
    data_all = [mx.nd.array(data)]
    label_all = [mx.nd.array(label)]
    # 创建一个 Batch 的数据
    data_batch = OneBatch(data_all, label_all)
    # yield 的作用是每次返回一个 batch, 而不是所有的 batch, 上文所谓的节省内存的秘密就在这里
    yield data_batch

多进程操作

优点: 不需要 MXNet 停下来等待读取/处理数据
缺点: 代码有点小麻烦
MXNet 中使用的是 threading 来操作的 Prefetching, 我这里使用 multiprocessing

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import multiprocessing
class DataIter(mx.io.DataIter):
def __init__(self, count, batch_size, num_label, height, width):
super(DataIter, self).__init__()
self.batch_size = batch_size
self.cursor = -batch_size
self.batch_size = batch_size
self.num_data = count
self.height = height # 图片的高度
self.width = width # 图片的宽度
# 这非常重要, 第一个参数 'data' 一定要和网络中的 `mxnet.symbol.Variable 的 name 对应起来`
self.provide_data = [('data', (batch_size, 3, height, width))]
self.provide_label = [('softmax_label', (self.batch_size, num_label))] # 同上
# 首先定义一个 Queue 对象, 用来存储 Prefetch 的数据
self.q = multiprocessing.Queue(maxsize=2)
# 创建4个读取数据的进程, 具体的读取方法在 self.write 中
self.pws = [multiprocessing.Process(target=self.write) for in in range(4)]
for pw in self.pws:
pw.daemon = True
pw.start()
def write(self):
while True:
for i in range(self.batch_size):
# 获取一张图片数据, numpy.array 格式
one_img = func_get_img_data()
# 获取该图片的 label 数据, numpy.array 格式
one_label = func_get_label_data()
data.append(one_img)
label.append(one_label)
data_all = [mx.nd.array(data)]
label_all = [mx.nd.array(label)]
# 创建一个 Batch 的数据
data_batch = OneBatch(data_all, label_all)
# block=True, 允许在队列满的时候阻塞, timeout=None, 永不超时
self.q.put(obj=data_batch, block=True, timeout=None)
def iter_next(self):
self.cursor += self.batch_size
return self.cursor < self.num_data
def next(self):
if self.q.empty():
logging.debug("waiting for data")
if self.iter_next():
# block=True, 允许在队列空的时候阻塞, timeout=None, 永不超时
return self.q.get(block=True, timeout=None)
else:
raise StopIteration
def reset(self):
self.cursor = -self.batch_size + (self.cursor%self.num_data) % self.batch_size