[FSDP][optim_state_dict] Call synchronize() to ensure DTensors.to_local() is synchronized #117799
Conversation
…ensors being recycled

Temporary tensors could not be recycled unless the operations are finished. Calling synchronize() can ensure all the operations are finished. This can prevent the OOM from happening. Differential Revision: [D52890462](https://our.internmc.facebook.com/intern/diff/D52890462/) [ghstack-poisoned]
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/117799
Note: Links to docs will display an error until the docs builds have been completed.
❌ 2 New Failures as of commit ead91ac with merge base 5c17f66. The following jobs have failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
…ensors being recycled

Temporary tensors could not be recycled unless the operations are finished. Calling synchronize() can ensure all the operations are finished. This can prevent the OOM from happening. Differential Revision: [D52890462](https://our.internmc.facebook.com/intern/diff/D52890462/) ghstack-source-id: 212464431 Pull Request resolved: #117799
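To make the commit message concrete, here is a minimal sketch of the pattern it describes. The function and variable names are illustrative only (not the PR's actual code); the real change lives in the FSDP optimizer state dict path.

```python
import torch

def slice_and_clone(gathered: torch.Tensor, start: int, end: int) -> torch.Tensor:
    # The gathered tensor may still be backed by in-flight GPU/communication
    # work. Waiting for all queued device work before the slice-and-clone
    # ensures the temporary buffers it reads from can be recycled safely.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return gathered.flatten()[start : end + 1].clone()

# Single-GPU usage example; the distributed gather that produces
# `gathered` in the real code path is elided here.
if torch.cuda.is_available():
    t = torch.randn(1024, device="cuda")
    shard = slice_and_clone(t, 0, 511)
```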
Does this part of the optim state dict load use multiple streams (which is why memory is not freed immediately when there are no more references)?
No multiple streams. But even if the temporary tensors are no longer referenced in Python, they can still be in use, iiuc. For example, the …
I was wondering if you have tried … Suppose at first we have this: …
Now, suppose we …
I did not see any special handling in the clone implementation: pytorch/aten/src/ATen/native/TensorFactories.cpp, lines 1730 to 1751 at 2f84a9d.
@awgu that makes sense. However, the tensor should already be deleted as it has no reference after the inner util function. So I don't think …
@awgu I can confirm that …
…temporary tensors being recycled"

Temporary tensors could not be recycled unless the operations are finished. Calling synchronize() can ensure all the operations are finished. This can prevent the OOM from happening. Differential Revision: [D52890462](https://our.internmc.facebook.com/intern/diff/D52890462/) cc mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 awgu penguinwu XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 [ghstack-poisoned]
@awgu The root cause is …
Thanks for looking into this! @wanchaol Is there any way to specify to … ? If we can run synchronous collectives and avoid …
@awgu This is the optimizer state_dict code path. What's the downside of using …
@fegin The downside is just performance from synchronizing the CPU. If there is not much performance downside, then using … (I do not have a good mental model of the optimizer state dict performance, so maybe this is not really an issue. However, in general, I do think that considering the performance of state dict makes sense.)
@awgu Yup, agree that performance is important. Currently, …
…memory used by DTensors.to_local() being recycled" If a tensor is converted from DTensor (to_local()), there may be some async communication that has not finished yet. Calling synchronize() can ensure all the operations are finished. The action can prevent OOM from happening. Differential Revision: [D52890462](https://our.internmc.facebook.com/intern/diff/D52890462/) cc mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 awgu penguinwu XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 [ghstack-poisoned]
…sors.to_local() is synchronized" If a tensor is converted from DTensor (to_local()), there may be some async communication that has not finished yet. Calling `clone()` with that tensor does not seem to work (and may increase the memory usage, users report OOM). This is a temporary fix. Differential Revision: [D52890462](https://our.internmc.facebook.com/intern/diff/D52890462/) cc mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 awgu penguinwu XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 [ghstack-poisoned]
…al() is synchronized Pull Request resolved: #117799 If a tensor is converted from DTensor (to_local()), there may be some async communication that has not finished yet. Calling clone() with that tensor does not seem to work (and may increase the memory usage, users report OOM). This is a temporary fix. ghstack-source-id: 212762982 @exported-using-ghexport Differential Revision: [D52890462](https://our.internmc.facebook.com/intern/diff/D52890462/)
Can we use stream sync instead of device sync?
Lacking a bit of context but approving to unblock the fix
if only device sync works and stream sync doesn't, that means there is something wrong with DTensor's to_local() -- it should sync its communication work back to the "current stream" (or else provide a Work handle for the user to sync at some point)
If (1) proper stream dependency is maintained by DTensor's to_local() and (2) torch.clone observes stream properly, you don't even need to call stream sync.
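As a plain-PyTorch illustration of the distinction being drawn here (a sketch independent of DTensor; the matrix multiply just stands in for queued GPU work):

```python
import torch

if torch.cuda.is_available():
    x = torch.randn(4096, 4096, device="cuda")
    y = x @ x  # queued asynchronously on the current stream

    # Stream sync: the host waits only for work already queued on the
    # current stream. This is enough if to_local() and clone() maintain
    # proper stream dependencies.
    torch.cuda.current_stream().synchronize()

    # Device sync: the host waits for every stream on the device. This is
    # the conservative, temporary fix the PR takes.
    torch.cuda.synchronize()
```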
@kwen2501 …
If …
Landing this as a temporary fix sounds good to me!
We should separately figure out how to make AsyncCollectiveTensor wait before clone() (if that is the root issue).
In the more general sense of a user using DTensor.to_local, I don't know if it makes sense that the operation has to be synchronous. Wouldn't it be reasonable to expect a to_local call to return a new 'ACT' that represents ongoing tensor work?
value = value.flatten()[intra_param_start_idx : intra_param_end_idx + 1].clone()  # type: ignore[operator]
if fsdp_state._device_mesh is not None:
    # We have to call synchronize() if the tensor is gathered from
    # DTensor. Otherwise, the later `clone()` will cause errors.
Are you using the full_tensor() API to do the gathering or something else? iirc full_tensor gives sync behavior.
No. We are still using redistribute for the all_gather, since the full_tensor() API was introduced later. Maybe we could update this line to use the full_tensor API? https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_optim_utils.py#L1455
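For context, a hedged sketch contrasting the two gathering styles discussed above. It assumes a CUDA machine launched with torchrun; the mesh shape, tensor sizes, and seeding are illustrative only.

```python
import torch
import torch.distributed as dist
from torch.distributed._tensor import Replicate, Shard, distribute_tensor
from torch.distributed.device_mesh import init_device_mesh

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
mesh = init_device_mesh("cuda", (dist.get_world_size(),))

torch.manual_seed(0)  # same global tensor on every rank
dtensor = distribute_tensor(torch.randn(16, 16, device="cuda"), mesh, [Shard(0)])

# Style used in _optim_utils.py at the time: redistribute + to_local().
# The result may be an AsyncCollectiveTensor backed by an in-flight all-gather.
local_maybe_async = dtensor.redistribute(mesh, [Replicate()]).to_local()

# full_tensor() was added later; iiuc it waits on the collective and returns
# a plain, synchronized torch.Tensor.
full = dtensor.full_tensor()

dist.destroy_process_group()
```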
How reasonable that is depends on how likely users are to write …
@kwen2501 …
@awgu You are right. Thanks for the correction!
See the discussion in #117799. There are some issues when returning an AsyncCollectiveTensor (haven't found the root causes), including OOM and unexpected values. This PR forces `_gather_state_dict()` to be synchronous with respect to the main stream. Differential Revision: [D53049807](https://our.internmc.facebook.com/intern/diff/D53049807/) [ghstack-poisoned]
I have not found the root cause. Since …
See the discussion in #117799. There are some issues when returning an AsyncCollectiveTensor (haven't found the root causes), including OOM and unexpected values. This PR forces `_gather_state_dict()` to be synchronous with respect to the main stream. Differential Revision: [D53049807](https://our.internmc.facebook.com/intern/diff/D53049807/) Pull Request resolved: #118197 Approved by: https://github.com/wz337, https://github.com/LucasLLC
…to_local() result" See the discussion in #117799. There are some issues when returning an AsyncCollectiveTensor (haven't found the root causes), including OOM and unexpected values. This PR forces `_gather_state_dict()` to be synchronous with respect to the main stream. Differential Revision: [D53049807](https://our.internmc.facebook.com/intern/diff/D53049807/) cc mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 awgu penguinwu XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 [ghstack-poisoned]
See the discussion in #117799. There are some issues when returning an AsyncCollectiveTensor (haven't found the root causes), including OOM and unexpected values. This PR forces `_gather_state_dict()` to be synchronous with respect to the main stream. Differential Revision: [D53049807](https://our.internmc.facebook.com/intern/diff/D53049807/) cc mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 awgu penguinwu XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 [ghstack-poisoned]
…118197) See the discussion in pytorch#117799. There are some issues when returning an AsyncCollectiveTensor (haven't found the root causes), including OOM and unexpected values. This PR forces `_gather_state_dict()` to be synchronous with respect to the main stream. Differential Revision: [D53049807](https://our.internmc.facebook.com/intern/diff/D53049807/) Pull Request resolved: pytorch#118197 Approved by: https://github.com/wz337, https://github.com/LucasLLC
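A hedged sketch of what "synchronous with respect to the main stream" could look like inside a gathering utility; the helper name is illustrative and this is not the actual #118197 diff.

```python
import torch
from torch.distributed._functional_collectives import AsyncCollectiveTensor

def _maybe_wait(value: torch.Tensor) -> torch.Tensor:
    # iiuc AsyncCollectiveTensor.wait() blocks the current stream on the
    # pending collective and returns the underlying plain tensor, so later
    # clone()/copy_() calls on the main stream see completed data.
    if isinstance(value, AsyncCollectiveTensor):
        return value.wait()
    return value
```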
Stack from ghstack (oldest at bottom):
If a tensor is converted from DTensor (to_local()), there may be some async communication that has not finished yet. Calling clone() with that tensor does not seem to work (and may increase the memory usage, users report OOM). This is a temporary fix.
Differential Revision: D52890462
cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225