0%

MXNet 源码分析-SimpleBind

上次说到了 MXNet 的基本过程,其中最重要的一步是构建 Executor, 而构建 Executor 中最重要的就是通过 Bind 方法把 MXNet 的各个部件如 Symbol, context, NDArray 参数等数据捏合到一起,因此,要深入理解 MXNet 的计算过程,需要对该 Bind 方法进行深入研究。由于 Bind 方法中的参数,尤其是 NDArray 的参数要求是和 computation graph 的 topo 序相同,因此,直接分析 Bind 仍然不能理解这一点,所以,这里首先分析 SimpleBind, 了解 MXNet 的 Bind 需要怎样的参数输入,而 SimpleBind 又是怎样实现这一点的。

cpp 接口的 Bind 调用的是 MXNet 提供的 C API, 所以,具体的 C API 细节先不讨论,主要关注 MXNet 本身。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
// symbol.hpp
inline Executor *Symbol::SimpleBind(
const Context &context, const std::map<std::string, NDArray> &args_map,
const std::map<std::string, NDArray> &arg_grad_store,
const std::map<std::string, OpReqType> &grad_req_type,
const std::map<std::string, NDArray> &aux_map) {
std::vector<NDArray> arg_arrays;
std::vector<NDArray> grad_arrays;
std::vector<OpReqType> grad_reqs;
std::vector<NDArray> aux_arrays;

InferExecutorArrays(context, &arg_arrays, &grad_arrays, &grad_reqs,
&aux_arrays, args_map, arg_grad_store, grad_req_type,
aux_map);

return new Executor(*this, context, arg_arrays, grad_arrays, grad_reqs,
aux_arrays);
}

SimpleBind 的实现如上,其中,这里最重要的是 InferExecutorArrays 方法,该方法实现了网络需要的arg_arrays, arg_arrays, grad_reqs, aux_arrays 的推导。注意,这里的已知条件是三个 std::map, 其 key 值需要与定义的 Symbol 对应 (如前一篇所述). MXNet cpp API 正式根据这些 key 进行数据绑定的。
既然这次主要关注 MXNet Bind 的参数,那么,接下来看一下 InferExecutorArrays 中发生了什么。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
inline void Symbol::InferExecutorArrays(
const Context &context, std::vector<NDArray> *arg_arrays,
std::vector<NDArray> *grad_arrays, std::vector<OpReqType> *grad_reqs,
std::vector<NDArray> *aux_arrays,
const std::map<std::string, NDArray> &args_map,
const std::map<std::string, NDArray> &arg_grad_store,
const std::map<std::string, OpReqType> &grad_req_type,
const std::map<std::string, NDArray> &aux_map) const {

// 获取该 Symbol 所有的参数,结果是按照拓扑序排序的
const auto arg_name_list = ListArguments();
std::vector<std::vector<mx_uint> > in_shapes, aux_shapes, out_shapes;
std::map<std::string, std::vector<mx_uint> > arg_shapes;

// 获取用户层给的输入 NDArray 参数的 shape, 使用该 shape 可以推导出中间层的 shape
for (const auto &arg_name : arg_name_list) {
auto iter = args_map.find(arg_name);
if (iter != args_map.end()) {
arg_shapes[arg_name] = iter->second.GetShape();
}
}
// 推导中间层的 shape. 如果用户没有给出中间层参数,那么 MXNet 计算需要的参数的 shape 信息,如果用户给出了,那么,这里检查
// 确保用户给定的 NDArray 的 shape 和 MXNet 推导出来的一致。
InferShape(arg_shapes, &in_shapes, &aux_shapes, &out_shapes);

// 上面已经推导出来了所有输入的 shape, 并且分配了 memory, 放在了 in_shapes 中,
// 那么,这里,就要把 Symbol 和 和 NDArray 按照 key(也就是 name) 对应起来
for (size_t i = 0; i < in_shapes.size(); ++i) {
const auto &shape = in_shapes[i];
const auto &arg_name = arg_name_list[i];
// args_map 中是用户层给定的 NDArray, 按照名字 push_back 到 arg_arrays 中去
auto iter_arg = args_map.find(arg_name);
if (iter_arg != args_map.end()) {
arg_arrays->push_back(iter_arg->second)
// 如果不在用户层给的 NDArray 中,那么,这里 MXNet 完成 memory 申请,并且把 NDArray push_back
} else {
arg_arrays->push_back(NDArray(shape, context, false));
NDArray::SampleGaussian(0, 1, &arg_arrays->back());
}
// 存储梯度的 NDArray 的处理,参数 NDArray 逻辑相同
auto iter_grad = arg_grad_store.find(arg_name);
if (iter_grad != arg_grad_store.end()) {
grad_arrays->push_back(iter_grad->second);
} else {
grad_arrays->push_back(NDArray(shape, context, false));
}
// 同样的,req_type 的逻辑也是相同的
auto iter_req = grad_req_type.find(arg_name);
if (iter_req != grad_req_type.end()) {
grad_reqs->push_back(iter_req->second);
} else if (arg_name.rfind("data") == arg_name.length() - 4
|| arg_name.rfind("label") == arg_name.length() - 5) {
grad_reqs->push_back(OpReqType::kNullOp);
} else {
grad_reqs->push_back(OpReqType::kWriteTo);
}
}
// auxiliary state 单独处理,相同的逻辑
const auto aux_name_list = ListAuxiliaryStates();
for (size_t i = 0; i < aux_shapes.size(); ++i) {
const auto &shape = aux_shapes[i];
const auto &aux_name = aux_name_list[i];
auto iter_aux = aux_map.find(aux_name);
if (iter_aux != aux_map.end()) {
aux_arrays->push_back(iter_aux->second);
} else {
aux_arrays->push_back(NDArray(shape, context, false));
NDArray::SampleGaussian(0, 1, &aux_arrays->back());
}
}
}

