@@ -7,7 +7,7 @@ The main features of this library are:
- High level API (just two lines to create a neural network)
- 4 model architectures for binary and multi-class segmentation (including the legendary Unet)
- - 31 available encoders for each architecture
+ - 45 available encoders for each architecture
- All encoders have pre-trained weights for faster and better convergence

### Table of content
@@ -16,10 +16,14 @@ The main features of this library are:
3. [Models](#models)
    1. [Architectures](#architectires)
    2. [Encoders](#encoders)
-     3. [Pretrained weights](#weights)
4. [Models API](#api)
+     1. [Input channels](#input-channels)
+     2. [Auxiliary classification output](#auxiliary-classification-output)
+     3. [Depth](#depth)
5. [Installation](#installation)
- 6. [License](#license)
+ 6. [Competitions won with the library](#competitions-won-with-the-library)
+ 7. [License](#license)
+ 8. [Contributing](#contributing)

### Quick start <a name="start"></a>
Since the library is built on the PyTorch framework, the created segmentation model is just a PyTorch nn.Module, which can be created as easily as:
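As a minimal sketch of that usage (my illustration, not part of this commit; the actual quick-start snippet sits in unchanged lines this diff skips), assuming the package imports as `segmentation_models_pytorch`:

```python
import segmentation_models_pytorch as smp

# two lines: pick an encoder and build a segmentation model around it
model = smp.Unet('resnet34', encoder_weights='imagenet')
```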
@@ -60,33 +64,95 @@ preprocess_input = get_preprocessing_fn('resnet18', pretrained='imagenet')
#### Encoders <a name="encoders"></a>

- | Type         | Encoder names |
- |--------------|---------------|
- | VGG          | vgg11, vgg13, vgg16, vgg19, vgg11bn, vgg13bn, vgg16bn, vgg19bn |
- | DenseNet     | densenet121, densenet169, densenet201, densenet161 |
- | DPN          | dpn68, dpn68b, dpn92, dpn98, dpn107, dpn131 |
- | Inception    | inceptionresnetv2 |
- | ResNet       | resnet18, resnet34, resnet50, resnet101, resnet152 |
- | ResNeXt      | resnext50_32x4d, resnext101_32x8d, resnext101_32x16d, resnext101_32x32d, resnext101_32x48d |
- | SE-ResNet    | se_resnet50, se_resnet101, se_resnet152 |
- | SE-ResNeXt   | se_resnext50_32x4d, se_resnext101_32x4d |
- | SENet        | senet154 |
- | EfficientNet | efficientnet-b0, efficientnet-b1, efficientnet-b2, efficientnet-b3, efficientnet-b4, efficientnet-b5, efficientnet-b6, efficientnet-b7 |
-
- #### Weights <a name="weights"></a>
-
- | Weights name | Encoder names |
- |--------------|---------------|
- | imagenet+5k  | dpn68b, dpn92, dpn107 |
- | imagenet     | vgg11, vgg13, vgg16, vgg19, vgg11bn, vgg13bn, vgg16bn, vgg19bn, <br> densenet121, densenet169, densenet201, densenet161, dpn68, dpn98, dpn131, <br> inceptionresnetv2, <br> resnet18, resnet34, resnet50, resnet101, resnet152, <br> resnext50_32x4d, resnext101_32x8d, <br> se_resnet50, se_resnet101, se_resnet152, <br> se_resnext50_32x4d, se_resnext101_32x4d, <br> senet154, <br> efficientnet-b0, efficientnet-b1, efficientnet-b2, efficientnet-b3, efficientnet-b4, efficientnet-b5, efficientnet-b6, efficientnet-b7 |
- | [instagram](https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/) | resnext101_32x8d, resnext101_32x16d, resnext101_32x32d, resnext101_32x48d |
+ | Encoder             | Weights                         | Params |
+ |---------------------|:-------------------------------:|:------:|
+ | resnet18            | imagenet                        | 11M    |
+ | resnet34            | imagenet                        | 21M    |
+ | resnet50            | imagenet                        | 23M    |
+ | resnet101           | imagenet                        | 42M    |
+ | resnet152           | imagenet                        | 58M    |
+ | resnext50_32x4d     | imagenet                        | 22M    |
+ | resnext101_32x8d    | imagenet<br>instagram           | 86M    |
+ | resnext101_32x16d   | instagram                       | 191M   |
+ | resnext101_32x32d   | instagram                       | 466M   |
+ | resnext101_32x48d   | instagram                       | 826M   |
+ | dpn68               | imagenet                        | 11M    |
+ | dpn68b              | imagenet+5k                     | 11M    |
+ | dpn92               | imagenet+5k                     | 34M    |
+ | dpn98               | imagenet                        | 58M    |
+ | dpn107              | imagenet+5k                     | 84M    |
+ | dpn131              | imagenet                        | 76M    |
+ | vgg11               | imagenet                        | 9M     |
+ | vgg11_bn            | imagenet                        | 9M     |
+ | vgg13               | imagenet                        | 9M     |
+ | vgg13_bn            | imagenet                        | 9M     |
+ | vgg16               | imagenet                        | 14M    |
+ | vgg16_bn            | imagenet                        | 14M    |
+ | vgg19               | imagenet                        | 20M    |
+ | vgg19_bn            | imagenet                        | 20M    |
+ | senet154            | imagenet                        | 113M   |
+ | se_resnet50         | imagenet                        | 26M    |
+ | se_resnet101        | imagenet                        | 47M    |
+ | se_resnet152        | imagenet                        | 64M    |
+ | se_resnext50_32x4d  | imagenet                        | 25M    |
+ | se_resnext101_32x4d | imagenet                        | 46M    |
+ | densenet121         | imagenet                        | 6M     |
+ | densenet169         | imagenet                        | 12M    |
+ | densenet201         | imagenet                        | 18M    |
+ | densenet161         | imagenet                        | 26M    |
+ | inceptionresnetv2   | imagenet<br>imagenet+background | 54M    |
+ | inceptionv4         | imagenet<br>imagenet+background | 41M    |
+ | efficientnet-b0     | imagenet                        | 4M     |
+ | efficientnet-b1     | imagenet                        | 6M     |
+ | efficientnet-b2     | imagenet                        | 7M     |
+ | efficientnet-b3     | imagenet                        | 10M    |
+ | efficientnet-b4     | imagenet                        | 17M    |
+ | efficientnet-b5     | imagenet                        | 28M    |
+ | efficientnet-b6     | imagenet                        | 40M    |
+ | efficientnet-b7     | imagenet                        | 63M    |
+ | mobilenet_v2        | imagenet                        | 2M     |

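For illustration (my sketch, not part of this commit): any Encoder/Weights pair from the table above can be plugged into a model constructor, assuming the standard `encoder_weights` keyword:

```python
import segmentation_models_pytorch as smp

# e.g. a ResNeXt encoder with the Instagram-pretrained weights listed in the table
model = smp.FPN('resnext101_32x8d', encoder_weights='instagram')
```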
### Models API <a name="api"></a>
+
- `model.encoder` - pretrained backbone to extract features of different spatial resolution
- - `model.decoder` - segmentation head, depends on the model's architecture (`Unet`/`Linknet`/`PSPNet`/`FPN`)
- - `model.activation` - output activation function, one of `sigmoid`, `softmax`
- - `model.forward(x)` - sequentially pass `x` through the model's encoder and decoder (returns logits!)
- - `model.predict(x)` - inference method: switches the model to `.eval()` mode, calls `.forward(x)` and applies the activation function with `torch.no_grad()`
+ - `model.decoder` - depends on the model's architecture (`Unet`/`Linknet`/`PSPNet`/`FPN`)
+ - `model.segmentation_head` - the last block, which produces the required number of mask channels (it also includes optional upsampling and activation)
+ - `model.classification_head` - an optional block that creates a classification head on top of the encoder
+ - `model.forward(x)` - sequentially passes `x` through the model's encoder, decoder and segmentation head (and the classification head, if specified) - see the sketch after this list
+
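A short sketch of that flow (my illustration; it assumes only what the list above states - `forward` returns raw logits, so any activation is applied by the caller):

```python
import torch
import segmentation_models_pytorch as smp

model = smp.Unet('resnet34')                    # one output mask channel by default
model.eval()
with torch.no_grad():
    logits = model(torch.ones([2, 3, 64, 64]))  # encoder -> decoder -> segmentation head
    probs = logits.sigmoid()                    # activation applied manually
print(probs.shape)  # torch.Size([2, 1, 64, 64])
```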
+ ##### Input channels
+ The `in_channels` parameter allows you to create models that process tensors with an arbitrary number of channels.
+ If you use pretrained weights from ImageNet, the weights of the first convolution will be reused for
+ 1- or 2-channel inputs; for more than 4 input channels, the weights of the first convolution will be initialized randomly.
+ ```python
+ model = smp.FPN('resnet34', in_channels=1)
+ mask = model(torch.ones([1, 1, 64, 64]))
+ ```
+
+ ##### Auxiliary classification output
+ All models support the `aux_params` argument, which is set to `None` by default.
+ If `aux_params = None`, the auxiliary classification output is not created; otherwise the
+ model produces not only a `mask`, but also a `label` output with shape `NC`.
+ The classification head consists of GlobalPooling -> Dropout (optional) -> Linear -> Activation (optional) layers, which can be
+ configured by `aux_params` as follows:
+ ```python
+ aux_params = dict(
+     pooling='avg',             # one of 'avg', 'max'
+     dropout=0.5,               # dropout ratio, default is None
+     activation='sigmoid',      # activation function, default is None
+     classes=4,                 # define number of output labels
+ )
+ model = smp.Unet('resnet34', classes=4, aux_params=aux_params)
+ mask, label = model(x)         # x - a batch of input images
+ ```
+
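To verify the `NC`-shaped `label` output, a quick shape check (my sketch, reusing the `aux_params` above):

```python
import torch
import segmentation_models_pytorch as smp

aux_params = dict(pooling='avg', dropout=0.5, activation='sigmoid', classes=4)
model = smp.Unet('resnet34', classes=4, aux_params=aux_params)
mask, label = model(torch.ones([1, 3, 64, 64]))
print(mask.shape)   # torch.Size([1, 4, 64, 64])
print(label.shape)  # torch.Size([1, 4]) - i.e. N x C
```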
+ ##### Depth
+ The `depth` parameter specifies the number of downsampling operations in the encoder, so you can make
+ your model lighter by specifying a smaller `depth` (e.g. with `depth=4` the encoder downsamples the input 4 times instead of the default 5).
+ ```python
+ model = smp.FPN('resnet34', depth=4)
+ ```
+

### Installation <a name="installation"></a>
PyPI version:
@@ -97,11 +163,24 @@ Latest version from source:
```bash
$ pip install git+https://github.com/qubvel/segmentation_models.pytorch
```
+
+ ### Competitions won with the library
+
+ The `Segmentation Models` package is widely used in image segmentation competitions.
+ [Here](https://github.com/qubvel/segmentation_models.pytorch/blob/master/HALLOFFAME.md) you can find the competitions, the names of the winners and links to their solutions.
+
+
### License <a name="license"></a>
The project is distributed under the [MIT License](https://github.com/qubvel/segmentation_models.pytorch/blob/master/LICENSE).
- ### Run tests
+
+ ### Contributing
+
+ ##### Run tests
+ ```bash
+ $ docker build -f docker/Dockerfile.dev -t smp:dev . && docker run --rm smp:dev pytest -p no:cacheprovider
+ ```
+ ##### Generate table
```bash
- $ docker build -f docker/Dockerfile.dev -t smp:dev .
- $ docker run --rm smp:dev pytest -p no:cacheprovider
+ $ docker build -f docker/Dockerfile.dev -t smp:dev . && docker run --rm smp:dev python misc/generate_table.py
```