Commit df651ff

snarayan21 authored and petrex committed
Remove activation checkpointing tag to get correct FQNs (#124698)
Fixes #124546

When setting `use_orig_params = False` and using activation checkpointing, the FQN mapping retrieved by the `_get_fqns` function is incorrect because the prefix added to the name of each activation-checkpointed module, `_checkpoint_wrapped_module`, can still be present. I think this is an edge case with the `_get_fqns` function that was not addressed by the previous commit #118119.

Without the change, the list of object names for an activation-checkpointed module with FSDP (and `use_orig_params=False`) can be something like:

```
['model', '_fsdp_wrapped_module', 'transformer', 'blocks', '0', '_fsdp_wrapped_module', '_checkpoint_wrapped_module', '_flat_param']
```

which incorrectly returns just one FQN, `{'model.transformer.blocks.0._flat_param'}`, when all the FQNs of the parameters of the transformer block should be returned.

With the change, the list of object names now has `_checkpoint_wrapped_module` removed:

```
['model', '_fsdp_wrapped_module', 'transformer', 'blocks', '0', '_fsdp_wrapped_module', '_flat_param']
```

and the FQNs are correctly retrieved and returned in `_get_fqns` when [this condition](https://github.com/pytorch/pytorch/blob/ea61c9cb299b6dfebc57dc9d8821c34321d568ab/torch/distributed/checkpoint/state_dict.py#L168) is satisfied. The correct FQNs are:

```
{'model.transformer.blocks.0.attn.Wqkv.bias', 'model.transformer.blocks.0.ffn.up_proj.bias', 'model.transformer.blocks.0.attn.out_proj.weight', 'model.transformer.blocks.0.norm_2.weight', 'model.transformer.blocks.0.ffn.down_proj.weight', 'model.transformer.blocks.0.attn.Wqkv.weight', 'model.transformer.blocks.0.norm_2.bias', 'model.transformer.blocks.0.ffn.up_proj.weight', 'model.transformer.blocks.0.ffn.down_proj.bias', 'model.transformer.blocks.0.norm_1.bias', 'model.transformer.blocks.0.norm_1.weight', 'model.transformer.blocks.0.attn.out_proj.bias'}
```

Pull Request resolved: #124698
Approved by: https://github.com/Skylion007
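For illustration, a minimal standalone sketch of the prefix stripping that this change performs inside `_get_fqns`. The `object_path` helper and the literal `_CHECKPOINT_PREFIX` value below are assumptions for demonstration, not code from the patch:

```python
# Hypothetical helper mirroring the patched behavior: strip the activation
# checkpoint prefix before splitting the dotted name into module components.
_CHECKPOINT_PREFIX = "_checkpoint_wrapped_module."  # assumed prefix value


def object_path(name: str) -> list[str]:
    # Remove the checkpoint prefix, if it exists (as the patch does), so the
    # stale `_checkpoint_wrapped_module` component never enters the traversal.
    name = name.replace(_CHECKPOINT_PREFIX, "")
    return name.split(".")


raw = (
    "model._fsdp_wrapped_module.transformer.blocks.0"
    "._fsdp_wrapped_module._checkpoint_wrapped_module._flat_param"
)
print(object_path(raw))
# ['model', '_fsdp_wrapped_module', 'transformer', 'blocks', '0',
#  '_fsdp_wrapped_module', '_flat_param']
```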
1 parent 9ae9be2 commit df651ff

File tree


1 file changed: +4 −3 lines changed

torch/distributed/checkpoint/state_dict.py

+4 −3: 4 additions & 3 deletions
@@ -152,8 +152,11 @@ def _get_fqns(
     Returns:
         The canonical FQNs based on the model traversal.
     """
+
+    # Remove the checkpoint prefix, if it exists.
+    name = name.replace(_CHECKPOINT_PREFIX, "")
     if "." not in name:
-        return {name.replace(_CHECKPOINT_PREFIX, "")}
+        return {name}
 
     obj_names = name.split(".")
     fqn_obj_names = []
@@ -170,8 +173,6 @@ def _get_fqns(
                 flat_param = getattr(curr_obj, FLAT_PARAM)
                 if prefix:
                     prefix = f"{prefix}."
-                # FSDP already handles removal of checkpoint prefix, so we can return
-                # directly
                 return {f"{prefix}{fqn}" for fqn in flat_param._fqns}
             curr_obj = getattr(curr_obj, FSDP_WRAPPED_MODULE)
             if curr_obj_name != FSDP_WRAPPED_MODULE:
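For context, a hedged sketch of a setup that exercises this code path: activation checkpointing plus FSDP with `use_orig_params=False`, followed by the `torch.distributed.checkpoint` state-dict helper. The tiny `Block` module, the single-process gloo group, and the expected key names are assumptions for illustration, not part of the patch.

```python
import os

import torch.distributed as dist
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
)
from torch.distributed.checkpoint.state_dict import get_model_state_dict
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


class Block(nn.Module):  # hypothetical stand-in for a transformer block
    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(8, 8)

    def forward(self, x):
        return self.proj(x)


# Single-process CPU process group so FSDP can be constructed locally.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

model = nn.Sequential(Block(), Block())
# Adds the `_checkpoint_wrapped_module` wrapper around each Block ...
apply_activation_checkpointing(model, check_fn=lambda m: isinstance(m, Block))
# ... then FSDP adds `_fsdp_wrapped_module` and, with use_orig_params=False,
# a `_flat_param` that `_get_fqns` must expand into per-parameter names.
model = FSDP(model, use_orig_params=False)

# With the fix, the returned keys are clean FQNs (e.g. "0.proj.weight")
# rather than names still carrying the checkpoint wrapper prefix.
print(sorted(get_model_state_dict(model).keys()))

dist.destroy_process_group()
```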
