MXNet 源码详解 -- 自定义 Operator
本文是 MXNet 源码解读系列的第二篇, 本篇介绍如何使用 C++ 实现新的 Operator.
随着 MXNet 的演进, 在 MXNet 中实现新的 Operator 有两种方式, 一种是早期的版本, 通过定义计算类和相应的属性类的方式, 另外一种是现在 MXNet 中在使用的通过 nnvm 的方式. MXNet 仍然对早期通过类定义的 Operator 的方式提供良好的兼容. 本文将详细介绍如何通过类的方式实现新的 Operator. 具体来说本文通过实现一个简单的量化训练 Operator 来详细解释如何在 MXNet 中实现一个新的 Operator.
量化训练
关于量化训练的详细内容可以参考其它的文章, 简单来说, 本文将实现的 Operator 为:
$$
y = \text{clip}(\text{round}(s \cdot x), -127, 127)
$$
在 MXNet 中实现一个新的 Operator 需要实现 3 个主要的部分, Operator 的属性参数, Operator 的计算, Operator 的属性信息. 接下来逐个介绍.
属性参数定义
Operator 的属性参数定义了 Operator 的一般属性, 例如, 在 Convolution 中的, kernel_size, num_filter, strides, padding 等等这些信息, 都属于属性参数. 在本文的量化训练的例子中, scale 是属性参数. 属性的参数定义通过继承 dmlc::Parameter来实现.
1 | struct QuantiParam : dmlc::Parameter<QuantiParam> { |
关于 Parameter 内部的实现, 可以参考源码或者关注后续的文章, 当前可以先照猫画虎实现对应的 Parameter. 其中, 要注意的一点是, 在 Parameter 中定义的所有的参数, 都必须支持 >> 这个操作符, 这是因为, 在 MXNet 的传递中使用的是 std::pair<std::string, std::string>, 在 Parameter 内部解析的时候, 需要把的std::string 类型转换成对应的参数类型, 底层是通过 >> 这个操作符实现的.
计算的定义
对于一个 Operator 来说, 最重要的就是定义该 Operator 的计算是什么, 在 MXNet 中, 我们定义的 Operator 通常都是要用来训练神经网络的, 因此, 我们要同时定义 Forward 和 Backward, 如果明确不需要 Backward 的时候, 也可以不实现 Backward. 例如, 在部署量化模型的预测库中, 因为只需要跑前向计算, 因此, 我们就可以只实现 Forward 而不用实现 Backward. 这里, 因为我定义的是量化训练的 Operator, 因此, 我会同时实现 Forward 和 Backward.
Forward
1 | virtual void Forward(const OpContext& ctx, |
首先我们关注 Forward 的接口, ctx 是用来定义 Operator 是在什么设备上跑的, 例如 CPU 还是 GPU, in_data 是该计算所需要输入的数据, 例如, 在 Convolution 中, in_data 就是计算卷积 data, weight 和 bias, 在本文的量化 Operator 中是输入数据data, 也就是上述公式中的 x, req 是用来提示 memory 的复用关系, 它定义了 Operator 的计算结果是如何放到out_data中去的, 这个参数基本上不需要用户干预, 框架会自己维护该参数. out_data 是计算的输出. aux_states 是辅助参数, 例如, 在 BatchNorm 中的running_mean, running_var 等都是 aux_states, 简单来说, aux_states 是模型的参数, 但是, 他和 in_data 中的参数的不同点是, aux_states 不需要计算梯度.
第 10-11 行: 把 TBlob 类型的输入数据转换成 mshadow::Tensor 类型, 其目的是方便使用 mshadow 进行计算. Tensor 并没有存储具体的数据, 只是保存了指向数据的指针, 因此, 这个转换和具体的计算相比, 其代价可以忽略不计. mshadow 是一个实现了 lazy compute 的Tensor计算库, 在 MXNet 中, 熟练掌握 shadow 能给实现 Operator 带来巨大的便利.
第 13-14 行: 量化训练的实现, F 是 mshadow 中定义的一个接口, 在这里我们完全使用了mshadow完成了具体的计算, 并且, 通过mshadow实现的代码是可以同时在 CPU 和 GPU 上面运行的, 因此, 我们不需要单独再实现 GPU 计算的代码了.
Backward
因为我们要实现量化训练, 因此, 这里我们还需要实现对应的 Backward 的计算. 我们这里实现一个最简单的版本–其反向的梯度永远是 1.
$$
\text{grad} = 1
$$
1 | virtual void Backward(const OpContext& ctx, |
Forward 的计算中我们没有展示req是如何使用的, 在这里, 我特意使用了 req 来表明计算出来的梯度要如何放到 dgrad 中. 如果深究的话, OpReqType 有以下几个选项: kNullOp(什么也不做), kWriteTo(把梯度写到 dgrad 中), kWriteInplace(原地写, 这个要配合 Operator 属性定义中的 InplaceOption 使用), kAddTo(和原来的数据相加)
属性定义
在QuantiParam中我们已经定义了 Operator 的一些属性, 在QuantiParam中定义的参数, 最终都会暴露给用户, 用户可以根据具体的需求给参数设置不同的值. 这里定义的属性, 和 Operator 的计算行为没有关系, 不管这里如何定义, 都不应该影响 Operator 的计算行为 (如果影响了, 那么说明代码里有 bug). 这里的属性, 有些是需要在运行时才知道具体行为的, OperatorProperty 中包含了 Operator 的所有的信息. OperatorProperty 的接口很多, 但是, 大部分时间我们不需要全部实现. 这里, 为了简单, 我们只实现几个必要的接口. 其它的接口有的是为了进行 memory 优化的, 有的是用来定义 Operator 的数据类型的, 可以在需要的时候做具体的实现, 尤其是InferType接口, 在实现量化预测库的时候, 因为所有的数据都是使用的不同位宽的整型数据, 在实现的时候需要在InferType中做具体的推断.
1 | class QuantiProp : public OperatorProperty { |
在这里最重要的就是 InerShape 这个接口, 该接口给定第一个输入的 shape, 然后根据QuantiParam参数信息推导出其它输入的 shape 和输出的 shape. 本文的 Operator 非常简单, 只有一个输入和一个输出, 并且输入和输出的 shape 完全一样, 因此, 该接口实现起来非常简单. 在其它的一些 Operator 中, 例如, Convolution 中, 输入有 data, weight, bias (也有可能没有), 并且他们的 shape 都不一样, 这时候就可以在这里通过 data 的 shape 和 Convolution 的 Parameter 参数推断出 weight, bias, output 的 shape.
注册 Operator
以上我们实现了具体的 Operator, 为何运行起来, 我们还需要把 Operator 注册到框架中.
1 | Operator* QuantiProp::CreateOperatorEx(Context ctx, |
这个接口实现在大部分 Operator 中也不需要修改. CreateOp 要根据 CPU 或者 GPU 做相应的特化. 例如, 在 CPU 上的定义为:
1 | template <> |
另外注意, 这里直接把 QuantiOp 实例化为了 float 类型, 如果要支持其他类型, 比如 float16, 可以根据 dtype 来做具体的实例化.
最后就是提供给用户的接口:
1 | DMLC_REGISTER_PARAMETER(QuantiParam); |
其中, 第 3 行是 Operator 的计算数据, 第 4 行是 Operator 的属性参数.
支持 GPU 上计算
为了支持 GPU 计算, 我们还需要创建 GPU 上对应的 Operator, 和上面cpu上的CreateOp类似, 把CreateOp特化到gpu上就好了.
1 | Operator* CreateOp<gpu>(QuantiParam param, int dtype) { |
总结
本文介绍了如何动手实现 MXNet 的 Operator, 为了重点关注实现一个 Operator 中的主要工作内容, 本文没有介绍关于 memory 优化的内容 (在很多时候还是比较重要的) 以及 OperatorProperty 中其它接口的定义和实现. 本文的主要目的是用最简洁的例子和代码教大家实现一个能运行的 Operator, 至于各种优化, 各种特殊场景下的接口实现, 相信大家遇到具体需求的时候会逐渐学会.
本文所有的代码实现在这里: https://github.com/shuokay/incubator-mxnet/tree/legacy-operator/plugin/quantization