0%

使用 Gluon 实现 DeepLab V3

Gluon 是 MXNet 实现的一套同时可以支持动态图和静态图的计算接口。和原有的 Symbol 接口相比,Gluon 的封装层次更高,在某种程度上使用更灵活。本文记录在学习Gluon的过程中实现 DeepLab V3 的过程。同时,数据的处理和输入也是使用的 Gluon 中提供的 Dataset 接口。

DeepLabV3

首先要简要说明一下 DeepLab V3。DeepLab V3 和 PSPNet 基本一致,主要不同在于如何融合不同特征,改进了 PSPnet 的 ASPP module。PSPNet 在把 dilated conv 作用到 feature map 上的时候,如果 dilation rate 太大,那么,3x3 的 conv 就无法获取全局的特征而是退化到了一个 1x1 的 conv,为了有效提取全局特征, DeepLab V3 使用了 Global average pooling 操作。其它部分基本上和 PSPNet 是相同的,也讲不出什么道理。

Bottleneck

因为基本网络结构延续resent的思想,因此,首先要实现resnet的基本模块。

1
class Bottleneck(HybridBlock):
2
3
    def __init__(self, channels, strides, in_channels=0):
4
        super(Bottleneck, self).__init__()
5
        self.body = HybridSequential(prefix="")
