MXNet 源码分析-Bind

Bind 是 mxnet 中最重要的一个函数, 它连接起来了各个模块. Bind 启动了整个 mxnet 的工作. 因此, 在分析 mxnet 的各个模块之前最好能先对 Bind 有个初步的了解, 理解 mxnet 的各个模块是怎么串起来的. SimpleBind 主要的工作是构建和整理参数, 最后还是要把参数传递给 Bind.

SimpleBind 的调用关系如下:
SimpleBind->Executor->MXExecutorBindEX->Bind
其中, Executor 是 mxnet 的 c++ API 封装的 Executor, MXExecutorBindEX 是 mxnet 的 C API, BindExecutor::Bind

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
Executor *Executor::Bind(nnvm::Symbol symbol,
const Context& default_ctx,
const std::map<std::string, Context>& group2ctx,
const std::vector<NDArray> &in_args,
const std::vector<NDArray> &arg_grad_store,
const std::vector<OpReqType> &grad_req_type,
const std::vector<NDArray> &aux_states,
Executor* shared_exec) {
// 创建 GraphExecutor 实例
auto exec = new exec::GraphExecutor();
// 调用 Init 初始化该实例, 这一函数完成了 Bind 的主要工作
exec->Init(symbol, default_ctx, group2ctx,
in_args, arg_grad_store, grad_req_type, aux_states,
reinterpret_cast<Executor*>(shared_exec));
return exec;
}

由上面代码可以看到, 首先创建一个 GraphExecutor, 然后调用 GraphExecutor::Init 去初始化该 GraphExecutor. Bind 的几乎所有的工作都是在 Init 中完成的. Init 中又调用了其它的初始化方法, 例如 memory 的初始化, operator 的初始化.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
void GraphExecutor::Init(nnvm::Symbol symbol,
const Context& default_ctx,
const std::map<std::string, Context>& ctx_map,
const std::vector<NDArray>& in_args,
const std::vector<NDArray>& arg_grad_store,
const std::vector<OpReqType>& grad_req_type,
const std::vector<NDArray>& aux_states,
Executor* shared_exec,
const nnvm::NodeEntryMap<NDArray>& feed_dict) {
// 通过用户层传递过来的 Symbol, 构建一个 computation graph
// 如果所有的 grad_req_type 都是 kNullOp 的话表明不需要计算梯度
// 直接返回原 Symbol, 否则, mxnet 会自动构建计算 gradient 的节点加入到 computation graph 中
nnvm::Graph g = InitGraph(symbol, default_ctx,
ctx_map, in_args, arg_grad_store,
grad_req_type, aux_states, feed_dict);
g.attrs["saved_opr"] = std::make_shared<nnvm::any>(std::move(saved_opr_));
// 绑定 graph 中各个节点的 operator
g = AttachOpExecs(g);
// 绑定 graph 中各个节点所需要的资源, 这里(应该是)只声明了需要的资源类型以及数量, 并没有进行真正的资源分配
g = AttachOpResources(g);
graph_ = std::move(g);
// 分配 memory
if (shared_exec != nullptr) {
this->InitDataEntryMemory(&(dynamic_cast<GraphExecutor*>(shared_exec)->data_pool_));
} else {
this->InitDataEntryMemory(nullptr);
}
{
// initialize output arrays
auto& idx = graph_.indexed_graph();
for (size_t i = 0; i < num_forward_outputs_; ++i) {
auto& e = idx.outputs()[i];
output_arrays_.push_back(data_entry_[idx.entry_id(e)]);
}
// initialize head gradient array
head_grad_array_.resize(symbol.outputs.size());
for (size_t i = num_forward_inputs_; i < idx.input_nodes().size(); ++i) {
uint32_t nid = idx.input_nodes().at(i);
uint32_t oid = head_grad_map_.at(idx[nid].source);
head_grad_array_[oid] = data_entry_[idx.entry_id(nid, 0)];
}
}
this->InitCachedOps();
this->InitOpSegs();
}

