MXNet 源码分析-SimpleBind

上次说到了 MXNet 的基本过程, 其中最重要的一步是构建 Executor, 而构建 Executor 中最重要的就是通过 Bind 方法把 MXNet 的各个部件如 Symbol, context, NDArray 参数等数据捏合到一起, 因此, 要深入理解 MXNet 的计算过程, 需要对该 Bind 方法进行深入研究. 由于 Bind 方法中的参数, 尤其是 NDArray 的参数要求是和 computation graph 的 topo 序相同, 因此, 直接分析 Bind 仍然不能理解这一点, 所以, 这里首先分析 SimpleBind, 了解 MXNet 的 Bind 需要怎样的参数输入, 而 SimpleBind 又是怎样实现这一点的.

cpp 接口的 Bind 调用的是 MXNet 提供的 C API, 所以, 具体的 C API 细节先不讨论, 主要关注 MXNet 本身.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
// symbol.hpp
inline Executor *Symbol::SimpleBind(
const Context &context, const std::map<std::string, NDArray> &args_map,
const std::map<std::string, NDArray> &arg_grad_store,
const std::map<std::string, OpReqType> &grad_req_type,
const std::map<std::string, NDArray> &aux_map) {
std::vector<NDArray> arg_arrays;
std::vector<NDArray> grad_arrays;
std::vector<OpReqType> grad_reqs;
std::vector<NDArray> aux_arrays;
InferExecutorArrays(context, &arg_arrays, &grad_arrays, &grad_reqs,
&aux_arrays, args_map, arg_grad_store, grad_req_type,
aux_map);
return new Executor(*this, context, arg_arrays, grad_arrays, grad_reqs,
aux_arrays);
}

SimpleBind 的实现如上, 其中, 这里最重要的是 InferExecutorArrays 方法, 该方法实现了网络需要的arg_arrays, arg_arrays, grad_reqs, aux_arrays 的推导. 注意, 这里的已知条件是三个 std::map, 其 key 值需要与定义的 Symbol 对应(如前一篇所述). MXNet cpp API 正式根据这些 key 进行数据绑定的.
既然这次主要关注 MXNet Bind 的参数, 那么, 接下来看一下 InferExecutorArrays 中发生了什么.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
inline void Symbol::InferExecutorArrays(
const Context &context, std::vector<NDArray> *arg_arrays,
std::vector<NDArray> *grad_arrays, std::vector<OpReqType> *grad_reqs,
std::vector<NDArray> *aux_arrays,
const std::map<std::string, NDArray> &args_map,
const std::map<std::string, NDArray> &arg_grad_store,
const std::map<std::string, OpReqType> &grad_req_type,
const std::map<std::string, NDArray> &aux_map) const {
// 获取该 Symbol 所有的参数, 结果是按照拓扑序排序的
const auto arg_name_list = ListArguments();
std::vector<std::vector<mx_uint> > in_shapes, aux_shapes, out_shapes;
std::map<std::string, std::vector<mx_uint> > arg_shapes;
// 获取用户层给的输入 NDArray 参数的 shape, 使用该 shape 可以推导出中间层的 shape
for (const auto &arg_name : arg_name_list) {
auto iter = args_map.find(arg_name);
if (iter != args_map.end()) {
arg_shapes[arg_name] = iter->second.GetShape();
}
}
// 推导中间层的 shape. 如果用户没有给出中间层参数, 那么 MXNet 计算需要的参数的 shape 信息, 如果用户给出了, 那么, 这里检查
// 确保用户给定的 NDArray 的 shape 和 MXNet 推导出来的一致.
InferShape(arg_shapes, &in_shapes, &aux_shapes, &out_shapes);
// 上面已经推导出来了所有输入的 shape, 并且分配了 memory, 放在了 in_shapes 中,
// 那么, 这里, 就要把 Symbol 和 和 NDArray 按照 key(也就是 name) 对应起来
for (size_t i = 0; i < in_shapes.size(); ++i) {
const auto &shape = in_shapes[i];
const auto &arg_name = arg_name_list[i];
// args_map 中是用户层给定的 NDArray, 按照名字 push_back 到 arg_arrays 中去
auto iter_arg = args_map.find(arg_name);
if (iter_arg != args_map.end()) {
arg_arrays->push_back(iter_arg->second)
// 如果不在用户层给的 NDArray 中, 那么, 这里 MXNet 完成 memory 申请, 并且把 NDArray push_back
} else {
arg_arrays->push_back(NDArray(shape, context, false));
NDArray::SampleGaussian(0, 1, &arg_arrays->back());
}
// 存储梯度的 NDArray 的处理, 参数 NDArray 逻辑相同
auto iter_grad = arg_grad_store.find(arg_name);
if (iter_grad != arg_grad_store.end()) {
grad_arrays->push_back(iter_grad->second);
} else {
grad_arrays->push_back(NDArray(shape, context, false));
}
// 同样的, req_type 的逻辑也是相同的
auto iter_req = grad_req_type.find(arg_name);
if (iter_req != grad_req_type.end()) {
grad_reqs->push_back(iter_req->second);
} else if (arg_name.rfind("data") == arg_name.length() - 4
|| arg_name.rfind("label") == arg_name.length() - 5) {
grad_reqs->push_back(OpReqType::kNullOp);
} else {
grad_reqs->push_back(OpReqType::kWriteTo);
}
}
// auxiliary state 单独处理, 相同的逻辑
const auto aux_name_list = ListAuxiliaryStates();
for (size_t i = 0; i < aux_shapes.size(); ++i) {
const auto &shape = aux_shapes[i];
const auto &aux_name = aux_name_list[i];
auto iter_aux = aux_map.find(aux_name);
if (iter_aux != aux_map.end()) {
aux_arrays->push_back(iter_aux->second);
} else {
aux_arrays->push_back(NDArray(shape, context, false));
NDArray::SampleGaussian(0, 1, &aux_arrays->back());
}
}
}

