Commit dbb7bd1

release MobileViT, from @murufeng
1 parent 86a7302 commit dbb7bd1

File tree

3 files changed: +76, -55 lines

README.md: 13 additions & 2 deletions

@@ -559,12 +559,12 @@ perspective for the global processing of information with transformers.
 
 You can use it with the following code (ex. mobilevit_xs)
 
-```
+```python
 import torch
 from vit_pytorch.mobile_vit import MobileViT
 
 mbvit_xs = MobileViT(
-    image_size=(256, 256),
+    image_size = (256, 256),
     dims = [96, 120, 144],
     channels = [16, 32, 48, 48, 64, 64, 80, 80, 96, 96, 384],
     num_classes = 1000
@@ -1190,6 +1190,17 @@ Coming from computer vision and new to transformers? Here are some resources tha
 }
 ```
 
+```bibtex
+@misc{mehta2021mobilevit,
+    title = {MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer},
+    author = {Sachin Mehta and Mohammad Rastegari},
+    year = {2021},
+    eprint = {2110.02178},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.CV}
+}
+```
+
 ```bibtex
 @misc{vaswani2017attention,
     title = {Attention Is All You Need},

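As a quick sanity check of the README snippet above (our own illustration, not part of the commit): with `image_size = (256, 256)` and the default patch size of (2, 2), the xs-sized model should map a standard 256×256 RGB batch to one logit per class. The new optional `depths` argument added in this commit (see the mobile_vit.py diff below) defaults to `(2, 4, 3)`.

```python
import torch
from vit_pytorch.mobile_vit import MobileViT

# hyperparameters for the xs variant, copied from the README snippet
mbvit_xs = MobileViT(
    image_size = (256, 256),
    dims = [96, 120, 144],
    channels = [16, 32, 48, 48, 64, 64, 80, 80, 96, 96, 384],
    num_classes = 1000
    # depths = (2, 4, 3) is the new optional argument (transformer depth per stage)
)

img = torch.randn(1, 3, 256, 256)  # dummy batch, illustrative only
preds = mbvit_xs(img)              # -> shape (1, 1000)
```
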
setup.py: 1 addition & 1 deletion

@@ -3,7 +3,7 @@
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.24.3',
+  version = '0.25.0',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   author = 'Phil Wang',

vit_pytorch/mobile_vit.py: 62 additions & 52 deletions

@@ -9,9 +9,9 @@
 import torch.nn as nn
 
 from einops import rearrange
+from einops.layers.torch import Reduce
 
 def _make_divisible(v, divisor, min_value=None):
-
     if min_value is None:
         min_value = divisor
     new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
@@ -20,7 +20,7 @@ def _make_divisible(v, divisor, min_value=None):
     return new_v
 
 
-def Conv_BN_ReLU(inp, oup, kernel, stride=1):
+def conv_bn_relu(inp, oup, kernel, stride=1):
     return nn.Sequential(
         nn.Conv2d(inp, oup, kernel_size=kernel, stride=stride, padding=1, bias=False),
         nn.BatchNorm2d(oup),
@@ -63,8 +63,6 @@ class Attention(nn.Module):
     def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
         super().__init__()
         inner_dim = dim_head * heads
-        project_out = not (heads == 1 and dim_head == dim)
-
         self.heads = heads
         self.scale = dim_head ** -0.5
 
@@ -74,7 +72,7 @@ def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
         self.to_out = nn.Sequential(
             nn.Linear(inner_dim, dim),
             nn.Dropout(dropout)
-        ) if project_out else nn.Identity()
+        )
 
     def forward(self, x):
         qkv = self.to_qkv(x).chunk(3, dim=-1)
@@ -96,6 +94,7 @@ def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
             PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
             PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
         ]))
+
     def forward(self, x):
         for attn, ff in self.layers:
             x = attn(x) + x
@@ -136,23 +135,24 @@ def __init__(self, inp, oup, stride=1, expand_ratio=4):
         )
 
     def forward(self, x):
+        out = self.conv(x)
+
         if self.identity:
-            return x + self.conv(x)
-        else:
-            return self.conv(x)
+            out = out + x
+        return out
 
 class MobileViTBlock(nn.Module):
     def __init__(self, dim, depth, channel, kernel_size, patch_size, mlp_dim, dropout=0.):
         super().__init__()
         self.ph, self.pw = patch_size
 
-        self.conv1 = Conv_BN_ReLU(channel, channel, kernel_size)
+        self.conv1 = conv_bn_relu(channel, channel, kernel_size)
         self.conv2 = conv_1x1_bn(channel, dim)
 
         self.transformer = Transformer(dim, depth, 1, 32, mlp_dim, dropout)
 
         self.conv3 = conv_1x1_bn(dim, channel)
-        self.conv4 = Conv_BN_ReLU(2 * channel, channel, kernel_size)
+        self.conv4 = conv_bn_relu(2 * channel, channel, kernel_size)
 
     def forward(self, x):
         y = x.clone()
@@ -165,8 +165,7 @@ def forward(self, x):
         _, _, h, w = x.shape
         x = rearrange(x, 'b d (h ph) (w pw) -> b (ph pw) (h w) d', ph=self.ph, pw=self.pw)
         x = self.transformer(x)
-        x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h // self.ph, w=w // self.pw, ph=self.ph,
-                      pw=self.pw)
+        x = rearrange(x, 'b (ph pw) (h w) d -> b d (h ph) (w pw)', h=h // self.ph, w=w // self.pw, ph=self.ph, pw=self.pw)
 
         # Fusion
         x = self.conv3(x)
@@ -176,54 +175,65 @@ def forward(self, x):
 
 
 class MobileViT(nn.Module):
-    def __init__(self, image_size, dims, channels, num_classes, expansion=4, kernel_size=3, patch_size=(2, 2)):
+    def __init__(
+        self,
+        image_size,
+        dims,
+        channels,
+        num_classes,
+        expansion = 4,
+        kernel_size = 3,
+        patch_size = (2, 2),
+        depths = (2, 4, 3)
+    ):
         super().__init__()
+        assert len(dims) == 3, 'dims must be a tuple of 3'
+        assert len(depths) == 3, 'depths must be a tuple of 3'
+
         ih, iw = image_size
         ph, pw = patch_size
         assert ih % ph == 0 and iw % pw == 0
 
-        L = [2, 4, 3]
-
-        self.conv1 = Conv_BN_ReLU(3, channels[0], kernel=3, stride=2)
-
-        self.mv2 = nn.ModuleList([])
-        self.mv2.append(MV2Block(channels[0], channels[1], 1, expansion))
-        self.mv2.append(MV2Block(channels[1], channels[2], 2, expansion))
-        self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion))
-        self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion))
-        self.mv2.append(MV2Block(channels[3], channels[4], 2, expansion))
-        self.mv2.append(MV2Block(channels[5], channels[6], 2, expansion))
-        self.mv2.append(MV2Block(channels[7], channels[8], 2, expansion))
-
-        self.mvit = nn.ModuleList([])
-        self.mvit.append(MobileViTBlock(dims[0], L[0], channels[5], kernel_size, patch_size, int(dims[0] * 2)))
-        self.mvit.append(MobileViTBlock(dims[1], L[1], channels[7], kernel_size, patch_size, int(dims[1] * 4)))
-        self.mvit.append(MobileViTBlock(dims[2], L[2], channels[9], kernel_size, patch_size, int(dims[2] * 4)))
-
-        self.conv2 = conv_1x1_bn(channels[-2], channels[-1])
-
-        self.pool = nn.AvgPool2d(ih // 32, 1)
-        self.fc = nn.Linear(channels[-1], num_classes, bias=False)
+        init_dim, *_, last_dim = channels
+
+        self.conv1 = conv_bn_relu(3, init_dim, kernel=3, stride=2)
+
+        self.stem = nn.ModuleList([])
+        self.stem.append(MV2Block(channels[0], channels[1], 1, expansion))
+        self.stem.append(MV2Block(channels[1], channels[2], 2, expansion))
+        self.stem.append(MV2Block(channels[2], channels[3], 1, expansion))
+        self.stem.append(MV2Block(channels[2], channels[3], 1, expansion))
+
+        self.trunk = nn.ModuleList([])
+        self.trunk.append(nn.ModuleList([
+            MV2Block(channels[3], channels[4], 2, expansion),
+            MobileViTBlock(dims[0], depths[0], channels[5], kernel_size, patch_size, int(dims[0] * 2))
+        ]))
+
+        self.trunk.append(nn.ModuleList([
+            MV2Block(channels[5], channels[6], 2, expansion),
+            MobileViTBlock(dims[1], depths[1], channels[7], kernel_size, patch_size, int(dims[1] * 4))
+        ]))
+
+        self.trunk.append(nn.ModuleList([
+            MV2Block(channels[7], channels[8], 2, expansion),
+            MobileViTBlock(dims[2], depths[2], channels[9], kernel_size, patch_size, int(dims[2] * 4))
+        ]))
+
+        self.to_logits = nn.Sequential(
+            conv_1x1_bn(channels[-2], last_dim),
+            Reduce('b c h w -> b c', 'mean'),
+            nn.Linear(channels[-1], num_classes, bias=False)
+        )
 
     def forward(self, x):
         x = self.conv1(x)
-        x = self.mv2[0](x)
-
-        x = self.mv2[1](x)
-        x = self.mv2[2](x)
-        x = self.mv2[3](x)
 
-        x = self.mv2[4](x)
-        x = self.mvit[0](x)
+        for conv in self.stem:
+            x = conv(x)
 
-        x = self.mv2[5](x)
-        x = self.mvit[1](x)
-
-        x = self.mv2[6](x)
-        x = self.mvit[2](x)
-        x = self.conv2(x)
-
-        x = self.pool(x).view(-1, x.shape[1])
-        x = self.fc(x)
-        return x
+        for conv, attn in self.trunk:
+            x = conv(x)
+            x = attn(x)
 
+        return self.to_logits(x)
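
One note on the refactor above, with a minimal sketch (ours, not from the commit): the new `to_logits` head replaces the old `AvgPool2d(ih // 32, 1)` + `view` + `Linear` sequence with einops' `Reduce` layer, which mean-pools over whatever spatial size arrives instead of a pool size fixed at construction time. The two are numerically equivalent on the old shapes; the tensor sizes below are illustrative only.

```python
import torch
import torch.nn as nn
from einops.layers.torch import Reduce

# illustrative feature map: (batch, channels, height, width)
x = torch.randn(2, 384, 8, 8)

# old head: average pool sized to the feature map, then flatten
pooled_old = nn.AvgPool2d(8, 1)(x).view(-1, x.shape[1])

# new head: mean-reduce the spatial dims, shape-agnostic
pooled_new = Reduce('b c h w -> b c', 'mean')(x)

assert torch.allclose(pooled_old, pooled_new, atol = 1e-6)  # same result
```

The MV2Block change is behavior-preserving in the same way: computing `out = self.conv(x)` once and adding the residual only when `self.identity` holds returns the same tensors as the old two-branch `return`.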
