0%

MXNet 源码详解 -- Parameter

在 MXNet 中,参数管理是非常重要的一个功能,MXNet 通过 Parameter 提供了统一的参数定义使用方式。

MXNet 中的大部分 Operator 都其特定的参数,例如,Convolution 的 padding, strides, kernel, num_filter 等,这些参数就是通过 Parameter 设置的。Parameter 封装地比较晚上,即使不完全理解 Parameter 的细节也不影响实现一个带有参数的 Operator, 但是,通过对 Parameter 的深入学习,可以更灵活得使用 Parameter, 并且学习 Parameter 的设计思想。

基本流程

1
2
3
4
5
6
7
8
9
10
11
struct Param : public dmlc::Parameter<Param> {
float learning_rate;
int num_hidden;
// declare parameters in header file
DMLC_DECLARE_PARAMETER(Param) {
DMLC_DECLARE_FIELD(num_hidden);
DMLC_DECLARE_FIELD(learning_rate);
}
};
// register it in cc file
DMLC_REGISTER_PARAMETER(Param);

Parameter 总是定义为一个结构体,DMLC_DECLARE_PARAMETER 定义如下:

1
2
3
#define DMLC_DECLARE_PARAMETER(PType)                                   \
static ::dmlc::parameter::ParamManager *__MANAGER__(); \
inline void __DECLARE__(::dmlc::parameter::ParamManagerSingleton<PType> *manager) \

结合 DMLC_REGISTER_PARAMETER 来看,DMLC_REGISTER_PARAMETER 的定义如下:

