上次说到了 MXNet 的基本过程,其中最重要的一步是构建 Executor, 而构建 Executor 中最重要的就是通过 Bind 方法把 MXNet 的各个部件如 Symbol, context, NDArray 参数等数据捏合到一起,因此,要深入理解 MXNet 的计算过程,需要对该 Bind 方法进行深入研究。由于 Bind 方法中的参数,尤其是 NDArray 的参数要求是和 computation graph 的 topo 序相同,因此,直接分析 Bind 仍然不能理解这一点,所以,这里首先分析 SimpleBind, 了解 MXNet 的 Bind 需要怎样的参数输入,而 SimpleBind 又是怎样实现这一点的。
cpp 接口的 Bind 调用的是 MXNet 提供的 C API, 所以,具体的 C API 细节先不讨论,主要关注 MXNet 本身。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
| inline Executor *Symbol::SimpleBind( const Context &context, const std::map<std::string, NDArray> &args_map, const std::map<std::string, NDArray> &arg_grad_store, const std::map<std::string, OpReqType> &grad_req_type, const std::map<std::string, NDArray> &aux_map) { std::vector<NDArray> arg_arrays; std::vector<NDArray> grad_arrays; std::vector<OpReqType> grad_reqs; std::vector<NDArray> aux_arrays;
InferExecutorArrays(context, &arg_arrays, &grad_arrays, &grad_reqs, &aux_arrays, args_map, arg_grad_store, grad_req_type, aux_map);
return new Executor(*this, context, arg_arrays, grad_arrays, grad_reqs, aux_arrays); }
|
SimpleBind
的实现如上,其中,这里最重要的是 InferExecutorArrays
方法,该方法实现了网络需要的arg_arrays
, arg_arrays
, grad_reqs
, aux_arrays
的推导。注意,这里的已知条件是三个 std::map
, 其 key 值需要与定义的 Symbol 对应 (如前一篇所述). MXNet cpp API 正式根据这些 key 进行数据绑定的。
既然这次主要关注 MXNet Bind 的参数,那么,接下来看一下 InferExecutorArrays
中发生了什么。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71
| inline void Symbol::InferExecutorArrays( const Context &context, std::vector<NDArray> *arg_arrays, std::vector<NDArray> *grad_arrays, std::vector<OpReqType> *grad_reqs, std::vector<NDArray> *aux_arrays, const std::map<std::string, NDArray> &args_map, const std::map<std::string, NDArray> &arg_grad_store, const std::map<std::string, OpReqType> &grad_req_type, const std::map<std::string, NDArray> &aux_map) const {
const auto arg_name_list = ListArguments(); std::vector<std::vector<mx_uint> > in_shapes, aux_shapes, out_shapes; std::map<std::string, std::vector<mx_uint> > arg_shapes;
for (const auto &arg_name : arg_name_list) { auto iter = args_map.find(arg_name); if (iter != args_map.end()) { arg_shapes[arg_name] = iter->second.GetShape(); } } InferShape(arg_shapes, &in_shapes, &aux_shapes, &out_shapes);
for (size_t i = 0; i < in_shapes.size(); ++i) { const auto &shape = in_shapes[i]; const auto &arg_name = arg_name_list[i]; auto iter_arg = args_map.find(arg_name); if (iter_arg != args_map.end()) { arg_arrays->push_back(iter_arg->second) } else { arg_arrays->push_back(NDArray(shape, context, false)); NDArray::SampleGaussian(0, 1, &arg_arrays->back()); } auto iter_grad = arg_grad_store.find(arg_name); if (iter_grad != arg_grad_store.end()) { grad_arrays->push_back(iter_grad->second); } else { grad_arrays->push_back(NDArray(shape, context, false)); } auto iter_req = grad_req_type.find(arg_name); if (iter_req != grad_req_type.end()) { grad_reqs->push_back(iter_req->second); } else if (arg_name.rfind("data") == arg_name.length() - 4 || arg_name.rfind("label") == arg_name.length() - 5) { grad_reqs->push_back(OpReqType::kNullOp); } else { grad_reqs->push_back(OpReqType::kWriteTo); } } const auto aux_name_list = ListAuxiliaryStates(); for (size_t i = 0; i < aux_shapes.size(); ++i) { const auto &shape = aux_shapes[i]; const auto &aux_name = aux_name_list[i]; auto iter_aux = aux_map.find(aux_name); if (iter_aux != aux_map.end()) { aux_arrays->push_back(iter_aux->second); } else { aux_arrays->push_back(NDArray(shape, context, false)); NDArray::SampleGaussian(0, 1, &aux_arrays->back()); } } }
|
这段代码至少说明了 2 点问题:
- 这里对 Symbol(更严格的说是 computation graph) 进行了 topo 排序,参数的 Symbol 的输入和 NDArray 的绑定都是按照该唯一的 topo 序进行的。简单来说,topo 排序体现在
const auto arg_name_list = ListArguments();
这一语句中。ListArguments()
是下面这样的调用关系:
Symbol::ListArguments
–>MXSymbolListArguments
–>NNSymbolListInputNames
–>Symbol::NNSymbolListInputNames
–>Symbol::ListInputs
–>DFSVisit
Symbol::ListArguments
的 Symbol
是 cpp API 定义的 Symbol, 后面的 Symbol 是 MXNet 本身定义的 Symbol.
DFSVisit
是对 computation graph 的深度优先遍历,遍历结果就是输出该图的唯一的拓扑排序
- NDArray 和 Symbol 是按照 name 进行对应的。
另外,需要关注的另外一个点是 InferShape(arg_shapes, &in_shapes, &aux_shapes, &out_shapes);
, 该语句最终调用了每个 operator 的 InferShape
方法。举例来讲,FullyConnectedOp
的 InferShape
(定义在FullyConnectedProp
中) 是:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28
| bool InferShape(std::vector<TShape> *in_shape, std::vector<TShape> *out_shape, std::vector<TShape> *aux_shape) const override { using namespace mshadow; if (!param_.no_bias) { CHECK_EQ(in_shape->size(), 3U) << "Input:[data, weight, bias]"; } else { CHECK_EQ(in_shape->size(), 2U) << "Input:[data, weight]"; } CHECK_EQ(out_shape->size(), 1U); TShape dshape = (*in_shape)[fullc::kData]; TShape oshape = (*out_shape)[0]; if (dshape.ndim() == 0) return false;
index_t num_input = dshape.ProdShape(1, dshape.ndim()); SHAPE_ASSIGN_CHECK(*in_shape, fullc::kWeight, Shape2(param_.num_hidden, num_input)); if (!param_.no_bias) { SHAPE_ASSIGN_CHECK(*in_shape, fullc::kBias, Shape1(param_.num_hidden)); }
SHAPE_ASSIGN_CHECK(*out_shape, 0, Shape2(dshape[0], param_.num_hidden)); if (oshape.ndim() != 0) { dshape[0] = oshape[0]; SHAPE_ASSIGN_CHECK(*in_shape, fullc::kData, dshape); } return true; }
|
这里,关键的代码是SHAPE_ASSIGN_CHECK(*in_shape, fullc::kWeight, Shape2(param_.num_hidden, num_input));
, SHAPE_ASSIGN_CHECK
是一个宏定义,完成了分配/一致性检查的工作, 具体如下:
1 2 3 4 5 6 7 8 9
| #define SHAPE_ASSIGN_CHECK(shape_array, index, shape) \ { \ if (!shape_assign(&(shape_array)[index], TShape(shape))) { \ std::ostringstream os; \ os << "Shape inconsistent, Provided=" << (shape_array)[index] << ',' \ << " inferred shape=" << shape; \ throw ::MXNet::op::InferShapeError(os.str(), index); \ } \ }
|
这个宏定义体现了 MXNet 中宏的飘逸的用法,更加飘逸的用法还有很多。
shape_assign
的定义如下,其功能还是,如果 y 是空的,也就是y->ndim()==0
那么,把 x 赋值给 y 否则,检查 x 和 y 的一致性
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
| inline bool shape_assign(TShape *y, const TShape& x) { if (y->ndim() == 0) { *y = x; return true; } else if (y->ndim() != x.ndim()) { return x.ndim() == 0; } else { for (size_t i = 0; i < y->ndim(); ++i) { if ((*y)[i] == 0) { (*y)[i] = x[i]; } else if ((*y)[i] != x[i] && x[i] != 0) { return false; } } return true; } }
|
现在,我们已经了解了 2 个重要事实:
- MXNet 中所有的输入参数 NDArray 是按照拓扑序和 Symbol 对应起来的
- (optimal) cpp API 和 Python API 中的 simple_bind 按照 NDArray 和 Symbol 的 name 进行对应
- 给定必要的参数 NDArray, MXNet 可以自动推导出其它参数信息。举例来说,一个全连接,给定出入输出节点的个数,可以推导出 w 的 shape