Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit b77cf5a

Browse filesBrowse files
authored
Add DeepLabV3 implementation (qubvel-org#149)
* Add DeepLabV3 implemetation * Add DeepLabV3 to README * Add DeepLabV3 docstring
1 parent b74d91f commit b77cf5a
Copy full SHA for b77cf5a

File tree

Expand file treeCollapse file tree

7 files changed

+231
-13
lines changed
Filter options
Expand file treeCollapse file tree

7 files changed

+231
-13
lines changed

‎README.md

Copy file name to clipboardExpand all lines: README.md
+1Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ preprocess_input = get_preprocessing_fn('resnet18', pretrained='imagenet')
6868
- [FPN](http://presentations.cocodataset.org/COCO17-Stuff-FAIR.pdf)
6969
- [PSPNet](https://arxiv.org/abs/1612.01105)
7070
- [PAN](https://arxiv.org/abs/1805.10180)
71+
- [DeepLabV3](https://arxiv.org/abs/1706.05587)
7172

7273
#### Encoders <a name="encoders"></a>
7374

‎segmentation_models_pytorch/__init__.py

Copy file name to clipboardExpand all lines: segmentation_models_pytorch/__init__.py
+1Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from .linknet import Linknet
33
from .fpn import FPN
44
from .pspnet import PSPNet
5+
from .deeplabv3 import DeepLabV3
56
from .pan import PAN
67

78
from . import encoders
+1Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .model import DeepLabV3
+107Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
"""
2+
BSD 3-Clause License
3+
4+
Copyright (c) Soumith Chintala 2016,
5+
All rights reserved.
6+
7+
Redistribution and use in source and binary forms, with or without
8+
modification, are permitted provided that the following conditions are met:
9+
10+
* Redistributions of source code must retain the above copyright notice, this
11+
list of conditions and the following disclaimer.
12+
13+
* Redistributions in binary form must reproduce the above copyright notice,
14+
this list of conditions and the following disclaimer in the documentation
15+
and/or other materials provided with the distribution.
16+
17+
* Neither the name of the copyright holder nor the names of its
18+
contributors may be used to endorse or promote products derived from
19+
this software without specific prior written permission.
20+
21+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
22+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
23+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
25+
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
26+
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
27+
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28+
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
29+
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31+
"""
32+
33+
import torch
34+
from torch import nn
35+
from torch.nn import functional as F
36+
37+
__all__ = ["DeepLabV3Decoder"]
38+
39+
40+
class DeepLabV3Decoder(nn.Sequential):
41+
def __init__(self, in_channels, out_channels=256, atrous_rates=(12, 24, 36)):
42+
super().__init__(
43+
ASPP(in_channels, out_channels, atrous_rates),
44+
nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
45+
nn.BatchNorm2d(out_channels),
46+
nn.ReLU(),
47+
)
48+
self.out_channels = out_channels
49+
50+
def forward(self, *features):
51+
return super().forward(features[-1])
52+
53+
54+
class ASPPConv(nn.Sequential):
55+
def __init__(self, in_channels, out_channels, dilation):
56+
modules = [
57+
nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False),
58+
nn.BatchNorm2d(out_channels),
59+
nn.ReLU()
60+
]
61+
super(ASPPConv, self).__init__(*modules)
62+
63+
64+
class ASPPPooling(nn.Sequential):
65+
def __init__(self, in_channels, out_channels):
66+
super(ASPPPooling, self).__init__(
67+
nn.AdaptiveAvgPool2d(1),
68+
nn.Conv2d(in_channels, out_channels, 1, bias=False),
69+
nn.BatchNorm2d(out_channels),
70+
nn.ReLU())
71+
72+
def forward(self, x):
73+
size = x.shape[-2:]
74+
for mod in self:
75+
x = mod(x)
76+
return F.interpolate(x, size=size, mode='bilinear', align_corners=False)
77+
78+
79+
class ASPP(nn.Module):
80+
def __init__(self, in_channels, out_channels, atrous_rates):
81+
super(ASPP, self).__init__()
82+
modules = []
83+
modules.append(nn.Sequential(
84+
nn.Conv2d(in_channels, out_channels, 1, bias=False),
85+
nn.BatchNorm2d(out_channels),
86+
nn.ReLU()))
87+
88+
rate1, rate2, rate3 = tuple(atrous_rates)
89+
modules.append(ASPPConv(in_channels, out_channels, rate1))
90+
modules.append(ASPPConv(in_channels, out_channels, rate2))
91+
modules.append(ASPPConv(in_channels, out_channels, rate3))
92+
modules.append(ASPPPooling(in_channels, out_channels))
93+
94+
self.convs = nn.ModuleList(modules)
95+
96+
self.project = nn.Sequential(
97+
nn.Conv2d(5 * out_channels, out_channels, 1, bias=False),
98+
nn.BatchNorm2d(out_channels),
99+
nn.ReLU(),
100+
nn.Dropout(0.5))
101+
102+
def forward(self, x):
103+
res = []
104+
for conv in self.convs:
105+
res.append(conv(x))
106+
res = torch.cat(res, dim=1)
107+
return self.project(res)
+81Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
import torch.nn as nn
2+
3+
from typing import Optional
4+
from .decoder import DeepLabV3Decoder
5+
from ..base import SegmentationModel, SegmentationHead, ClassificationHead
6+
from ..encoders import get_encoder
7+
8+
9+
class DeepLabV3(SegmentationModel):
10+
"""DeepLabV3_ implemetation from "Rethinking Atrous Convolution for Semantic Image Segmentation"
11+
Args:
12+
encoder_name: name of classification model (without last dense layers) used as feature
13+
extractor to build segmentation model.
14+
encoder_depth: number of stages used in decoder, larger depth - more features are generated.
15+
e.g. for depth=3 encoder will generate list of features with following spatial shapes
16+
[(H,W), (H/2, W/2), (H/4, W/4), (H/8, W/8)], so in general the deepest feature will have
17+
spatial resolution (H/(2^depth), W/(2^depth)]
18+
encoder_weights: one of ``None`` (random initialization), ``imagenet`` (pre-training on ImageNet).
19+
decoder_channels: a number of convolution filters in ASPP module (default 256).
20+
in_channels: number of input channels for model, default is 3.
21+
classes: a number of classes for output (output shape - ``(batch, classes, h, w)``).
22+
activation (str, callable): activation function used in ``.predict(x)`` method for inference.
23+
One of [``sigmoid``, ``softmax2d``, callable, None]
24+
upsampling: optional, final upsampling factor
25+
(default is 8 to preserve input -> output spatial shape identity)
26+
aux_params: if specified model will have additional classification auxiliary output
27+
build on top of encoder, supported params:
28+
- classes (int): number of classes
29+
- pooling (str): one of 'max', 'avg'. Default is 'avg'.
30+
- dropout (float): dropout factor in [0, 1)
31+
- activation (str): activation function to apply "sigmoid"/"softmax" (could be None to return logits)
32+
Returns:
33+
``torch.nn.Module``: **DeepLabV3**
34+
.. _DeeplabV3:
35+
https://arxiv.org/abs/1706.05587
36+
"""
37+
38+
def __init__(
39+
self,
40+
encoder_name: str = "resnet34",
41+
encoder_depth: int = 5,
42+
encoder_weights: Optional[str] = "imagenet",
43+
decoder_channels: int = 256,
44+
in_channels: int = 3,
45+
classes: int = 1,
46+
activation: Optional[str] = None,
47+
upsampling: int = 8,
48+
aux_params: Optional[dict] = None,
49+
):
50+
super().__init__()
51+
52+
self.encoder = get_encoder(
53+
encoder_name,
54+
in_channels=in_channels,
55+
depth=encoder_depth,
56+
weights=encoder_weights,
57+
)
58+
self.encoder.make_dilated(
59+
stage_list=[4, 5],
60+
dilation_list=[2, 4]
61+
)
62+
63+
self.decoder = DeepLabV3Decoder(
64+
in_channels=self.encoder.out_channels[-1],
65+
out_channels=decoder_channels,
66+
)
67+
68+
self.segmentation_head = SegmentationHead(
69+
in_channels=self.decoder.out_channels,
70+
out_channels=classes,
71+
activation=activation,
72+
kernel_size=1,
73+
upsampling=upsampling,
74+
)
75+
76+
if aux_params is not None:
77+
self.classification_head = ClassificationHead(
78+
in_channels=self.encoder.out_channels[-1], **aux_params
79+
)
80+
else:
81+
self.classification_head = None

‎segmentation_models_pytorch/encoders/_base.py

Copy file name to clipboardExpand all lines: segmentation_models_pytorch/encoders/_base.py
+2-1Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import torch
22
import torch.nn as nn
33
from typing import List
4+
from collections import OrderedDict
45

56
from . import _utils as utils
67

@@ -12,7 +13,7 @@ class EncoderMixin:
1213
"""
1314

1415
@property
15-
def out_channels(self) -> List:
16+
def out_channels(self):
1617
"""Return channels dimensions for each tensor of forward output of encoder"""
1718
return self._out_channels[: self._depth + 1]
1819

‎tests/test_models.py

Copy file name to clipboardExpand all lines: tests/test_models.py
+38-12Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -27,18 +27,32 @@ def get_encoders():
2727

2828
ENCODERS = get_encoders()
2929
DEFAULT_ENCODER = "resnet18"
30-
DEFAULT_SAMPLE = torch.ones([1, 3, 64, 64])
31-
DEFAULT_PAN_SAMPLE = torch.ones([2, 3, 256, 256])
3230

3331

34-
def _test_forward(model):
32+
def get_sample(model_class):
33+
if model_class in [smp.Unet, smp.Linknet, smp.FPN, smp.PSPNet]:
34+
sample = torch.ones([1, 3, 64, 64])
35+
elif model_class == smp.PAN:
36+
sample = torch.ones([2, 3, 256, 256])
37+
elif model_class == smp.DeepLabV3:
38+
sample = torch.ones([2, 3, 128, 128])
39+
else:
40+
raise ValueError("Not supported model class {}".format(model_class))
41+
return sample
42+
43+
44+
def _test_forward(model, sample, test_shape=False):
3545
with torch.no_grad():
36-
model(DEFAULT_SAMPLE)
46+
out = model(sample)
47+
if test_shape:
48+
assert out.shape[2:] == sample.shape[2:]
3749

3850

39-
def _test_forward_backward(model, sample):
51+
def _test_forward_backward(model, sample, test_shape=False):
4052
out = model(sample)
4153
out.mean().backward()
54+
if test_shape:
55+
assert out.shape[2:] == sample.shape[2:]
4256

4357

4458
@pytest.mark.parametrize("encoder_name", ENCODERS)
@@ -50,12 +64,22 @@ def test_forward(model_class, encoder_name, encoder_depth, **kwargs):
5064
model = model_class(
5165
encoder_name, encoder_depth=encoder_depth, encoder_weights=None, **kwargs
5266
)
53-
_test_forward(model)
67+
sample = get_sample(model_class)
5468

69+
if encoder_depth == 5 and model_class != smp.PSPNet:
70+
test_shape = True
71+
else:
72+
test_shape = False
5573

56-
@pytest.mark.parametrize("model_class", [smp.PAN, smp.FPN, smp.PSPNet, smp.Linknet, smp.Unet])
74+
_test_forward(model, sample, test_shape)
75+
76+
77+
@pytest.mark.parametrize(
78+
"model_class",
79+
[smp.PAN, smp.FPN, smp.PSPNet, smp.Linknet, smp.Unet, smp.DeepLabV3]
80+
)
5781
def test_forward_backward(model_class):
58-
sample = DEFAULT_PAN_SAMPLE if model_class is smp.PAN else DEFAULT_SAMPLE
82+
sample = get_sample(model_class)
5983
model = model_class(DEFAULT_ENCODER, encoder_weights=None)
6084
_test_forward_backward(model, sample)
6185

@@ -65,8 +89,8 @@ def test_aux_output(model_class):
6589
model = model_class(
6690
DEFAULT_ENCODER, encoder_weights=None, aux_params=dict(classes=2)
6791
)
68-
sample = DEFAULT_PAN_SAMPLE if model_class is smp.PAN else DEFAULT_SAMPLE
69-
label_size = (2, 2) if model_class is smp.PAN else (1, 2)
92+
sample = get_sample(model_class)
93+
label_size = (sample.shape[0], 2)
7094
mask, label = model(sample)
7195
assert label.size() == label_size
7296

@@ -76,7 +100,8 @@ def test_aux_output(model_class):
76100
def test_upsample(model_class, upsampling):
77101
default_upsampling = 4 if model_class is smp.FPN else 8
78102
model = model_class(DEFAULT_ENCODER, encoder_weights=None, upsampling=upsampling)
79-
mask = model(DEFAULT_SAMPLE)
103+
sample = get_sample(model_class)
104+
mask = model(sample)
80105
assert mask.size()[-1] / 64 == upsampling / default_upsampling
81106

82107

@@ -106,7 +131,8 @@ def test_dilation(encoder_name):
106131

107132
encoder.eval()
108133
with torch.no_grad():
109-
output = encoder(DEFAULT_SAMPLE)
134+
sample = torch.ones([1, 3, 64, 64])
135+
output = encoder(sample)
110136

111137
shapes = [out.shape[-1] for out in output]
112138
assert shapes == [64, 32, 16, 8, 4, 4] # last downsampling replaced with dilation

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.