Bucket RNN 分析和实现

RNN 是另外一种在深度学习领域常见的网络模型, 在语音识别, NLP 等需要对 sequence 数据建模的场景中应用广泛. 由于 sequence 数据的长度是变化的, 所以, 在实际工程中, 常见的一种做法是指定一个 sequence 的长度 \(L\)(一般是最长的一个 sequence 的长度), 然后, 把长度不等于 \(L\) 的数据进行 mask, 从而使得所有的 sequence 都具有相同的长度. 然而, 这种做法存在的问题是, 在 sequence 的长度差别比较大的情况下, 这样生硬的进行 mask 的做法显然不靠谱. 当前, 比较常见一种 trade-off 的做法是 Bucket RNN.

RNN

RNN 是一种能够学习 sequence to sequence 数据的模型, 一个基本的 RNN 在每一个状态点对应一个输入和输出, 而且, 在状态 \(t+1\) 的输出是由输入 \(x_{t+1}\) 和 状态 \(t\) 共同控制的, 如下图所示.


但是, 在工程实现中会面临一个比较大的问题. 在实现中, 为了加快计算速度, 一般都是把数据进行向量化之后通过矩阵计算实现加快计算的目的(例如调用 BLAS 接口, 常见的如 openblas, atlas, mkl 等). 但是, 在 RNN 中, 不同的样本数据中, 状态数量 T 往往是不同的, 这就导致无法进行向量化, 常见的一种解决方法就是上面说的使用一个特定的数把 sequence 长度短的数据进行补齐. 另外一种更合理的方法是实现 bucket RNN.


如上图, bucket rnn 的做法是实现了几个固定长度的 rnn, 例如 100, 110, 120 等. 其中, sequence 长度小于等于 100 的样本, 输入到”第一个网络”中, 同理, sequence 长度大于 100 小于 110 的样本输入到”第二个网络”中…… 这里非常关键的一点是 bucket size 不同的 rnn 一定要进行权值共享. 也就是说, bucket=100 的参数, 要和 bucket=110 的 rnn 的前 100 个 time step 的参数要共享. 其它的同理. 在训练的时候, 每一个 mini batch 的数据的 bucket size 是相同的, 这样, 就可以把数据进行向量化, 然后使用 BLAS 进行加速了.
理解这一点关键是要理解, 在训练 RNN 的时候是把 RNN 中的所有的循环展开, 然后直接进行 forward 计算. 之前一直迷惑就是因为没有抓住这个关键点. 例如, 我用 MXNet 可视化了一个 bucket=3 的 RNN 如下图:

MXNet 源码

DataIter 部分

这部分代码主要实现了两个特殊的功能:

  1. 把样本数据放入到其对应的 bucket 中, 并且做好相应的 mask 操作
  2. 每一个 mini batch 要告诉 executor 目前的 bucket, 从而 executor 决定使用哪个 symbol

第一个功能是在 mx.io.DataIter 的初始化中完成的, 没有与 MXNet 密切相关的特殊的地方.
第二个功能实现如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
class SimpleBatch(object):
def __init__(self, data_names, data, label_names, label, bucket_key):
self.data = data
self.label = label
self.data_names = data_names
self.label_names = label_names
self.bucket_key = bucket_key # 告诉 executor 当前的 mini batch 的 bucket
self.pad = 0
self.index = None # TODO: what is index?
@property
def provide_data(self):
return [(n, x.shape) for n, x in zip(self.data_names, self.data)]
@property
def provide_label(self):
return [(n, x.shape) for n, x in zip(self.label_names, self.label)]

BucketModule 部分

这个部分与非 bucket module 不同主要是实现把 mini batch 根据其 bucket 绑定到不同的 symbol 上面去.

  1. 切换 Symbol

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    def switch_bucket(self, bucket_key, data_shapes, label_shapes=None):
    """Switch to a different bucket. This will change `self.curr_module`.
    Parameters
    ----------
    bucket_key : str (or any python object)
    The key of the target bucket.
    data_shapes : list of (str, tuple)
    Typically `data_batch.provide_data`.
    label_shapes : list of (str, tuple)
    Typically `data_batch.provide_label`.
    """
    assert self.binded, 'call bind before switching bucket'
    if not bucket_key in self._buckets: # 这里体现了 DataIter 中的 bucket 的作用
    symbol, data_names, label_names = self._sym_gen(bucket_key)
    module = Module(symbol, data_names, label_names,
    logger=self.logger, context=self._context,
    work_load_list=self._work_load_list)
    module.layout_mapper = self.layout_mapper
    module.bind(data_shapes, label_shapes, self._curr_module.for_training,
    self._curr_module.inputs_need_grad,
    force_rebind=False, shared_module=self._curr_module)
    self._buckets[bucket_key] = module
    # self._buckets 存放的是 bind 了 input shape 的 forward model. 在 forward 的时候只需要把输入 feed 到模型中就可以了.
    self._curr_module = self._buckets[bucket_key]
  2. 参数共享
    这里, 首先要理解的一点是, MNXet 根据节点, 或者说是 Variable 的 name 去初始化/寻找权值参数. 例如, 在 transfer learning 中, 如果在 pretrain 的模型中找到相同的 name 那么就用 pretrain 的模型中的数值去初始化待训练的 Symbol 的参数, 如果找不到, 使用默认的方法初始化. 这里也是类似的方法.

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    def bind(self, data_shapes, label_shapes=None, for_training=True,
    inputs_need_grad=False, force_rebind=False, shared_module=None,
    grad_req='write'):
    """Binding for a `BucketingModule` means setting up the buckets and bind the
    executor for the default bucket key. Executors corresponding to other keys are
    binded afterwards with `switch_bucket`.
    Parameters
    ----------
    data_shapes : list of (str, tuple)
    This should correspond to the symbol for the default bucket.
    label_shapes : list of (str, tuple)
    This should correspond to the symbol for the default bucket.
    for_training : bool
    Default is `True`.
    inputs_need_grad : bool
    Default is `False`.
    force_rebind : bool
    Default is `False`.
    shared_module : BucketingModule
    Default is `None`. This value is currently not used.
    grad_req : str, list of str, dict of str to str
    Requirement for gradient accumulation. Can be 'write', 'add', or 'null'
    (default to 'write').
    Can be specified globally (str) or for each argument (list, dict).
    """
    # in case we already initialized params, keep it
    if self.params_initialized:
    arg_params, aux_params = self.get_params()
    # force rebinding is typically used when one want to switch from
    # training to prediction phase.
    if force_rebind:
    self._reset_bind()
    if self.binded:
    self.logger.warning('Already binded, ignoring bind()')
    return
    assert shared_module is None, 'shared_module for BucketingModule is not supported'
    self.for_training = for_training
    self.inputs_need_grad = inputs_need_grad
    self.binded = True
    symbol, data_names, label_names = self._sym_gen(self._default_bucket_key)
    module = Module(symbol, data_names, label_names, logger=self.logger,
    context=self._context, work_load_list=self._work_load_list)
    module.layout_mapper = self.layout_mapper
    module.bind(data_shapes, label_shapes, for_training, inputs_need_grad,
    force_rebind=False, shared_module=None, grad_req=grad_req)
    self._curr_module = module
    self._buckets[self._default_bucket_key] = module
    # copy back saved params, if already initialized
    # 如果参数被训练过/初始化过. 那么, 直接把参数迁移过来. 实现了参数共享.
    if self.params_initialized:
    self.set_params(arg_params, aux_params)