在 MXNet 中,参数管理是非常重要的一个功能,MXNet 通过 Parameter
提供了统一的参数定义使用方式。
MXNet 中的大部分 Operator 都其特定的参数,例如,Convolution 的 padding, strides, kernel, num_filter 等,这些参数就是通过 Parameter 设置的。Parameter 封装地比较晚上,即使不完全理解 Parameter 的细节也不影响实现一个带有参数的 Operator, 但是,通过对 Parameter 的深入学习,可以更灵活得使用 Parameter, 并且学习 Parameter 的设计思想。
基本流程 1 2 3 4 5 6 7 8 9 10 11 struct Param : public dmlc::Parameter<Param> { float learning_rate; int num_hidden; DMLC_DECLARE_PARAMETER (Param) { DMLC_DECLARE_FIELD (num_hidden); DMLC_DECLARE_FIELD (learning_rate); } }; DMLC_REGISTER_PARAMETER (Param);
Parameter 总是定义为一个结构体,DMLC_DECLARE_PARAMETER
定义如下:
1 2 3 #define DMLC_DECLARE_PARAMETER(PType) \ static ::dmlc::parameter::ParamManager *__MANAGER__(); \ inline void __DECLARE__(::dmlc::parameter::ParamManagerSingleton<PType> *manager) \
结合 DMLC_REGISTER_PARAMETER
来看,DMLC_REGISTER_PARAMETER
的定义如下:
1 2 3 4 5 6 7 #define DMLC_REGISTER_PARAMETER(PType) \ ::dmlc::parameter::ParamManager *PType::__MANAGER__() { \ static ::dmlc::parameter::ParamManagerSingleton<PType> inst(#PType); \ return &inst.manager; \ } \ static ::dmlc::parameter::ParamManager &__make__ ## PType ## ParamManager__ = \ (*PType::__MANAGER__())
由于 __MANAGER__
返回的是 static
对象,因此,每一个类型的 Operator 有且仅有一个 ParamManager.__DECLARE__
和 后续的多个 DMLC_DECLARE_FIELD
共同定义了一个 __DECLARE__
函数。
DMLC_DECLARE_FIELD
定义如下:
1 #define DMLC_DECLARE_FIELD(FieldName) this->DECLARE(manager, #FieldName, FieldName)
DECLARE
的定义如下:
1 2 3 4 5 6 7 8 9 10 template <typename DType>inline parameter::FieldEntry<DType>& DECLARE ( parameter::ParamManagerSingleton<PType> *manager, const std::string &key, DType &ref) { parameter::FieldEntry<DType> *e = new parameter::FieldEntry <DType>(); e->Init (key, this ->head (), ref); manager->manager.AddEntry (key, e); return *e;}
可以看到,这里的核心是 DECLARE
. 其涉及了两个类:FieldEntry
, ParamManagerSingleton
. 把上面例子中的所有的宏展开:
1 2 3 4 5 6 7 8 struct MyParam :: Parameter<MyParam> { float learning_rate; int num_hidden; static ::dmlc::parameter::ParamManager *__MANAGER__(); \ inline void __DECLARE__(::dmlc::parameter::ParamManagerSingleton<MyParam> *manager) { this ->DECLARE (manager, "learning_rate" , learning_rate); this ->DECLARE (manager, "num_hidden" , num_hidden); }
DMLC_REGISTER_PARAMETER
展开:
1 2 3 4 5 6 ::dmlc::parameter::ParamManager *MyParam::__MANAGER__() { static ::dmlc::parameter::ParamManagerSingleton<MyParam> inst ("MyParam" ) ; return &inst.manager; } static ::dmlc::parameter::ParamManager &__make__MyParamParamManager__ = (*MyParam::__MANAGER__());
而 ParamManagerSingleton
的初始化为:
1 2 3 4 5 6 7 8 9 template <typename PType>struct ParamManagerSingleton { ParamManager manager; explicit ParamManagerSingleton (const std::string ¶m_name) { PType param; param.__DECLARE__(this ); manager.set_name (param_name); } };
可以看到,Parameter 的流程为:针对某个 Op 的参数例如 MyParam, 创建全局唯一的 ParamManagerSingleton. ParamManagerSingleton 在初始化的时候调用 Parameter 的 DELCARE 函数,在 DECLARE 中完成了对 Parameter 每个参数的初始化。
参数的初始化 在具体的 Operator 中,参数初始化如下:
1 2 3 void Init (const std::vector<std::pair<std::string, std::string> >& kwargs) override { param_.Init (kwargs); }
可以看到,用户输入的参数的 key
和 value
都是字符串,那么,Parameter
中是如何把字符串类型的参数转换成实际类型例如 float
, int
的呢?
首先看到,上面的 Init 的调用为:
1 2 3 4 5 template <typename Container>inline void Init (const Container &kwargs) { PType::__MANAGER__()->RunInit (static_cast <PType*>(this ), kwargs.begin (), kwargs.end (), NULL ); }
进一步,我们看一下 manager 的 RunInit
调用:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 template <typename RandomAccessIterator>inline void RunInit (void *head, RandomAccessIterator begin, RandomAccessIterator end, std::vector<std::pair<std::string, std::string> > *unknown_args) const { std::set<FieldAccessEntry*> selected_args; for (RandomAccessIterator it = begin; it != end; ++it) { FieldAccessEntry *e = Find (it->first); if (e != NULL ) { e->Set (head, it->second); e->Check (head); selected_args.insert (e); } else { if (unknown_args != NULL ) { unknown_args->push_back (*it); } else { std::ostringstream os; os << "Cannot find argument \'" << it->first << "\', Possible Arguments:\n" ; os << "----------------\n" ; PrintDocString (os); throw dmlc::ParamError (os.str ()); } } } for (std::map<std::string, FieldAccessEntry*>::const_iterator it = entry_map_.begin (); it != entry_map_.end (); ++it) { if (selected_args.count (it->second) == 0 ) { it->second->SetDefault (head); } } }
可以看到,RunInit
通过 for
循环对输入的参数逐个设置。核心部分就是 e->Set(head, it->second);
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 virtual void Set (void *head, const std::string &value) const { std::istringstream is (value) ; is >> this ->Get (head); if (!is.fail ()) { while (!is.eof ()) { int ch = is.get (); if (ch == EOF) { is.clear (); break ; } if (!isspace (ch)) { is.setstate (std::ios::failbit); break ; } } } if (is.fail ()) { std::ostringstream os; os << "Invalid Parameter format for " << key_ << " expect " << type_ << " but value=\'" << value<< '\'' ; throw dmlc::ParamError (os.str ()); } }
从这里可以看到,is >> this->Get(head)
完成了参数的设置,因此,这就要求,每一个参数类型必须要定义 >>
操作符。 为了确认这一点,我们来看一下 TShape
:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 inline std::istream &operator >>(std::istream &is, TShape &shape) { while (true ) { char ch = is.get (); if (ch == '(' ) break ; if (!isspace (ch)) { is.setstate (std::ios::failbit); return is; } } index_t idx; std::vector<index_t > tmp; while (is >> idx) { tmp.push_back (idx); char ch; do { ch = is.get (); } while (isspace (ch)); if (ch == ',' ) { while (true ) { ch = is.peek (); if (isspace (ch)) { is.get (); continue ; } if (ch == ')' ) { is.get (); break ; } break ; } if (ch == ')' ) break ; } else if (ch == ')' ) { break ; } else { is.setstate (std::ios::failbit); return is; } } shape.CopyFrom (tmp.begin (), tmp.end ()); return is; }
因此,我们可以看到,TShape
是定义了 >>
操作符的,因此,TShape
也就可以用于参数的设置。