HSDP + DTensor Support in FSDP #118618
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/118618
Note: Links to docs will display an error until the docs builds have been completed.
❌ 5 New Failures as of commit 831141c with merge base 5dfcf07.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@pytorchbot label "module: distributed_checkpoint"
Please seek CI approval before scheduling CIFlow labels
@wz337 Do we still support process group for HSDP?
Is it possible to wrap the PG you have in a device_mesh, or construct a 2D mesh early on and pull a PG out of it for other uses, but then only have a device_mesh inside FSDP? It doesn't seem great to have mixed pg+mesh, IMO.
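A minimal sketch of that suggestion, assuming an 8-rank job laid out as 2 replicas x 4 shards and made-up mesh dim names (the toy nn.Linear stands in for the real model):

```python
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
import torch.nn as nn

# Hypothetical layout: 2 replica groups x 4 shard ranks = 8 GPUs total
# (assumes the process group / env has already been set up, e.g. by torchrun).
mesh_2d = init_device_mesh("cuda", (2, 4), mesh_dim_names=("replicate", "shard"))

# A process group for other uses can be pulled back out of the same mesh.
replicate_pg = mesh_2d.get_group(mesh_dim="replicate")

# Inside FSDP, pass only the mesh -- no process_group argument.
model = FSDP(
    nn.Linear(16, 16).cuda(),  # toy module standing in for the user's model
    device_mesh=mesh_2d,
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,
)
```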
Users can still use ProcessGroup as input for HSDP, but they need to do some additional work for checkpointing due to the duplicate FQNs. If they use DCP, they would need to pass only the one replicate group as the process group for dcp.save().
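For illustration, a hedged sketch of that workaround (the `single_replica_pg` group and checkpoint path are assumptions; which group covers a single model replica depends on how the HSDP process groups were constructed):

```python
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint import FileSystemWriter

def save_hsdp_checkpoint(model, single_replica_pg, path="/tmp/hsdp_ckpt"):
    # With process-group-based HSDP, every replica holds the same FQNs, so only
    # the ranks of one replica participate in the save, and that group is passed
    # explicitly to dcp.save(); dist.get_rank(group) is -1 outside the group.
    state_dict = {"model": model.state_dict()}
    if dist.get_rank(single_replica_pg) >= 0:
        dcp.save(
            state_dict,
            storage_writer=FileSystemWriter(path),
            process_group=single_replica_pg,
        )
```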
I think we should error out or throw a warning, either on the HSDP side or on the DCP side, when we find the user is using a process group for HSDP. We don't want a silent issue for the user.
@mvpatel2000 Could you give a little bit more information regarding the current issue without the change? For example, the error trace.
This PR does break some unit tests.
Removes raising error if a device_mesh has a parent. The comment says that HSDP + TP is not supported, but I'm able to do 2D parallelism + HSDP fine. The only issues are:
- this check
- #118618
- a series of PRs related to checkpointing with 3D meshes that I will open

We currently monkeypatch for the above, which I am slowly upstreaming. I imagine torch will have a better, native integration eventually, but this check seems too aggressive in the meantime given DTensor now lets users do some things themselves (which is amazing 🎉)!

Pull Request resolved: #118620
Approved by: https://github.com/wz337, https://github.com/wanchaol
Essentially, the issue is that the outermost root FSDP module is passed a device_mesh but no process_group (which is correct). But this line then forwards the root's process_group to the child modules, and the trace ends with a ValueError complaining that both are specified. I unfortunately don't have a trace handy as we've monkeypatched this for a while. I'm not sure what the correct fix is given the unit test failures. @fegin or @wz337, do either of you have recommendations?
Sorry, I have not been following closely, but is the issue the same as #118906?
Let me look into this.
-        if sharding_strategy in HYBRID_SHARDING_STRATEGIES:
+        if (
+            sharding_strategy in HYBRID_SHARDING_STRATEGIES
+            and device_mesh is not None
This might just be a typo? I think we want device_mesh is None (i.e., the user did not pass device_mesh) -- then we forward the process_group constructed by the root to the children. I opened a PR with this fix, along with a basic test: #119064
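For clarity, a small self-contained illustration of the condition being suggested (the enum and helper below are stand-ins, not the actual FSDP source):

```python
from enum import Enum, auto

class ShardingStrategy(Enum):  # stand-in for torch.distributed.fsdp.ShardingStrategy
    FULL_SHARD = auto()
    HYBRID_SHARD = auto()
    _HYBRID_SHARD_ZERO2 = auto()

HYBRID_SHARDING_STRATEGIES = {
    ShardingStrategy.HYBRID_SHARD,
    ShardingStrategy._HYBRID_SHARD_ZERO2,
}

def should_forward_root_process_group(sharding_strategy, device_mesh) -> bool:
    # Forward the root's process_group to child FSDP instances only for hybrid
    # sharding when the user did NOT pass a device_mesh; if a mesh was passed,
    # the groups are derived from the mesh instead, avoiding the ValueError about
    # both process_group and device_mesh being specified.
    return sharding_strategy in HYBRID_SHARDING_STRATEGIES and device_mesh is None
```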
I am going to close this in favor of @awgu's PR. Thanks for taking it over!
FSDP should take either process groups or a device_mesh. When a device_mesh is specified with DTensor, passing in process groups as well seems to make things blow up.
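As a rough sketch of the mutual-exclusion check implied here (the function name and message are illustrative, not FSDP's actual code):

```python
def _validate_pg_vs_mesh(process_group, device_mesh):
    # Illustrative only: FSDP treats process_group and device_mesh as mutually
    # exclusive inputs, so supplying both should fail fast with a clear error.
    if process_group is not None and device_mesh is not None:
        raise ValueError(
            "Pass either process_group or device_mesh to FSDP, not both; "
            "for HSDP, prefer a 2D DeviceMesh and let FSDP derive the groups."
        )
```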
cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @LucasLLC