[FSDP][optim_state_dict] Call synchronize() to ensure DTensors.to_local() is synchronized #117799


Closed
fegin wants to merge 4 commits

Conversation

fegin (Contributor) commented Jan 18, 2024

Stack from ghstack (oldest at bottom):

If a tensor is converted from a DTensor via to_local(), there may be async communication that has not finished yet. Calling clone() on that tensor does not seem to work (and may increase memory usage; users report OOMs). This is a temporary fix.
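
For illustration, a minimal sketch of the pattern this temporary fix introduces (names are illustrative, not the exact _optim_utils.py code; torch.distributed._tensor is the private DTensor module path at the time of this PR):

import torch
from torch.distributed._tensor import DTensor

def _gather_local_value(dtensor: DTensor) -> torch.Tensor:
    # to_local() may hand back a tensor whose backing collective has not finished yet.
    local = dtensor.to_local()
    # Temporary fix: a device-wide sync guarantees the communication is complete
    # before the values are cloned/read.
    torch.cuda.synchronize()
    return local.clone()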

Differential Revision: D52890462

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225

…ensors being recycled

Temporary tensors cannot be recycled until the operations on them are finished. Calling synchronize() ensures all the operations are finished, which can prevent OOM from happening.

Differential Revision: [D52890462](https://our.internmc.facebook.com/intern/diff/D52890462/)

[ghstack-poisoned]

pytorch-bot (bot) commented Jan 18, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/117799

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures

As of commit ead91ac with merge base 5c17f66:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

pytorch-bot added the release notes: distributed (fsdp) label Jan 18, 2024
fegin added a commit that referenced this pull request Jan 18, 2024
…ensors being recycled

Temporary tensors cannot be recycled until the operations on them are finished. Calling synchronize() ensures all the operations are finished, which can prevent OOM from happening.

Differential Revision: [D52890462](https://our.internmc.facebook.com/intern/diff/D52890462/)

ghstack-source-id: 212464431
Pull Request resolved: #117799
github-actions added the oncall: distributed and ciflow/inductor labels Jan 18, 2024

awgu (Collaborator) commented Jan 18, 2024

Does this part of the optim state dict load use multiple streams (which is why memory is not freed immediately when there are no more references)?

fegin (Contributor, author) commented Jan 18, 2024

No multiple streams. But even if the temporary tensors are not referenced in Python, they can still be in use, IIUC. For example, clone() will keep the source tensor (and its memory) alive until the operation is done. Or do you think this should not happen?

awgu (Collaborator) commented Jan 18, 2024

I was wondering if you have tried del-ing the source that is being cloned and seeing if that frees it at the time of del.

Suppose at first we have this:

a = torch.empty((3,), device="cuda")
b = a.clone()
c = torch.empty((3,), device="cuda")

c cannot reuse the memory of a since a is still alive (due to the Python reference).

Now, suppose we del a:

a = torch.empty((3,), device="cuda")
b = a.clone()
del a  # <--- add this
c = torch.empty((3,), device="cuda")

c can reuse the memory of a. It does not need to wait until the GPU copy kernel from clone() finishes, because any subsequent GPU kernel using the memory for c is ordered after the copy kernel on the same stream.

I did not see any special handling of clone that would record a CUDA event and free the memory later (like in the case of multiple streams and recordStream). I would be curious whether del-ing frees the memory instantly when there is only a single stream.

clone implementation

Tensor clone(const Tensor& src, c10::optional<c10::MemoryFormat> optional_memory_format) {
  auto memory_format =
      optional_memory_format.value_or(MemoryFormat::Preserve);
  Tensor self;
  if (memory_format == MemoryFormat::Preserve) {
    if (src.is_non_overlapping_and_dense()) {
      // Copy all strides, this is marginally faster than calling empty_like
      self = at::empty_strided_symint(src.sym_sizes(), src.sym_strides(), src.options());
    } else {
      self = at::empty_like(src);
    }
  } else {
    self = at::empty_like(src, src.options(), memory_format);
  }
  if (src._is_zerotensor()) {
    self.zero_();
  } else {
    self.copy_(src);
  }
  return self;
}
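
For reference, a small sketch of the experiment suggested above (single default stream, illustrative sizes), checking via the caching-allocator counters whether del frees the clone source immediately:

import torch

a = torch.empty((1 << 20,), device="cuda")   # 4 MiB of float32
b = a.clone()                                # enqueues an async GPU copy kernel
before = torch.cuda.memory_allocated()
del a                                        # drop the only Python reference to the source
after = torch.cuda.memory_allocated()
# With a single stream, the caching allocator should free a's block right away,
# even if the copy kernel is still running on the GPU.
print((before - after) / 2**20, "MiB freed at del time")
c = torch.empty((1 << 20,), device="cuda")   # free to reuse a's block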

fegin (Contributor, author) commented Jan 19, 2024

@awgu that makes sense. However, the tensor should already be deleted, as it has no references after the inner util function returns. So I don't think del makes any difference in our case. I will do more tests to verify.

fegin (Contributor, author) commented Jan 19, 2024

@awgu I can confirm that del alone is not enough. I'll check if multiple streams are being used.

…temporary tensors being recycled"

Temporary tensors cannot be recycled until the operations on them are finished. Calling synchronize() ensures all the operations are finished, which can prevent OOM from happening.

Differential Revision: [D52890462](https://our.internmc.facebook.com/intern/diff/D52890462/)

cc mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 awgu penguinwu XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225

[ghstack-poisoned]
@fegin fegin changed the title [FSDP][optim_state_dict] Call synchronize() to ensure the temporary tensors being recycled [FSDP][optim_state_dict] Call synchronize() to ensure the memory used by DTensors.to_local() being recycled Jan 22, 2024

fegin (Contributor, author) commented Jan 22, 2024

@awgu The root cause is that DTensor.to_local() performs asynchronous communication by default. So I believe we need to call synchronize() in such a case.

awgu (Collaborator) commented Jan 22, 2024

@awgu The root cause is that DTensor.to_local() performs asynchronous communication by default. So I believe we need to call synchronize() in such a case.

Thanks for looking into this!

@wanchaol Is there any way to tell to_local() to use synchronous collectives? We want to avoid recordStream, and IIUC, we only avoid it for synchronous collectives. (cc: @kwen2501)

If we can run synchronous collectives and avoid recordStream, then I think we can avoid the CPU sync from synchronize() and instead just rely on the PG NCCL calling current_stream.wait_stream(nccl_stream). Avoiding the CPU sync here could prevent some CPU boundedness (though I have not looked at any profiles).
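
Roughly, the two options look like this (comm_stream is a stand-in; in practice the NCCL stream is owned by ProcessGroupNCCL and not exposed directly):

import torch

comm_stream = torch.cuda.Stream()  # placeholder for the collective's stream

# Option A: device-wide CPU sync. Simple, but the host blocks until all GPU work finishes.
torch.cuda.synchronize()

# Option B: GPU-side ordering only. The current stream waits on the comm stream, so later
# kernels (e.g. the clone) run after the collective without stalling the CPU.
torch.cuda.current_stream().wait_stream(comm_stream)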

fegin (Contributor, author) commented Jan 22, 2024

@awgu This is the optimizer state_dict code path. What's the downside of using torch.cuda.synchronize()?

awgu (Collaborator) commented Jan 22, 2024

@fegin The downside is just performance from synchronizing the CPU. If there is not much performance downside, then using synchronize sounds good to me!

(I do not have a good mental model of the optimizer state dict performance, so maybe this is not really an issue. However, in general, I do think that considering the performance of state dict makes sense.)

fegin (Contributor, author) commented Jan 22, 2024

@awgu Yup, agreed that performance is important. Currently, synchronize() is probably not the main bottleneck. The main bottleneck is the all-gather, which is done per parameter without batching. However, it is hard to batch the all-gathers for optimizer state dict load.

…memory used by DTensors.to_local() being recycled"


If a tensor is converted from a DTensor via to_local(), there may be async communication that has not finished yet. Calling synchronize() ensures all the operations are finished, which can prevent OOM from happening.

Differential Revision: [D52890462](https://our.internmc.facebook.com/intern/diff/D52890462/)

cc mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 awgu penguinwu XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225

[ghstack-poisoned]
@fegin fegin changed the title [FSDP][optim_state_dict] Call synchronize() to ensure the memory used by DTensors.to_local() being recycled [FSDP][optim_state_dict] Call synchronize() to ensure the memory used by DTensors.to_local() is synchronized Jan 23, 2024
@fegin fegin requested review from wz337, awgu and LucasLLC January 23, 2024 01:02

fegin (Contributor, author) commented Jan 23, 2024

This has become more than a performance issue. If we don't call torch.cuda.synchronize(), the subsequent clone() will not work -- the value will not be correctly cloned.

cc., @wanchaol @awgu @wz337 @LucasLLC

@fegin fegin changed the title [FSDP][optim_state_dict] Call synchronize() to ensure the memory used by DTensors.to_local() is synchronized [FSDP][optim_state_dict] Call synchronize() to ensure DTensors.to_local() is synchronized Jan 23, 2024
…sors.to_local() is synchronized"


If a tensor is converted from a DTensor via to_local(), there may be async communication that has not finished yet. Calling `clone()` on that tensor does not seem to work (and may increase memory usage; users report OOMs). This is a temporary fix.

Differential Revision: [D52890462](https://our.internmc.facebook.com/intern/diff/D52890462/)

cc mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 awgu penguinwu XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225

[ghstack-poisoned]
fegin added a commit that referenced this pull request Jan 23, 2024
…al() is synchronized

Pull Request resolved: #117799

If a tensor is converted from a DTensor via to_local(), there may be async communication that has not finished yet. Calling clone() on that tensor does not seem to work (and may increase memory usage; users report OOMs). This is a temporary fix.

ghstack-source-id: 212762982
@exported-using-ghexport

Differential Revision: [D52890462](https://our.internmc.facebook.com/intern/diff/D52890462/)

kwen2501 (Contributor):

Can we use stream sync instead of device sync?

LucasLLC (Contributor) left a comment:

Lacking a bit of context but approving to unblock the fix

kwen2501 (Contributor):

If only device sync works and stream sync doesn't, that means there is something wrong with DTensor's to_local() -- it should sync its communication work back to the "current stream" (or else provide a Work handle for the user to sync on at some point).

kwen2501 (Contributor):

If (1) proper stream dependency is maintained by DTensor's to_local() and (2) torch.clone observes stream properly, you don't even need to call stream sync.

fegin (Contributor, author) commented Jan 23, 2024

@kwen2501 _optim_utils.py calls DTensor.to_local() to gather the tensor. It does not know which stream DTensor.to_local() uses, so I don't think I'm able to use a stream wait.

kwen2501 (Contributor):

If to_local()'s API does not have flags like async_op=True|False, it must always sync back to the main stream, so that you don't need to figure out which stream to wait on. Said plainly, to_local() must always call work.wait() internally or expose the work handle.
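
For reference, the Work-handle contract with a plain c10d collective looks roughly like this (assumes an initialized process group; to_local() itself does not expose async_op, this only illustrates the pattern):

import torch
import torch.distributed as dist

# dist.init_process_group(...) is assumed to have been called already.
t = torch.ones(4, device="cuda")
work = dist.all_reduce(t, async_op=True)  # returns a Work handle instead of blocking
work.wait()  # orders the current stream after the collective before later ops read `t`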

awgu (Collaborator) left a comment:

Landing this as a temporary fix sounds good to me!

We should separately figure out how to make AsyncCollectiveTensor wait before clone() (if that is the root issue).

wconstab (Contributor):

In the more general sense of a user calling DTensor.to_local, I don't know if it makes sense that the operation has to be synchronous. Wouldn't it be reasonable to expect a to_local call to return a new AsyncCollectiveTensor ('ACT') that represents ongoing tensor work?

value = value.flatten()[intra_param_start_idx : intra_param_end_idx + 1].clone()  # type: ignore[operator]
if fsdp_state._device_mesh is not None:
    # We have to call synchronize() if the tensor is gathered from
    # DTensor. Otherwise, the later `clone()` will cause errors.

Collaborator:

Are you using the full_tensor() API to do the gathering, or something else? IIRC full_tensor gives sync behavior.

Contributor:

Are you using the full_tensor() API to do the gathering, or something else? IIRC full_tensor gives sync behavior.

No. We are still using redistribute for the all_gather, since the full_tensor() API was introduced later.
Maybe we could update this line to use the full_tensor API? https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_optim_utils.py#L1455
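
A hedged sketch of the two paths being compared (the private torch.distributed._tensor module path and the sync behavior of full_tensor() are assumptions based on this discussion):

import torch
from torch.distributed._tensor import DTensor, Replicate

def gather_via_full_tensor(dt: DTensor) -> torch.Tensor:
    # Suggested path: full_tensor() all-gathers and, per the comment above,
    # is expected to return an already-synchronized plain tensor.
    return dt.full_tensor()

def gather_via_redistribute(dt: DTensor) -> torch.Tensor:
    # Roughly the current per-parameter path: redistribute + to_local(), where
    # to_local() may return a tensor backed by pending async communication.
    return dt.redistribute(dt.device_mesh, [Replicate()]).to_local()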

kwen2501 (Contributor) commented Jan 23, 2024

Wouldn't it be reasonable to expect a to_local call to return a new AsyncCollectiveTensor ('ACT') that represents ongoing tensor work?

The reasonableness depends on how likely users are to write
x.to_local()
versus
x = x.to_local()

In a world where tensor.to(...) is prevalent, I'd say that tendency is non-negligible...

awgu (Collaborator) commented Jan 23, 2024

@kwen2501 tensor.to(...) is not in-place, though. Only nn.Module.to is in-place.

kwen2501 (Contributor):

@awgu You are right. Thanks for the correction!

fegin added a commit that referenced this pull request Jan 24, 2024
See the discussion in #117799.

There are some issues when returning an AsyncCollectiveTensor (we haven't found the root causes), including OOM and unexpected values.

This PR forces `_gather_state_dict()` to be synchronous with respect to the main stream.

Differential Revision: [D53049807](https://our.internmc.facebook.com/intern/diff/D53049807/)

[ghstack-poisoned]
fegin (Contributor, author) commented Jan 24, 2024

I have not found the root cause. Since _gather_state_dict() is used not only by FSDP but also by PP-FSDP, I decided to change the behavior of _gather_state_dict() to always call wait() before returning the tensor. The new PR is #118197.
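
A minimal sketch of the behavior adopted in #118197, assuming the AsyncCollectiveTensor wrapper from torch.distributed._functional_collectives (the helper name is illustrative, not the exact state-dict utility code):

import torch
from torch.distributed._functional_collectives import AsyncCollectiveTensor

def _wait_if_async(value: torch.Tensor) -> torch.Tensor:
    # If the gathered value is still an AsyncCollectiveTensor, wait on the underlying
    # collective so callers always receive a plain, ready-to-use tensor.
    if isinstance(value, AsyncCollectiveTensor):
        return value.wait()
    return value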

@fegin fegin closed this Jan 24, 2024
pytorchmergebot pushed a commit that referenced this pull request Jan 25, 2024
See the discussion in #117799.

There are some issues when returning an AsyncCollectiveTensor (we haven't found the root causes), including OOM and unexpected values.

This PR forces `_gather_state_dict()` to be synchronous with respect to the main stream.

Differential Revision: [D53049807](https://our.internmc.facebook.com/intern/diff/D53049807/)

Pull Request resolved: #118197
Approved by: https://github.com/wz337, https://github.com/LucasLLC
fegin added a commit that referenced this pull request Jan 25, 2024
…to_local() result"


See the discussion in #117799.

There are some issues when returning an AsyncCollectiveTensor (we haven't found the root causes), including OOM and unexpected values.

This PR forces `_gather_state_dict()` to be synchronous with respect to the main stream.

Differential Revision: [D53049807](https://our.internmc.facebook.com/intern/diff/D53049807/)

cc mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 awgu penguinwu XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225

[ghstack-poisoned]
fegin added a commit that referenced this pull request Jan 25, 2024
See the discussion in #117799.

There are some issues when returning an AsyncCollectiveTensor (we haven't found the root causes), including OOM and unexpected values.

This PR forces `_gather_state_dict()` to be synchronous with respect to the main stream.

Differential Revision: [D53049807](https://our.internmc.facebook.com/intern/diff/D53049807/)

cc mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 awgu penguinwu XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225

[ghstack-poisoned]
Skylion007 pushed a commit to Skylion007/pytorch that referenced this pull request Feb 12, 2024
…118197)

See the discussion in pytorch#117799.

There are some issues when returning an AsyncCollectiveTensor (we haven't found the root causes), including OOM and unexpected values.

This PR forces `_gather_state_dict()` to be synchronous with respect to the main stream.

Differential Revision: [D53049807](https://our.internmc.facebook.com/intern/diff/D53049807/)

Pull Request resolved: pytorch#118197
Approved by: https://github.com/wz337, https://github.com/LucasLLC
@github-actions github-actions bot deleted the gh/fegin/199/head branch February 24, 2024 01:50
Labels
ciflow/inductor, oncall: distributed, release notes: distributed (fsdp)