Commit dfcfa20

Commit message: add proposed parallel vit from facebook ai for exploration purposes
Parent: c2b2db2
4 files changed: 179 additions, 1 deletion

README.md: 41 additions, 0 deletions

@@ -27,6 +27,7 @@
 - [Adaptive Token Sampling](#adaptive-token-sampling)
 - [Patch Merger](#patch-merger)
 - [Vision Transformer for Small Datasets](#vision-transformer-for-small-datasets)
+- [Parallel ViT](#parallel-vit)
 - [Dino](#dino)
 - [Accessing Attention](#accessing-attention)
 - [Research Ideas](#research-ideas)

@@ -240,6 +241,7 @@ preds = v(img) # (1, 1000)
 ```
 
 ## CCT
+
 <img src="https://raw.githubusercontent.com/SHI-Labs/Compact-Transformers/main/images/model_sym.png" width="400px"></img>
 
 <a href="https://arxiv.org/abs/2104.05704">CCT</a> proposes compact transformers

@@ -866,6 +868,37 @@ img = torch.randn(4, 3, 256, 256)
 tokens = spt(img) # (4, 256, 1024)
 ```
 
+## Parallel ViT
+
+<img src="./images/parallel-vit.png" width="350px"></img>
+
+This <a href="https://arxiv.org/abs/2203.09795">paper</a> proposes parallelizing multiple attention and feedforward blocks per layer (2 parallel blocks), claiming that this makes the network easier to train without loss of performance.
+
+You can try this variant as follows:
+
+```python
+import torch
+from vit_pytorch.parallel_vit import ViT
+
+v = ViT(
+    image_size = 256,
+    patch_size = 16,
+    num_classes = 1000,
+    dim = 1024,
+    depth = 12,
+    heads = 8,
+    mlp_dim = 2048,
+    num_parallel_branches = 2,  # in the paper, they claimed 2 was optimal
+    dropout = 0.1,
+    emb_dropout = 0.1
+)
+
+img = torch.randn(4, 3, 256, 256)
+
+preds = v(img) # (4, 1000)
+```
+
+
 ## Dino
 
 <img src="./images/dino.png" width="350px"></img>

@@ -1396,6 +1429,14 @@
 }
 ```
 
+```bibtex
+@inproceedings{Touvron2022ThreeTE,
+    title  = {Three things everyone should know about Vision Transformers},
+    author = {Hugo Touvron and Matthieu Cord and Alaaeldin El-Nouby and Jakob Verbeek and Hervé Jégou},
+    year   = {2022}
+}
+```
+
 ```bibtex
 @misc{vaswani2017attention,
     title = {Attention Is All You Need},

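To make the parallel formulation added to the README above concrete: instead of the usual sequential pre-norm update, each layer sums the outputs of its parallel branches inside the residual, i.e. x = x + sum_i Attn_i(LN(x)) followed by x = x + sum_i FF_i(LN(x)); this is what the Parallel and Transformer classes in vit_pytorch/parallel_vit.py further down implement. The snippet below is only a minimal standalone sketch of that composition rule, using stand-in feedforward branches rather than the library's classes.

```python
# Minimal standalone sketch of the parallel-branches composition rule.
# Stand-in modules for illustration only, not the classes from vit_pytorch.parallel_vit.
import torch
from torch import nn

dim, num_branches = 64, 2

def make_ff_branch():
    # one pre-norm feedforward branch; parallel attention branches are combined the same way
    return nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim))

branches = nn.ModuleList([make_ff_branch() for _ in range(num_branches)])

x = torch.randn(2, 16, dim)                    # (batch, tokens, dim)
x = x + sum(branch(x) for branch in branches)  # residual + sum of the parallel branch outputs
print(x.shape)                                 # torch.Size([2, 16, 64])
```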
‎images/parallel-vit.png

Binary file added, 14.3 KB

‎setup.py

1 addition, 1 deletion

@@ -3,7 +3,7 @@
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.28.2',
+  version = '0.29.0',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   author = 'Phil Wang',

‎vit_pytorch/parallel_vit.py

New file, 137 additions, 0 deletions:

```python
import torch
from torch import nn

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helpers

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# classes

class Parallel(nn.Module):
    def __init__(self, *fns):
        super().__init__()
        self.fns = nn.ModuleList(fns)

    def forward(self, x):
        # sum the outputs of all parallel branches applied to the same input
        return sum([fn(x) for fn in self.fns])

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, num_parallel_branches = 2, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])

        attn_block = lambda: PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))
        ff_block = lambda: PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))

        for _ in range(depth):
            # each layer holds num_parallel_branches attention blocks and
            # num_parallel_branches feedforward blocks, evaluated in parallel
            self.layers.append(nn.ModuleList([
                Parallel(*[attn_block() for _ in range(num_parallel_branches)]),
                Parallel(*[ff_block() for _ in range(num_parallel_branches)]),
            ]))

    def forward(self, x):
        for attns, ffs in self.layers:
            x = attns(x) + x  # residual add of the summed parallel attention branches
            x = ffs(x) + x    # residual add of the summed parallel feedforward branches
        return x

class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', num_parallel_branches = 2, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.Linear(patch_dim, dim),
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, num_parallel_branches, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x = self.transformer(x)

        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        return self.mlp_head(x)
```

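As a quick sanity check of the new module (assuming the package is installed from this revision so that vit_pytorch.parallel_vit is importable, and using small dimensions chosen only for speed, not taken from the paper): at fixed depth, increasing num_parallel_branches should leave the output shape unchanged while the parameter count grows, since every branch is a full attention or feedforward block.

```python
# Illustrative check only; the dimensions below are arbitrary small values.
import torch
from vit_pytorch.parallel_vit import ViT

img = torch.randn(2, 3, 64, 64)

for branches in (1, 2, 3):
    v = ViT(
        image_size = 64, patch_size = 8, num_classes = 10,
        dim = 128, depth = 4, heads = 4, mlp_dim = 256,
        num_parallel_branches = branches
    )
    num_params = sum(p.numel() for p in v.parameters())
    # output shape stays (2, 10); parameter count grows with the number of branches
    print(branches, tuple(v(img).shape), num_params)
```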