-
Notifications
You must be signed in to change notification settings - Fork 24.4k
[state_dict] Calls wait() for the DTensor to_local() result #118197
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
See the discussion in #117799. There are some issues when returning a AsyncCollectiveTensor (haven't found the root causes), including OOM and unexpected values. This PR force `_gather_state_dict()` to be synchronous with respect to the mian stream. Differential Revision: [D53049807](https://our.internmc.facebook.com/intern/diff/D53049807/) [ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/118197
Note: Links to docs will display an error until the docs builds have been completed. ✅ You can merge normally! (1 Unrelated Failure)As of commit c23292f with merge base cef5b93 ( BROKEN TRUNK - The following job failed but were present on the merge base:👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm!
@pytorchbot merge -f 'Landed internally' (Initiating merge automatically since Phabricator Diff has merged, using force because this PR might not pass merge_rules.json but landed internally) |
Merge startedYour change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Maybe add a TODO here (to change things back) in case the ACT + torch.clone bug is fixed? |
As title Differential Revision: [D53038703](https://our.internmc.facebook.com/intern/diff/D53038703/) Pull Request resolved: #118195 Approved by: https://github.com/rohan-varma, https://github.com/wz337 ghstack dependencies: #118197
…118196) As title Differential Revision: [D53038704](https://our.internmc.facebook.com/intern/diff/D53038704/) Pull Request resolved: #118196 Approved by: https://github.com/rohan-varma, https://github.com/wz337 ghstack dependencies: #118197, #118195
D53049807 and #118197 got out of sync somehow Fixing externally since I'm pretty sure the internal version is correct Pull Request resolved: #118509 Approved by: https://github.com/malfet
D53049807 and pytorch#118197 got out of sync somehow Fixing externally since I'm pretty sure the internal version is correct Pull Request resolved: pytorch#118509 Approved by: https://github.com/malfet
…118197) See the discussion in pytorch#117799. There are some issues when returning a AsyncCollectiveTensor (haven't found the root causes), including OOM and unexpected values. This PR forces `_gather_state_dict()` to be synchronous with respect to the mian stream. Differential Revision: [D53049807](https://our.internmc.facebook.com/intern/diff/D53049807/) Pull Request resolved: pytorch#118197 Approved by: https://github.com/wz337, https://github.com/LucasLLC
…e dict loading Summary: This PR serves as a follow-up fix to address numerical correctness concerns identified in PR pytorch#118197, and we should only wait on `AsyncCollectiveTensor`. Without the change, we occasionally ran into exception: `AttributeError("'Tensor' object has no attribute 'wait'")` Test Plan: **CI**: Wait for the CI test **Test with prod model**: - Tested with models and no-longer ran into the exception after checkpoint loading. Differential Revision: D53680406
…e dict loading (#119716) Summary: This PR serves as a follow-up fix to address numerical correctness concerns identified in PR #118197, and we should only wait on `AsyncCollectiveTensor`. Without the change, we occasionally ran into exception: `AttributeError("'Tensor' object has no attribute 'wait'")` Test Plan: **CI**: Wait for the CI test **Test with prod model**: - Tested with models and no-longer ran into the exception after checkpoint loading. Differential Revision: D53680406 Pull Request resolved: #119716 Approved by: https://github.com/fegin, https://github.com/Skylion007, https://github.com/wz337
…e dict loading (pytorch#119716) Summary: This PR serves as a follow-up fix to address numerical correctness concerns identified in PR pytorch#118197, and we should only wait on `AsyncCollectiveTensor`. Without the change, we occasionally ran into exception: `AttributeError("'Tensor' object has no attribute 'wait'")` Test Plan: **CI**: Wait for the CI test **Test with prod model**: - Tested with models and no-longer ran into the exception after checkpoint loading. Differential Revision: D53680406 Pull Request resolved: pytorch#119716 Approved by: https://github.com/fegin, https://github.com/Skylion007, https://github.com/wz337
Stack from ghstack (oldest at bottom):
See the discussion in #117799.
There are some issues when returning a AsyncCollectiveTensor (haven't found the
root causes), including OOM and unexpected values.
This PR forces
_gather_state_dict()
to be synchronous with respect to the mian stream.Differential Revision: D53049807
cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225