MXNet Transfer Learning

Transfer learning 通常在训练数据不足的情况下使用. 所谓 transfer learning 最直观的理解是给模型一个很好初始值. 由于使用的 base model 通常都是在各大 benchmark 上取得 state of the art 效果的模型, 因此, transfer learning 通常能带来事半功倍的效果. 另外, transfer learning 还可以防止网络过早出现过拟合. 当然, 使用 transfer learning 的基础是两个数据集的数据分布是相似的. 本文使用 BOT 2016 计算机视觉大赛的数据, 详细记录了 transfer learning 的过程.

准备数据

下载/解压数据就不再赘述了. 下面是我的工作目录的状态


其中, data 目录结构如下

生成 MXNet 的 rec 格式的数据

分为两个步骤, 首先是生成一个 list 文件, 该文件没一行有 3 各字段 label, index, path, 分别对应了路径为 path 的图片的 labelindex. 第二步是使用上面生成的 list 文件, 来生成训练过程中真正使用的 rec 文件.

步骤一: 生成 list 文件

暂时不能使用 gif 格式的图片, 所以, 我没有处理 gif 格式的图片.

1
2
3
4
5
6
# --list=true 是说明我们要生成list
# --train-ratio=0.8 我们把 80% 的数据用来作为训练数据, 剩下的 20% 自动作为 validation
# --recursive=true 迭代地去找图片
# bot 生成的 list 文件的前缀
# data/ 我们存放的图片数据的根目录
python path/to/mxnet/tools/im2rec.py --list=true --train-ratio=0.8 --exts=['.jpg','.jpeg','.png'] --recursive=true bot data/

这个步骤完成之后, 会生成 3 个 list 文件, 分别是:

  1. bot_train.lst, 存放了训练数据的信息, 顺序已经打乱
  2. bot_val.lst, 存放了 validation 数据, 顺序没有打乱
  3. bot_test.lst, 存放 test 数据, 因为在生成 list 的时候我们没有指定要把多少比例的数据作为 test 数据, 默认为 0, 所以, 这个文件是空的.

步骤二: 生成 rec 文件

最终输入到 MXNet 中的数据是 rec 格式, 我们需要利用步骤一得到的 list 文件, 生成 rec 文件.

1
2
3
4
# --resize=224, 把图片较短的边 resize 到 224, 不同的配置 resize 的方法不同, 具体内容请参考源码
# bot_val.lst, 步骤一中生生成的 list 文件
# data/ 图片存放的根目录, 使用 bot_val.lst 中每一行的地址拼上 data/ 就可以准确定位到一张图片
python path/to/mxnet/tools/im2rec.py --resize=224 --quality=100 bot_val.lst data/

这个步骤之后, 我们得到了 2 个 rec 文件, 分别是:

  1. bot_train.rec, 训练数据
  2. bot_val.rec, validation 数据.

im2rec.py 还有很多其它的选项配置, 这里没有用到,所以就不一一解释了.

转换模型参数和网络结构信息

首先要说明, MXNet 提供的用于转换 caffe 网络结构的脚本有一点小错误, 首先我们要修改一下脚本.
打开 path/to/mxnet/tools/caffe_converter/convert_symbol.py 文件, 修改一下 188 行.

1
2
# symbol_string, output_name = proto2script(sys.argv[1])
symbol_string, output_name, input_dim = proto2script(sys.argv[1])

你也可以打印出来 input_dim 的值, 这个值是 (10, 3, 224, 224)

步骤一: 转换网络结构

1
2
3
4
5
6
# 下载 caffe 格式的 deploy 文件
wget -c https://gist.githubusercontent.com/ksimonyan/211839e770f7b538e2d8/raw/c3ba00e272d9f48594acef1f67e5fd12aff7a806/VGG_ILSVRC_16_layers_deploy.prototxt
# 转换
# VGG_ILSVRC_16_layers_deploy.prototxt 上面下载的文件
# vgg16.py 转换结果的存储位置
python convert_symbol.py VGG_ILSVRC_16_layers_deploy.prototxt vgg16.py