由上面可知, Init 主要调用了 InitGraph, AttachOpExecs, AttachOpResources, InitDataEntryMemory, InitCachedOpsInitOpSegs.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
Graph GraphExecutor::InitGraph(nnvm::Symbol symbol,
const Context& default_ctx,
const std::map<std::string, Context>& ctx_map,
const std::vector<NDArray>& in_args,
const std::vector<NDArray>& arg_grad_store,
const std::vector<OpReqType>& grad_req_type,
const std::vector<NDArray>& aux_states,
const nnvm::NodeEntryMap<NDArray>& feed_dict) {
// setup gradient
// 构建 computation graph, 其中一个比较重要的工作是, 如果 grad_req_type 中有
// 非 kNullOp 的元素, 那么, 就要构建 gradient 节点
// 这里还完成了 arg_grad_store 的输入, 而 in_args 的输入是在本函数后面部分完成的
nnvm::Graph g = InitFullGraph(symbol, grad_req_type, arg_grad_store);
// 分配计算设备, CPU 还是 GPU 以及哪块 GPU 等信息
g = AssignContext(g, default_ctx, ctx_map,
in_args,
grad_store_,
aux_states,
num_forward_inputs_,
num_forward_outputs_);
const auto& idx = g.indexed_graph();
// get number of nodes used in forward pass
num_forward_nodes_ = 0;
for (size_t i = 0; i < num_forward_outputs_; ++i) {
num_forward_nodes_ = std::max(
num_forward_nodes_, static_cast<size_t>(idx.outputs()[i].node_id + 1));
}
// Setup data entry, shape and type.
// 对于所有的forward 过程中需要用的节点绑定 NDArray.
// 对比 SimpleBind 中需要按照 computation graph 的 topo 序绑定 NDArray
data_entry_.resize(idx.num_node_entries());
auto mutable_nodes = idx.mutable_input_nodes();
nnvm::ShapeVector arg_shapes;
nnvm::DTypeVector arg_types;
size_t arg_top = 0, aux_top = 0;
for (size_t i = 0; i < num_forward_inputs_; ++i) {
// 这里的 idx 是上面 g.indexed_graph() 的结果, 所以, 是 topo 排序好的
// 因为, indexed_graph 最终调用的也是 DFSVisit, 因此, 的到是相同的唯一的排序结果
const uint32_t nid = idx.input_nodes().at(i);
if (mutable_nodes.count(nid)) {
CHECK_LT(aux_top, aux_states.size());
data_entry_[idx.entry_id(nid, 0)] = aux_states[aux_top];
arg_shapes.push_back(aux_states[aux_top].shape());
arg_types.push_back(aux_states[aux_top].dtype());
++aux_top;
} else {
CHECK_LT(arg_top, in_args.size());
data_entry_[idx.entry_id(nid, 0)] = in_args[arg_top];
arg_shapes.push_back(in_args[arg_top].shape());
arg_types.push_back(in_args[arg_top].dtype());
++arg_top;
}
}
for (size_t j = num_forward_outputs_; j < idx.outputs().size(); ++j) {
data_entry_[idx.entry_id(idx.outputs()[j])]
= grad_store_[j - num_forward_outputs_].second;
}
arg_shapes.resize(idx.input_nodes().size(), TShape());
arg_types.resize(idx.input_nodes().size(), -1);
// other initializations
g = nnvm::pass::InferShape(g, arg_shapes, "__shape__");
g = nnvm::pass::InferType(g, arg_types, "__dtype__");
{
// memory allocator
const int kBadStorageID = -1;
const int kExternalStorageID = -2;
nnvm::StorageVector arg_storage_id(idx.num_node_entries(), kBadStorageID);
for (size_t j = num_forward_outputs_; j < idx.outputs().size(); ++j) {
arg_storage_id[idx.entry_id(idx.outputs()[j])] = kExternalStorageID;
}
for (const auto& kv : feed_dict) {
uint32_t eid = idx.entry_id(kv.first);
data_entry_[eid] = kv.second;
arg_storage_id[eid] = kExternalStorageID;
}
g.attrs["storage"] = std::make_shared<dmlc::any>(std::move(arg_storage_id));
g = nnvm::ApplyPass(g, "PlanMemory");
}
g = DetectInplaceAddTo(g);
return g;
}