0%

Winograd 方法快速计算卷积

在 ConvNet 中,大部分的计算耗费在计算卷积的过程中,尤其是在端上设备中,对于性能的要求更为苛刻。程序的性能优化是一个复杂而庞大的话题。高性能计算就像系统设计一样,虽然有一些指导原则,但是,对于不同的场景需要有不同的设计方案,因此,对于同一个优化问题,不同的人可能会给出完全不同的优化方案。本文不是探讨硬件和代码级的优化,仅针对计算卷积的一个特定方法 — Winograd 方法做一个简单的记述。

计算卷积的各种方法

首先要明确一点,这里说的卷积是是指 ConvNet 中默认的卷积,而不是数学意义上的卷积。其实,ConvNet 中的卷积对于与数学中的 cross correlation.
计算卷积的方法有很多种,常见的有以下几种方法:

  1. 滑窗:这种方法是最直观最简单的方法。但是,该方法不容易实现大规模加速,因此,通常情况下不采用这种方法 (但是也不是绝对不会用,在一些特定的条件下该方法反而是最高效的.)
  2. im2col: 目前几乎所有的主流计算框架包括 Caffe, MXNet 等都实现了该方法。该方法把整个卷积过程转化成了 GEMM 过程,而 GEMM 在各种 BLAS 库中都是被极致优化的,一般来说,速度较快。
  3. FFT: 傅里叶变换和快速傅里叶变化是在经典图像处理里面经常使用的计算方法,但是,在 ConvNet 中通常不采用,主要是因为在 ConvNet 中的卷积模板通常都比较小,例如 \(3 \times3\) 等,这种情况下,FFT 的时间开销反而更大。
  4. Winograd: Winograd 是存在已久最近被重新发现的方法,在大部分场景中,Winograd 方法都显示和较大的优势,目前 cudnn 中计算卷积就使用了该方法。

Winograd 方法

Winograd 方法简单来讲,就是用更多的加法计算来减少乘法计算。因此,一个前提就是,在处理器中,乘法计算的时钟周期数要大于加法计算的时钟周期数。
Winograd 计算卷积需要完成的乘法的次数为:

\[ \mu\left(F\left(m \times n, r \times s\right)\right) = \left(m+n-1\right) \times \left(n+s-1\right) \]

\(r \times s\) 表示卷积核的大小,\(m \times n\) 表示输出大小。

所以,做一个简单的对比计算:\(3 \times 3\) 的卷积核,输出为 \(2 \times 2\), 那么,滑窗或者 im2col 需要的乘法计算次数为 \(3 \times 3 \times 2 \times 2 = 36\), Winograd 需要的乘法计算次数为 \((3+2-1) \times (3+2-1) = 16\).

Winograd 的证明方法较为复杂,要用到数论中的一些知识,但是,使用起来很简单。只需要按照如下公式计算:

\[ Y = A^T \left[\left[GgG^T\right] \odot \left[B^TdB\right]\right]A \]

其中,\(\odot\) 表示 element-wise multiplication. \(A, G, B\) 根据输出大小和卷积核大小不同有不同的定义,并且是提前确定了的。每种输出大小和卷积核的 \(A, G, B\) 具体是多少,可以通过 https://github.com/andravin/wincnn 的脚本计算。\(g\) 表示的是卷积核,\(d\) 表示要进行卷积计算的 data. \(g\) 的 shape 为 \(r \times r\), \(d\) 的 shape 为 \((m+r-1) \times (m+r-1)\)

后注

  1. 以上描述的 Winograd 算法只展示了在二维的图像 (更确切的说是 tile) 上的过程,具体在 ConvNet 的多个 channel 的情况,直接逐个 channel 按照上述方法计算完然后相加即可;
  2. 按照 #1 的思路,在计算多个 channel 的时候,仍然有可减少计算次数的地方。
  3. 按照 #2 的思路,Winograd 在目前使用越来越多的 depthwise conv 中其优势不明显了。
  4. 在 tile 较大的时候,Winograd 方法不适用,因为,在做 inverse transform 的时候的计算开销抵消了 Winograd 带来的计算节省。
  5. Winograd 会产生误差