0%

MXNet 源码详解 -- Static Graph

在很多机器学习系统中,计算图是一个非常重要的概念,计算图是一个有向无环图,它描述的计算的依赖关系。有了 计算图之后,我们就可以根据 static graph 对计算做各种优化。

计算图

在计算系统的优化中,计算图可以帮助减少内存的占用和提升计算的并行度。

内存占用

假设我们有下面的一个计算图:

上图的计算为:

1
2
3
4
5
var data, weight, bias;
out = Conv(data, weight, bias);
var gamma, beta;
out = BatchNorm(out, gamma, beta);
out = ReLU(out)

在上面的计算中, data, weight, bias, gamma, beta 所占用的储存知道上面整段代码跑完之后才会释放,虽然,data, weight, bias 只在 Conv 的计算中用到了,后面再也没有用到。因此,在这里就有内存占用的优化空间。
在第 2 行之后,系统并不知道后面再也不会用到 data, weight, bias 了,因此,系统不能释放它们。因此,为了优化内存占用,我们可以做如下优化:

1
2
3
4
5
6
7
var data, weight, bias;
out = Conv(data, weight, bias);
del data, weight, bias;
var gamma, beta;
out = BatchNorm(out, gamma, beta);
del gamma, beta;
out = ReLU(out)

如上所示,我们显式地加入和释放内存的语句。那么,这个过程有没有可能自动化呢?答案显然是可以的。这里就引入了计算图的概念。
计算图的意思是,我们在定义计算的时候,并没有真正进行计算,而是首先构建一个计算图,然后,针对计算图进行优化,最后才是真正的计算,还是上面的例子:

1
2
3
4
5
6
7
Symbol data, weight, bias;
out = Conv(data, weight, bias);
Symbol gamma, beta;
out = BatchNorm(out, gamma, beta);
out = ReLU(out)
out.optimize()
out.run([nd_data, nd_weight, nd_bias, nd_gamma, nd_beta]);

如上所示,在第 5 行得到的 out 就是上面的示意图,通过 out.optimize(), 系统就分析出在计算完Conv之后可以释放 data, weight, bias, 在计算完 BatchNorm 之后系统可以释放 gamma, beta. 通过这种方法,我们就可以节省一部分内存占用。

计算并行度

我们对上图示例稍加改动,得到如下的计算图:

对应的计算为:

1
2
3
4
5
var data, weight, bias;
out0 = Conv(data, weight, bias);
out1 = Pooling(data)
out = Concat(out0, out1);
out = ReLU(out)

上面的计算先计算一个 Conv, 再计算一次 Pooling,然后把结果 Concat 起来,最后算一次 ReLU. 虽然 ConvPooling 的计算没有互相依赖,可以同时进行,但是,在上述代码中,系统并不知道他们之间的依赖关系,只能按部就班地一行一行进行计算。如果可以提前得到计算图,通过分析就可以知道计算图的依赖关系,那么,就可以大大提升计算的并行,从而加速整个计算的过程。

Static Graph

static graph 就是 mxnet 中的计算图。