0%

MXNet Source Code Memo

阅读源码过程中遇到的一些疑惑。有些已解决记录一下,有些没有解决。

MXNet 中 Symbol 的 name 是怎么确定的

MXNet 并不处理 Symbol 命名的工作,这个工作交给了 wrapper 来完成。
以 Python wrapper 为例,在用户指定了名字的情况下,直接使用用户指定的名字,如果冲突就报错。
如果用户没有指定具体的名字,那么,python wrapper 会自动生成名字,具体如下:
在 mxnet 的 python wrapper 中,mxnet/symbol/op.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
def %s(*%s, **kwargs):"""%(func_name, arr_name))
code.append("""
sym_args = []
for i in {}:
assert isinstance(i, SymbolBase), \\
"Positional arguments must be Symbol instances, " \\
"but got %s"%str(i)
sym_args.append(i)""".format(arr_name))
if dtype_name is not None:
code.append("""
if '%s' in kwargs:
kwargs['%s'] = _numpy.dtype(kwargs['%s']).name"""%(
dtype_name, dtype_name, dtype_name))
code.append("""
attr = kwargs.pop('attr', None)
kwargs.update(AttrScope.current.get(attr))
name = kwargs.pop('name', None)
name = NameManager.current.get(name, '%s')
_ = kwargs.pop('out', None)

具体 NameManager 在 name.py 中定义的

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
def get(self, name, hint):
"""Get the canonical name for a symbol.

This is the default implementation.
If the user specifies a name,
the user-specified name will be used.

When user does not specify a name, we automatically generate a
name based on the hint string.

Parameters
----------
name : str or None
The name specified by the user.

hint : str
A hint string, which can be used to generate name.

Returns
-------
full_name : str
A canonical name for the symbol.
"""
if name:
return name
if hint not in self._counter:
self._counter[hint] = 0
name = '%s%d' % (hint, self._counter[hint])
self._counter[hint] += 1
return name

MXNet 是如何关联 layer 之间的输入和输出的

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
// symbolic.cc:
// compositional logic
// args 输入的 symbol, 可认为是低层的symbol
void Symbol::Compose(const array_view<const Symbol*>& args,
const std::unordered_map<std::string, const Symbol*>& kwargs,
const std::string& name) {
static auto& flist_inputs = Op::GetAttr<FListInputNames>("FListInputNames");
static auto& fset_attrs = Op::GetAttr<FSetInputVarAttrOnCompose>("FSetInputVarAttrOnCompose");

CHECK(!outputs[0].node->is_variable()) << "Variable cannot be composed";
// parameter check.
for (size_t i = 0; i < args.size(); ++i) {
CHECK_EQ(args[i]->outputs.size(), 1U)
<< "Argument " << i << " is a tuple, single value is required";
}
for (const auto& kv : kwargs) {
CHECK_EQ(kv.second->outputs.size(), 1U)
<< "Keyword Argument " << kv.first << " is a tuple, single value is required";
}
// assign new name
outputs[0].node->attrs.name = name;

// Atomic functor composition.
if (IsAtomic(outputs)) {
// n 是本 symbol 的输出 Node
Node* n = outputs[0].node.get();
uint32_t n_req = n->num_inputs();

if (n_req != kVarg) {
n->inputs.resize(n_req);
CHECK_LE(args.size(), n_req)
<< "Incorrect number of arguments, requires " << n_req
<< ", provided " << args.size();
for (size_t i = 0; i < args.size(); ++i) {
// 本 symbol 的 input 来自 下层的 symbol 即 args 的outputs
// 十分注意一点, 这里使用的是 outputs[0], 所以, 使用的是下层输出的第 0 的位置
// 的内容, 这点在写具体 operator 的时候回有用
n->inputs[i] = args[i]->outputs[0];
}

MXNet 的输入节点的规划

在 InferShape 中完成的,包括每个 Node 之间的顺序
在每个具体的 operator 的 InferShape 函数确定了 in_shapeout_shape, 这个顺序又是具体 operator 中开始的两个枚举确定的,例如:

1
2
3
4
namespace fullc {
enum FullyConnectedOpInputs {kData, kWeight, kBias};
enum FullyConnectedOpOutputs {kOut};
} // fullc

InferShape 例子如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
bool InferShape(std::vector<TShape> *in_shape,
std::vector<TShape> *out_shape,
std::vector<TShape> *aux_shape) const override {
using namespace mshadow;
if (!param_.no_bias) {
CHECK_EQ(in_shape->size(), 3U) << "Input:[data, weight, bias]";
} else {
CHECK_EQ(in_shape->size(), 2U) << "Input:[data, weight]";
}
CHECK_EQ(out_shape->size(), 1U);
TShape dshape = (*in_shape)[fullc::kData];
TShape oshape = (*out_shape)[0];
// require data to be known
if (dshape.ndim() == 0) return false;

index_t num_input = dshape.ProdShape(1, dshape.ndim());
// 处理 in_shape
SHAPE_ASSIGN_CHECK(*in_shape, fullc::kWeight, Shape2(param_.num_hidden, num_input));
if (!param_.no_bias) {
SHAPE_ASSIGN_CHECK(*in_shape, fullc::kBias, Shape1(param_.num_hidden));
}
// 处理 out_shape
SHAPE_ASSIGN_CHECK(*out_shape, 0, Shape2(dshape[0], param_.num_hidden));
if (oshape.ndim() != 0) {
dshape[0] = oshape[0];
SHAPE_ASSIGN_CHECK(*in_shape, fullc::kData, dshape);
}
return true;
}

那么,还有一个问题,这里的输入输出的 TShape 和它们的 name 是怎么关联起来的呢?在 cpp package 的 InferExecutorArrays 中按照 ListArguments() 的顺序处理的。

1
2
3
4
5
6
7
8
9
10
11
12
InferShape(arg_shapes, &in_shapes, &aux_shapes, &out_shapes);

for (size_t i = 0; i < in_shapes.size(); ++i) {
const auto &shape = in_shapes[i];
const auto &arg_name = arg_name_list[i];
auto iter_arg = args_map.find(arg_name);
if (iter_arg != args_map.end()) {
arg_arrays->push_back(iter_arg->second);
} else {
arg_arrays->push_back(NDArray(shape, context, false));
NDArray::SampleGaussian(0, 1, &arg_arrays->back());
}