0%

MXNet 中 Parameter 模块的设计

参数是机器学习中很重要的一个部分,参数是用户和机器学习框架进行交互的一种方式。这篇文章将会介绍 DMLC 的参数模块,设计该该轻量级的 C++ 模块的目的就是支持通用的机器学习库,其具备以下几个优点:

  • 类型字段,默认值和约束的声明很简单
  • 自动检查是否满足约束,如果不满足的话抛出异常
  • 自动生成易读的参数文档
  • 序列化到 JSON 中,反序列化到 std::map(std::string, std::string) 中。

使用参数模块

声明参数

在 dmlc 的 parameter 模块中,每一个参数都可以声明成一个 structure. 因此,可以非常高效的获取 parameter 的各个字段,例如:

1
weight -= param.learning_rate * gradient;

和普通的 structure 唯一的区别是需要声明所有的字段以及其默认值和约束。例如声明 MyParam 的参数结构:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
#include <dmlc/parameter.h>

// declare the parameter, normally put it in header file.
struct MyParam : public dmlc::Parameter<MyParam> {
float learning_rate;
int num_hidden;
int activation;
std::string name;
// declare parameters
DMLC_DECLARE_PARAMETER(MyParam) {
DMLC_DECLARE_FIELD(num_hidden).set_range(0, 1000)
.describe("Number of hidden unit in the fully connected layer.");
DMLC_DECLARE_FIELD(learning_rate).set_default(0.01f)
.describe("Learning rate of SGD optimization.");
DMLC_DECLARE_FIELD(activation).add_enum("relu", 1).add_enum("sigmoid", 2)
.describe("Activation function type.");
DMLC_DECLARE_FIELD(name).set_default("layer")
.describe("Name of the net.");
}
};

// register the parameter, this is normally in a cc file.
DMLC_REGISTER_PARAMETER(MyParam);

从以上代码可知,不同之处是在 DMLC_DECLARE_PARAMETER(MyParam) 后面这一部分,这部分声明了所有的字段。这个例子声明了 float, intstring 类型的参数。这个实例代码的特点是:

  • 对于数值型的参数,可以通过 .set_range(begin, end) 来限定范围
  • 可以定义枚举类型,这个例子中是 activation. 用户只能设置 sigmoidrelu 两种类型的激活函数字段,并且分别映射到 1 和 2 上去
  • describe 函数增加对字段的描述,用来生成用户易读的文档。

设置参数

声明了参数之后,我们可以像普通的 structure 一样声明该 structure. 不同的地方是,MyParam structure 带有成员函数,目的是使得对 parameter 的操作方便一些。为了从外部数据源设置参数,使用 Init 函数:

1
2
3
4
5
6
7
8
9
10
11
int main() {
MyParam param;
std::vector<std::pair<std::string, std::string> > param_data = {
{"num_hidden", "100"},
{"activation", "relu"},
{"name", "myname"}
};
// set the parameters
param.Init(param_data);
return 0;
}

调用完 Init 之后 param 就被 param_data 中的特定的 key 值填充了。更重要的是,Init 函数会自动检查 parameter range, 并且在出现错误的时候抛出 dmlc::ParamError 异常以及详细的错误信息

生成用户易读的文档

参数模块的另外一个特点是生成用户易读的文档。这在构建 language binding(例如 Python, R) 的时候非常有用。获取描述文档的方法是:

1
std::string docstring = MyParam::__DOC__();

另外该模块还提供了结构化的获取字段信息的方法:

1
std::vector<dmlc::ParamFieldInfo> fields = MyParam::__FIELDS__();

参数序列化

之中最常用的序列化方法是使用下面的代码把参数转化成 std::map<string, string> 表示:

1
std::map<string, string> dict = param.__DICT__();

std::map<string, string> dict = param.__DICT__(); 进行序列化就比较简单了。这种序列化的方法根据设备和平台的不同而不同。然而,这种序列化方法不是很紧凑,只推荐用来序列化一般的参数。
直接序列化到 JSON 以及从 JSON 文件中 load 参数也是支持的。

应用实例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
// This is an example program showing usage of parameter module
// Build, on root folder, type
//
// make example
//
// Example usage:
//
// example/parameter num_hidden=100 name=aaa activation=relu
//

#include <dmlc/parameter.h>

