[fsdp] feat: Memory efficient cross entropy with a linear layer fused #462
Conversation
Force-pushed from 1221335 to a14f31d.
Could you please format the code according to the README?
The integration has an OOM problem with the current fake-weight approach. We will reconsider how to fuse the linear layer with cross entropy.
One success of the integration is that max_token_len can be increased significantly compared to not using this kernel.
Liger has a similar kernel.
The kernel in Liger can't satisfy the requirement, because there is additional loss computation after the kernel that the Liger kernel can't support. (A sketch of what this means in practice follows below.)
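For context, here is a minimal sketch (not verl's actual loss code) of the kind of post-kernel computation meant here: because `linear_cross_entropy` returns both the log-prob loss and a per-token entropy, further loss terms such as an entropy bonus can still be applied after the fused kernel. The tensor sizes, the `entropy_coeff` value, and the way the terms are combined are illustrative assumptions, not taken from this PR.

```python
import torch
from verl.utils.kernel import linear_cross_entropy

# Illustrative shapes only.
num_tokens, hidden_size, vocab_size = 4096, 1024, 32000
hidden = torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16, device="cuda", requires_grad=True)
weight = torch.randn(hidden_size, vocab_size, dtype=torch.bfloat16, device="cuda", requires_grad=True)
labels = torch.randint(0, vocab_size, (num_tokens,), device="cuda")

# The fused op returns both outputs, so downstream loss terms remain possible.
loss, entropy = linear_cross_entropy(hidden, weight, labels, reduction="mean")

entropy_coeff = 0.01  # hypothetical coefficient, for illustration only
# Additional loss computation after the kernel; the exact combination and sign
# convention depend on the training objective -- this line is just a placeholder.
total_loss = loss - entropy_coeff * entropy.mean()
total_loss.backward()
```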
There are multiple CI failures. Could you please fix them? Thanks.
Signed-off-by: Jianbing Dong <jianbingd@nvidia.com>
Sorry for the close-and-open operations. Opening a PR from the main branch can be a dangerous operation for maintainers to cooperate on and rebase (QaQ). Next time I will open the PR from a branch in my own fork instead.
Signed-off-by: Jianbing Dong <jianbingd@nvidia.com>
[fsdp] feat: Memory efficient cross entropy with a linear layer fused (volcengine#462)

Implemented forward and backward of the following compute logic, eliminating many intermediate storage tensors and reducing peak memory usage.

## Equivalent compute logic:

```python
import typing
import torch

def run_torch_entropy(hidden: torch.Tensor,
                      weight: torch.Tensor,
                      labels: torch.Tensor) -> typing.List[torch.Tensor]:
    logits = torch.matmul(hidden.to(torch.float32), weight.to(torch.float32))  # [num_tokens, vocab_size]
    pd = torch.nn.functional.softmax(logits, dim=-1)  # [num_tokens, vocab_size]
    entropy_a = torch.logsumexp(logits, dim=-1)  # [num_tokens]
    entropy_b = torch.sum(pd * logits, dim=-1)  # [num_tokens]
    entropy = entropy_a - entropy_b
    logprobs = torch.nn.functional.cross_entropy(logits, labels)  # [1]
    logprobs = torch.neg(logprobs)
    return logprobs, entropy
```

## API

```python
from verl.utils.kernel import linear_cross_entropy

hidden = torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16, device="cuda")
weight = torch.randn(hidden_size, vocab_size, dtype=torch.bfloat16, device="cuda")
labels = torch.randint(0, vocab_size, (num_tokens,), device="cuda")

loss, entropy = linear_cross_entropy(hidden, weight, labels, reduction="mean")
```

## Storage and latency

<img width="636" alt="image" src="https://github.com/user-attachments/assets/396b7303-a46a-46b1-a261-917fda034b02" />

## Unit test

```shell
$ cd verl/
$ python3 tests/kernel/test_memory_efficient_entropy.py
```

# NOTE

For compatibility, `torch.library.triton_op` was not applied to these APIs, so `torch.compile` might not be usable on top of them.

---------

Signed-off-by: Jianbing Dong <jianbingd@nvidia.com>
Co-authored-by: ETOgaosion <gaoziyuan19@mails.ucas.ac.cn>
Co-authored-by: gaoziyuan.955 <gaoziyuan.955@bytedance.com>
Co-authored-by: Blue Space <57280232+ETOgaosion@users.noreply.github.com>
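To make the "Storage and latency" claim easy to check locally, here is a minimal sketch (not part of the PR) that compares peak CUDA memory of the fused op against the plain-torch reference from the description above. It assumes verl is installed so that `verl.utils.kernel.linear_cross_entropy` is importable; the tensor sizes and the scalar used to drive backward are illustrative.

```python
import torch
from verl.utils.kernel import linear_cross_entropy

def run_torch_entropy(hidden, weight, labels):
    # Plain-torch reference, same math as in the PR description.
    logits = torch.matmul(hidden.to(torch.float32), weight.to(torch.float32))
    pd = torch.nn.functional.softmax(logits, dim=-1)
    entropy = torch.logsumexp(logits, dim=-1) - torch.sum(pd * logits, dim=-1)
    logprobs = torch.neg(torch.nn.functional.cross_entropy(logits, labels))
    return logprobs, entropy

num_tokens, hidden_size, vocab_size = 8192, 4096, 32768  # illustrative sizes
hidden = torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16, device="cuda", requires_grad=True)
weight = torch.randn(hidden_size, vocab_size, dtype=torch.bfloat16, device="cuda", requires_grad=True)
labels = torch.randint(0, vocab_size, (num_tokens,), device="cuda")

def peak_gib(fn):
    """Run fn once and report the peak CUDA memory it allocated, in GiB."""
    hidden.grad = weight.grad = None
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    fn()
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 2**30

def fused():
    loss, entropy = linear_cross_entropy(hidden, weight, labels, reduction="mean")
    (loss + entropy.mean()).backward()  # arbitrary scalar, just to exercise backward

def reference():
    logprobs, entropy = run_torch_entropy(hidden, weight, labels)
    (logprobs + entropy.mean()).backward()

print(f"fused:     {peak_gib(fused):.2f} GiB")
print(f"reference: {peak_gib(reference):.2f} GiB")
```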
Does this PR improve the GRPO loss computation in terms of peak memory? I've come across https://unsloth.ai/blog/grpo, which describes how to implement GRPO in a chunked/fused style as well, so I wonder if verl implements such a technique too.
I was wondering whether the recent introduction of this feature might have contributed to the issue described below.
Curious, why Wouldn't it be better to be also be able to use torch.compile on the whole model / loss? |
I noticed some weird results after enabling kernel fusion, as described in #2656. Wondering if it's a bug or if I didn't use it correctly. @Jianbing-D
@WindowsXp-Beta do the problems show up with both the torch and the triton fused backends?
Sorry for the late response. I was testing whether it's caused by our internal model. Our current results show
@vadimkantorov sorry for the late update. I spent some time setting up the environment to run Qwen2.5-VL with the mainline code. We found that the log_probs and entropy calculated by the fused kernel and the vanilla torch implementation matched for Qwen2.5-VL. So it looks like the problem is on our side, and we're still working on it.
Hi @vadimkantorov, after more tests we suspect the triton kernel may have bugs on certain
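For readers hitting similar mismatches, here is a minimal sketch of the kind of cross-check described above: comparing the fused kernel's outputs against the plain-torch math from the PR description. The shapes, the tolerances, and the assumption that the fused loss follows the same sign/reduction convention as the reference are mine, not from this thread; the repository's own check is `tests/kernel/test_memory_efficient_entropy.py`.

```python
import torch
from verl.utils.kernel import linear_cross_entropy

num_tokens, hidden_size, vocab_size = 2048, 1024, 32000  # illustrative; sweep the shapes you suspect
hidden = torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16, device="cuda")
weight = torch.randn(hidden_size, vocab_size, dtype=torch.bfloat16, device="cuda")
labels = torch.randint(0, vocab_size, (num_tokens,), device="cuda")

fused_loss, fused_entropy = linear_cross_entropy(hidden, weight, labels, reduction="mean")

# Reference in plain torch, same math as run_torch_entropy in the PR description.
logits = torch.matmul(hidden.to(torch.float32), weight.to(torch.float32))
pd = torch.nn.functional.softmax(logits, dim=-1)
ref_entropy = torch.logsumexp(logits, dim=-1) - torch.sum(pd * logits, dim=-1)
ref_loss = -torch.nn.functional.cross_entropy(logits, labels)

# bf16 inputs with fp32 accumulation: use loose tolerances. The sign/reduction
# convention of the fused loss is assumed to match the reference here.
torch.testing.assert_close(fused_loss.float(), ref_loss.float(), rtol=1e-2, atol=1e-2)
torch.testing.assert_close(fused_entropy.float(), ref_entropy.float(), rtol=1e-2, atol=1e-2)
```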