在很多机器学习系统中,计算图是一个非常重要的概念,计算图是一个有向无环图,它描述的计算的依赖关系。有了 计算图之后,我们就可以根据 static graph 对计算做各种优化。
计算图
在计算系统的优化中,计算图可以帮助减少内存的占用和提升计算的并行度。
内存占用
假设我们有下面的一个计算图:
上图的计算为:
1 | var data, weight, bias; |
在上面的计算中, data, weight, bias, gamma, beta
所占用的储存知道上面整段代码跑完之后才会释放,虽然,data, weight, bias
只在 Conv 的计算中用到了,后面再也没有用到。因此,在这里就有内存占用的优化空间。
在第 2 行之后,系统并不知道后面再也不会用到 data, weight, bias
了,因此,系统不能释放它们。因此,为了优化内存占用,我们可以做如下优化:
1 | var data, weight, bias; |
如上所示,我们显式地加入和释放内存的语句。那么,这个过程有没有可能自动化呢?答案显然是可以的。这里就引入了计算图的概念。
计算图的意思是,我们在定义计算的时候,并没有真正进行计算,而是首先构建一个计算图,然后,针对计算图进行优化,最后才是真正的计算,还是上面的例子:
1 | Symbol data, weight, bias; |
如上所示,在第 5 行得到的 out
就是上面的示意图,通过 out.optimize()
, 系统就分析出在计算完Conv
之后可以释放 data, weight, bias
, 在计算完 BatchNorm
之后系统可以释放 gamma, beta
. 通过这种方法,我们就可以节省一部分内存占用。
计算并行度
我们对上图示例稍加改动,得到如下的计算图:
对应的计算为:
1 | var data, weight, bias; |
上面的计算先计算一个 Conv
, 再计算一次 Pooling
,然后把结果 Concat
起来,最后算一次 ReLU
. 虽然 Conv
和 Pooling
的计算没有互相依赖,可以同时进行,但是,在上述代码中,系统并不知道他们之间的依赖关系,只能按部就班地一行一行进行计算。如果可以提前得到计算图,通过分析就可以知道计算图的依赖关系,那么,就可以大大提升计算的并行,从而加速整个计算的过程。
Static Graph
static graph 就是 mxnet 中的计算图。