完成之后, 生成 vgg16.py 文件, 就是我们需要的结果
下面是我的转换结果

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
conv1_1 = mx.symbol.Convolution(name='conv1_1', data=data , num_filter=64, pad=(1,1), kernel=(3,3), stride=(1,1), no_bias=False)
relu1_1 = mx.symbol.Activation(name='relu1_1', data=conv1_1 , act_type='relu')
conv1_2 = mx.symbol.Convolution(name='conv1_2', data=relu1_1 , num_filter=64, pad=(1,1), kernel=(3,3), stride=(1,1), no_bias=False)
relu1_2 = mx.symbol.Activation(name='relu1_2', data=conv1_2 , act_type='relu')
pool1 = mx.symbol.Pooling(name='pool1', data=relu1_2 , pad=(0,0), kernel=(2,2), stride=(2,2), pool_type='max')
conv2_1 = mx.symbol.Convolution(name='conv2_1', data=pool1 , num_filter=128, pad=(1,1), kernel=(3,3), stride=(1,1), no_bias=False)
relu2_1 = mx.symbol.Activation(name='relu2_1', data=conv2_1 , act_type='relu')
conv2_2 = mx.symbol.Convolution(name='conv2_2', data=relu2_1 , num_filter=128, pad=(1,1), kernel=(3,3), stride=(1,1), no_bias=False)
relu2_2 = mx.symbol.Activation(name='relu2_2', data=conv2_2 , act_type='relu')
pool2 = mx.symbol.Pooling(name='pool2', data=relu2_2 , pad=(0,0), kernel=(2,2), stride=(2,2), pool_type='max')
conv3_1 = mx.symbol.Convolution(name='conv3_1', data=pool2 , num_filter=256, pad=(1,1), kernel=(3,3), stride=(1,1), no_bias=False)
relu3_1 = mx.symbol.Activation(name='relu3_1', data=conv3_1 , act_type='relu')
conv3_2 = mx.symbol.Convolution(name='conv3_2', data=relu3_1 , num_filter=256, pad=(1,1), kernel=(3,3), stride=(1,1), no_bias=False)
relu3_2 = mx.symbol.Activation(name='relu3_2', data=conv3_2 , act_type='relu')
conv3_3 = mx.symbol.Convolution(name='conv3_3', data=relu3_2 , num_filter=256, pad=(1,1), kernel=(3,3), stride=(1,1), no_bias=False)
relu3_3 = mx.symbol.Activation(name='relu3_3', data=conv3_3 , act_type='relu')
pool3 = mx.symbol.Pooling(name='pool3', data=relu3_3 , pad=(0,0), kernel=(2,2), stride=(2,2), pool_type='max')
conv4_1 = mx.symbol.Convolution(name='conv4_1', data=pool3 , num_filter=512, pad=(1,1), kernel=(3,3), stride=(1,1), no_bias=False)
relu4_1 = mx.symbol.Activation(name='relu4_1', data=conv4_1 , act_type='relu')
conv4_2 = mx.symbol.Convolution(name='conv4_2', data=relu4_1 , num_filter=512, pad=(1,1), kernel=(3,3), stride=(1,1), no_bias=False)
relu4_2 = mx.symbol.Activation(name='relu4_2', data=conv4_2 , act_type='relu')
conv4_3 = mx.symbol.Convolution(name='conv4_3', data=relu4_2 , num_filter=512, pad=(1,1), kernel=(3,3), stride=(1,1), no_bias=False)
relu4_3 = mx.symbol.Activation(name='relu4_3', data=conv4_3 , act_type='relu')
pool4 = mx.symbol.Pooling(name='pool4', data=relu4_3 , pad=(0,0), kernel=(2,2), stride=(2,2), pool_type='max')
conv5_1 = mx.symbol.Convolution(name='conv5_1', data=pool4 , num_filter=512, pad=(1,1), kernel=(3,3), stride=(1,1), no_bias=False)
relu5_1 = mx.symbol.Activation(name='relu5_1', data=conv5_1 , act_type='relu')
conv5_2 = mx.symbol.Convolution(name='conv5_2', data=relu5_1 , num_filter=512, pad=(1,1), kernel=(3,3), stride=(1,1), no_bias=False)
relu5_2 = mx.symbol.Activation(name='relu5_2', data=conv5_2 , act_type='relu')
conv5_3 = mx.symbol.Convolution(name='conv5_3', data=relu5_2 , num_filter=512, pad=(1,1), kernel=(3,3), stride=(1,1), no_bias=False)
relu5_3 = mx.symbol.Activation(name='relu5_3', data=conv5_3 , act_type='relu')
pool5 = mx.symbol.Pooling(name='pool5', data=relu5_3 , pad=(0,0), kernel=(2,2), stride=(2,2), pool_type='max')
flatten_0=mx.symbol.Flatten(name='flatten_0', data=pool5)
fc6 = mx.symbol.FullyConnected(name='fc6', data=flatten_0 , num_hidden=4096, no_bias=False)
relu6 = mx.symbol.Activation(name='relu6', data=fc6 , act_type='relu')
drop6 = mx.symbol.Dropout(name='drop6', data=relu6 , p=0.500000)
fc7 = mx.symbol.FullyConnected(name='fc7', data=drop6 , num_hidden=4096, no_bias=False)
relu7 = mx.symbol.Activation(name='relu7', data=fc7 , act_type='relu')
drop7 = mx.symbol.Dropout(name='drop7', data=relu7 , p=0.500000)
fc8 = mx.symbol.FullyConnected(name='fc8', data=drop7 , num_hidden=1000, no_bias=False)
prob = mx.symbol.SoftmaxOutput(name='prob', data=fc8 )