1
2
3
4
5
6
7
#define DMLC_REGISTER_PARAMETER(PType)                                  \
::dmlc::parameter::ParamManager *PType::__MANAGER__() { \
static ::dmlc::parameter::ParamManagerSingleton<PType> inst(#PType); \
return &inst.manager; \
} \
static ::dmlc::parameter::ParamManager &__make__ ## PType ## ParamManager__ = \
(*PType::__MANAGER__())

由于 __MANAGER__ 返回的是 static 对象,因此,每一个类型的 Operator 有且仅有一个 ParamManager.
__DECLARE__ 和 后续的多个 DMLC_DECLARE_FIELD 共同定义了一个 __DECLARE__ 函数。

DMLC_DECLARE_FIELD 定义如下:

1
#define DMLC_DECLARE_FIELD(FieldName)  this->DECLARE(manager, #FieldName, FieldName)

DECLARE 的定义如下:

1
2
3
4
5
6
7
8
9
10
template<typename DType>
inline parameter::FieldEntry<DType>& DECLARE(
parameter::ParamManagerSingleton<PType> *manager,
const std::string &key, DType &ref) { // NOLINT(*)
parameter::FieldEntry<DType> *e =
new parameter::FieldEntry<DType>();
e->Init(key, this->head(), ref);
manager->manager.AddEntry(key, e);
return *e;
}

可以看到,这里的核心是 DECLARE. 其涉及了两个类:FieldEntry, ParamManagerSingleton.
把上面例子中的所有的宏展开:

1
2
3
4
5
6
7
8
struct MyParam :: Parameter<MyParam> {
float learning_rate;
int num_hidden;
static ::dmlc::parameter::ParamManager *__MANAGER__(); \
inline void __DECLARE__(::dmlc::parameter::ParamManagerSingleton<MyParam> *manager) {
this->DECLARE(manager, "learning_rate", learning_rate);
this->DECLARE(manager, "num_hidden", num_hidden);
}

DMLC_REGISTER_PARAMETER 展开:

1
2
3
4
5
6
::dmlc::parameter::ParamManager *MyParam::__MANAGER__() {
static ::dmlc::parameter::ParamManagerSingleton<MyParam> inst("MyParam");
return &inst.manager;
}
static ::dmlc::parameter::ParamManager &__make__MyParamParamManager__ =
(*MyParam::__MANAGER__());

ParamManagerSingleton 的初始化为:

1
2
3
4
5
6
7
8
9
template<typename PType>
struct ParamManagerSingleton {
ParamManager manager;
explicit ParamManagerSingleton(const std::string &param_name) {
PType param;
param.__DECLARE__(this);
manager.set_name(param_name);
}
};

可以看到,Parameter 的流程为:针对某个 Op 的参数例如 MyParam, 创建全局唯一的 ParamManagerSingleton. ParamManagerSingleton 在初始化的时候调用 Parameter 的 DELCARE 函数,在 DECLARE 中完成了对 Parameter 每个参数的初始化。

参数的初始化

在具体的 Operator 中,参数初始化如下:

1
2
3
void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
param_.Init(kwargs);
}

可以看到,用户输入的参数的 keyvalue 都是字符串,那么,Parameter 中是如何把字符串类型的参数转换成实际类型例如 float, int 的呢?

首先看到,上面的 Init 的调用为:

1
2
3
4
5
template<typename Container>
inline void Init(const Container &kwargs) {
PType::__MANAGER__()->RunInit(static_cast<PType*>(this),
kwargs.begin(), kwargs.end(), NULL);
}

进一步,我们看一下 manager 的 RunInit 调用:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
template<typename RandomAccessIterator>
inline void RunInit(void *head,
RandomAccessIterator begin,
RandomAccessIterator end,
std::vector<std::pair<std::string, std::string> > *unknown_args) const {
std::set<FieldAccessEntry*> selected_args;
for (RandomAccessIterator it = begin; it != end; ++it) {
FieldAccessEntry *e = Find(it->first);
if (e != NULL) {
e->Set(head, it->second);
e->Check(head);
selected_args.insert(e);
} else {
if (unknown_args != NULL) {
unknown_args->push_back(*it);
} else {
std::ostringstream os;
os << "Cannot find argument \'" << it->first << "\', Possible Arguments:\n";
os << "----------------\n";
PrintDocString(os);
throw dmlc::ParamError(os.str());
}
}
}

for (std::map<std::string, FieldAccessEntry*>::const_iterator it = entry_map_.begin();
it != entry_map_.end(); ++it) {
if (selected_args.count(it->second) == 0) {
it->second->SetDefault(head);
}
}
}

可以看到,RunInit 通过 for 循环对输入的参数逐个设置。核心部分就是 e->Set(head, it->second);

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
virtual void Set(void *head, const std::string &value) const {
std::istringstream is(value);
is >> this->Get(head);
if (!is.fail()) {
while (!is.eof()) {
int ch = is.get();
if (ch == EOF) {
is.clear(); break;
}
if (!isspace(ch)) {
is.setstate(std::ios::failbit); break;
}
}
}

if (is.fail()) {
std::ostringstream os;
os << "Invalid Parameter format for " << key_
<< " expect " << type_ << " but value=\'" << value<< '\'';
throw dmlc::ParamError(os.str());
}
}

从这里可以看到,is >> this->Get(head) 完成了参数的设置,因此,这就要求,每一个参数类型必须要定义 >>操作符。
为了确认这一点,我们来看一下 TShape:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
inline std::istream &operator>>(std::istream &is, TShape &shape) {
// get (
while (true) {
char ch = is.get();
if (ch == '(') break;
if (!isspace(ch)) {
is.setstate(std::ios::failbit);
return is;
}
}
index_t idx;
std::vector<index_t> tmp;
while (is >> idx) {
tmp.push_back(idx);
char ch;
do {
ch = is.get();
} while (isspace(ch));
if (ch == ',') {
while (true) {
ch = is.peek();
if (isspace(ch)) {
is.get(); continue;
}
if (ch == ')') {
is.get(); break;
}
break;
}
if (ch == ')') break;
} else if (ch == ')') {
break;
} else {
is.setstate(std::ios::failbit);
return is;
}
}
shape.CopyFrom(tmp.begin(), tmp.end());
return is;
}

因此,我们可以看到,TShape 是定义了 >> 操作符的,因此,TShape 也就可以用于参数的设置。