Transfer learning 通常在训练数据不足的情况下使用。所谓 transfer learning 最直观的理解是给模型一个很好初始值。由于使用的 base model 通常都是在各大 benchmark 上取得 state of the art 效果的模型,因此,transfer learning 通常能带来事半功倍的效果。另外,transfer learning 还可以防止网络过早出现过拟合。当然,使用 transfer learning 的基础是两个数据集的数据分布是相似的。本文使用 BOT 2016 计算机视觉大赛的数据,详细记录了 transfer learning 的过程。
准备数据
下载/解压数据就不再赘述了. 下面是我的工作目录的状态
其中,data
目录结构如下
生成 MXNet 的 rec 格式的数据
分为两个步骤,首先是生成一个 list 文件,该文件没一行有 3 各字段 label, index, path
, 分别对应了路径为 path
的图片的 label
和 index
. 第二步是使用上面生成的 list 文件,来生成训练过程中真正使用的 rec
文件。
步骤一:生成 list 文件
暂时不能使用 gif 格式的图片,所以,我没有处理 gif 格式的图片。
1 | # --list=true 是说明我们要生成list |
这个步骤完成之后,会生成 3 个 list 文件,分别是:
bot_train.lst
, 存放了训练数据的信息,顺序已经打乱bot_val.lst
, 存放了 validation 数据,顺序没有打乱bot_test.lst
, 存放 test 数据,因为在生成 list 的时候我们没有指定要把多少比例的数据作为 test 数据,默认为 0, 所以,这个文件是空的。
步骤二:生成 rec 文件
最终输入到 MXNet 中的数据是 rec
格式,我们需要利用步骤一得到的 list 文件,生成 rec
文件。
1 | # --resize=224, 把图片较短的边 resize 到 224, 不同的配置 resize 的方法不同, 具体内容请参考源码 |
这个步骤之后,我们得到了 2 个 rec
文件,分别是:
bot_train.rec
, 训练数据bot_val.rec
, validation 数据。
im2rec.py
还有很多其它的选项配置,这里没有用到,所以就不一一解释了。
转换模型参数和网络结构信息
首先要说明,MXNet 提供的用于转换 caffe
网络结构的脚本有一点小错误,首先我们要修改一下脚本。
打开 path/to/mxnet/tools/caffe_converter/convert_symbol.py
文件,修改一下 188 行。
1 | # symbol_string, output_name = proto2script(sys.argv[1]) |
你也可以打印出来 input_dim
的值,这个值是 (10, 3, 224, 224)
步骤一:转换网络结构
1 | # 下载 caffe 格式的 deploy 文件 |
完成之后,生成 vgg16.py
文件,就是我们需要的结果
下面是我的转换结果
1 | conv1_1 = mx.symbol.Convolution(name='conv1_1', data=data , num_filter=64, pad=(1,1), kernel=(3,3), stride=(1,1), no_bias=False) |
步骤二:转换参数
1 | # 下载 caffe 格式的模型参数 caffemodel |
完成之后生成两个文件:
vgg16-symbol.json
, 定义了网络图结构vgg16-0001.params
, 对应的参数,约 528M
构建网络
现在万事俱备,我们开始构建自己的用于 transfer learning 的网络。首先是把 vgg16.py 略做修改。
修改最后两行
1 | # fc8 = mx.symbol.FullyConnected(name='fc8', data=drop7 , num_hidden=1000, no_bias=False) |
在最开始的部分加入输入 variable
1 | data=mx.symbol.Variable("data") |
修改相应层的 lr_mult
参数,该参数的意义是该层的参数更新的 learn rate 需要在基础的 base learn rate 上乘以该参数,即 lr=lr_base*lr_mult
, 例如:
1 | conv1_1 = mx.symbol.Convolution(name='conv1_1', data=data , num_filter=64, pad=(1,1), kernel=(3,3), stride=(1,1), no_bias=False, attr={'lr_mult':'0.01'}) |
加载数据
1 | def get_iterator(batch_size=128): |
训练
1 | # 配置 logging, 不然不会打印输出 |
结语
因为是图像识别任务,所以,整个过程还是中规中矩,比较简答。通过这种方法,在没有调整训练参数的情况下,跑了 2 个 epoch, validation 准确率能达到 92.6%. 效果还可以。
完整代码请参考 github