Pyramid Scene Parseing Network Note

​ PSPNet 采用金字塔池化模块搭建的场景分析网络 , 基于语义分割的场景解析,其目的是赋予图像中每个像素一个类别标签

1 PSPNet介绍

由于传统的FCN存在缺陷:

  1. 不匹配上下文关系, FCN将水中的“boat”预测为“car”
  2. 类别混淆,将摩天大楼一部分识别成了其他。
  3. 不显著的类别难以预测,可能会忽略小的东西。

    基于对FCN的透彻分析,作者提出了能够获取全局场景的深度网络PSPNet,可以融合局部特征和全局特征,有较好的效果。

1.1 Pyramid Pooling Module(金字塔池化模块)

​ 金字塔池化模块Pyramid Pooling Module由一组不同尺度的池化块组成

​ 该模块融合了4种不同金字塔尺度的特征,第一行红色是最粗糙的特征–全局池化生成单个bin输出,后面三行是不同尺度的池化特征。为了保证全局特征的权重,如果金字塔共有N个级别,则在每个级别后使用1X1的卷积将对于级别通道降为原本的1/N。再通过双线性插值获得未池化前的大小,最终concat到一起。
​ 在POOL过程中,论文一共采用了$1X1,2X2,3X3,6X6$ 4个等级的池化尺寸(即输出尺寸)进行池化。

1.2 基于ResNet的预训练辅助损失函数

 在ResNet101的基础上做了改进,除了使用后面的softmax分类做loss,额外的在第四阶段添加了一个辅助的loss, 叫做**auxiliary loss**, 两个loss一起传播,使用不同的权重,共同优化参数。后续的实验证明这样做有利于快速收敛。 

2 代码实现

金字塔池化以及PSPNet网络架构

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
class _PyramidPoolingModule(nn.Module):
def __init__(self, in_dim, reduction_dim, setting):
super(_PyramidPoolingModule, self).__init__()
self.features = []
for s in setting:
self.features.append(nn.Sequential(
nn.AdaptiveAvgPool2d(s),
nn.Conv2d(in_dim, reduction_dim, kernel_size=1, bias=False),
nn.BatchNorm2d(reduction_dim, momentum=.95),
nn.ReLU(inplace=True)
))
self.features = nn.ModuleList(self.features)

def forward(self, x):
x_size = x.size()
out = [x]
for f in self.features:
out.append(F.upsample(f(x), x_size[2:], mode='bilinear'))
out = torch.cat(out, 1)
return out


class PSPNet(nn.Module):
def __init__(self, num_classes, pretrained=True, use_aux=True):
super(PSPNet, self).__init__()
self.use_aux = use_aux
resnet = models.resnet101()
if pretrained:
resnet.load_state_dict(torch.load(res101_path))
self.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
self.layer1, self.layer2, self.layer3, self.layer4 = resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4

for n, m in self.layer3.named_modules():
if 'conv2' in n:
m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
elif 'downsample.0' in n:
m.stride = (1, 1)
for n, m in self.layer4.named_modules():
if 'conv2' in n:
m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
elif 'downsample.0' in n:
m.stride = (1, 1)

self.ppm = _PyramidPoolingModule(2048, 512, (1, 2, 3, 6))
self.final = nn.Sequential(
nn.Conv2d(4096, 512, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(512, momentum=.95),
nn.ReLU(inplace=True),
nn.Dropout(0.1),
nn.Conv2d(512, num_classes, kernel_size=1)
)

if use_aux:
self.aux_logits = nn.Conv2d(1024, num_classes, kernel_size=1)
initialize_weights(self.aux_logits)

initialize_weights(self.ppm, self.final)

def forward(self, x):
x_size = x.size()
x = self.layer0(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
if self.training and self.use_aux:
aux = self.aux_logits(x)
x = self.layer4(x)
x = self.ppm(x)
x = self.final(x)
if self.training and self.use_aux:
return F.upsample(x, x_size[2:], mode='bilinear'), F.upsample(aux, x_size[2:], mode='bilinear')
return F.upsample(x, x_size[2:], mode='bilinear')