Commit 1de866d

add the proposed jumbo vit from Fuller et al. of Carleton University
1 parent 9f49a31 commit 1de866d

3 files changed: +211 -1

‎README.md

+9

@@ -2172,4 +2172,13 @@

Added citation:

```bibtex
@inproceedings{Fuller2025SimplerFV,
    title   = {Simpler Fast Vision Transformers with a Jumbo CLS Token},
    author  = {Anthony Fuller and Yousef Yassin and Daniel G. Kyrollos and Evan Shelhamer and James R. Green},
    year    = {2025},
    url     = {https://api.semanticscholar.org/CorpusID:276557720}
}
```

*I visualise a time when we will be to robots what dogs are to humans, and I’m rooting for the machines.* — Claude Shannon
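
Since the commit adds only the citation to the README, here is a minimal usage sketch of the new module (an illustration, not part of the commit), assuming it is importable as `vit_pytorch.jumbo_vit.JumboViT` per the file added below; it also exercises the `num_jumbo_cls` option mentioned in the code comments:

```python
import torch
from vit_pytorch.jumbo_vit import JumboViT  # module added in this commit

# two jumbo CLS tokens at 3x the model dim, instead of the paper's single 6x token
v = JumboViT(
    image_size = 64,
    patch_size = 8,
    num_classes = 1000,
    dim = 16,
    depth = 2,
    heads = 2,
    mlp_dim = 32,
    num_jumbo_cls = 2,
    jumbo_cls_k = 3,
    jumbo_ff_mult = 2,
)

images = torch.randn(1, 3, 64, 64)
logits = v(images)  # (1, 1000)
```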

‎setup.py

+1 -1

@@ -6,7 +6,7 @@
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '1.9.2',
+  version = '1.10.1',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description = long_description,
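
If installing from PyPI rather than from source, a quick way to confirm the release containing `jumbo_vit` is present (a sketch; assumes the 1.10.1 bump above is the published version):

```python
# check the installed vit-pytorch version; 1.10.1 is the version bumped in this commit
from importlib.metadata import version

print(version('vit-pytorch'))  # expect '1.10.1' or newer
```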

‎vit_pytorch/jumbo_vit.py

+201 (new file)

@@ -0,0 +1,201 @@
import torch
from torch import nn
from torch.nn import Module, ModuleList

from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange

# helpers

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

def divisible_by(num, den):
    return (num % den) == 0

def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype = torch.float32):
    # 2d sinusoidal position embedding, returns shape (h * w, dim)
    y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing = "ij")
    assert divisible_by(dim, 4), "feature dimension must be multiple of 4 for sincos emb"

    omega = torch.arange(dim // 4) / (dim // 4 - 1)
    omega = temperature ** -omega

    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :]
    pos_emb = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1)

    return pos_emb.type(dtype)

# classes

def FeedForward(dim, mult = 4.):
    hidden_dim = int(dim * mult)
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, hidden_dim),
        nn.GELU(),
        nn.Linear(hidden_dim, dim),
    )

class Attention(Module):
    def __init__(self, dim, heads = 8, dim_head = 64):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.norm = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim = -1)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x):
        x = self.norm(x)

        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class JumboViT(Module):
    def __init__(
        self,
        *,
        image_size,
        patch_size,
        num_classes,
        dim,
        depth,
        heads,
        mlp_dim,
        num_jumbo_cls = 1,  # differing from the paper, allow for multiple jumbo CLS tokens - e.g. one could use 2 jumbo CLS tokens at 3x the dim instead of 1 at 6x
        jumbo_cls_k = 6,    # they use a CLS token with this factor times the dimension - 6 was the value they settled on
        jumbo_ff_mult = 2,  # expansion factor of the jumbo CLS token feedforward
        channels = 3,
        dim_head = 64
    ):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert divisible_by(image_height, patch_height) and divisible_by(image_width, patch_width), 'Image dimensions must be divisible by the patch size.'

        patch_dim = channels * patch_height * patch_width

        self.to_patch_embedding = nn.Sequential(
            Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        self.pos_embedding = posemb_sincos_2d(
            h = image_height // patch_height,
            w = image_width // patch_width,
            dim = dim,
        )

        jumbo_cls_dim = dim * jumbo_cls_k

        self.jumbo_cls_token = nn.Parameter(torch.zeros(num_jumbo_cls, jumbo_cls_dim))

        # split each jumbo CLS token of width (jumbo_cls_k * dim) into jumbo_cls_k tokens of width dim for attention
        jumbo_cls_to_tokens = Rearrange('b n (k d) -> b (n k) d', k = jumbo_cls_k)
        self.jumbo_cls_to_tokens = jumbo_cls_to_tokens

        self.norm = nn.LayerNorm(dim)
        self.layers = ModuleList([])

        # attention and feedforwards

        # the jumbo feedforward has its own (wider) parameters, weight tied across depth for parameter efficiency
        self.jumbo_ff = nn.Sequential(
            Rearrange('b (n k) d -> b n (k d)', k = jumbo_cls_k),
            FeedForward(jumbo_cls_dim, jumbo_ff_mult),  # hidden dim is jumbo_cls_dim * jumbo_ff_mult
            jumbo_cls_to_tokens
        )

        for _ in range(depth):
            self.layers.append(ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head),
                FeedForward(dim, mlp_dim),  # note: FeedForward's second argument is a width multiplier
            ]))

        self.to_latent = nn.Identity()

        self.linear_head = nn.Linear(dim, num_classes)

    def forward(self, img):

        batch, device = img.shape[0], img.device

        x = self.to_patch_embedding(img)

        # pos embedding

        pos_emb = self.pos_embedding.to(device, dtype = x.dtype)

        x = x + pos_emb

        # add cls tokens

        cls_tokens = repeat(self.jumbo_cls_token, 'nj d -> b nj d', b = batch)

        jumbo_tokens = self.jumbo_cls_to_tokens(cls_tokens)

        # prepend the (split) jumbo cls tokens to the patch tokens
        x, cls_packed_shape = pack([jumbo_tokens, x], 'b * d')

        # attention and feedforwards

        for layer, (attn, ff) in enumerate(self.layers, start = 1):
            is_last = layer == len(self.layers)

            x = attn(x) + x

            # jumbo feedforward

            jumbo_cls_tokens, x = unpack(x, cls_packed_shape, 'b * d')

            x = ff(x) + x
            jumbo_cls_tokens = self.jumbo_ff(jumbo_cls_tokens) + jumbo_cls_tokens

            if is_last:
                continue

            x, _ = pack([jumbo_cls_tokens, x], 'b * d')

        pooled = reduce(jumbo_cls_tokens, 'b n d -> b d', 'mean')

        # normalization and project to logits

        embed = self.norm(pooled)

        embed = self.to_latent(embed)
        logits = self.linear_head(embed)
        return logits

# copy pasteable file

if __name__ == '__main__':

    v = JumboViT(
        num_classes = 1000,
        image_size = 64,
        patch_size = 8,
        dim = 16,
        depth = 2,
        heads = 2,
        mlp_dim = 32,
        jumbo_cls_k = 3,
        jumbo_ff_mult = 2,
    )

    images = torch.randn(1, 3, 64, 64)

    logits = v(images)
    assert logits.shape == (1, 1000)
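
As a quick illustration of the shape bookkeeping above (a standalone sketch, not part of the commit): the jumbo CLS parameter of width `jumbo_cls_k * dim` is split into `jumbo_cls_k` ordinary-width tokens that attend alongside the patch tokens, then re-merged so its dedicated feedforward can act on the full jumbo width.

```python
# standalone shape check for the jumbo CLS split / merge round trip
import torch
from einops.layers.torch import Rearrange

b, n, k, d = 2, 1, 3, 16                              # batch, num jumbo cls, split factor, model dim
jumbo = torch.zeros(b, n, k * d)                      # jumbo CLS token(s), width k * d

split = Rearrange('b n (k d) -> b (n k) d', k = k)    # used before attention
merge = Rearrange('b (n k) d -> b n (k d)', k = k)    # used inside the jumbo feedforward

tokens = split(jumbo)                                 # (2, 3, 16) - k ordinary-width tokens
assert tokens.shape == (b, n * k, d)
assert merge(tokens).shape == (b, n, k * d)           # merged back for the wider feedforward
```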
