Commit 1e1f13e

[Feature] Dilated encoders (qubvel-org#125)
* Prepare encoder stages
* Dilated encoders
* Fix for encoders which do not support dilation
1 parent 1204d2f commit 1e1f13e

14 files changed: +251 -187 lines
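
In short, the commit gives each encoder a get_stages() method (one module per feature level) and a make_dilated() mixin method that rewrites the strides of selected stages as dilations. A minimal, hypothetical usage sketch, assuming the public get_encoder helper and a ResNet encoder that also received get_stages() in this commit's other files; the stage and dilation values below are illustrative, not taken from this diff:

import torch

from segmentation_models_pytorch.encoders import get_encoder

# Hypothetical usage of the API added by this commit: build an encoder and
# replace the strides of its last two stages with dilation rates 2 and 4,
# keeping a higher output resolution (DeepLab-style reduced output stride).
encoder = get_encoder("resnet34")
encoder.make_dilated(stage_list=[4, 5], dilation_list=[2, 4])

with torch.no_grad():
    features = encoder(torch.rand(1, 3, 224, 224))
print([f.shape for f in features])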

‎docker/Dockerfile.dev

1 addition & 1 deletion
@@ -1,4 +1,4 @@
-FROM python:3.6 #anibali/pytorch:cuda-9.0
+FROM anibali/pytorch:no-cuda
 
 WORKDIR /tmp/smp/

‎segmentation_models_pytorch/encoders/_base.py

14 additions & 33 deletions
@@ -2,12 +2,15 @@
 import torch.nn as nn
 from typing import List
 
+from . import _utils as utils
+
 
 class EncoderMixin:
     """Add encoder functionality such as:
     - output channels specification of feature tensors (produced by encoder)
     - patching first convolution for arbitrary input channels
     """
+
     @property
     def out_channels(self) -> List:
         """Return channels dimensions for each tensor of forward output of encoder"""
@@ -22,38 +25,16 @@ def set_in_channels(self, in_channels):
         if self._out_channels[0] == 3:
             self._out_channels = tuple([in_channels] + list(self._out_channels)[1:])
 
-        patch_first_conv(model=self, in_channels=in_channels)
+        utils.patch_first_conv(model=self, in_channels=in_channels)
 
+    def get_stages(self):
+        """Method should be overridden in encoder"""
+        raise NotImplementedError
 
-def patch_first_conv(model, in_channels):
-    """Change first convolution layer input channels.
-    In case:
-        in_channels == 1 or in_channels == 2 -> reuse original weights
-        in_channels > 3 -> make random kaiming normal initialization
-    """
-
-    # get first conv
-    for module in model.modules():
-        if isinstance(module, nn.Conv2d):
-            break
-
-    # change input channels for first conv
-    module.in_channels = in_channels
-    weight = module.weight.detach()
-    reset = False
-
-    if in_channels == 1:
-        weight = weight.sum(1, keepdim=True)
-    elif in_channels == 2:
-        weight = weight[:, :2] * (3.0 / 2.0)
-    else:
-        reset = True
-        weight = torch.Tensor(
-            module.out_channels,
-            module.in_channels // module.groups,
-            *module.kernel_size
-        )
-
-    module.weight = nn.parameter.Parameter(weight)
-    if reset:
-        module.reset_parameters()
+    def make_dilated(self, stage_list, dilation_list):
+        stages = self.get_stages()
+        for stage_indx, dilation_rate in zip(stage_list, dilation_list):
+            utils.replace_strides_with_dilation(
+                module=stages[stage_indx],
+                dilation_rate=dilation_rate,
+            )
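
To make the new EncoderMixin contract concrete, here is a toy sketch (illustrative, not from this diff) of what a concrete encoder now has to provide: get_stages() returning one module per feature level, with make_dilated() from the mixin patching the chosen stages in place via the new utils helper:

import torch.nn as nn

from segmentation_models_pytorch.encoders._base import EncoderMixin

class ToyEncoder(nn.Module, EncoderMixin):
    """Toy encoder: three strided convs, each one a 'stage' that halves resolution."""

    def __init__(self):
        super().__init__()
        self._out_channels = (3, 16, 32, 64)
        self._depth = 3
        self._in_channels = 3
        self.stage1 = nn.Conv2d(3, 16, 3, stride=2, padding=1)
        self.stage2 = nn.Conv2d(16, 32, 3, stride=2, padding=1)
        self.stage3 = nn.Conv2d(32, 64, 3, stride=2, padding=1)

    def get_stages(self):
        # One entry per feature level; index 0 is the identity on the raw input.
        return [nn.Identity(), self.stage1, self.stage2, self.stage3]

    def forward(self, x):
        features = []
        stages = self.get_stages()
        for i in range(self._depth + 1):
            x = stages[i](x)
            features.append(x)
        return features

encoder = ToyEncoder()
# make_dilated() comes from EncoderMixin: it fetches get_stages() and converts
# the listed stages to stride 1 with the given dilation rates (values illustrative).
encoder.make_dilated(stage_list=[3], dilation_list=[2])
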
+50Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
+import torch
+import torch.nn as nn
+
+
+def patch_first_conv(model, in_channels):
+    """Change first convolution layer input channels.
+    In case:
+        in_channels == 1 or in_channels == 2 -> reuse original weights
+        in_channels > 3 -> make random kaiming normal initialization
+    """
+
+    # get first conv
+    for module in model.modules():
+        if isinstance(module, nn.Conv2d):
+            break
+
+    # change input channels for first conv
+    module.in_channels = in_channels
+    weight = module.weight.detach()
+    reset = False
+
+    if in_channels == 1:
+        weight = weight.sum(1, keepdim=True)
+    elif in_channels == 2:
+        weight = weight[:, :2] * (3.0 / 2.0)
+    else:
+        reset = True
+        weight = torch.Tensor(
+            module.out_channels,
+            module.in_channels // module.groups,
+            *module.kernel_size
+        )
+
+    module.weight = nn.parameter.Parameter(weight)
+    if reset:
+        module.reset_parameters()
+
+
+def replace_strides_with_dilation(module, dilation_rate):
+    """Patch Conv2d modules replacing strides with dilation"""
+    for mod in module.modules():
+        if isinstance(mod, nn.Conv2d):
+            mod.stride = (1, 1)
+            mod.dilation = (dilation_rate, dilation_rate)
+            kh, kw = mod.kernel_size
+            mod.padding = ((kh // 2) * dilation_rate, (kh // 2) * dilation_rate)
+
+            # Kostyl (workaround) for EfficientNet
+            if hasattr(mod, "static_padding"):
+                mod.static_padding = nn.Identity()
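
A small self-contained check of what the new helper does to a strided 3x3 convolution (the module path _utils is inferred from the `from . import _utils as utils` import above; kernel sizes are square here, so deriving the padding from the kernel height alone is equivalent):

import torch
import torch.nn as nn

from segmentation_models_pytorch.encoders._utils import replace_strides_with_dilation

# A toy "stage": one strided 3x3 conv that would normally halve the resolution.
stage = nn.Sequential(nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1))

replace_strides_with_dilation(stage, dilation_rate=2)

conv = stage[0]
print(conv.stride, conv.dilation, conv.padding)  # (1, 1) (2, 2) (2, 2)

# With stride 1, dilation 2 and padding 2, a 3x3 conv preserves the spatial size.
out = stage(torch.rand(1, 16, 64, 64))
print(out.shape)  # torch.Size([1, 32, 64, 64])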

‎segmentation_models_pytorch/encoders/densenet.py

38 additions & 35 deletions
@@ -32,6 +32,20 @@
 from ._base import EncoderMixin
 
 
+class TransitionWithSkip(nn.Module):
+
+    def __init__(self, module):
+        super().__init__()
+        self.module = module
+
+    def forward(self, x):
+        for module in self.module:
+            x = module(x)
+            if isinstance(module, nn.ReLU):
+                skip = x
+        return x, skip
+
+
 class DenseNetEncoder(DenseNet, EncoderMixin):
     def __init__(self, out_channels, depth=5, **kwargs):
         super().__init__(**kwargs)
@@ -40,44 +54,33 @@ def __init__(self, out_channels, depth=5, **kwargs):
         self._in_channels = 3
         del self.classifier
 
-    @staticmethod
-    def _transition(x, transition_block):
-        for module in transition_block:
-            x = module(x)
-            if isinstance(module, nn.ReLU):
-                skip = x
-        return x, skip
+    def make_dilated(self, stage_list, dilation_list):
+        raise ValueError("DenseNet encoders do not support dilated mode "
+                         "due to pooling operation for downsampling!")
+
+    def get_stages(self):
+        return [
+            nn.Identity(),
+            nn.Sequential(self.features.conv0, self.features.norm0, self.features.relu0),
+            nn.Sequential(self.features.pool0, self.features.denseblock1,
+                          TransitionWithSkip(self.features.transition1)),
+            nn.Sequential(self.features.denseblock2, TransitionWithSkip(self.features.transition2)),
+            nn.Sequential(self.features.denseblock3, TransitionWithSkip(self.features.transition3)),
+            nn.Sequential(self.features.denseblock4, self.features.norm5)
+        ]
 
     def forward(self, x):
 
-        features = [x]
-
-        if self._depth > 0:
-            x = self.features.conv0(x)
-            x = self.features.norm0(x)
-            x = self.features.relu0(x)
-            features.append(x)
-
-        if self._depth > 1:
-            x = self.features.pool0(x)
-            x = self.features.denseblock1(x)
-            x, x1 = self._transition(x, self.features.transition1)
-            features.append(x1)
-
-        if self._depth > 2:
-            x = self.features.denseblock2(x)
-            x, x2 = self._transition(x, self.features.transition2)
-            features.append(x2)
-
-        if self._depth > 3:
-            x = self.features.denseblock3(x)
-            x, x3 = self._transition(x, self.features.transition3)
-            features.append(x3)
-
-        if self._depth > 4:
-            x = self.features.denseblock4(x)
-            x4 = self.features.norm5(x)
-            features.append(x4)
+        stages = self.get_stages()
+
+        features = []
+        for i in range(self._depth + 1):
+            x = stages[i](x)
+            if isinstance(x, (list, tuple)):
+                x, skip = x
+                features.append(skip)
+            else:
+                features.append(x)
 
         return features
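
The new TransitionWithSkip wrapper is what lets the stage-based forward still return the pre-pooling skip tensors. A standalone sketch with a toy transition block shaped like torchvision's _Transition (norm -> relu -> 1x1 conv -> pool); the wrapper records the ReLU output as the skip and returns it alongside the pooled tensor:

import torch
import torch.nn as nn

from segmentation_models_pytorch.encoders.densenet import TransitionWithSkip

transition = nn.Sequential(
    nn.BatchNorm2d(8),
    nn.ReLU(inplace=True),
    nn.Conv2d(8, 4, kernel_size=1, bias=False),
    nn.AvgPool2d(kernel_size=2, stride=2),
)

x, skip = TransitionWithSkip(transition)(torch.rand(1, 8, 32, 32))
print(x.shape)     # torch.Size([1, 4, 16, 16]) -- downsampled output
print(skip.shape)  # torch.Size([1, 8, 32, 32]) -- skip taken right after the ReLU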

‎segmentation_models_pytorch/encoders/dpn.py

6 additions & 3 deletions
@@ -43,9 +43,8 @@ def __init__(self, stage_idxs, out_channels, depth=5, **kwargs):
 
         del self.last_linear
 
-    def forward(self, x):
-
-        stages = [
+    def get_stages(self):
+        return [
             nn.Identity(),
             nn.Sequential(self.features[0].conv, self.features[0].bn, self.features[0].act),
             nn.Sequential(self.features[0].pool, self.features[1 : self._stage_idxs[0]]),
@@ -54,6 +53,10 @@ def forward(self, x):
             self.features[self._stage_idxs[2] : self._stage_idxs[3]],
         ]
 
+    def forward(self, x):
+
+        stages = self.get_stages()
+
         features = []
         for i in range(self._depth + 1):
             x = stages[i](x)
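
The refactored get_stages() above relies on nn.Sequential supporting slice indexing, which returns a new nn.Sequential over the same modules, so self.features can be carved into per-stage chunks without copying anything. A minimal standalone illustration (values are arbitrary, not DPN's real stage boundaries):

import torch.nn as nn

features = nn.Sequential(*[nn.Conv2d(3, 3, 3, padding=1) for _ in range(6)])

stage_idxs = (2, 4)  # illustrative boundaries
stages = [
    features[: stage_idxs[0]],
    features[stage_idxs[0] : stage_idxs[1]],
    features[stage_idxs[1] :],
]
print([len(s) for s in stages])  # [2, 2, 2] -- each slice is itself an nn.Sequential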

‎segmentation_models_pytorch/encoders/efficientnet.py

37 additions & 26 deletions
@@ -22,7 +22,7 @@
     number of feature tensors = 6 (one with same resolution as input and 5 downsampled),
     depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled).
 """
-
+import torch.nn as nn
 from efficientnet_pytorch import EfficientNet
 from efficientnet_pytorch.utils import url_map, get_model_params
 
@@ -35,33 +35,44 @@ def __init__(self, stage_idxs, out_channels, model_name, depth=5):
         blocks_args, global_params = get_model_params(model_name, override_params=None)
         super().__init__(blocks_args, global_params)
 
-        self._stage_idxs = list(stage_idxs) + [len(self._blocks)]
+        self._stage_idxs = stage_idxs
         self._out_channels = out_channels
         self._depth = depth
         self._in_channels = 3
 
         del self._fc
 
+    def get_stages(self):
+        return [
+            nn.Identity(),
+            nn.Sequential(self._conv_stem, self._bn0, self._swish),
+            self._blocks[:self._stage_idxs[0]],
+            self._blocks[self._stage_idxs[0]:self._stage_idxs[1]],
+            self._blocks[self._stage_idxs[1]:self._stage_idxs[2]],
+            self._blocks[self._stage_idxs[2]:],
+        ]
+
     def forward(self, x):
+        stages = self.get_stages()
 
-        features = [x]
+        block_number = 0.
+        drop_connect_rate = self._global_params.drop_connect_rate
 
-        if self._depth > 0:
-            x = self._swish(self._bn0(self._conv_stem(x)))
-            features.append(x)
+        features = []
+        for i in range(self._depth + 1):
 
-        if self._depth > 1:
-            skip_connection_idx = 0
-            for idx, block in enumerate(self._blocks):
-                drop_connect_rate = self._global_params.drop_connect_rate
-                if drop_connect_rate:
-                    drop_connect_rate *= float(idx) / len(self._blocks)
-                x = block(x, drop_connect_rate=drop_connect_rate)
-                if idx == self._stage_idxs[skip_connection_idx] - 1:
-                    skip_connection_idx += 1
-                    features.append(x)
-                    if skip_connection_idx + 1 == self._depth:
-                        break
+            # Identity and Sequential stages
+            if i < 2:
+                x = stages[i](x)
+
+            # Block stages need drop_connect rate
+            else:
+                for module in stages[i]:
+                    drop_connect = drop_connect_rate * block_number / len(self._blocks)
+                    block_number += 1.
+                    x = module(x, drop_connect)
+
+            features.append(x)
 
         return features
 
@@ -90,7 +101,7 @@ def _get_pretrained_settings(encoder):
         "pretrained_settings": _get_pretrained_settings("efficientnet-b0"),
         "params": {
             "out_channels": (3, 32, 24, 40, 112, 320),
-            "stage_idxs": (3, 5, 9),
+            "stage_idxs": (3, 5, 9, 16),
             "model_name": "efficientnet-b0",
         },
     },
@@ -99,7 +110,7 @@ def _get_pretrained_settings(encoder):
         "pretrained_settings": _get_pretrained_settings("efficientnet-b1"),
         "params": {
             "out_channels": (3, 32, 24, 40, 112, 320),
-            "stage_idxs": (5, 8, 16),
+            "stage_idxs": (5, 8, 16, 23),
             "model_name": "efficientnet-b1",
         },
     },
@@ -108,7 +119,7 @@ def _get_pretrained_settings(encoder):
         "pretrained_settings": _get_pretrained_settings("efficientnet-b2"),
         "params": {
             "out_channels": (3, 32, 24, 48, 120, 352),
-            "stage_idxs": (5, 8, 16),
+            "stage_idxs": (5, 8, 16, 23),
             "model_name": "efficientnet-b2",
         },
     },
@@ -117,7 +128,7 @@ def _get_pretrained_settings(encoder):
         "pretrained_settings": _get_pretrained_settings("efficientnet-b3"),
         "params": {
             "out_channels": (3, 40, 32, 48, 136, 384),
-            "stage_idxs": (5, 8, 18),
+            "stage_idxs": (5, 8, 18, 26),
             "model_name": "efficientnet-b3",
         },
     },
@@ -126,7 +137,7 @@ def _get_pretrained_settings(encoder):
         "pretrained_settings": _get_pretrained_settings("efficientnet-b4"),
         "params": {
             "out_channels": (3, 48, 32, 56, 160, 448),
-            "stage_idxs": (6, 10, 22),
+            "stage_idxs": (6, 10, 22, 32),
             "model_name": "efficientnet-b4",
         },
     },
@@ -135,7 +146,7 @@ def _get_pretrained_settings(encoder):
         "pretrained_settings": _get_pretrained_settings("efficientnet-b5"),
        "params": {
             "out_channels": (3, 48, 40, 64, 176, 512),
-            "stage_idxs": (8, 13, 27),
+            "stage_idxs": (8, 13, 27, 39),
             "model_name": "efficientnet-b5",
         },
     },
@@ -144,7 +155,7 @@ def _get_pretrained_settings(encoder):
         "pretrained_settings": _get_pretrained_settings("efficientnet-b6"),
         "params": {
             "out_channels": (3, 56, 40, 72, 200, 576),
-            "stage_idxs": (9, 15, 31),
+            "stage_idxs": (9, 15, 31, 45),
             "model_name": "efficientnet-b6",
         },
     },
@@ -153,7 +164,7 @@ def _get_pretrained_settings(encoder):
         "pretrained_settings": _get_pretrained_settings("efficientnet-b7"),
         "params": {
             "out_channels": (3, 64, 48, 80, 224, 640),
-            "stage_idxs": (11, 18, 38),
+            "stage_idxs": (11, 18, 38, 55),
             "model_name": "efficientnet-b7",
         },
     },
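
The new fourth value in every stage_idxs tuple replaces the `+ [len(self._blocks)]` that __init__ used to append, so the block-list boundaries are now written out in the settings themselves. A short sketch of how the b0 indices partition the 16 blocks into encoder stages (the list of ints stands in for self._blocks):

stage_idxs = (3, 5, 9, 16)  # from the efficientnet-b0 entry above
blocks = list(range(16))    # stand-in for self._blocks

stages = [
    blocks[: stage_idxs[0]],
    blocks[stage_idxs[0] : stage_idxs[1]],
    blocks[stage_idxs[1] : stage_idxs[2]],
    blocks[stage_idxs[2] :],
]
print([len(s) for s in stages])  # [3, 2, 4, 7]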
