在上一篇中介绍了如何使用 MXNet 的 legacy 接口实现一个 Operator, 本文介绍如何使用 nnvm 接口实现 Operator.
本文仍然以实现一个量化训练的 Operator 为例,具体的数学过程参考上一篇。
使用 nnvm 实现 Operator 和原来的 legacy 接口要实现的内容比较类似,只不过 nnvm 方式所有的接口都是使用函数实现的,因此,在 Operator 中无法记录 Operator 的中间状态,所有需要记录的状态都要在外部初始化,然后以参数的形式传到相应的函数中进行更新。
nnvm 的方法实现 Operator 同样需要实现 3 个主要的部分,Parameter 的定义,Operator 的计算,以及 Operator 的属性。
由于量化训练的原理在上一篇中已经讲过了,因此,在这片文章中只关注具体的代码实现。
属性参数定义
属性参数 Parameter 的定义和之前的完全一致,这里直接跳过了。
计算的定义
计算的定义接口和 legacy 的方法也几乎相同,只不过,这里不需要继承 Operator 类了,而是直接实现具体的计算函数。
Forward
这里为了简单,同样默认只支持了 float
类型。
1 | template <typename xpu> |
Backward
在反向中,需要特别注意的一点是,对于网络来说,输出的节点的梯度是放在 inputs 中的,而输入节点的 grad 是放在 ouputs 中的,这是因为,在 nnvm 中,由于计算图的关系,在计算梯度的时候,我们是从前向图的输出节点的 grad 去计算输入节点的 grad, 因此,计算梯度的计算图恰好和前向图是反的。在 gradient graph 中,原始数据的输入对应的梯度体现在计算图中是输出,原始数据的输出对应的梯度体现在计算图中是出入。
1 | template <typename xpu> |
计算属性
同样的,我们也需要对计算定义一些属性,类似上一篇中的 InferShape, InferType 等,nnvm 简化了这部分的工作,直接注册成对应的方法就可以了。
1 | NNVM_REGISTER_OP(Quanti) |
从上面的代码可以看到几个比较重要的地方:InferShape, InferType 和上一篇一样,也是通过输入的 Shape 和 Type 去推导其它输入输出数据的 Shape 和 Type, FCompute 用来指定该 Operator 的具体的计算的实现。整体上看,和上一篇基于类的方法要完成的工作是一样的,只不过换了一种方式。
最后在对应的 .cu 文件中指定在 GPU 上如何计算之后,整个 Operator 就完成了。
1 | NNVM_REGISTER_OP(Quanti).set_attr<FCompute>("FCompute<gpu>", QuantizationCompute<gpu>); |
如何选择
最后说一说在实现具体的 Operator 的时候,是选择基于上一篇介绍的基于 class 的方法还是选择本篇介绍的基于 nnvm 的方法。
理论上,在写一个 Operator 的时候,使用两种方法的任意一种都是可以的,MXNet 需要做到兼容上一篇的 class 的方法。两种方法各有优缺点,但是,我个人推荐使用 nnvm 的这种方法。
- 由于 nnvm 兼容之前的代码做的还不够完备,如果现在使用 class 方法实现 Operator 的话可能会出现一些不可预料的问题。下面会举例说明我遇到的一些兼容方面的问题。
- 使用 nnvm 的方法的话 Operator 写起来更方面一些。模式也更加统一。
nnvm 兼容旧代码的问题
到目前为止,我至少碰到了三次因为兼容做的不完备导致而踩到的坑。
一次是 InplaceOption
的坑,由于没有兼容InplaceOption
选项,网络在设置 OpReqType 的时候出现了问题,训练的时候直接崩溃。解决方法是在 legacy_op_util.cc
中加上对于 InplaceOption
的兼容。
第二次的问题就更隐蔽了,在使用 CacheOp 的时候,由于我需要在运行之前,根据用户的设置,修改 graph 的一些属性,因此,在我拿到 CacheOp 的 graph 之后,我会去修改 graph 中节点的一些 attrs 并且保存该 graph. 但是,由于兼容的问题,在我修改了 graph 之后,虽然落盘的 graph 显示是已经修改了 attrs, 但是,实际上在内存中用于训练的 graph 并没有变更 attrs. 最终,导致,模型训练完之后再 load 模型进行 inference 的时候结果完全不符合预期。解决方法大概有两个,一个是把 graph 保存成 json 文件之后重新 load 该文件去 build graph, 另外一种方法是在 MXNet 中 fix, 见 https://github.com/shuokay/incubator-mxnet/commit/9865ce9213af956bb200eeb414daff31055d9096
第三次是,由于 nnvm 的模式不保存内部状态,因此导致,之前用 class 方法写的带有内部状态的 Operator 全部都需要做相应的修改,去掉内部状态。