Commit d19c436

[FSDP][optim_state_dict] Call synchronize() to ensure the temporary tensors are recycled

Temporary tensors cannot be recycled until the operations that use them have finished. Calling synchronize() ensures all outstanding operations are finished, which can prevent OOM from happening.

Differential Revision: [D52890462](https://our.internmc.facebook.com/intern/diff/D52890462/)
ghstack-source-id: 212464431
Pull Request resolved: #117799
1 parent: 2c5488d
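For background, here is a minimal sketch, not part of the commit, of the situation the commit message describes: sharding optimizer state allocates temporaries via flatten()/clone(), and a torch.cuda.synchronize() call waits for the queued work so that memory can be recycled. The helper name shard_one_param_state and its arguments are hypothetical.

import torch

def shard_one_param_state(optim_state, start_idx, end_idx):
    # Hypothetical helper: each clone() below allocates a new tensor, and the
    # flattened views are temporaries that stay alive while the queued CUDA
    # work that uses them is still in flight.
    sharded = {}
    for name, value in optim_state.items():
        if torch.is_tensor(value) and value.dim() > 0:
            sharded[name] = value.flatten()[start_idx : end_idx + 1].clone()
        else:
            sharded[name] = value
    # Waiting for outstanding work lets the temporaries' memory be recycled
    # before the next parameter's state is processed (the OOM-avoidance idea
    # described in the commit message).
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return sharded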

1 file changed: +17 −1 lines changed
‎torch/distributed/fsdp/_optim_utils.py

@@ -393,6 +393,7 @@ def _shard_orig_param_state(
     )
     if not shard_param_info.in_shard:
         return {}
+
     # Flatten and shard the state.
     new_optim_state: Dict[str, Any] = {}
     intra_param_start_idx = shard_param_info.intra_param_start_idx
@@ -403,8 +404,11 @@
             and value.dim() > 0
             and fsdp_state.sharding_strategy != ShardingStrategy.NO_SHARD
         ):
-            value = value.flatten()[intra_param_start_idx : intra_param_end_idx + 1].clone()  # type: ignore[operator]
+            value = value.flatten()[
+                intra_param_start_idx : intra_param_end_idx + 1
+            ].clone()  # type: ignore[operator]
         new_optim_state[state_name] = value
+    del optim_state
     return new_optim_state
 
 
@@ -461,6 +465,8 @@ def _flatten_optim_state_dict(
     unflat_osd_state = unflat_osd["state"]
     all_state_keys = set(unflat_osd_state.keys())
 
+    sync_threshold = 200
+    curr_numel = 0
     for param, fqns in param_to_fqns.items():
         fqn = fqns[0]
         if fqn not in unflat_osd_state:
@@ -485,6 +491,7 @@
                         fqn,
                         unflat_osd_state[fqn],
                     )
+
             else:
                 flat_state = _flatten_optim_state(
                     fsdp_param_info,
@@ -515,6 +522,15 @@
                     f"The state of {key} is empty. This should happen when "
                     "use_orig_params=True."
                 )
+
+            for t in flat_state.values():
+                if torch.is_tensor(t):
+                    curr_numel += t.numel()
+            # Call synchronize() to ensure that some temporary tensors are recycled.
+            if curr_numel > sync_threshold:
+                torch.cuda.synchronize()
+                curr_numel = 0
+
         else:  # do not flatten non-FSDP parameters' states
             assert len(fqns) == 1
             key = _OptimStateKey(tuple(fqns), False)
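Read on its own, the threshold logic added in the last hunk amounts to roughly the following standalone sketch. The function flatten_all_states and its arguments are illustrative, and the 200-element threshold is simply the value that appears in the diff; torch.cuda.synchronize() and torch.is_tensor() are the only PyTorch calls relied on.

import torch
from typing import Any, Dict, List

def flatten_all_states(
    per_param_states: List[Dict[str, Any]], sync_threshold: int = 200
) -> List[Dict[str, Any]]:
    # Track how many elements' worth of temporaries have been produced since
    # the last device synchronization.
    curr_numel = 0
    flat_states = []
    for state in per_param_states:
        flat = {
            name: value.flatten().clone() if torch.is_tensor(value) else value
            for name, value in state.items()
        }
        flat_states.append(flat)
        for value in flat.values():
            if torch.is_tensor(value):
                curr_numel += value.numel()
        # Synchronize periodically rather than per tensor: often enough that
        # temporaries can be recycled, rarely enough to limit the cost of a
        # full device sync.
        if curr_numel > sync_threshold and torch.cuda.is_available():
            torch.cuda.synchronize()
            curr_numel = 0
    return flat_states

Because the counter resets after each sync, the synchronization frequency scales with the total size of the processed state tensors rather than with the number of parameters.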
