WaveNet 分析和实现

这篇文章是 DeepMind 团队使用深度学习.具体来说是卷积神经网络来做语音生成的一个工作. 基于基本的语音生成, 他们还做了 tts (text to speech), 音乐合成等. 实验效果很好. 我本人一直很喜欢 DeepMind 团队的工作, 因为, 他们对问题的提炼很好, 所用的方法也比较简单实用, 效果常常出人意料地好. 顺着这篇文章, 学习到了 dilate convolution 这个一个概念. dilate convolution 可以使用较少的计算就能 cover 到较大的 receptive field. 而且由于 dilate 本身的原因, 还可以防止 overfitting.

Dilate Convolution

既然这篇论文中用到了 dilate convolution, 而 dilate conv 又是我很喜欢的一种操作, 那就先来说明一下什么是 dilate conv.
其实, 一张图就足以解释什么是 dilate conv 了. 对于好奇心比较强的同学, 可以看这篇文章. 下图也是来自这篇文章.


上图的说明已经足够解释什么是 dilate convolution 了. 简单来说, dilate convolution 引入一个新的 hyper-parameter, dilate, 这个 hyper-parameter 的涵义是,

每隔 dilate-1 个像素取一个”像素”, 做卷积操作

tensorflow 和 MXNet 均实现了这个操作, 但是, tensorflow 实现的要求必须 x 方向和 y 方向的 dilate 要一致. MXNet 没有这个要求. 但是, MXNet 的实现出现了一点 bug, 修复方法见 #3479. 相信 MXNet 官方不久也会修复这个 bug.

WaveNet

WaveNet 主体思想是 casual layer, 为了 cover 到较大的 receptive field, 使用了 kernel size 较大的卷积核. 较大的 kernel size 带来的是计算量的迅速增加. 为了解决这个问题, 文章使用了上面介绍的 dilate convolution. 除了第一层还是传统的 convlution 操作, 剩下的层全部是 dilate convolution.

Causal Convolution Layer


上图是 casual layer 的示意图, 在 casual layer 中使用的是一维的卷积操作. 这样, 语音中每一个采样点的数据顺序仍然和输入的顺序保持一致.

Causal Dilate Convolution Layer


上图是 causal dilate convolution layer. dilate convolution 的优点上面已经说过了. 在文章中, 作者的 dilate 值分别是 \(\left[2^1, 2^2, 2^3, \cdots, 2^8\right]\).

输入输出

输入: 没什么好说的, 就是把输入归一化到 0 到 255 之间的整数.
输出: 把输入向 x 的负方向 sift 1, 得到 label 数据. 这样做的目的是, 使用前 t 个采样点, 可以预测第 t+1 个采样点. 缺少的一位补 0. 例如:
输入:\(\left[1,2,3,4,5,6,7,8,9\right]\)
输出:\(\left[2,3,4,5,6,7,8,9,0\right]\)

Activation

在这个工作中我认为另外一个比较重要点是 gated activation. 简单来说, 经过 Convolution 操作 activation 操作之后不是直接输出, 而是有一个 gate 来控制其输出大小. 该 gate 本质上也是一个 dialte convolution 操作加 activation 操作.
\[ \mathbf{z}=tanh\left(W_{f,k}*\mathbf{x}\right)\odot\sigma\left(W_{g,k}*\mathbf{x}\right) \]
这里的 \(*\) 是 convolution 操作, \(\odot\) 是 element-wise multiplication.

Residual and Skipped Connection

上面已经介绍了没一个 block 的结构, 最后就是把这些 block stack 起来, 然后形成最终的网络结构, 如下图. 采用了 Residual Net 中的 skip connection 的思想.

代码实践