struct MyParam : public dmlc::Parameter<MyParam> {
float learning_rate;
int num_hidden;
int activation;
std::string name;
// declare parameters in header file
DMLC_DECLARE_PARAMETER(MyParam) {
DMLC_DECLARE_FIELD(num_hidden).set_range(0, 1000)
.describe("Number of hidden unit in the fully connected layer.");
DMLC_DECLARE_FIELD(learning_rate).set_default(0.01f)
.describe("Learning rate of SGD optimization.");
DMLC_DECLARE_FIELD(activation).add_enum("relu", 1).add_enum("sigmoid", 2)
.describe("Activation function type.");
DMLC_DECLARE_FIELD(name).set_default("mnet")
.describe("Name of the net.");

// user can also set nhidden besides num_hidden
DMLC_DECLARE_ALIAS(num_hidden, nhidden);
DMLC_DECLARE_ALIAS(activation, act);
}
};

// register it in cc file
DMLC_REGISTER_PARAMETER(MyParam);


int main(int argc, char *argv[]) {
if (argc == 1) {
printf("Usage: [key=value] ...\n");
return 0;
}

MyParam param;
std::map<std::string, std::string> kwargs;
for (int i = 0; i < argc; ++i) {
char name[256], val[256];
if (sscanf(argv[i], "%[^=]=%[^\n]", name, val) == 2) {
kwargs[name] = val;
}
}
printf("Docstring\n---------\n%s", MyParam::__DOC__().c_str());

printf("start to set parameters ...\n");
param.Init(kwargs);
printf("-----\n");
printf("param.num_hidden=%d\n", param.num_hidden);
printf("param.learning_rate=%f\n", param.learning_rate);
printf("param.name=%s\n", param.name.c_str());
printf("param.activation=%d\n", param.activation);
return 0;
}

工作原理

构建这样一个模块需要一些技巧,因为需要反射–获取结构体中的字段信息,而 C++不支持该功能
例如下面这段代码,Init 函数是怎么知道 num_hidden 的位置并且在 Init 函数中正确设置的呢?

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
#include <vector>
#include <string>
#include <dmlc/parameter.h>

// declare the parameter, normally put it in header file.
struct MyParam : public dmlc::Parameter<MyParam> {
float learning_rate;
int num_hidden;
// declare parameters
DMLC_DECLARE_PARAMETER(MyParam) {
DMLC_DECLARE_FIELD(num_hidden);
DMLC_DECLARE_FIELD(learning_rate).set_default(0.01f);
}
};

// register the parameter, this is normally in a cc file.
DMLC_REGISTER_PARAMETER(MyParam);

int main(int argc, char *argv[]) {
MyParam param;
std::vector<std::pair<std::string, std::string> > param_data = {
{"num_hidden", "100"},
};
param.Init(param_data);
return 0;
}

这里面的关键技巧在DMLC_DECLARE_PARAMETER(MyParam)中,这是在参数模块中定义的一个宏。如果把这个宏展开,大约是下面这样的代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
struct Parameter<MyParam> {
template<typename ValueType>
inline FieldEntry<ValueType>&
DECLARE(ParamManagerSingleton<MyParam> *manager,
const std::string& key,
ValueType& ref){
// offset gives a generic way to access the address of the field
// from beginning of the structure.
size_t offset = ((char*)&ref - (char*)this);
parameter::FieldEntry<ValueType> *e =
new parameter::FieldEntry<ValueType>(key, offset);
manager->AddEntry(key, e);
return *e;
}
};

struct MyParam : public dmlc::Parameter<MyParam> {
float learning_rate;
int num_hidden;
// declare parameters
inline void __DECLARE__(ParamManagerSingleton<MyParam> *manager) {
this->DECLARE(manager, "num_hidden", num_hidden);
this->DECLARE(manager, "learning_rate", learning_rate).set_default(0.01f);
}
};

// This code is only used to show the general idea.
// This code will only run once, the real code is done via singleton declaration pattern.
{
static ParamManagerSingleton<MyParam> manager;
MyParam tmp;
tmp->__DECLARE__(&manager);
}

这段代码并不是实际参数模块中使用的代码,但是,大体上展示了该模块是怎么工作的。关键的地方是,structure 的 layout 对于所有对象的实例来说都是固定的,为了说明怎么访问每一个字段,可以:

  • 创建 MyParam 的一个实例,调用 DECLARE 函数。
  • 字段和 structure header 的相对位置记录在一个全局的 singleton 中。
  • 调用 Init 的时候,可以在 singleton 中获取 offset 信息,通过 (ValueType*)((char*)this + offset) 访问字段的地址。

通过 C++ 的泛型编程,我们可以为机器学习库构建一个简单有用的参数模块,该模块在 dmlc 项目中已广泛使用,希望该模块对你也有帮助。