Operators 是构建神经网络的必要元素, operators 定义了输入到输出的映射关系. MXNet 有一系列非常丰富的 operators, 有一些是简单的 operators, 例如 element-wise sum, 也有一些复杂的 operators 例如 convolution. 通常情况下, 使用这些 operators 可以构建大部分常见的 nn. 在 MXNet 实现的很多 operators 通常在 numpy 中是有等价形式的, 例如 repreat, tile 等. 那么, 为什么在MXNet中不直接使用 numpy 呢? 其中最重要的原因就是MXNet需要同时支持cpu和gpu运算, 而numpy目前并不支持gpu计算. 另外, 为了最大化 memory 和 runtime efficiency, MXNet 中的大量的 components 做了深度优化, 例如 tensor data structure (NDArray
), execution engine, computational graph 等等. MXNet 中实现的 operators 会综合考虑前面的各种优化从而做到性能的极致优化.
这个tutorial将会在MXNet backend 中用 C++ 实现一个operator. 之后, 使用python完成 unit test.
Implementation
An Operator Example
使用二次函数作为例子. \( f(x)=ax^2+bx+c \). 实现一个名字为 quadratic
的 operator, 要求如下:
- 输入为一个 tensor,
x
; - 输出为一个 tensor,
y
; - 满足
y.shape == x.shape
; - 把
x
中的元素输入到f
中得到相应的y
的值; a
,b
,c
是用户输入的 parameter.
在frontend, 该 op 的工作类似如下:
1 | x = [[1, 2], [3, 4]] |
实现该 op, 首先要创建 3 个文件, quadratic_op-inl.h
, quadratic_op.cc
, quadratic_op.cu
, 头文件的名字是op的前缀加 op
和 -inh
, 表示这是 op 的实现, 并且是在 CPU 和GPU 之间共享的 inline function. CPU 和 GPU 特定的实现 分别在他们各自的 .cc
和 .cu
中. 通常把 tensor 相关的operators放在 src/operator/tensor
中, nn 相关的operators放在src/operator/nn
中(目前还没有完成迁移).
接下来要完成以下几个工作:
- 在
quadratic_op-inl.h
中定义parameter struct 来注册a
,b
,c
; - 在
quadratic_op-inl.h
中定义 type 和 shape inference 的函数; - 在
quadratic_op-inl.h
中定义 forward 和 backward 的函数; - 在
quadratic.cc
和quadratic.cu
中使用 nnvm 分别注册 CPU 和 GPU 计算.
下面 step by step 地解释.
Parameter Registration
首先在 quadratic_op-inl.h
中定义 struct QuadraticParam
作为参数 a
, b
c
的 placeholder. 该 struct 继承自名字为 dmlc::Parameter
的 base template struct. 其中 template 的参数时候派生出来的 QuadraticParam. 这种技术成为 curiously recurring template pattern, 实现了 static polymorphism. 这个方法和 virtual function 很像, 但是, 节省了和 dynamic polymorphism 相关的开销.
1 | struct QuadraticParam : public dmlc::Parameter<QuadraticParam> { |
上面struct parameter 调用的函数的名字解释了它们的作用. 每一个 parameter 都设置了默认值 0, 目的是用户不需要传递 0 的参数. 对于参数如果在 runtime 是必须的, 可以不用设置默认值. 同时, 对每个参数增加了简单的描述, 因此 documentation engine 可以显示该描述(documentation engine 不在本文解释范围内)
Attribute Inference
Attribute Inference 是从用户提供的信息中推断神经网络中的 NDArray
的性质. NDArray
两种最常见的 attribute 是 data shape 和 data type. 举例来说, 给定一个 NDArray
名字为 data
, 执行 quadratic
op, output = mx.nd.quadratic(data, a=1, b=2, c=3)
. 在计算 output
之前, output
的 shape 和 type 根据 data
的shape 和 type 通过你定义的规则推理出来了, 该规则是为了给 output tensor 分配 memory.
其中需要注意的一点是, inference function 必须是可以 mutual inference 的. 即, 根据 op 的定义, 如果可能的, 可以通过一个argument 的 attribute来推理另外一个argument 的 attribute. 对于一个符号编程的nn来说, 这一点对于计算图推理 unknown attribute 非常有用. 用户可以把计算图看做是一个 symbol, 该symbol 拥有神经网络的每一个为 running data 初始化的 element, 包括 每一个 tensor 的 memory allocation, 每个 op 的 device placement 等. 用户通常只需要为计算图提供最少量的必要信息, 例如 input data shape, 计算图会利用该信息, 通过 inference function推理出 unknown attribute 来构建 nn.
例如下面的例子:
1 | import MXNet as mx |
最后一行代码是片段是包含三个 lists 的 tuple, 该 tuple 是 d.infer_shape()
返回的. 第一个 list 包含了所有的 augment a
, b
, c
的 shape, 第二个 list 包含了输出 d
的shape, 第三个 list 包含了 auxiliary 的 shape, 在这例子中没有使用, 因此是空的. 这个例子中, 只提供了 a
的第一个维度的信息和 c
的第二个 dimension 的信息, 在 shape [2, 0]
中的 0
表示该 dimension 的信息是不知道的, 在 shape (0, 3)
中的 0
也是同样的意思. 然而, symbol d
仍然成功地 infer 到了所有的 variables 的 shapes. 这就是 mutual inference 的作用. 在 MXNet 中, 整个过程可以表述为:
a
和b
是通过 element-wise multiplication operator 组合到一起的, 因此,a
和b
的 shape 应该是相同的, 因此,b
的 first dimension size 应该是2
;b
和c
是通过 element-wise multiplication operator 组合到一起的, 因此,b
和c
的 shape 应该是相同的, 因此,b
的 second dimension size 应该是3
;- 现在,
b
的 shape 已经是完全已知的了,a
和c
之前不知道的 dimension size 现在也知道了; d
是a*b
和b*c
相加的结果, 因此,d
的shape 也可以得到.
上面的是个步骤说明了 MXNet 中的 shape inference 逻辑是怎么工作的. 实际上, 它是在实现 element-wise multiplication and addition 的 shape inference function 中实现的.
对于我们的 quadratic
operator, shape inference 过程是极其类似的:
1 | inline bool QuadraticOpShape(const nnvm::NodeAttrs& attrs, |
上面的 function 需要注意以下几点:
attrs
包含了用户的输入参数a
,b
,c
. 在这里, 这三个参数没有用到, 因为对于 shape inference 来说并不依赖上述三个参数.in_attrs
是包含了 all input shapes 的 vector. 对于quadtatic
来说, 只有一个 input augment , 使用CHECK_EQ
来断言 vector 的 size 是否正确.out_attrs
是包含了 all output shapes 的 vector, 同样使用CHECK_EQ
来断言 vector 的 size- 使用
SHAPE_ASSIGN_CHECK
两次来完成 mutual inference, 一次是通过输入来 infer 输出, 一次是通过输出来 infer 输入. 如果在两个 shapes 的同一个 dimension 上有任何的非零的不相等的 values 就会抛出异常. - 在函数体的最后, 通过检查 shape 是不是非空以及 shape 的 size 是不是大于 0 来检查 output shape 是不是完全已知了. 在 MXNet 中, empty shape 意味着 shape 是 unknown 的, shape 中的 0 意味着 the size of that dimension is unknown. 这两种情形中的 missed information 必须要通过其它的 shapes 信息来 infer 到, 否则, 函数返回
false
来表示 shape inference 失败. - 对于 element-wise operators 的 mutual inference, MXNet 提供了如下的更简便的函数实现. 用户可以在 operator registration 中通过使用
n_in=1
和n_out=1
实例化该函数来取代上面的函数QuadraticOpShape
. 这里的QuadraticOpShape
只是为了解释方便.
1 | template<int n_in, int n_out> |
同样的逻辑也适用于 data type inference. 下面的 code sample 分析留给读者, 注意, -1
表示 data type unknown and must be inferred from other input or output data types.
1 | inline bool QuadraticOpType(const nnvm::NodeAttrs& attrs, |
同样的, 对于 element-wise operators MXNet 提供了下面的简单的函数来完成 mutual inference. 用户可以在 operator registration 中使用.
1 | template<int n_in, int n_out> |
Forward Function
Forward function 定义了 nn 中前向传播中 operator 的行为, forward function 的 signature 是固定的:
1 | void (const nnvm::NodeAttrs& attrs, |
下面是整个的 forward function code:
1 | template<typename xpu> // 1 |
- Line 1:
xpu
表示 generic device type, 从而该函数可以通过cpu
和gpu
来实例化成 支持 CPU 和 GPU 计算. 该实例化发生在.cc
和.cu
中注册 operator 的时候 - Line2:
attrs
是 node attribute, 包含了用户的输入参数a
,b
,c
. 这里的 node 代表了在整个 nn 的 computational graph 中该 operator 的 placeholder. - Line3:
ctx
包含了称为stream
的用来序列化异步执行的东西. 举例来说, 我们想使用和 CPU 上相同的stream
来 launch 多个 GPU kernels, 尽管 launch 操作是非阻塞的,stream
保证了 kernel 在 GPU 上执行的顺序和在 CPU 上执行的顺序是相同的. - Line4:
inputs
是 input tensors 的 vector (在 quadratic 中只有一个 input tensor) - Line5:
req
是OpReqType
value 的 vector, 每一个 value 定义了计算得到的结果如何写入到 output tensors 中. 因此,req
的数量必须和 output tensors 的数量相同. MXNet 目前支持三种类型的req
:null
,write
,add
.null
表示跳过计算对应的 output tensor,write
表示使用该 operator 的计算结果来覆盖当前 output tensors 中的值,add
表示把该 operator 的计算结果加到 output tensors 中去. 注意,null
和add
一般只会出现在 backward 中.null
通常用来跳过计算 un-learnable parameters(例如 index arrays),add
通常累加整个网络中的 gradients. - Line 6:
outputs
是 output tensors 的 vector (在quadratic
中只有一个 output tensor) - Lines 7-9: 检查每个 vector 的 size;
- Line 10: 为了 launch kernels, 从
ctx
中获取stream
- Lines 11-12: 为了后续编码方便, 定义 input tensors 和output tensors 的引用.
TBlob
可以看做是不同 dimension 的 tensors 的一个统一的数据结构, 从而具有不同 dimension 的tensors 可以放到一个同族的 container 中去, 例如std::vector
和std::list
. 通过TBlob
的get_with_shape
借口可以 get 到 tensors of desired dimension. - Line 13: 从 node attribute 中回去用户的 input parameters.
- Line 15-21: 这里是完成数学表达式计算的地方.
MSHADOW_TYPE_SWITCH
和MXNET_ASSIGN_REQ_SWITCH
两个宏似的代码可以支持 MXNet 的所有的 data types 和req
types. 在最里面的宏中, 我们 launch 到 kernel 从而计算output tensor, 每一个线程从 input tensor 中取一个 element, 输入到 quadratic function 中, 然后根据req
的值把结果赋值到 output tensor. 其中,Kernel::Launch
作为一个统一的借口来 launch parallel computation on both CPU and GPU. 因为在 CPU 和 GPU 上 parallelization approachs 经常是相同的, 因此, 这种方法使得大部分的 simple operators 可以在 CPU 和 GPU 上共享代码. kernel function 的定义如下, 其中, 函数Map
被每一个线程针对每一个输入元素执行. 其中的几个宏解释如下: (1)MSHADOW_XINLINE
是个强化宏用来 inline CPU 和 GPU 编译的 function. 它使得 CPU 计算和 GPU 计算可以共享相同的代码. (2)KERNEL_ASSIGN
宏的作用是统一不同的req
的语句到相同的一行代码中. 之所以被命名为KERNEL_ASSIGN
是因为我们称用来并行计算的代码为 kernels. 在 CPU 上, kernel 使用 OpenMP 的parallel
directive 来 wrap, 在 GPU 上, 它们是通过 CUDA library launch 的 kernel functions.
1 | template<int req> |
Backward Function
Backward function 的作用是在整个网络中传递最后一层输出的 loss function 的导数. 整个过程一般称为反向传播, 这里不会解释反向传播的具体理论. 这里要解决的问题是, 给定 operator 的 output 的 loss function 的 gradient(使用 tensor 表示), 计算该 operator 的输入的 gradient. 因为 a
, b
, c
是用户输入的不可训练的参数, 因此, 不需要计算 loss function 对于 a
, b
, c
的导数. 给定 dL/dy
和 y=a*x^2+b*x+c
, 其中 L
代表 loss function, y
代表 quadratic tensor 的输出, 需要计算 dL/dx
. 使用链式法则, 可以得到
1 | dL/dx = dL/dy * dy/dx = dL/dy * (2*a*x + b). |
上面的表达式表明, dL/dx
依赖于 output tensor 的 gradient 和 input tensor 的 value. backward function 的 signature 和 forward 的相同.
1 | template<typename xpu> // 1 |
- Lines 1-6: 同 forward.
- Lines 7-9: 检查 function arguments. 需要注意的一点是, 因为 input 的 gradient 同时依赖于 gradient of output 和 input tensor,
inputs
必须包含两个TBlob
对象. - Line 10: 同 forward
- Lines 11-13: 为了简化后面的代码, 使用
out_grad
来表示gradient of the operator output,in_data
表示 input of the operator,in_grad
表示 gradient of the operator input. - Line 14: get the parameter of object of
QuadraticParam
- Lines 16-22: 同 forward. this is where parallel computation for
in_grad
happens. structquadratic_backward
实现了每个线程计算in_grad
中的一个元素, 如下所示:
1 | template<int req> |
Operator Registration
到目前为止, 我们实现了 quadratic
operator 的必要的数据结构和函数. 现在, 需要使用 nnvm
来把 quadratic
operator 暴露到 frontend. 可以把注册过程想象成创建 operator object 实例, 保存到 operator manager (a singleton) 中, 设置 operator instance 的 attributes.
下面的代码来自 quadratic_op.cc
中, 用来注册在 CPU 上工作的 operator.
1 | DMLC_REGISTER_PARAMETER(QuadraticParam); // 1 |
- Line 1: 注册 parameter struct
- Line 2: 注册名字为
quadratic
的operator, 方法是, 创建一个Op
类型的实例, 保存该实例到 operator manager 中, 返回刚刚创建的 operator object. - Lines 3-4: 增加描述文档作为该 operator 的 operator attribute. documentation engine 会抽取该描述文档并且显示到文档页面.
- Line 5: 给该 operator 设置 parameter struct parser. 用来解析 front 传来的
a
,b
,c
. - Line 6: 设置该 operator 的输入数量
- Line7: 设置该 operator 的输出数量
- Lines 8-11: 定义一个 function, 该 function 的作用是产生 operator input arguments 的 names, 并且 names 放在 vector 中. 这个 function 的使用场景是, add missing arguments that users did not specify when creating a symbolic operator, 例如,
quad_func = mx.sym.quadratic()
仍然是 valid 的 symbol 因为我们已经对该 computational graph 的该 operator node 增加了 attributeFListInputNames
. MXNet would add the missing argument with namequadratic0_data
, 其中, 前缀quadratic0
是 operator 的 name 加上 an index, 后缀data
来自于用户定义的FListInputName
函数的返回值. 用户仍然可以像下面这样从quad_func
生成一个 executor:
1 | quad_exe = quad_func.simple_bind(ctx=mx.cpu(), quandratic0_data=(1,)) |
- Line 12: 注册 shape inference function
- Line 13: 注册 type inference function
- Line 14: 注册 forward function
- Lines 16-19: 这是一个注册函数, 表明哪个输出 tensor 可以 reuse 哪个输入 tensor 的 memory, 从而避免为 output tensor 分配 memory. 在
quadratic
这个 op 中, 只有一个输入和一个输出 tensor, 并且输出 tensor 可以 reuse 输入 tensor 的 memory space, 因此, 返回一个存储std::pair
对的std::vector
, 其中, pair 对的作用是说,input[0]
的 memory 可以被output[0]
reuse. 这里需要注意的是, 这只是给计算图的初始化提供了一个线索, 如果有其它的 Node 依赖 input tensor, 那么, input 的 memory space 就不会被 output 覆盖. - Line 20: Define the input argument name as
data
for the operator - Line 21: Add user input parameters
a
,b
,c
as the attributes of the operator - Line 22: 注册个名字为
_backward_quadratic
的函数, 作用是完成quadratic
的 backward pass. 名字最前面的下划线的意思是该函数不是直接暴露给用户的. 内部的 backward operator 的命名习惯是在相应的 forward operator 前面加上_backward_
前缀. - Line 23: 给
_backward_quadratic
设置 parameter parser. - Line 24: 设置输入的数量
- Line 25: 设置输出的数量
- Line 26: 给 operator 添加
TIsBackward
attribute. 添加该 attribute 的原因是, shape 和 type inference passes 都需要这个 attribute 来决定图中的某个 node 是 forward node 还是 backward node. - Line 27: 注册 backward function
到目前为止, 已经完成了 CPU 上的工作, 为了让代码也能够在 GPU 上工作, 只需要在 quadratic_op.cu
中增加以下代码. 注意, forward 和 backward functions 是通过 FCompute<gpu>
而不是 FCompute<cpu>
注册的.
1 | NNVM_REGISTER_OP(quadratic) |
Unit Test
现在已经在 MXNet backend 完成了 quadratic
op 的实现, 如果使用 python, 那么, 在 import MXNet as mx
的时候, 两个运行该后端实现的 Python function 也同时生成了, 分别是用于 imperative programming 的 MXNet.ndarray.quadratic
和用于 symbolic programming 的 MXNet.symbol.quadratic
.
为了在 frontend 进行测试, 需要在在 test_operator.py
中增加下面的代码. forward 的测试比较简单, 但是, backward 的测试稍显复杂. 首先创建一个 quadratic
symbol, 然后喂到 check_numeric_gradient
中. check_numeric_gradient
做的就是在在输入上加上一个轻微的扰动, 然后通过有限微分方法得到一个输出, 把该输出和通过 backward pass 得到的输出进行比较, 如果两个输出的差值在一定的范围内就认为测试通过, 否则测试不通过.(就是常规的检验 backward 的套路了)
这里使用 mx.nd.quadratic
检查 forward function, 使用 check_numeric_gradient
检查 backward function. 在 MXNet 中海油另外两个经常用到的 utility functions, check_symbolic_forward
和 check_symbolic_backward
. 如果在单元测试中使用者两个函数, users need to pass in the operator symbols and expected results for comparison.
以上内容翻译自文档