MXNet 的 DataIter 提供了 Python 数据输入接口,不仅适用于图片,也适用于其它文件格式。DataIter 相比 numpy.array
的一大优点是不需要一次性把数据读入内存。想象一下,CNN 训练中动辄几百 GB 的训练数据,一次性全部载入内存并不现实。另一方面,通过 MXNet 的 IO 接口,可以更灵活地控制数据的 IO 过程。
单进程操作
优点:代码简单,容易实现
缺点:每处理一个新的 Batch,MXNet 都要停下来等待数据的读取/处理,时间浪费非常严重
首先定义一个 batch 中的数据:
1 2 3 4 5
class DataBatch(object):
    """Minimal container for one mini-batch handed to MXNet.

    Attributes:
        data: the batch's input arrays (presumably a list with one NDArray
            per entry of the iterator's ``provide_data``).
        label: the batch's label arrays (presumably one per entry of
            ``provide_label``).
        pad: number of padding samples at the end of the batch; 0 means the
            batch is completely filled with real samples.
        index: optional indices of the samples in this batch.
    """

    def __init__(self, data, label, pad=0, index=None):
        self.data = data
        self.label = label
        # MXNet's prediction helpers (e.g. Module.predict / iter_predict) read
        # ``batch.pad`` to strip padded samples, so provide it explicitly.
        self.pad = pad
        self.index = index
|
然后使用 Python 接口定义 MXNet 的 DataIter(其中 func_get_img_data / func_get_label_data 是读取单个样本的图片和标签的自定义函数,需要自行实现):
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27
class DataIter(mx.io.DataIter):
    """Single-process MXNet data iterator.

    Every batch is assembled synchronously in the training process, so the
    training loop waits while each sample is loaded.

    Args:
        count: number of samples per epoch. Only ``count // batch_size`` full
            batches are produced; a trailing partial batch is dropped.
        batch_size: number of samples per batch.
        num_label: number of label values per sample.
        height, width: image size; images are declared 3-channel, i.e.
            ``provide_data`` is ``(batch_size, 3, height, width)``.
    """

    def __init__(self, count, batch_size, num_label, height, width):
        super(DataIter, self).__init__()
        self.batch_size = batch_size
        self.count = count
        self.height = height
        self.width = width
        self.provide_data = [('data', (batch_size, 3, height, width))]
        self.provide_label = [('softmax_label', (self.batch_size, num_label))]

    def __iter__(self):
        """Yield one ``DataBatch`` per full batch of the epoch.

        A generator: every ``iter(it)`` call starts a fresh epoch.
        """
        # ``//`` keeps the batch count an int; with ``/`` Python 3 produces a
        # float and ``range`` raises TypeError.
        for _ in range(self.count // self.batch_size):
            data = []
            label = []
            for _ in range(self.batch_size):
                # func_get_img_data / func_get_label_data are user-supplied
                # loaders not defined here; presumably each returns one sample
                # shaped to match provide_data / provide_label.
                data.append(func_get_img_data())
                label.append(func_get_label_data())
            yield DataBatch([mx.nd.array(data)], [mx.nd.array(label)])
|
多进程操作
优点:数据由后台进程提前读取/处理好,MXNet 无需停下来等待
缺点:代码稍微复杂一些
MXNet 自带的 Prefetching 是用 threading 实现的,这里改用 multiprocessing 实现。
多进程可以绕开 Python 的 GIL,让 CPU 密集的图片解码/预处理真正并行执行:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51
import logging
import multiprocessing


def _produce_batches(queue, batch_size):
    """Worker loop: build raw batches forever and push them onto ``queue``.

    Runs in a child process. Each item put on the queue is a ``(data, label)``
    tuple of plain Python lists holding ``batch_size`` samples each.
    """
    while True:
        # Fresh lists for every batch; reusing them would grow without bound.
        data = []
        label = []
        for _ in range(batch_size):
            # NOTE(review): func_get_img_data / func_get_label_data are
            # user-supplied and not defined here. If they sample randomly,
            # reseed per worker (e.g. from os.getpid()); forked workers
            # otherwise start from the same RNG state and emit identical data.
            data.append(func_get_img_data())
            label.append(func_get_label_data())
        # Ship plain data only; NDArrays are built in the consumer so no MXNet
        # engine calls happen inside forked workers.
        queue.put((data, label), block=True, timeout=None)


class DataIter(mx.io.DataIter):
    """MXNet data iterator whose batches are prepared by worker processes.

    ``num_workers`` daemon processes keep a bounded queue filled with ready
    batches, so the training loop blocks only when the queue is empty.

    Args:
        count: number of samples per epoch. Workers produce batches
            indefinitely; ``count`` only bounds how many are served per epoch.
        batch_size: number of samples per batch.
        num_label: number of label values per sample.
        height, width: image size; images are declared 3-channel, i.e.
            ``provide_data`` is ``(batch_size, 3, height, width)``.
        num_workers: number of producer processes (default 4).
        queue_size: maximum number of ready batches buffered (default 2).
    """

    def __init__(self, count, batch_size, num_label, height, width,
                 num_workers=4, queue_size=2):
        super(DataIter, self).__init__()
        self.batch_size = batch_size
        # One batch "before" 0 so the first iter_next() lands on 0.
        self.cursor = -batch_size
        self.num_data = count
        self.height = height
        self.width = width
        self.provide_data = [('data', (batch_size, 3, height, width))]
        self.provide_label = [('softmax_label', (self.batch_size, num_label))]
        self.q = multiprocessing.Queue(maxsize=queue_size)
        # Workers receive only the queue and batch size, not ``self``: a bound
        # method target would force ``self`` (including the Process objects in
        # ``self.pws``) to be pickled under the "spawn" start method.
        self.pws = [
            multiprocessing.Process(target=_produce_batches,
                                    args=(self.q, self.batch_size))
            for _ in range(num_workers)
        ]
        for pw in self.pws:
            pw.daemon = True  # do not keep the interpreter alive on exit
            pw.start()

    def write(self):
        """Run the producer loop in the calling process (blocks forever)."""
        _produce_batches(self.q, self.batch_size)

    def iter_next(self):
        """Advance the cursor one batch; return True while the epoch has data."""
        self.cursor += self.batch_size
        return self.cursor < self.num_data

    def next(self):
        """Return the next batch, blocking until a worker has produced one.

        Raises:
            StopIteration: once ``count`` samples have been served this epoch.
        """
        if not self.iter_next():
            raise StopIteration
        if self.q.empty():
            logging.debug("waiting for data")
        # NOTE(review): if every worker has died (e.g. a loader raised), this
        # blocks forever, as the original did.
        data, label = self.q.get(block=True, timeout=None)
        return DataBatch([mx.nd.array(data)], [mx.nd.array(label)])

    def reset(self):
        """Prepare for a new epoch.

        After a finished epoch, the samples by which the last batch overshot
        ``count`` are credited to the next epoch (roll-over). When reset is
        called before the first batch or mid-epoch, the cursor simply restarts.
        """
        if self.num_data > 0 and self.cursor >= self.num_data:
            self.cursor = -self.batch_size + (self.cursor % self.num_data) % self.batch_size
        else:
            # Applying the roll-over formula here would shift the start (e.g.
            # -bs -> -bs + 1) and silently skip samples.
            self.cursor = -self.batch_size
|