这段代码至少说明了 2 点问题:

  1. 这里对 Symbol(更严格的说是 computation graph) 进行了 topo 排序, 参数的 Symbol 的输入和 NDArray 的绑定都是按照该唯一的 topo 序进行的. 简单来说, topo 排序体现在 const auto arg_name_list = ListArguments(); 这一语句中. ListArguments() 是下面这样的调用关系:
    Symbol::ListArguments–>MXSymbolListArguments–>NNSymbolListInputNames–>Symbol::NNSymbolListInputNames–>Symbol::ListInputs–>DFSVisit
    • Symbol::ListArgumentsSymbol 是 cpp API 定义的 Symbol, 后面的 Symbol 是 MXNet 本身定义的 Symbol.
    • DFSVisit 是对 computation graph 的深度优先遍历, 遍历结果就是输出该图的唯一的拓扑排序
  2. NDArray 和 Symbol 是按照 name 进行对应的.

另外, 需要关注的另外一个点是 InferShape(arg_shapes, &in_shapes, &aux_shapes, &out_shapes);, 该语句最终调用了每个 operator 的 InferShape 方法. 举例来讲, FullyConnectedOpInferShape(定义在FullyConnectedProp中)是:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
bool InferShape(std::vector<TShape> *in_shape,
std::vector<TShape> *out_shape,
std::vector<TShape> *aux_shape) const override {
using namespace mshadow;
if (!param_.no_bias) {
CHECK_EQ(in_shape->size(), 3U) << "Input:[data, weight, bias]";
} else {
CHECK_EQ(in_shape->size(), 2U) << "Input:[data, weight]";
}
CHECK_EQ(out_shape->size(), 1U);
TShape dshape = (*in_shape)[fullc::kData];
TShape oshape = (*out_shape)[0];
// require data to be known
if (dshape.ndim() == 0) return false;
index_t num_input = dshape.ProdShape(1, dshape.ndim());
SHAPE_ASSIGN_CHECK(*in_shape, fullc::kWeight, Shape2(param_.num_hidden, num_input));
if (!param_.no_bias) {
SHAPE_ASSIGN_CHECK(*in_shape, fullc::kBias, Shape1(param_.num_hidden));
}
SHAPE_ASSIGN_CHECK(*out_shape, 0, Shape2(dshape[0], param_.num_hidden));
if (oshape.ndim() != 0) {
dshape[0] = oshape[0];
SHAPE_ASSIGN_CHECK(*in_shape, fullc::kData, dshape);
}
return true;
}

这里, 关键的代码是SHAPE_ASSIGN_CHECK(*in_shape, fullc::kWeight, Shape2(param_.num_hidden, num_input));, SHAPE_ASSIGN_CHECK 是一个宏定义, 完成了分配/一致性检查的工作, 具体如下:

1
2
3
4
5
6
7
8
9
#define SHAPE_ASSIGN_CHECK(shape_array, index, shape) \
{ \
if (!shape_assign(&(shape_array)[index], TShape(shape))) { \
std::ostringstream os; \
os << "Shape inconsistent, Provided=" << (shape_array)[index] << ',' \
<< " inferred shape=" << shape; \
throw ::MXNet::op::InferShapeError(os.str(), index); \
} \
}

这个宏定义体现了 MXNet 中宏的飘逸的用法, 更加飘逸的用法还有很多.
shape_assign 的定义如下, 其功能还是, 如果 y 是空的, 也就是y->ndim()==0 那么, 把 x 赋值给 y 否则, 检查 x 和 y 的一致性

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
inline bool shape_assign(TShape *y, const TShape& x) {
if (y->ndim() == 0) {
*y = x;
return true;
} else if (y->ndim() != x.ndim()) {
return x.ndim() == 0;
} else {
for (size_t i = 0; i < y->ndim(); ++i) {
if ((*y)[i] == 0) {
(*y)[i] = x[i];
} else if ((*y)[i] != x[i] && x[i] != 0) {
return false;
}
}
return true;
}
}

现在, 我们已经了解了2个重要事实:

  1. MXNet 中所有的输入参数 NDArray 是按照拓扑序和 Symbol 对应起来的
  2. (optimal) cpp API 和 Python API 中的 simple_bind 按照 NDArray 和 Symbol 的 name 进行对应
  3. 给定必要的参数 NDArray, MXNet 可以自动推导出其它参数信息. 举例来说, 一个全连接, 给定出入输出节点的个数, 可以推导出 w 的 shape