HSDP + DTensor Support in FSDP #118618

Closed
wants to merge 2 commits into from

Conversation

mvpatel2000
Contributor

@mvpatel2000 mvpatel2000 commented Jan 30, 2024

FSDP should take either process groups or a device_mesh. When a device_mesh is specified with DTensor, passing in process groups as well seems to make things blow up.
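To make the intended usage concrete, here is a minimal sketch (not code from this PR; the mesh shape, dimension names, and module are illustrative assumptions) of passing only a 2-D DeviceMesh to FSDP for HSDP:

```python
# Minimal HSDP-via-DeviceMesh sketch. Assumes torch.distributed is already
# initialized across 8 ranks (e.g. via torchrun) with CUDA devices set.
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

# Hypothetical layout: 2 replica groups x 4 shards each.
mesh_2d = init_device_mesh("cuda", (2, 4), mesh_dim_names=("replicate", "shard"))

model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16)).cuda()

# Pass only the mesh; per this issue, a wrapped module that also ends up with a
# process_group (e.g. via forwarded root kwargs) triggers the failure below.
hsdp_model = FSDP(
    model,
    device_mesh=mesh_2d,
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,
)
```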

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @LucasLLC

@pytorch-bot pytorch-bot bot added the release notes: distributed (sharded) release notes category label Jan 30, 2024

pytorch-bot bot commented Jan 30, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/118618

Note: Links to docs will display an error until the docs builds have been completed.

❌ 5 New Failures

As of commit 831141c with merge base 5dfcf07:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@github-actions github-actions bot added oncall: distributed Add this issue/PR to distributed oncall triage queue ciflow/inductor labels Jan 30, 2024

pytorch-bot bot commented Jan 30, 2024

Please seek CI approval before scheduling CIFlow labels

@mvpatel2000
Contributor Author

@pytorchbot label "module: distributed_checkpoint"


pytorch-bot bot commented Jan 30, 2024

Please seek CI approval before scheduling CIFlow labels

@mvpatel2000 mvpatel2000 changed the title HSDP + DTensor Support HSDP + DTensor Support in FSDP Jan 30, 2024
@fegin fegin added ciflow/trunk Trigger trunk jobs on your pull request ciflow/periodic Trigger jobs ran periodically on master (periodic.yml) on the PR labels Jan 30, 2024

pytorch-bot bot commented Jan 30, 2024

Please seek CI approval before scheduling CIFlow labels


@pytorch-bot pytorch-bot bot removed ciflow/periodic Trigger jobs ran periodically on master (periodic.yml) on the PR ciflow/trunk Trigger trunk jobs on your pull request labels Jan 30, 2024
@fegin
Contributor

fegin commented Jan 30, 2024

@wz337 Do we still support process group for HSDP?

@wz337 wz337 added ciflow/periodic Trigger jobs ran periodically on master (periodic.yml) on the PR release notes: distributed (fsdp) release notes category module: fsdp labels Jan 30, 2024
@wconstab
Contributor

Is it possible to wrap the PG you have in a device_mesh, or construct a 2D mesh early on and pull a PG out of it for other uses, but then only have a device_mesh inside FSDP? It doesn't seem great to have mixed pg+mesh, IMO.
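One rough sketch of this approach (not code from this PR; the mesh shape and dimension names are assumptions) is to build the mesh once and extract groups from it wherever a plain ProcessGroup is still required:

```python
# Build the 2-D mesh up front and treat it as the single source of truth.
# Assumes torch.distributed is initialized across 8 ranks.
from torch.distributed.device_mesh import init_device_mesh

mesh_2d = init_device_mesh("cuda", (2, 4), mesh_dim_names=("replicate", "shard"))

# Code that still wants a plain ProcessGroup can pull one out of the mesh
# (dim 0 = "replicate", dim 1 = "shard" in this hypothetical layout)...
replicate_pg = mesh_2d.get_group(0)
shard_pg = mesh_2d.get_group(1)

# ...while FSDP itself is only ever given device_mesh=mesh_2d, so there is no
# mixed pg + mesh configuration inside the wrapper.
```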

@wz337
Contributor

wz337 commented Jan 30, 2024

@wz337 Do we still support process group for HSDP?

Users can still use ProcessGroup as input for HSDP, but they need to do some additional work for checkpointing due to the duplicate FQNs. If they use DCP, they would need to pass only the one replicate group as the process group for dcp.save().
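As a hedged sketch of that workaround (the group ranks and path below are purely illustrative and not part of this PR):

```python
# ProcessGroup-based HSDP checkpointing workaround: because FQNs are duplicated
# across replicas, save through a single group covering one replica only.
import torch.distributed as dist
import torch.distributed.checkpoint as dcp

# Hypothetical: ranks 0-3 hold one full (sharded) copy of the model.
one_replica_group = dist.new_group(ranks=[0, 1, 2, 3])

state_dict = hsdp_model.state_dict()  # HSDP-wrapped model (construction omitted)

if dist.get_rank() in (0, 1, 2, 3):
    # Only the ranks in that single group participate in the save.
    dcp.save(
        state_dict,
        storage_writer=dcp.FileSystemWriter("/tmp/hsdp_ckpt"),
        process_group=one_replica_group,
    )
```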

@wanchaol
Collaborator

Users can still use ProcessGroup as input for HSDP, but they need to do some additional work for checkpointing due to the duplicate FQNs. If they use DCP, they would need to pass only the one replicate group as the process group for dcp.save().

I think we should error out or throw a warning, either on the HSDP side or the DCP side, when we find the user is using a process group for HSDP. We don't want a silent issue for the user.

@wz337
Contributor

wz337 commented Jan 30, 2024

@mvpatel2000 Could you give a little bit more information regarding the current issue without the change? For example, the error trace.

@fegin
Contributor

fegin commented Jan 30, 2024

This PR does break some unit tests.

@ezyang ezyang requested review from wanchaol and awgu February 1, 2024 12:45
@ezyang ezyang added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Feb 1, 2024
pytorchmergebot pushed a commit that referenced this pull request Feb 2, 2024
Removes the error raised if a device_mesh has a parent.

The comment says that HSDP + TP is not supported, but I'm able to do 2D parallelism + HSDP fine. The only issues are:
- this check
- #118618
- a series of PRs related to checkpointing with 3D meshes that I will open
We currently monkeypatch for the above, which I am slowly upstreaming.

I imagine torch will have a better, native integration eventually, but this check seems too aggressive in the meantime given DTensor now lets users do some things themselves (which is amazing 🎉)!

Pull Request resolved: #118620
Approved by: https://github.com/wz337, https://github.com/wanchaol
@mvpatel2000
Contributor Author

@mvpatel2000 Could you give a little bit more information regarding the current issue without the change? For example, the error trace.

Essentially, the issue is that the outermost root FSDP module is passed a device_mesh but no process_group (correct). But this line root_kwargs["process_group"] = (self.process_group, self._inter_node_pg) adds process_group to the root kwargs, which get passed down recursively. So, if you have recursive wrapping, like some child module model.block1 that is also meant to be FSDP'd, then it gets root_kwargs passed into it for FSDP init. But that child module, when it tries to init FSDP, now has both a device_mesh and a process_group.

The trace ends with a ValueError complaining that both are specified. I unfortunately don't have one handy, as we've monkeypatched this for a while.
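To make the failure mode concrete, something like the following (module and policy choices are hypothetical; this is a sketch of the scenario described above, not a verified repro) hits that path:

```python
# Root FSDP gets only device_mesh, but recursive auto-wrapping forwards the
# root-constructed process_group tuple to children, so a child FSDP init sees
# both device_mesh and process_group and raises a ValueError.
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
from torch.distributed.fsdp.wrap import ModuleWrapPolicy

mesh_2d = init_device_mesh("cuda", (2, 4), mesh_dim_names=("replicate", "shard"))

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(16, 16)

    def forward(self, x):
        return self.lin(x)

model = nn.Sequential(Block(), Block()).cuda()

wrapped = FSDP(
    model,
    device_mesh=mesh_2d,
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,
    auto_wrap_policy=ModuleWrapPolicy({Block}),  # re-enters FSDP init per Block
)
```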

I'm not sure what the correct fix is given the unit test failures. @fegin or @wz337, do either of you have recommendations?

@awgu
Collaborator

awgu commented Feb 2, 2024

Sorry, I have not been following closely, but is the issue the same as #118906?

@mvpatel2000
Contributor Author

Sorry, I have not been following closely, but is the issue the same as #118906?

@awgu I think it is the same issue :)

@awgu
Collaborator

awgu commented Feb 2, 2024

Let me take a look into this

-    if sharding_strategy in HYBRID_SHARDING_STRATEGIES:
+    if (
+        sharding_strategy in HYBRID_SHARDING_STRATEGIES
+        and device_mesh is not None
Collaborator

This might just be a typo? I think we want device_mesh is None (i.e. the user did not pass device_mesh) -- then we forward the process_group constructed by the root to the children.

I opened a PR with this fix, along with a basic test: #119064
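For reference, a rough sketch of the guard described above (a paraphrase, not necessarily the exact code in #119064; root_kwargs, self.process_group, and self._inter_node_pg mirror the names quoted earlier in this thread):

```python
# Forward the root-constructed process groups to children only when the user
# did not pass a device_mesh; mesh users let children derive groups from the mesh.
if sharding_strategy in HYBRID_SHARDING_STRATEGIES and device_mesh is None:
    root_kwargs["process_group"] = (self.process_group, self._inter_node_pg)
```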

@mvpatel2000
Contributor Author

I am going to close this in favor of @awgu's PR. Thanks for taking it over!

@mvpatel2000 mvpatel2000 closed this Feb 2, 2024
@mvpatel2000 mvpatel2000 deleted the patch-4 branch February 5, 2024 18:27
pytorchmergebot pushed a commit that referenced this pull request Feb 8, 2024
pytorch-bot bot pushed a commit that referenced this pull request Feb 8, 2024
pytorch-bot bot pushed a commit that referenced this pull request Feb 8, 2024
mvpatel2000 added a commit to mvpatel2000/pytorch that referenced this pull request Feb 13, 2024
clee2000 pushed a commit that referenced this pull request Feb 14, 2024