这段代码至少说明了 2 点问题:

  • 这里对 Symbol(更严格的说是 computation graph) 进行了 topo 排序,参数的 Symbol 的输入和 NDArray 的绑定都是按照该唯一的 topo 序进行的。简单来说,topo 排序体现在 const auto arg_name_list = ListArguments(); 这一语句中。ListArguments() 是下面这样的调用关系:
    Symbol::ListArguments–>MXSymbolListArguments–>NNSymbolListInputNames–>Symbol::NNSymbolListInputNames–>Symbol::ListInputs–>DFSVisit
    • Symbol::ListArgumentsSymbol 是 cpp API 定义的 Symbol, 后面的 Symbol 是 MXNet 本身定义的 Symbol.
    • DFSVisit 是对 computation graph 的深度优先遍历,遍历结果就是输出该图的唯一的拓扑排序
  • NDArray 和 Symbol 是按照 name 进行对应的。

另外,需要关注的另外一个点是 InferShape(arg_shapes, &in_shapes, &aux_shapes, &out_shapes);, 该语句最终调用了每个 operator 的 InferShape 方法。举例来讲,FullyConnectedOpInferShape(定义在FullyConnectedProp中) 是:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
bool InferShape(std::vector<TShape> *in_shape,
std::vector<TShape> *out_shape,
std::vector<TShape> *aux_shape) const override {
using namespace mshadow;
if (!param_.no_bias) {
CHECK_EQ(in_shape->size(), 3U) << "Input:[data, weight, bias]";
} else {
CHECK_EQ(in_shape->size(), 2U) << "Input:[data, weight]";
}
CHECK_EQ(out_shape->size(), 1U);
TShape dshape = (*in_shape)[fullc::kData];
TShape oshape = (*out_shape)[0];
// require data to be known
if (dshape.ndim() == 0) return false;

index_t num_input = dshape.ProdShape(1, dshape.ndim());
SHAPE_ASSIGN_CHECK(*in_shape, fullc::kWeight, Shape2(param_.num_hidden, num_input));
if (!param_.no_bias) {
SHAPE_ASSIGN_CHECK(*in_shape, fullc::kBias, Shape1(param_.num_hidden));
}

SHAPE_ASSIGN_CHECK(*out_shape, 0, Shape2(dshape[0], param_.num_hidden));
if (oshape.ndim() != 0) {
dshape[0] = oshape[0];
SHAPE_ASSIGN_CHECK(*in_shape, fullc::kData, dshape);
}
return true;
}

这里,关键的代码是SHAPE_ASSIGN_CHECK(*in_shape, fullc::kWeight, Shape2(param_.num_hidden, num_input));, SHAPE_ASSIGN_CHECK 是一个宏定义,完成了分配/一致性检查的工作, 具体如下:

1
2
3
4
5
6
7
8
9
#define SHAPE_ASSIGN_CHECK(shape_array, index, shape)                       \
{ \
if (!shape_assign(&(shape_array)[index], TShape(shape))) { \
std::ostringstream os; \
os << "Shape inconsistent, Provided=" << (shape_array)[index] << ',' \
<< " inferred shape=" << shape; \
throw ::MXNet::op::InferShapeError(os.str(), index); \
} \
}

这个宏定义体现了 MXNet 中宏的飘逸的用法,更加飘逸的用法还有很多。
shape_assign 的定义如下,其功能还是,如果 y 是空的,也就是y->ndim()==0 那么,把 x 赋值给 y 否则,检查 x 和 y 的一致性

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
inline bool shape_assign(TShape *y, const TShape& x) {
if (y->ndim() == 0) {
*y = x;
return true;
} else if (y->ndim() != x.ndim()) {
return x.ndim() == 0;
} else {
for (size_t i = 0; i < y->ndim(); ++i) {
if ((*y)[i] == 0) {
(*y)[i] = x[i];
} else if ((*y)[i] != x[i] && x[i] != 0) {
return false;
}
}
return true;
}
}

现在,我们已经了解了 2 个重要事实:

  1. MXNet 中所有的输入参数 NDArray 是按照拓扑序和 Symbol 对应起来的
  2. (optimal) cpp API 和 Python API 中的 simple_bind 按照 NDArray 和 Symbol 的 name 进行对应
  3. 给定必要的参数 NDArray, MXNet 可以自动推导出其它参数信息。举例来说,一个全连接,给定出入输出节点的个数,可以推导出 w 的 shape