Commit 6b79c83

New features and refactoring (qubvel-org#100)
* Add aux classification head
* Add ability to change input channels
* Add encoder_depth parameter
* Add mobilenet encoder
* Add hall of fame
1 parent 378431b commit 6b79c83


41 files changed: +2,443 −1,541 lines

‎HALLOFFAME.md

+76 −0 (76 additions, 0 deletions)
@@ -0,0 +1,76 @@
# Hall of Fame

The `Segmentation Models` package is widely used in image segmentation competitions.
Here you can find competitions, the names of the winners, and links to their solutions.

Please follow these rules when adding a solution to the "Hall of Fame":

1. The solution should be highly ranked (e.g. a Kaggle gold or silver medal).
2. There should be a description of the solution (a forum post / code / blog post / paper / pre-print).

## Kaggle

### [Severstal: Steel Defect Detection](https://www.kaggle.com/c/severstal-steel-defect-detection)

- 1st place.
[Wuxi Jiangsu](https://www.kaggle.com/rguo97),
[Hongbo Zhu](https://www.kaggle.com/zhuhongbo),
[Yizhuo Yu](https://www.kaggle.com/paffpaffyu)
[[description](https://www.kaggle.com/c/severstal-steel-defect-detection/discussion/114254#latest-675874)]

- 5th place.
[Guanshuo Xu](https://www.kaggle.com/wowfattie)
[[description](https://www.kaggle.com/c/severstal-steel-defect-detection/discussion/117208#latest-675385)]

- 9th place.
[Jacek Poplawski](https://www.linkedin.com/in/jacekpoplawski/)
[[description](https://www.kaggle.com/c/severstal-steel-defect-detection/discussion/114297#latest-660842)]

- 10th place.
[Alexey Rozhkov](https://www.linkedin.com/in/alexisrozhkov)
[[description](https://www.kaggle.com/c/severstal-steel-defect-detection/discussion/114465#latest-659615)]

- 12th place.
[Pavel Yakubovskiy](https://www.linkedin.com/in/pavel-yakubovskiy/),
[Ilya Dobrynin](https://www.linkedin.com/in/ilya-dobrynin-79a89b106/),
[Denis Kolpakov](https://www.linkedin.com/in/denis-kolpakov-ab3137197/)
[[description](https://www.kaggle.com/c/severstal-steel-defect-detection/discussion/114309#latest-661404)]

- 31st place.
[Insaf Ashrapov](https://www.linkedin.com/in/iashrapov/),
[Igor Krashenyi](https://www.linkedin.com/in/igor-krashenyi-38b89b98),
[Pavel Pleskov](https://www.linkedin.com/in/ppleskov),
[Anton Zakharenkov](https://www.linkedin.com/in/anton-zakharenkov/),
[Nikolai Popov](https://www.linkedin.com/in/nikolai-popov-b2157370/)
[[description](https://www.kaggle.com/c/severstal-steel-defect-detection/discussion/114383#latest-658438)]
[[code](https://github.com/Diyago/Severstal-Steel-Defect-Detection)]

- 55th place.
[Karl Hornlund](https://www.linkedin.com/in/karl-hornlund/)
[[description](https://www.kaggle.com/c/severstal-steel-defect-detection/discussion/114410#latest-672682)]
[[code](https://github.com/khornlund/severstal-steel-defect-detection)]

- Efficiency round 1st place.
[Stefan Stefanov](https://www.linkedin.com/in/stefan-stefanov-63a77b1)
[[description](https://www.kaggle.com/c/severstal-steel-defect-detection/discussion/117486#latest-674229)]

### [Understanding Clouds from Satellite Images](https://www.kaggle.com/c/understanding_cloud_organization)

- 2nd place.
[Andrey Kiryasov](https://www.kaggle.com/ekydna)
[[description](https://www.kaggle.com/c/understanding_cloud_organization/discussion/118255#latest-678189)]

- 4th place.
[Ching-Loong Seow](https://www.linkedin.com/in/clseow/)
[[description](https://www.kaggle.com/c/understanding_cloud_organization/discussion/118016#latest-677333)]

- 34th place.
[Karl Hornlund](https://www.linkedin.com/in/karl-hornlund/)
[[description](https://www.kaggle.com/c/understanding_cloud_organization/discussion/118250#latest-678176)]
[[code](https://github.com/khornlund/understanding-cloud-organization)]

- 55th place.
[Pavel Yakubovskiy](https://www.linkedin.com/in/pavel-yakubovskiy/)
[[description](https://www.kaggle.com/c/understanding_cloud_organization/discussion/118019#latest-678626)]

‎README.md

+109 −30 (109 additions, 30 deletions)
@@ -7,7 +7,7 @@ The main features of this library are:
 
 - High level API (just two lines to create neural network)
 - 4 models architectures for binary and multi class segmentation (including legendary Unet)
-- 31 available encoders for each architecture
+- 45 available encoders for each architecture
 - All encoders have pre-trained weights for faster and better convergence
 
 ### Table of content
@@ -16,10 +16,14 @@ The main features of this library are:
 3. [Models](#models)
    1. [Architectures](#architectires)
    2. [Encoders](#encoders)
-   3. [Pretrained weights](#weights)
 4. [Models API](#api)
+   1. [Input channels](#input-channels)
+   2. [Auxiliary classification output](#auxiliary-classification-output)
+   3. [Depth](#depth)
 5. [Installation](#installation)
-6. [License](#license)
+6. [Competitions won with the library](#competitions-won-with-the-library)
+7. [License](#license)
+8. [Contributing](#contributing)
 
 ### Quick start <a name="start"></a>
 Since the library is built on the PyTorch framework, created segmentation model is just a PyTorch nn.Module, which can be created as easy as:
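The two lines referenced here sit just outside this hunk and are unchanged by the commit; for context, they look roughly like this (the encoder name and weights are illustrative values, not something this diff prescribes):

```python
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.encoders import get_preprocessing_fn

# any encoder/weights pair from the encoder table below works here
model = smp.Unet('resnet34', encoder_weights='imagenet')
preprocess_input = get_preprocessing_fn('resnet34', pretrained='imagenet')
```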
@@ -60,33 +64,95 @@ preprocess_input = get_preprocessing_fn('resnet18', pretrained='imagenet')
 
 #### Encoders <a name="encoders"></a>
 
-| Type | Encoder names |
-|------------|-----------------------------------------------------------------------------------|
-| VGG | vgg11, vgg13, vgg16, vgg19, vgg11bn, vgg13bn, vgg16bn, vgg19bn |
-| DenseNet | densenet121, densenet169, densenet201, densenet161 |
-| DPN | dpn68, dpn68b, dpn92, dpn98, dpn107, dpn131 |
-| Inception | inceptionresnetv2 |
-| ResNet | resnet18, resnet34, resnet50, resnet101, resnet152 |
-| ResNeXt | resnext50_32x4d, resnext101_32x8d, resnext101_32x16d, resnext101_32x32d, resnext101_32x48d |
-| SE-ResNet | se_resnet50, se_resnet101, se_resnet152 |
-| SE-ResNeXt | se_resnext50_32x4d, se_resnext101_32x4d |
-| SENet | senet154 |
-| EfficientNet | efficientnet-b0, efficientnet-b1, efficientnet-b2, efficientnet-b3, efficientnet-b4, efficientnet-b5, efficientnet-b6, efficientnet-b7
-
-#### Weights <a name="weights"></a>
-
-| Weights name | Encoder names |
-|----------------------------------------------------------------------------|----------------------------------------------------------------------------|
-| imagenet+5k | dpn68b, dpn92, dpn107 |
-| imagenet | vgg11, vgg13, vgg16, vgg19, vgg11bn, vgg13bn, vgg16bn, vgg19bn, <br> densenet121, densenet169, densenet201, densenet161, dpn68, dpn98, dpn131, <br> inceptionresnetv2, <br> resnet18, resnet34, resnet50, resnet101, resnet152, <br> resnext50_32x4d, resnext101_32x8d, <br> se_resnet50, se_resnet101, se_resnet152, <br> se_resnext50_32x4d, se_resnext101_32x4d, <br> senet154, <br> efficientnet-b0, efficientnet-b1, efficientnet-b2, efficientnet-b3, efficientnet-b4, efficientnet-b5, efficientnet-b6, efficientnet-b7 |
-| [instagram](https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/) | resnext101_32x8d, resnext101_32x16d, resnext101_32x32d, resnext101_32x48d |
+|Encoder |Weights |Params, M |
+|--------------------------------|:------------------------------:|:------------------------------:|
+|resnet18 |imagenet |11M |
+|resnet34 |imagenet |21M |
+|resnet50 |imagenet |23M |
+|resnet101 |imagenet |42M |
+|resnet152 |imagenet |58M |
+|resnext50_32x4d |imagenet |22M |
+|resnext101_32x8d |imagenet<br>instagram |86M |
+|resnext101_32x16d |instagram |191M |
+|resnext101_32x32d |instagram |466M |
+|resnext101_32x48d |instagram |826M |
+|dpn68 |imagenet |11M |
+|dpn68b |imagenet+5k |11M |
+|dpn92 |imagenet+5k |34M |
+|dpn98 |imagenet |58M |
+|dpn107 |imagenet+5k |84M |
+|dpn131 |imagenet |76M |
+|vgg11 |imagenet |9M |
+|vgg11_bn |imagenet |9M |
+|vgg13 |imagenet |9M |
+|vgg13_bn |imagenet |9M |
+|vgg16 |imagenet |14M |
+|vgg16_bn |imagenet |14M |
+|vgg19 |imagenet |20M |
+|vgg19_bn |imagenet |20M |
+|senet154 |imagenet |113M |
+|se_resnet50 |imagenet |26M |
+|se_resnet101 |imagenet |47M |
+|se_resnet152 |imagenet |64M |
+|se_resnext50_32x4d |imagenet |25M |
+|se_resnext101_32x4d |imagenet |46M |
+|densenet121 |imagenet |6M |
+|densenet169 |imagenet |12M |
+|densenet201 |imagenet |18M |
+|densenet161 |imagenet |26M |
+|inceptionresnetv2 |imagenet<br>imagenet+background |54M |
+|inceptionv4 |imagenet<br>imagenet+background |41M |
+|efficientnet-b0 |imagenet |4M |
+|efficientnet-b1 |imagenet |6M |
+|efficientnet-b2 |imagenet |7M |
+|efficientnet-b3 |imagenet |10M |
+|efficientnet-b4 |imagenet |17M |
+|efficientnet-b5 |imagenet |28M |
+|efficientnet-b6 |imagenet |40M |
+|efficientnet-b7 |imagenet |63M |
+|mobilenet_v2 |imagenet |2M |
 
 ### Models API <a name="api"></a>
+
 - `model.encoder` - pretrained backbone to extract features of different spatial resolution
-- `model.decoder` - segmentation head, depends on models architecture (`Unet`/`Linknet`/`PSPNet`/`FPN`)
-- `model.activation` - output activation function, one of `sigmoid`, `softmax`
-- `model.forward(x)` - sequentially pass `x` through model\`s encoder and decoder (return logits!)
-- `model.predict(x)` - inference method, switch model to `.eval()` mode, call `.forward(x)` and apply activation function with `torch.no_grad()`
+- `model.decoder` - depends on the model's architecture (`Unet`/`Linknet`/`PSPNet`/`FPN`)
+- `model.segmentation_head` - last block producing the required number of mask channels (also includes optional upsampling and activation)
+- `model.classification_head` - optional block that produces a classification output on top of the encoder
+- `model.forward(x)` - sequentially pass `x` through the model's encoder, decoder and segmentation head (and classification head if specified)
+
+##### Input channels
+The input channels parameter allows you to create models which process tensors with an arbitrary number of channels.
+If you use pretrained weights from imagenet, the weights of the first convolution will be reused for
+1- or 2-channel inputs; for input channels > 4, the weights of the first convolution will be initialized randomly.
+```python
+model = smp.FPN('resnet34', in_channels=1)
+mask = model(torch.ones([1, 1, 64, 64]))
+```
+
+##### Auxiliary classification output
+All models support the `aux_params` parameter, which is set to `None` by default.
+If `aux_params = None`, the auxiliary classification output is not created; otherwise the
+model produces not only a `mask`, but also a `label` output with shape `NC`.
+The classification head consists of GlobalPooling->Dropout(optional)->Linear->Activation(optional) layers, which can be
+configured by `aux_params` as follows:
+```python
+aux_params=dict(
+    pooling='avg',             # one of 'avg', 'max'
+    dropout=0.5,               # dropout ratio, default is None
+    activation='sigmoid',      # activation function, default is None
+    classes=4,                 # define number of output labels
+)
+model = smp.Unet('resnet34', classes=4, aux_params=aux_params)
+mask, label = model(x)
+```
+
+##### Depth
+The depth parameter specifies the number of downsampling operations in the encoder, so you can make
+your model lighter by specifying a smaller `depth`.
+```python
+model = smp.FPN('resnet34', depth=4)
+```
+
 
 ### Installation <a name="installation"></a>
 PyPI version:
@@ -97,11 +163,24 @@ Latest version from source:
 ```bash
 $ pip install git+https://github.com/qubvel/segmentation_models.pytorch
 ````
+
+### Competitions won with the library
+
+The `Segmentation Models` package is widely used in image segmentation competitions.
+[Here](https://github.com/qubvel/segmentation_models.pytorch/blob/master/HALLOFFAME.md) you can find competitions, the names of the winners, and links to their solutions.
+
+
 ### License <a name="license"></a>
 Project is distributed under [MIT License](https://github.com/qubvel/segmentation_models.pytorch/blob/master/LICENSE)
 
-### Run tests
+
+### Contributing
+
+##### Run tests
+```bash
+$ docker build -f docker/Dockerfile.dev -t smp:dev . && docker run --rm smp:dev pytest -p no:cacheprovider
+```
+##### Generate table
 ```bash
-$ docker build -f docker/Dockerfile.dev -t smp:dev .
-$ docker run --rm smp:dev pytest -p no:cacheprovider
+$ docker build -f docker/Dockerfile.dev -t smp:dev . && docker run --rm smp:dev python misc/generate_table.py
 ```
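Taken together, the parameters documented in the new README sections (`in_channels`, `aux_params`) can be combined in a single model. A minimal sketch — the encoder, shapes, and parameter values are illustrative choices, not prescribed by the diff:

```python
import torch
import segmentation_models_pytorch as smp

# auxiliary classification head configuration (illustrative values)
aux_params = dict(pooling='avg', dropout=0.5, activation='sigmoid', classes=4)

# single-channel input plus an auxiliary classification output in one model
model = smp.Unet(
    'resnet34',
    encoder_weights='imagenet',
    in_channels=1,
    classes=4,
    aux_params=aux_params,
)

x = torch.ones([2, 1, 64, 64])
mask, label = model(x)  # mask: (2, 4, 64, 64), label: (2, 4)
```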

‎docker/Dockerfile.dev

+1 −1 (1 addition, 1 deletion)
@@ -1,4 +1,4 @@
-FROM anibali/pytorch:cuda-9.0
+FROM python:3.6 #anibali/pytorch:cuda-9.0
 
 WORKDIR /tmp/smp/
 
‎examples/cars segmentation (camvid).ipynb

+145 −169 (145 additions, 169 deletions); large diff not rendered.

‎misc/generate_table.py

+33 −0 (33 additions, 0 deletions)
@@ -0,0 +1,33 @@
import segmentation_models_pytorch as smp

encoders = smp.encoders.encoders


WIDTH = 32
COLUMNS = [
    "Encoder",
    "Weights",
    "Params, M",
]

def wrap_row(r):
    return "|{}|".format(r)

header = "|".join([column.ljust(WIDTH, ' ') for column in COLUMNS])
separator = "|".join(["-" * WIDTH] + [":" + "-" * (WIDTH - 2) + ":"] * (len(COLUMNS) - 1))

print(wrap_row(header))
print(wrap_row(separator))

for encoder_name, encoder in encoders.items():
    weights = "<br>".join(encoder["pretrained_settings"].keys())
    encoder_name = encoder_name.ljust(WIDTH, " ")
    weights = weights.ljust(WIDTH, " ")

    model = encoder["encoder"](**encoder["params"], depth=5)
    params = sum(p.numel() for p in model.parameters())
    params = str(params // 1000000) + "M"
    params = params.ljust(WIDTH, " ")

    row = "|".join([encoder_name, weights, params])
    print(wrap_row(row))
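The markdown this script prints is the encoder table shown in the README diff above: a header row, an alignment row, and one row per registered encoder with its pretrained weights and (approximate) parameter count. Regenerating the table after adding an encoder is therefore just a matter of re-running the script (for example via the `Generate table` command in the Contributing section) and copying its output into the README.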

‎requirements.txt

+2 −2 (2 additions, 2 deletions)
@@ -1,3 +1,3 @@
-torchvision>=0.2.2,<=0.4.0
+torchvision>=0.3.0
 pretrainedmodels==0.7.4
-efficientnet-pytorch==0.4.0
+efficientnet-pytorch>=0.5.1
+1 −1 (1 addition, 1 deletion)
@@ -1,3 +1,3 @@
-VERSION = (0, 0, 3)
+VERSION = (0, 1, 0)
 
 __version__ = '.'.join(map(str, VERSION))
+11 −1 (11 additions, 1 deletion)
@@ -1 +1,11 @@
-from .encoder_decoder import EncoderDecoder
+from .model import SegmentationModel
+
+from .modules import (
+    Conv2dReLU,
+    Attention,
+)
+
+from .heads import (
+    SegmentationHead,
+    ClassificationHead,
+)
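These exports mirror the refactoring documented in the README's Models API section: a model is now assembled from an encoder, an architecture-specific decoder, a `SegmentationHead`, and an optional `ClassificationHead`. A minimal sketch of inspecting that composition — the attribute names come from the README above, while the exact classes returned are an assumption, not something this diff guarantees:

```python
import segmentation_models_pytorch as smp

aux_params = dict(pooling='avg', dropout=0.5, activation='sigmoid', classes=4)
model = smp.Unet('resnet34', classes=4, aux_params=aux_params)

# the model is built from the blocks exported here
print(type(model.encoder).__name__)              # encoder backbone
print(type(model.decoder).__name__)              # architecture-specific decoder
print(type(model.segmentation_head).__name__)    # expected: SegmentationHead
print(type(model.classification_head).__name__)  # expected: ClassificationHead
```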

‎segmentation_models_pytorch/base/encoder_decoder.py

−47 (0 additions, 47 deletions)
This file was deleted.
