Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit f249c81

Browse filesBrowse files
SiarheiFedartsouqubvel
authored andcommitted
Add EfficientNet encoder (qubvel-org#73)
1 parent 7670703 commit f249c81
Copy full SHA for f249c81

File tree

Expand file treeCollapse file tree

4 files changed

+134
-2
lines changed
Filter options
Expand file treeCollapse file tree

4 files changed

+134
-2
lines changed

‎README.md

Copy file name to clipboardExpand all lines: README.md
+3-2Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,14 +70,15 @@ preprocess_input = get_preprocessing_fn('resnet18', pretrained='imagenet')
7070
| ResNeXt | resnext50_32x4d, resnext101_32x8d, resnext101_32x16d, resnext101_32x32d, resnext101_32x48d |
7171
| SE-ResNet | se_resnet50, se_resnet101, se_resnet152 |
7272
| SE-ResNeXt | se_resnext50_32x4d, se_resnext101_32x4d |
73-
| SENet | senet154 |
73+
| SENet | senet154 |
74+
| EfficientNet | efficientnet-b0, efficientnet-b1, efficientnet-b2, efficientnet-b3, efficientnet-b4, efficientnet-b5, efficientnet-b6, efficientnet-b7
7475

7576
#### Weights <a name="weights"></a>
7677

7778
| Weights name | Encoder names |
7879
|---------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
7980
| imagenet+5k | dpn68b, dpn92, dpn107 |
80-
| imagenet | vgg11, vgg13, vgg16, vgg19, vgg11bn, vgg13bn, vgg16bn, vgg19bn, <br> densenet121, densenet169, densenet201, densenet161, dpn68, dpn98, dpn131, <br> inceptionresnetv2, <br> resnet18, resnet34, resnet50, resnet101, resnet152, <br> resnext50_32x4d, resnext101_32x8d, <br> se_resnet50, se_resnet101, se_resnet152, <br> se_resnext50_32x4d, se_resnext101_32x4d, <br> senet154 |
81+
| imagenet | vgg11, vgg13, vgg16, vgg19, vgg11bn, vgg13bn, vgg16bn, vgg19bn, <br> densenet121, densenet169, densenet201, densenet161, dpn68, dpn98, dpn131, <br> inceptionresnetv2, <br> resnet18, resnet34, resnet50, resnet101, resnet152, <br> resnext50_32x4d, resnext101_32x8d, <br> se_resnet50, se_resnet101, se_resnet152, <br> se_resnext50_32x4d, se_resnext101_32x4d, <br> senet154, <br> efficientnet-b0, efficientnet-b1, efficientnet-b2, efficientnet-b3, efficientnet-b4, efficientnet-b5, efficientnet-b6, efficientnet-b7 |
8182
| [instagram](https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/) | resnext101_32x8d, resnext101_32x16d, resnext101_32x32d, resnext101_32x48d |
8283

8384
### Models API <a name="api"></a>

‎requirements.txt

Copy file name to clipboard
+1Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
torchvision>=0.2.2,<=0.4.0
22
pretrainedmodels==0.7.4
3+
efficientnet-pytorch==0.4.0

‎segmentation_models_pytorch/encoders/__init__.py

Copy file name to clipboardExpand all lines: segmentation_models_pytorch/encoders/__init__.py
+3Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from .senet import senet_encoders
88
from .densenet import densenet_encoders
99
from .inceptionresnetv2 import inception_encoders
10+
from .efficientnet import efficient_net_encoders
11+
1012

1113
from ._preprocessing import preprocess_input
1214

@@ -17,6 +19,7 @@
1719
encoders.update(senet_encoders)
1820
encoders.update(densenet_encoders)
1921
encoders.update(inception_encoders)
22+
encoders.update(efficient_net_encoders)
2023

2124

2225
def get_encoder(name, encoder_weights=None):
+127Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
from efficientnet_pytorch import EfficientNet
2+
from efficientnet_pytorch.utils import relu_fn, url_map, get_model_params
3+
import torch.nn as nn
4+
import torch
5+
6+
7+
class EfficientNetEncoder(EfficientNet):
8+
def __init__(self, skip_connections, model_name):
9+
blocks_args, global_params = get_model_params(model_name, override_params=None)
10+
11+
super().__init__(blocks_args, global_params)
12+
self._skip_connections = list(skip_connections)
13+
self._skip_connections.append(len(self._blocks))
14+
15+
del self._fc
16+
17+
def forward(self, x):
18+
result = []
19+
x = relu_fn(self._bn0(self._conv_stem(x)))
20+
result.append(x)
21+
22+
skip_connection_idx = 0
23+
for idx, block in enumerate(self._blocks):
24+
drop_connect_rate = self._global_params.drop_connect_rate
25+
if drop_connect_rate:
26+
drop_connect_rate *= float(idx) / len(self._blocks)
27+
x = block(x, drop_connect_rate=drop_connect_rate)
28+
if idx == self._skip_connections[skip_connection_idx] - 1:
29+
skip_connection_idx += 1
30+
result.append(x)
31+
32+
return list(reversed(result))
33+
34+
def load_state_dict(self, state_dict, **kwargs):
35+
state_dict.pop('_fc.bias')
36+
state_dict.pop('_fc.weight')
37+
super().load_state_dict(state_dict, **kwargs)
38+
39+
40+
41+
def _get_pretrained_settings(encoder):
42+
pretrained_settings = {
43+
'imagenet': {
44+
'mean': [0.485, 0.456, 0.406],
45+
'std': [0.229, 0.224, 0.225],
46+
'url': url_map[encoder],
47+
'input_space': 'RGB',
48+
'input_range': [0, 1]
49+
}
50+
}
51+
return pretrained_settings
52+
53+
54+
efficient_net_encoders = {
55+
'efficientnet-b0': {
56+
'encoder': EfficientNetEncoder,
57+
'out_shapes': (320, 112, 40, 24, 32),
58+
'pretrained_settings': _get_pretrained_settings('efficientnet-b0'),
59+
'params': {
60+
'skip_connections': [3, 5, 9],
61+
'model_name': 'efficientnet-b0'
62+
}
63+
},
64+
'efficientnet-b1': {
65+
'encoder': EfficientNetEncoder,
66+
'out_shapes': (320, 112, 40, 24, 32),
67+
'pretrained_settings': _get_pretrained_settings('efficientnet-b1'),
68+
'params': {
69+
'skip_connections': [5, 8, 16],
70+
'model_name': 'efficientnet-b1'
71+
}
72+
},
73+
'efficientnet-b2': {
74+
'encoder': EfficientNetEncoder,
75+
'out_shapes': (352, 120, 48, 24, 32),
76+
'pretrained_settings': _get_pretrained_settings('efficientnet-b2'),
77+
'params': {
78+
'skip_connections': [5, 8, 16],
79+
'model_name': 'efficientnet-b2'
80+
}
81+
},
82+
'efficientnet-b3': {
83+
'encoder': EfficientNetEncoder,
84+
'out_shapes': (384, 136, 48, 32, 40),
85+
'pretrained_settings': _get_pretrained_settings('efficientnet-b3'),
86+
'params': {
87+
'skip_connections': [5, 8, 18],
88+
'model_name': 'efficientnet-b3'
89+
}
90+
},
91+
'efficientnet-b4': {
92+
'encoder': EfficientNetEncoder,
93+
'out_shapes': (448, 160, 56, 32, 48),
94+
'pretrained_settings': _get_pretrained_settings('efficientnet-b4'),
95+
'params': {
96+
'skip_connections': [6, 10, 22],
97+
'model_name': 'efficientnet-b4'
98+
}
99+
},
100+
'efficientnet-b5': {
101+
'encoder': EfficientNetEncoder,
102+
'out_shapes': (512, 176, 64, 40, 48),
103+
'pretrained_settings': _get_pretrained_settings('efficientnet-b5'),
104+
'params': {
105+
'skip_connections': [8, 13, 27],
106+
'model_name': 'efficientnet-b5'
107+
}
108+
},
109+
'efficientnet-b6': {
110+
'encoder': EfficientNetEncoder,
111+
'out_shapes': (576, 200, 72, 40, 56),
112+
'pretrained_settings': _get_pretrained_settings('efficientnet-b6'),
113+
'params': {
114+
'skip_connections': [9, 15, 31],
115+
'model_name': 'efficientnet-b6'
116+
}
117+
},
118+
'efficientnet-b7': {
119+
'encoder': EfficientNetEncoder,
120+
'out_shapes': (640, 224, 80, 48, 64),
121+
'pretrained_settings': _get_pretrained_settings('efficientnet-b7'),
122+
'params': {
123+
'skip_connections': [11, 18, 38],
124+
'model_name': 'efficientnet-b7'
125+
}
126+
}
127+
}

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.