Commit 7092296

FSDP2 example code for tutorial (#1343)
* FSDP2 example
* update README
* fix typo in README
* fix README
1 parent 54e132e commit 7092296

File tree: 5 files changed (+492, -0 lines)

‎distributed/FSDP2/README.md

26 additions & 0 deletions
@@ -0,0 +1,26 @@
## FSDP2
To run FSDP2 on a transformer model:
```
cd distributed/FSDP2
torchrun --nproc_per_node 2 train.py
```
* On the first run, it creates a "checkpoints" folder and saves state dicts there.
* On subsequent runs, it loads from the latest checkpoint.

To enable explicit prefetching:
```
torchrun --nproc_per_node 2 train.py --explicit-prefetch
```

To enable mixed precision:
```
torchrun --nproc_per_node 2 train.py --mixed-precision
```

To showcase the DCP API:
```
torchrun --nproc_per_node 2 train.py --dcp-api
```

## Ensure you are running a recent version of PyTorch
See https://pytorch.org/get-started/locally/ to install at least 2.5, and ideally a current nightly build.
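A quick, illustrative way to confirm which build is installed (this check is not part of the committed README):
```
import torch

# The FSDP2 example relies on APIs from recent releases; the README asks for 2.5 or newer.
print(torch.__version__)
```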

‎distributed/FSDP2/checkpoint.py

209 additions & 0 deletions
@@ -0,0 +1,209 @@
import os
import time

import torch
import torch.nn as nn
from torch.distributed.checkpoint.state_dict import (
    _init_optim_state,
    get_model_state_dict,
    get_optimizer_state_dict,
    set_model_state_dict,
    set_optimizer_state_dict,
    StateDictOptions,
)
from torch.distributed.fsdp import FSDPModule
from torch.distributed.tensor import distribute_tensor, DTensor


MODEL_CHECKPOINT = "model_state_dict.pt"
OPTIM_CHECKPOINT = "optim_state_dict.pt"
PARAMS = "params"


def get_latest_checkpoint_folder(path):
    max_num = None
    if not os.path.exists(path):
        return max_num
    for name in os.listdir(path):
        folder_path = os.path.join(path, name)
        if os.path.isdir(folder_path):
            try:
                num = int(name)
                if max_num is None or num > max_num:
                    max_num = num
            except ValueError:
                pass  # Skip non-numeric folder names
    return max_num


class Checkpointer:
    def __init__(self, folder: str, dcp_api: bool):
        self.folder = folder
        self.dcp_api = dcp_api
        self.last_training_time = get_latest_checkpoint_folder(
            f"{folder}/{'dcp_api' if dcp_api else 'dtensor_api'}"
        )

    def is_empty(self):
        return self.last_training_time is None

    def load_model(self, model: FSDPModule):
        last_model_checkpoint = (
            f"{self.folder}/{'dcp_api' if self.dcp_api else 'dtensor_api'}"
            f"/{self.last_training_time}/{MODEL_CHECKPOINT}"
        )
        full_sd = torch.load(
            last_model_checkpoint, mmap=True, weights_only=True, map_location="cpu"
        )
        if self.dcp_api:
            set_model_state_dict(
                model=model,
                model_state_dict=full_sd,
                options=StateDictOptions(
                    full_state_dict=True,
                    broadcast_from_rank0=True,
                ),
            )
            return
        meta_sharded_sd = model.state_dict()
        sharded_sd = {}
        for param_name, full_tensor in full_sd.items():
            sharded_meta_param = meta_sharded_sd.get(param_name)
            sharded_tensor = distribute_tensor(
                full_tensor,
                sharded_meta_param.device_mesh,
                sharded_meta_param.placements,
            )
            sharded_sd[param_name] = nn.Parameter(sharded_tensor)
        # choose `assign=True` since we cannot call `copy_` on meta tensor
        model.load_state_dict(sharded_sd, strict=False, assign=True)

    def load_optim(self, model: FSDPModule, opt: torch.optim.Optimizer):
        last_optim_checkpoint = (
            f"{self.folder}/{'dcp_api' if self.dcp_api else 'dtensor_api'}"
            f"/{self.last_training_time}/{OPTIM_CHECKPOINT}"
        )
        full_sd = torch.load(
            last_optim_checkpoint, mmap=True, weights_only=True, map_location="cpu"
        )
        if self.dcp_api:
            set_optimizer_state_dict(
                model=model,
                optimizers=opt,
                optim_state_dict=full_sd,
                options=StateDictOptions(
                    full_state_dict=True,
                    broadcast_from_rank0=True,
                ),
            )
            return
        _init_optim_state(opt)
        param_groups = opt.state_dict()["param_groups"]
        state = opt.state_dict()["state"]

        full_param_groups = full_sd["param_groups"]
        full_state = full_sd["state"]

        for param_group, full_param_group in zip(param_groups, full_param_groups):
            for key, value in full_param_group.items():
                if key == PARAMS:
                    continue
                param_group[key] = value
            for pid, full_pid in zip(param_group[PARAMS], full_param_group[PARAMS]):
                if pid not in state:
                    continue
                param_state = state[pid]
                full_param_state = full_state[full_pid]
                for attr, full_tensor in full_param_state.items():
                    sharded_tensor = param_state[attr]
                    if isinstance(sharded_tensor, DTensor):
                        # exp_avg is DTensor
                        param_state[attr] = distribute_tensor(
                            full_tensor,
                            sharded_tensor.device_mesh,
                            sharded_tensor.placements,
                        )
                    else:
                        # step is plain tensor
                        param_state[attr] = full_tensor
        opt.load_state_dict(
            {
                "param_groups": param_groups,
                "state": state,
            }
        )

    def _get_full_model_state_dict(self, model: FSDPModule):
        if self.dcp_api:
            return get_model_state_dict(
                model=model,
                options=StateDictOptions(
                    full_state_dict=True,
                    cpu_offload=True,
                ),
            )

        sharded_sd = model.state_dict()
        cpu_state_dict = {}
        for param_name, sharded_param in sharded_sd.items():
            full_param = sharded_param.full_tensor()
            if torch.distributed.get_rank() == 0:
                cpu_state_dict[param_name] = full_param.cpu()
            else:
                del full_param
        return cpu_state_dict

    def _get_full_optimizer_state_dict(
        self,
        model: FSDPModule,
        opt: torch.optim.Optimizer,
    ):
        if self.dcp_api:
            return get_optimizer_state_dict(
                model=model,
                optimizers=opt,
                options=StateDictOptions(
                    full_state_dict=True,
                    cpu_offload=True,
                ),
            )
        is_rank_zero = torch.distributed.get_rank() == 0
        sharded_sd = opt.state_dict()
        sharded_state = sharded_sd["state"]
        full_state = {}
        for group_id, sharded_group in sharded_state.items():
            group_state = {}
            for attr, sharded_tensor in sharded_group.items():
                if isinstance(sharded_tensor, DTensor):
                    # "exp_avg" in AdamW is `DTensor`
                    full_tensor = sharded_tensor.full_tensor()
                else:
                    # "step" in AdamW is plain tensor
                    full_tensor = sharded_tensor
                if is_rank_zero:
                    group_state[attr] = full_tensor.cpu()
                else:
                    del full_tensor
            if is_rank_zero:
                full_state[group_id] = group_state
            else:
                del group_state
        if is_rank_zero:
            return {
                "param_groups": sharded_sd["param_groups"],
                "state": full_state,
            }
        else:
            return {}

    def save(self, model: FSDPModule, optim: torch.optim.Optimizer):
        model_state_dict = self._get_full_model_state_dict(model)
        optim_state_dict = self._get_full_optimizer_state_dict(model, optim)
        if torch.distributed.get_rank() == 0:
            new_training_time = int(time.time() * 1000)
            new_checkpoint_folder = f"{self.folder}/{'dcp_api' if self.dcp_api else 'dtensor_api'}/{new_training_time}"
            new_model_checkpoint = f"{new_checkpoint_folder}/{MODEL_CHECKPOINT}"
            new_optim_checkpoint = f"{new_checkpoint_folder}/{OPTIM_CHECKPOINT}"
            os.makedirs(new_checkpoint_folder, exist_ok=True)
            torch.save(model_state_dict, new_model_checkpoint)
            torch.save(optim_state_dict, new_optim_checkpoint)
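The class above is meant to be driven from a training script. Below is a minimal sketch of that wiring, assuming an FSDP2 model built with `fully_shard` (assumed to be importable from `torch.distributed.fsdp` alongside `FSDPModule` in recent releases); the model setup, learning rate, and structure here are illustrative assumptions, not the commit's actual train.py:
```
# Hypothetical wiring for Checkpointer; run from distributed/FSDP2 with:
#   torchrun --nproc_per_node 2 this_script.py
import os

import torch
from torch.distributed.fsdp import fully_shard  # assumption: public FSDP2 API

from checkpoint import Checkpointer
from model import ModelArgs, Transformer

torch.distributed.init_process_group(backend="nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

model = Transformer(ModelArgs()).cuda()
for layer in model.layers:
    fully_shard(layer)  # shard each transformer block first
fully_shard(model)      # then the root module (embeddings, norm, output)
optim = torch.optim.AdamW(model.parameters(), lr=1e-2)  # lr is arbitrary here

checkpointer = Checkpointer("checkpoints", dcp_api=False)
if not checkpointer.is_empty():
    checkpointer.load_model(model)         # reshard the saved full state dict onto this rank
    checkpointer.load_optim(model, optim)

# ... training steps would go here ...

checkpointer.save(model, optim)  # rank 0 writes a new timestamped folder
torch.distributed.destroy_process_group()
```
Sharding each `TransformerBlock` before the root module follows the common FSDP2 pattern of one parameter group per block, which is what lets `Checkpointer` see per-parameter `DTensor` shards in `model.state_dict()`.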

‎distributed/FSDP2/model.py

134 additions & 0 deletions
@@ -0,0 +1,134 @@
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F


@dataclass
class ModelArgs:
    n_layers: int = 2
    vocab_size: int = 8
    max_seq_len: int = 16
    dim: int = 16
    n_heads: int = 4
    dropout_p: float = 0.1


class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        assert args.dim % args.n_heads == 0
        self.head_dim = args.dim // args.n_heads
        self.n_heads = args.n_heads
        self.dropout_p = args.dropout_p
        self.resid_dropout = nn.Dropout(args.dropout_p)

        self.wq = nn.Linear(args.dim, args.dim, bias=False)
        self.wk = nn.Linear(args.dim, args.dim, bias=False)
        self.wv = nn.Linear(args.dim, args.dim, bias=False)
        self.wo = nn.Linear(args.dim, args.dim, bias=False)

    def forward(self, x):
        bsz, seq_len, _ = x.size()
        queries, keys, values = self.wq(x), self.wk(x), self.wv(x)
        queries = queries.view(bsz, seq_len, self.n_heads, self.head_dim)
        keys = keys.view(bsz, seq_len, self.n_heads, self.head_dim)
        values = values.view(bsz, seq_len, self.n_heads, self.head_dim)

        queries = queries.transpose(1, 2)  # (bsz, n_heads, seq_len, head_dim)
        keys = keys.transpose(1, 2)  # (bsz, n_heads, seq_len, head_dim)
        values = values.transpose(1, 2)  # (bsz, n_heads, seq_len, head_dim)

        output = F.scaled_dot_product_attention(
            queries,
            keys,
            values,
            None,
            self.dropout_p if self.training else 0,
        )
        output = output.transpose(1, 2).contiguous().view(bsz, seq_len, -1)
        return self.resid_dropout(self.wo(output))

    def reset_parameters(self):
        self.wq.reset_parameters()
        self.wk.reset_parameters()
        self.wv.reset_parameters()
        self.wo.reset_parameters()


class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout_p):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim)
        self.gelu = nn.GELU()
        self.w2 = nn.Linear(hidden_dim, dim)
        self.resid_dropout = nn.Dropout(dropout_p)

    def forward(self, x):
        return self.resid_dropout(self.w2(self.gelu(self.w1(x))))

    def reset_parameters(self):
        self.w1.reset_parameters()
        self.w2.reset_parameters()


class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.attention_norm = nn.LayerNorm(args.dim)
        self.attention = Attention(args)
        self.ffn_norm = nn.LayerNorm(args.dim)
        self.feed_forward = FeedForward(
            args.dim, hidden_dim=4 * args.dim, dropout_p=args.dropout_p
        )

    def forward(self, x):
        h = x + self.attention(self.attention_norm(x))
        out = h + self.feed_forward(self.ffn_norm(h))
        return out

    def reset_parameters(self):
        self.attention_norm.reset_parameters()
        self.attention.reset_parameters()
        self.ffn_norm.reset_parameters()
        self.feed_forward.reset_parameters()


# A toy transformer model, partly inspired by the nanoGPT model:
# https://github.com/karpathy/nanoGPT.
class Transformer(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        assert args.vocab_size is not None
        assert args.max_seq_len is not None
        self.model_args = args
        self.max_seq_len = args.max_seq_len
        self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
        self.pos_embeddings = nn.Embedding(args.max_seq_len, args.dim)
        self.dropout = nn.Dropout(args.dropout_p)
        self.layers = nn.ModuleList()
        for _ in range(args.n_layers):
            self.layers.append(TransformerBlock(args))
        self.norm = nn.LayerNorm(args.dim)
        self.output = nn.Linear(args.dim, args.vocab_size, bias=False)

    def forward(self, tokens):
        _bsz, seq_len = tokens.size()
        assert seq_len <= self.max_seq_len
        h = self.tok_embeddings(tokens)
        pos = torch.arange(0, seq_len, device=tokens.device)
        p = self.pos_embeddings(pos)  # positional embeddings of shape (seq_len, dim)
        h = h + p
        h = self.dropout(h)
        for layer in self.layers:
            h = layer(h)
        h = self.norm(h)
        output = self.output(h).float()
        return output

    def reset_parameters(self):
        self.tok_embeddings.reset_parameters()
        self.pos_embeddings.reset_parameters()
        self.norm.reset_parameters()
        self.output.reset_parameters()
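Because the model above is a small, self-contained toy, it can be sanity-checked in a single process before any FSDP2 sharding is involved. A brief illustrative sketch (the batch size of 4 is an arbitrary choice, not from the commit):
```
import torch

from model import ModelArgs, Transformer

args = ModelArgs()  # defaults: n_layers=2, vocab_size=8, max_seq_len=16, dim=16
model = Transformer(args)

tokens = torch.randint(0, args.vocab_size, (4, args.max_seq_len))  # (bsz, seq_len)
logits = model(tokens)
print(logits.shape)  # torch.Size([4, 16, 8]) -> (bsz, seq_len, vocab_size)
```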

0 commit comments