步骤二: 转换参数

1
2
3
4
5
6
7
# 下载 caffe 格式的模型参数 caffemodel
wget -c http://www.robots.ox.ac.uk/~vgg/software/very_deep/caffe/VGG_ILSVRC_16_layers.caffemodel
# 转换格式
# VGG_ILSVRC_16_layers_deploy.prototxt 步骤一下载的 caffe 的 deploy 文件
# VGG_ILSVRC_16_layers.caffemodel 上面下载的 caffemodel 文件
# 这一步骤时间比较久, 需要 3~5min
python convert_model.py VGG_ILSVRC_16_layers_deploy.prototxt VGG_ILSVRC_16_layers.caffemodel vgg16

完成之后生成两个文件:

  1. vgg16-symbol.json, 定义了网络图结构
  2. vgg16-0001.params, 对应的参数, 约 528M

构建网络

现在万事俱备, 我们开始构建自己的用于 transfer learning 的网络. 首先是把 vgg16.py 略做修改.

  1. 修改最后两行

    1
    2
    3
    4
    # fc8 = mx.symbol.FullyConnected(name='fc8', data=drop7 , num_hidden=1000, no_bias=False)
    # prob = mx.symbol.SoftmaxOutput(name='prob', data=fc8 )
    fc8 = mx.symbol.FullyConnected(name='myfc8', data=drop7 , num_hidden=12, no_bias=False) # 因为 bot 比赛只有 12 类, 我们修改相应的类别数
    prob = mx.symbol.SoftmaxOutput(name='softmax', data=fc8 ) # mxnet 默认是 softmax_label, 所以, 这里把 prob 修改成 softmax
  2. 在最开始的部分加入输入 variable

    1
    data=mx.symbol.Variable("data")
  3. 修改相应层的 lr_mult 参数, 该参数的意义是该层的参数更新的 learn rate 需要在基础的base learn rate 上乘以该参数, 即 lr=lr_base*lr_mult, 例如:

    1
    conv1_1 = mx.symbol.Convolution(name='conv1_1', data=data , num_filter=64, pad=(1,1), kernel=(3,3), stride=(1,1), no_bias=False, attr={'lr_mult':'0.01'})

加载数据

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
def get_iterator(batch_size=128):
data_shape=(3, 224, 224)
data_iter="./"
train=mx.io.ImageRecordIter(
path_imgrec = data_iter+"bot_train.rec",
# "mean.bin" 并不存在, 这里 MXNet 发现不存在的话会自动生成
mean_img = data_iter+"mean.bin",
data_shape = data_shape,
batch_size = batch_size,
rand_crop = True,
rand_mirror = True,
)
val=mx.io.ImageRecordIter(
path_imgrec = data_iter+"bot_val.rec",
mean_img = data_iter+"mean.bin",
data_shape = data_shape,
batch_size = batch_size,
rand_crop = False,
rand_mirror = False,
)
return (train, val)

训练

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# 配置 logging, 不然不会打印输出
head = '%(asctime)-15s %(message)s'
logging.basicConfig(level=logging.DEBUG, format=head)
batch_size=32
# 如果有 2 块显卡, 可以使用 dev=[mx.gpu(0), mx.gpu(1)], 其它类推
dev=[mx.gpu(0)]
# 载入我们转换好的模型.
old_model = mx.model.FeedForward.load("vgg16", 1)
model= mx.model.FeedForward(ctx=dev,
symbol=get_symbol(),
num_epoch=200,
learning_rate=0.01,
wd=0.0001,
arg_params=old_model.arg_params,
aux_params=old_model.aux_params,
allow_extra_params=True
)
data_train, data_test=get_iterator(batch_size)
model.fit(X=data_train,
eval_data=data_test,
kvstore='local_allreduce_device',
batch_end_callback=mx.callback.Speedometer(batch_size, 50),
epoch_end_callback=mx.callback.do_checkpoint("bot"))

结语

因为是图像识别任务, 所以, 整个过程还是中规中矩, 比较简答. 通过这种方法, 在没有调整训练参数的情况下, 跑了 2 个 epoch, validation 准确率能达到 92.6%. 效果还可以.
完整代码请参考 github