Commit e051522

add vit with patch dropout, fully embrace structured dropout as multiple papers are now corroborating each other
1 parent 2f87c0c

File tree

3 files changed: 163 additions & 1 deletion

README.md

Lines changed: 10 additions & 0 deletions

@@ -1884,4 +1884,14 @@ Coming from computer vision and new to transformers? Here are some resources that
 }
 ```
 
+```bibtex
+@article{Liu2022PatchDropoutEV,
+    title   = {PatchDropout: Economizing Vision Transformers Using Patch Dropout},
+    author  = {Yue Liu and Christos Matsoukas and Fredrik Strand and Hossein Azizpour and Kevin Smith},
+    journal = {ArXiv},
+    year    = {2022},
+    volume  = {abs/2208.07220}
+}
+```
+
*I visualise a time when we will be to robots what dogs are to humans, and I’m rooting for the machines.* — Claude Shannon

setup.py

Lines changed: 1 addition & 1 deletion

@@ -3,7 +3,7 @@
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.39.1',
+  version = '0.40.1',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description_content_type = 'text/markdown',

vit_pytorch/vit_with_patch_dropout.py

Lines changed: 152 additions & 0 deletions

@@ -0,0 +1,152 @@

```python
import torch
from torch import nn

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helpers

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# classes

class PatchDropout(nn.Module):
    """structured dropout: during training, keep only a random subset of the patch tokens"""
    def __init__(self, prob):
        super().__init__()
        assert 0 <= prob < 1.
        self.prob = prob

    def forward(self, x):
        if not self.training or self.prob == 0.:
            return x

        b, n, _, device = *x.shape, x.device

        batch_indices = torch.arange(b, device = device)
        batch_indices = rearrange(batch_indices, '... -> ... 1')
        num_patches_keep = max(1, int(n * (1 - self.prob)))

        # random scores + top-k indices = a uniformly random subset of patches, sampled without replacement
        patch_indices_keep = torch.randn(b, n, device = device).topk(num_patches_keep, dim = -1).indices

        # gather the kept patches for each sample
        return x[batch_indices, patch_indices_keep]

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., patch_dropout = 0.25):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.Linear(patch_dim, dim),
        )

        self.pos_embedding = nn.Parameter(torch.randn(num_patches, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))

        self.patch_dropout = PatchDropout(patch_dropout)
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        # positional embedding is added before patch dropout, so the kept tokens retain their position information
        x += self.pos_embedding

        x = self.patch_dropout(x)

        # the CLS token is concatenated after patch dropout and is therefore never dropped
        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)

        x = torch.cat((cls_tokens, x), dim = 1)
        x = self.dropout(x)

        x = self.transformer(x)

        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        return self.mlp_head(x)
```
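As a quick sanity check (not part of the commit), the `PatchDropout` module can be exercised on its own. The batch size, patch count, and dimension below are hypothetical; with `prob = 0.25` and 64 patch tokens, training mode keeps `max(1, int(64 * 0.75)) = 48` randomly chosen tokens per sample, while evaluation mode is a no-op.

```python
import torch

# assumes this commit is installed, so the new module is importable from the package
from vit_pytorch.vit_with_patch_dropout import PatchDropout

patch_dropout = PatchDropout(prob = 0.25)

tokens = torch.randn(2, 64, 128)    # (batch, num patches, dim) - hypothetical sizes

patch_dropout.train()
print(patch_dropout(tokens).shape)  # torch.Size([2, 48, 128]) - a random 75% of patches kept, per sample

patch_dropout.eval()
print(patch_dropout(tokens).shape)  # torch.Size([2, 64, 128]) - identity at inference
```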

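And a minimal end-to-end sketch of the new `ViT` variant (again not part of the commit; the hyperparameter values are illustrative, but every keyword matches the constructor signature above):

```python
import torch
from vit_pytorch.vit_with_patch_dropout import ViT

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    patch_dropout = 0.5    # half of the patch tokens are dropped during training
)

img = torch.randn(1, 3, 256, 256)

preds = v(img)    # (1, 1000)
```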
0 commit comments