6
        self.body.add(nn.Conv2D(channels=channels // 4, kernel_size=1, strides=1))
7
        self.body.add(nn.BatchNorm())
8
        self.body.add(nn.Activation('relu'))
9
        self.body.add(nn.Conv2D(channels=channels // 4, kernel_size=3, strides=strides, padding=1, use_bias=False, in_channels=channels // 4))
10
        self.body.add(nn.BatchNorm())
11
        self.body.add(nn.Activation('relu'))
12
        self.body.add(nn.Conv2D(channels, kernel_size=1, strides=1))
13
        self.body.add(nn.BatchNorm())
14
        self.downsample = nn.HybridSequential()
15
        self.downsample.add(nn.Conv2D(channels=channels, kernel_size=1, strides=strides, use_bias=False, in_channels=in_channels))
16
        self.downsample.add(nn.BatchNorm())
17
18
    def hybrid_forward(self, F, x):
19
        residual = self.downsample(x)
20
        x = self.body(x)
21
        x = F.Activation(residual + x, act_type="relu")
22
        return x

DilatedBottleneck

DeepLab V3 和原始的 resnet 其中的一个不同点是deeplab V3 的网络使用了 dialted conv,目的是为了使用更少的参数来cover 到更大的 receptive field。在没有dilated conv的场景下,为了 cover 更大的 receptive field, 通常有两个做法:1. 使用pooling 操作,2. 使用更大的卷积核。Pooling 操作的存在的问题是损失和空间细节信息,这在semantic segmentation 这种追求像素级精度的场景下不太合适;使用更大的卷积核导致网络的参数量增加,网络有过拟合的风险。因此,目前dilated conv 在 semantic segmentation 领域非常流行。

1
class DilatedBottleneck(HybridBlock):
2
3
    def __init__(self, channels, strides, dilation=2, in_channels=0):
4
        super(DilatedBottleneck, self).__init__()
5
        self.body = HybridSequential(prefix="dialted-conv")
6
        self.body.add(nn.Conv2D(channels=channels // 4, kernel_size=1, strides=1))
7
        self.body.add(nn.BatchNorm())
8
        self.body.add(nn.Activation('relu'))
9
        self.body.add(nn.Conv2D(channels=channels // 4, kernel_size=3, strides=strides, padding=dilation, dilation=dilation, use_bias=False, in_channels=channels // 4))
10
        self.body.add(nn.BatchNorm())
11
        self.body.add(nn.Activation('relu'))
12
        self.body.add(nn.Conv2D(channels, kernel_size=1, strides=1))
13
        self.body.add(nn.BatchNorm())
14
15
    def hybrid_forward(self, F, x):
16
        residual = x
17
        x = self.body(x)
18
        x = F.Activation(residual + x, act_type="relu")
19
        return x

ASPP

DeepLab V3 的 ASPP 和 PSPNet 中的略有不同,主要是有一个 global average pooling。global average pooling 的 kernel size 要根据具体的 feature map 大小进行调整。
双线性插值的 UpSampling 似乎不太好用 gluon 接口实现。主要是因为,在双线性插值的时候,要提供一个 Data 的 symbol 和 weight 的symbol。下面代码的双线性插值是通过 deconv 实现的。双线性插值的 weight 是不需要训练的,因此,通过 deconv 实现 memory 占用会浪费一些。

1
class ASPP(HybridBlock):
2
3
    def __init__(self):
4
        super(ASPP, self).__init__(gap_kernel=56)
5
        self.aspp0 = nn.HybridSequential()
6
        self.aspp0.add(nn.Conv2D(channels=256, kernel_size=1, strides=1, padding=0))
7
        self.aspp0.add(nn.BatchNorm())
8
        self.aspp1 = self._make_aspp(6)
9
        self.aspp2 = self._make_aspp(12)
10
        self.aspp3 = self._make_aspp(18)
11
        self.gap = nn.HybridSequential()
12
        self.gap.add(nn.AvgPool2D(pool_size=gap_kernel, strides=1))
13
        self.gap.add(nn.Conv2D(channels=256, kernel_size=1))
14
        self.gap.add(nn.BatchNorm())
15
        upsampling = nn.Conv2DTranspose(channels=256, kernel_size=gap_kernel*2, strides=gap_kernel, padding=gap_kernel/2, weight_initializer=mx.init.Bilinear(), use_bias=False, groups=256)
16
        upsampling.collect_params().setattr("lr_mult", 0.0)
17
        self.gap.add(upsampling)
18
        self.concurent = gluon.contrib.nn.HybridConcurrent(axis=1)
19
        self.concurent.add(self.aspp0)
20
        self.concurent.add(self.aspp1)
21
        self.concurent.add(self.aspp2)
22
        self.concurent.add(self.aspp3)
23
        self.concurent.add(self.gap)
24
        self.fire = nn.HybridSequential()
25
        self.fire.add(nn.Conv2D(channels=256, kernel_size=1))
26
        self.fire.add(nn.BatchNorm())
27
28
    def hybrid_forward(self, F, x):
29
        return self.fire(self.concurent(x))
30
31
    def _make_aspp(self, dilation):
32
        aspp = nn.HybridSequential()
33
        aspp.add(nn.Conv2D(channels=256, kernel_size=3, strides=1, dilation=dilation, padding=dilation))
34
        aspp.add(nn.BatchNorm())
35
        return aspp

DeepLab V3

最后就是组合上述模块,实现 DeepLab V3。仍然有一个 deconv 实现双线性插值上采样的过程。

1
def ResNetFCN(pretrained=False):
2
3
    resnet = gluon.model_zoo.vision.resnet50_v1(pretrained=pretrained)
4
    net = nn.HybridSequential()
5
    for layer in resnet.features[:6]:
6
        net.add(layer)
7
    with net.name_scope():
8
        net.add(Bottleneck(1024, strides=2, in_channels=512))
9
        for _ in range(6):
10
            net.add(DilatedBottleneck(channels=1024, strides=1, dilation=2, in_channels=1024))
11
        net.add(nn.Conv2D(channels=2048, kernel_size=1, strides=1, padding=0))
12
        for _ in range(3):
13
            net.add(DilatedBottleneck(channels=2048, strides=1, dilation=4, in_channels=2048))
14
        net.add(ASPP())
15
        upsampling = nn.Conv2DTranspose(channels=4, kernel_size=32, strides=16, padding=8, weight_initializer=mx.init.Bilinear(), use_bias=False, groups=4)
16
        upsampling.collect_params().setattr("lr_mult", 0.0)
17
        net.add(upsampling)
18
        net.add(nn.BatchNorm())
19
    return net

DataIter

Gluon 提供了一套更灵活的feed数据的接口,下面是语义分割中的一个例子。

1
class DataIter(gluon.data.Dataset):
2
3
    def __init__(self, images, labels, train=True):
4
        self.data = images
5
        self.label = labels
6
        self.train = train
7
8
    def __getitem__(self, idx):
9
        if self.train:
10
            image, label = rand_crop(self.data[idx], self.label[idx])
11
            return nd.array(image), nd.array(label)
12
        else:
13
            return nd.array(self.data[idx]), nd.array(self.label[idx])
14
15
    def __len__(self):
16
        return self.data.shape[0]

训练

对网络进行初始化

1
ctx = [mx.gpu(i) for i in cfg.ctx]
2
net = ResNetFCN(pretrained=True)
3
net.initialize(init=mx.init.MSRAPrelu())
4
net.collect_params().reset_ctx(ctx=ctx)
5
if cfg.finetune:
6
    net.collect_params().load(cfg.base_param, ctx=ctx)
7
net.hybridize()
8
mx.nd.waitall()

定义相应的 DataIter

1
train_iter = DataIter(train_images, train_labels)
2
val_iter = DataIter(val_images, val_labels)
3
train_data = gluon.data.DataLoader(train_iter, cfg.batch_size, last_batch="discard", shuffle=True)
4
val_data = gluon.data.DataLoader(val_iter, cfg.batch_size, last_batch="discard")

定义loss和训练方法

1
loss = gluon.loss.SoftmaxCrossEntropyLoss(axis=1)
2
trainer = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate': cfg.lr})

tensorboard

通过tensorboard监控训练过程

1
train_writer = SummaryWriter(cfg.train_logdir)
2
eval_writer = SummaryWriter(cfg.eval_logdir)

迭代训练

1
test_miou_list = []
2
for epoch in range(100000):
3
    for idx, batch in enumerate(train_data):
4
        data, label = batch
5
        data = gluon.utils.split_and_load(data, ctx)
6
        label = gluon.utils.split_and_load(label, ctx)
7
        losses = []
8
        with mx.autograd.record():
9
            outputs = [net(x) for x in data]
10
            losses = [loss(yhat, y) for yhat, y in zip(outputs, label)]
11
        for l in losses:
12
            l.backward()
13
        if count % 10 == 0:
14
          mx.nd.waitall()
15
        trainer.step(cfg.batch_size)
16
        if count % cfg.test_interval == 0:
17
            train_miou = np.array([IoU(yhat, y) for yhat, y in zip(outputs, label)]).mean()
18
            train_mloss = np.array([l.sum().asscalar() for l in losses]).mean()
19
            eval_mloss, eval_miou = evaluate(net, val_data, loss, ctx)
20
            logging.info("epoch: %d\ttrain-loss: %f\ttrain-miou: %f\ttest-loss: %f\ttest-miou: %f" % (epoch, train_mloss, train_miou, eval_mloss, eval_miou))
21
            train_writer.add_scalar("loss", train_mloss, count)
22
            train_writer.add_scalar("miou", train_miou, count)
23
            eval_writer.add_scalar("loss", eval_mloss, count)
24
            eval_writer.add_scalar("miou", eval_miou, count)
25
            net.collect_params().save(os.path.join(cfg.params_dir, str(count)))
26
            # 监督训练过程,如果 10 次 evaluation 的结果变化不大,就减小 learning rate
27
            if len(test_miou_list) > 10:
28
                test_miou_list.pop(0)
29
                test_miou_list.append(eval_miou)
30
                if max(test_miou_list) - min(test_miou_list) < 0.01:
31
                    lr = lr / 10.0
32
                    trainer.set_learning_rate(lr)
33
        count += 1

IoU

1
from __future__ import division
2
def IoU(yhat, y):
3
    if isinstance(yhat, mx.nd.NDArray):
4
        yhat = yhat.asnumpy()
5
    if isinstance(y, mx.nd.NDArray):
6
        y = y.asnumpy()
7
    yhat = yhat.argmax(axis=1)
8
    a_and_b = np.sum((y == yhat) * (y > 0))
9
    a_or_b = np.sum(((y > 0) + (yhat) > 0))
10
    return a_and_b / a_or_b if a_or_b > 0 else 0

评估算法

1
def evaluate(net, val_data, loss, ctx):
2
    eval_loss, eval_iou = [], []
3
    for idx, batch in enumerate(val_data):
4
        data, label = batch
5
        data = gluon.utils.split_and_load(data, ctx)
6
        label = gluon.utils.split_and_load(label, ctx)
7
        losses = []
8
        with mx.autograd.record(train_mode=False):
9
            outputs = [net(x) for x in data]
10
            losses = [loss(yhat, y) for yhat, y in zip(outputs, label)]
11
            eval_loss += [l.sum().asscalar() for l in losses]
12
            eval_iou += [IoU(yhat, y) for yhat, y in zip(outputs, label)]
13
    return np.array(eval_loss).mean(), np.array(eval_iou).mean()

需要注意的点

因为 MXNet 使用了 lazy evaluation 策略,因此,在训练的过程中,我们至少需要每隔几个迭代要同步一次数据,否则前端不停地把计算 push 到后端,导致显存会爆掉。同步数据有多种方式,例如,上面训练代码中,for 循环中的 mx.nd.waitall() 是一种方式,还有通过消费计算结果,例如打印 loss 等来实现同步。数据同步也不能太频繁,否则影响计算效率。