NAS 提供了一种自动化搜索网络结构的方法。其大体思想是固定网络的训练过程,然后寻找通过该训练过程可以达到最好效果的网络结构。从一定程度上看这种方法和我们通常选定一种网络结构然后使用不同的优化参数和优化方法找到最好的模型参数的过程恰好相反。
目标网络结构的描述
NAS 使用一个 RNN 去训练网络结构,在 NAS 中,该 RNN 被称为 Controller.
初级网络结构描述
既然要用 RNN 去训练一个网络结构,那么就需要一种可以用 RNN 来描述网络结构的方法。在 NAS 中,使用字符串了描述网络结构。基本模式为 [number of filters][filter height][filter width][stride width], 例如如下图中描述的结构:
高级网络结构描述
上面的网络结构只能描述最简单的结构,如果像是 resnet 这种网络结构,上面的方法就无法描述。
为了描述 skip connection 这种结构,在 Controller 中增加了一个 Anchor Point 结构。例如,在第 N 层的 Anchor Point 描述的是前面 N-1 层了当前层是否有 skip connection
Controller 的训练
在训练的过程中,会逐渐增加目标网络的层数,也就是增加字符串的长度。
Controller 使用强化学习的 gradient policy 方法训练,训练的 reward 是 Controller 得到的神经网络在具体的数据集上的 performance, 例如,寻找 ConvNet 使用的是在 cifar10 上的 accuracy 作为强化学习的 reward.
怎样得到 reward
每次 Controller 输出的网络,全部使用同一套参数在数据集上进行相同的训练,训练结束后得到的 performance 作为 reward. 例如,搜索 cifar10 上的 CNN 结构,使用的策略是:
- 在该 CNN 上训练 50 个 epoch
- learning rate 设置为 0.1
- weight decay 设置为 1e-4
- momentum 设置为 0.9
- 使用 Nesterov Momentum
50 个 epoch 结束之后,使用该模型在 validation dataset 上测试 accuracy, 该 accuracy 作为 reward.