这个实现基于目前我所学习到的 MXNet, 尤其是输入的数据部分, 对以后的代码工作比较有帮助, 所以, 贴上来一个比较完整的代码.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import mxnet as mx
import numpy as np
# import minpy.numpy as np
import multiprocessing
from scipy.ndimage.interpolation import shift
def causal_layer(data=None, name="causal"):
assert isinstance(data, mx.symbol.Symbol)
zero = mx.symbol.Variable(name=name+"-zero")
concat = mx.symbol.Concat(*[data, zero], dim=3, name=name+"-concat")
causal = mx.symbol.Convolution(data=concat, kernel=(1, 2), stride=(1, 1), num_filter=16, name=name)
return causal
def residual_block(data=None, kernel=(1, 2), dilate=None, num_filter=16, name=None, stride=(1, 1), output_channel=None):
assert name is not None
assert dilate is not None
assert output_channel is not None
assert isinstance(data, mx.symbol.Symbol)
zero = mx.symbol.Variable(name=name+"-zero")
concat = mx.symbol.Concat(*[data, zero], dim=3, name=name+"-concat")
conv_filter = mx.symbol.Convolution(data=concat, kernel=kernel, stride=stride, dilate=dilate, num_filter=num_filter, name=name+"conv-filter")
conv_gate = mx.symbol.Convolution(data=concat, kernel=kernel, stride=stride, dilate=dilate, num_filter=num_filter, name=name+"conv-gate")
output_filter = mx.symbol.Activation(data=conv_filter, act_type="tanh")
output_gate = mx.symbol.Activation(data=conv_gate, act_type="sigmoid")
output = output_filter * output_gate
out = mx.symbol.Convolution(data=output, kernel=(1, 1), num_filter=output_channel)
return out+data, out
class DataBatch(object):
def __init__(self, data, label):
self.data = data
self.label = label
class DataIter(mx.io.DataIter):
def __init__(self, batch_size, length, names, shape):
self.provide_data = [(k, v) for k, v in shape.iteritems()]
self.provide_label = [("softmax_label", (batch_size, length))]
self.cur_batch = 0
self.num_batch = len(names)/batch_size
self.batch_size = batch_size
self.length = length
self.names = names
self.q = multiprocessing.Queue(maxsize=4)
self.pws = [multiprocessing.Process(target=self.get_batch) for i in xrange(4)]
for pw in self.pws:
pw.daemon = True
pw.start()
def reset(self):
self.cur_batch = 0
def __iter__(self):
return self
def __next__(self):
return self.next()
def get_batch(self):
while True:
data_all = np.empty(shape=(self.batch_size, 1, 1, self.length))
label_all = np.empty(shape=(self.batch_size, self.length))
mx_data = []
mx_label = []
idx = 0
while idx < self.batch_size:
name = random.choice(self.names)
# print name
audio, _ = librosa.load(name, sr=16000, mono=True)
if audio.shape[0] < self.length:
continue
audio = audio[:self.length]
magnitude = 1.0*np.log(1+255*np.abs(audio))/np.log(1.0+255)
signal = np.sign(audio) * magnitude
audio = ((signal+1)/2.0*255+0.5).astype(np.int32)
label = shift(audio, -1, cval=0)
data_all[idx, :, :, :] = audio
label_all[idx, :] = label
idx += 1
for k, v in shape.iteritems():
if "input" in k:
data = mx.nd.array(np.reshape(data_all, v))
else:
data = mx.nd.array(np.zeros(shape=v))
mx_data.append(data)
label = mx.nd.array(np.reshape(label_all, (self.batch_size, self.length)))
mx_label.append(label)
self.q.put(obj=DataBatch(mx_data, mx_label), block=True, timeout=None)
def next(self):
if self.q.empty():
logging.debug("waiting for data......")
if self.cur_batch < self.num_batch:
self.cur_batch += 1
return self.q.get(block=True, timeout=None)
else:
raise StopIteration
def get_network():
dilate = [2**i for i in range(1, 9)]
shape = {}
params = {'length': 2**15, 'batch_size': 100, 'num_batch': 1000}
batch_size = params['batch_size']
length = params['length']
data = mx.symbol.Variable(name="input")
net = causal_layer(data=data, name="causal")
shape = {
"input": (batch_size, 1, 1, length),
"causal-zero": (batch_size, 1, 1, 1)
}
residual = []
outs = []
for d in dilate:
name = "residual-"+str(d)
output_channel = 16
net, out = residual_block(data=net, kernel=(1, 2), dilate=(1, d), num_filter=32, stride=(1, 1), output_channel=output_channel, name=name)
residual.append(net)
outs.append(out)
shape[name+"-zero"] = (batch_size, output_channel, 1, d)
net = outs[0]+outs[1]+outs[2]+outs[3]+outs[4]+outs[5]+outs[6]+outs[7]
net = mx.symbol.Activation(data=net, act_type="relu", name="sum-activation")
net = mx.symbol.Convolution(data=net, kernel=(1, 1), num_filter=128, name="post-conv1")
net = mx.symbol.Activation(data=net, act_type="relu", name="post-activation1")
net = mx.symbol.Convolution(data=net, kernel=(1, 1), num_filter=256, name="post-conv2")
net = mx.symbol.SoftmaxOutput(data=net, name="softmax", multi_output=True)
return net, shape

完整实现 Github

说明

目前实现的是简单的生成模型, 并没有任何 condition, 如果是做 tts 还要进一步做 conditional 的模型.