LMDB 的 Python 接口使用方法

MXNet中对classification任务提供了把训练图像数据转换成一个大的二进制文件的方法,但是,对于其它任务例如语义分割等,并没有提供类似的功能。这里,介绍一下如何使用LMDB的Python接口把语义分割训练数据的图像和标签转换成LMDB的文件。首先,这里简要介绍一下问什么要把图像文件转换成大的二进制文件。

为什么要把图像数据转换成大的二进制文件

简单来说,是因为读写小文件的速度太慢。那么, 不禁要问,图像数据也是二进制文件,单个大的二进制文件例如LMDB文件也是二进制文件,为什么单个图像读写速度就慢了呢?这里分两种情况解释。

  1. 机械硬盘的情况:机械硬盘的每次读写启动时间比较场,例如磁头的寻道时间占比很高,因此,如果单个小文件读写,尤其是随机读写单个小文件的时候,这个寻道时间占比就会很高,最后导致大量读写小文件的时候时间会很浪费;
  2. NFS的情况:在NFS的场景下,系统的一次读写首先要进行上百次的网络通讯,并且这个通讯次数和文件的大小无关。因此,如果是读写小文件,这个网络通讯时间占据了整个读写时间的大部分。

固态硬盘的情况下应该也会有一些类似的开销,目前没有研究过。

LMDB 的使用方法

LMDB 是一个key value 内存映射的数据库。内存映射的意思就是说,LMDB在使用的时候,会把磁盘上的数据映射到内存中,因此,只顺序读写的情况下,相当于直接在内存中操作数据,因此速度很快。另外,在网络训练场景中存储的LMDB数据是一个二进制文件,因此,也克服了上一部分说的小文件读写的问题。

LMDB 写数据

这里展示一下如何把图像数据存储到LMDB中,基本上分为3个步骤:

  1. 创建LMDB文件
  2. 创建对应的database
  3. 向对应的database中写数据

注意,LMDB中的数据不会保存图像的shape信息,为了在读取数据的过程中正确地恢复出图像,需要知道图像的shape 信息。如果图像的shape全部是相同的,那么可以不用存储该信息,如果不同,也可以把该信息存储在LMDB中。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
def img2lmdb():
# 创建数据库文件
env = lmdb.open(cfg.dataset, max_dbs=4, map_size=int(1e12))
# 创建对应的数据库
train_data = env.open_db("train_data")
train_label = env.open_db("train_label")
val_data = env.open_db("val_data")
val_label = env.open_db("val_label")
train_image_list, train_label_list = get_image_label_list(train=True)
val_image_list, val_label_list = get_image_label_list(train=False)
# 把图像数据写入到LMDB中
with env.begin(write=True) as txn:
for idx, path in enumerate(train_image_list):
logging.debug("{} {}".format(idx, path))
data = read_fixed_image(path)
txn.put(str(idx), data, db=train_data)

for idx, path in enumerate(train_label_list):
logging.debug("{} {}".format(idx, path))
data = read_fixed_label(path)
txn.put(str(idx), data, db=train_label)

for idx, path in enumerate(val_image_list):
logging.debug("{} {}".format(idx, path))
data = read_fixed_image(path)
txn.put(str(idx), data, db=val_data)

for idx, path in enumerate(val_label_list):
logging.debug("{} {}".format(idx, path))
data = read_fixed_label(path)
txn.put(str(idx), data, db=val_label)

LMDB读数据

例子如下,不解释了~

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
class DataIter(gluon.data.Dataset):

def __init__(self, train=True):
env = lmdb.open(cfg.dataset, max_dbs=4, map_size=int(1e12), readonly=True)
self.train_data = env.open_db("train_data")
self.train_label = env.open_db("train_label")
self.val_data = env.open_db("val_data")
self.val_label = env.open_db("val_label")
self.txn = env.begin()
self._length = self.txn.stat(db=self.train_data)["entries"] if train else self.txn.stat(db=self.val_data)["entries"]
self.train = train

def __getitem__(self, idx):
idx = str(idx)
if self.train:
image = self.txn.get(idx, db=self.train_data)
image = np.frombuffer(image, 'uint8')
image = np.reshape(image, [4] + list(cfg.raw_size))
label = self.txn.get(idx, db=self.train_label)
label = np.frombuffer(label, 'uint8')
label = np.reshape(label, list(cfg.raw_size))
return nd.array(image), nd.array(label / 255.0)
else:
image = self.txn.get(idx, db=self.val_data)
image = np.frombuffer(image, 'uint8')
image = np.reshape(image, [4] + list(cfg.raw_size))
label = self.txn.get(idx, db=self.val_label)
label = np.frombuffer(label, 'uint8')
label = np.reshape(label, list(cfg.raw_size))
return nd.array(image), nd.array(label / 255.0)

def __len__(self):
return self._length

0%