Conversation

@Jianbing-D Jianbing-D commented Mar 4, 2025

Implemented the forward and backward passes of the compute logic below, which eliminates many intermediate storage tensors and reduces peak memory usage.

## Equivalent compute logic

```python
import typing

import torch


def run_torch_entropy(hidden: torch.Tensor,
                      weight: torch.Tensor,
                      labels: torch.Tensor) -> typing.Tuple[torch.Tensor, torch.Tensor]:
    logits = torch.matmul(hidden.to(torch.float32), weight.to(torch.float32))  # [num_tokens, vocab_size]
    pd = torch.nn.functional.softmax(logits, dim=-1)  # [num_tokens, vocab_size]
    entropy_a = torch.logsumexp(logits, dim=-1)  # [num_tokens]
    entropy_b = torch.sum(pd * logits, dim=-1)  # [num_tokens]
    entropy = entropy_a - entropy_b  # [num_tokens]
    logprobs = torch.nn.functional.cross_entropy(logits, labels)  # scalar (mean reduction)
    logprobs = torch.neg(logprobs)
    return logprobs, entropy
```
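For reference, the entropy expression follows from the softmax identity $\log p_i = z_i - \operatorname{logsumexp}(z)$, where $z$ denotes the logits of one token:

```latex
H(p) = -\sum_i p_i \log p_i
     = -\sum_i p_i \left(z_i - \operatorname{logsumexp}(z)\right)
     = \operatorname{logsumexp}(z) - \sum_i p_i z_i
```

which is exactly `entropy_a - entropy_b` above.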

## API

```python
import torch

from verl.utils.kernel import linear_cross_entropy

num_tokens, hidden_size, vocab_size = 4096, 4096, 32768  # illustrative sizes

hidden = torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16, device="cuda")
weight = torch.randn(hidden_size, vocab_size, dtype=torch.bfloat16, device="cuda")
labels = torch.randint(0, vocab_size, (num_tokens,), device="cuda")

loss, entropy = linear_cross_entropy(hidden, weight, labels, reduction="mean")
```
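Since the PR implements both forward and backward, gradients should flow to `hidden` and `weight` through both outputs; a minimal sketch (the entropy coefficient is illustrative, not part of the API):

```python
# reuses hidden / weight / labels from the snippet above
hidden = hidden.detach().requires_grad_()
weight = weight.detach().requires_grad_()

loss, entropy = linear_cross_entropy(hidden, weight, labels, reduction="mean")
(loss - 0.01 * entropy.mean()).backward()  # the 0.01 entropy bonus is illustrative

assert hidden.grad is not None and weight.grad is not None
```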

## Storage and latency

[figure: storage and latency comparison]

## Unit test

```shell
$ cd verl/
$ python3 tests/kernel/test_memory_efficient_entropy.py
```
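In essence, the test compares the fused kernel against the `run_torch_entropy` reference above; a minimal sketch of that comparison (sizes and tolerances are illustrative, not the ones in the actual test):

```python
import torch

from verl.utils.kernel import linear_cross_entropy

# assumes run_torch_entropy from the PR description is in scope
torch.manual_seed(0)
num_tokens, hidden_size, vocab_size = 512, 1024, 32000  # illustrative sizes
hidden = torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16, device="cuda")
weight = torch.randn(hidden_size, vocab_size, dtype=torch.bfloat16, device="cuda")
labels = torch.randint(0, vocab_size, (num_tokens,), device="cuda")

ref_logprobs, ref_entropy = run_torch_entropy(hidden, weight, labels)
loss, entropy = linear_cross_entropy(hidden, weight, labels, reduction="mean")

# bf16 inputs with fp32 accumulation, so loose tolerances (values assumed)
torch.testing.assert_close(loss, ref_logprobs, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(entropy, ref_entropy, rtol=1e-2, atol=1e-2)
```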

# NOTE

For compatibility, `torch.library.triton_op` was not applied to these APIs, so `torch.compile` may not be able to run on top of them.
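For reference, the registration the NOTE alludes to looks roughly like the following toy example (PyTorch ≥ 2.6; unrelated to this PR's kernels, and `mylib::neg` is made up for illustration):

```python
import torch
import triton
import triton.language as tl
from torch.library import triton_op, wrap_triton


@triton.jit
def _neg_kernel(x_ptr, out_ptr, n, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    tl.store(out_ptr + offs, -tl.load(x_ptr + offs, mask=mask), mask=mask)


# registering through triton_op (instead of calling the kernel directly)
# is what lets torch.compile trace through the op
@triton_op("mylib::neg", mutates_args={})
def neg(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n = x.numel()
    grid = (triton.cdiv(n, 1024),)
    wrap_triton(_neg_kernel)[grid](x, out, n, BLOCK=1024)
    return out
```

Presumably the compatibility concern is older PyTorch versions, where `torch.library.triton_op` is unavailable; skipping the registration keeps the kernels usable there, at the cost of graph breaks under `torch.compile`.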

@Jianbing-D Jianbing-D marked this pull request as ready for review March 4, 2025 07:01
@Jianbing-D Jianbing-D force-pushed the main branch 2 times, most recently from 1221335 to a14f31d on March 5, 2025 07:13
vermouth1992 previously approved these changes Mar 8, 2025
@vermouth1992 (Collaborator)

Could you please format the code according to the README?

@ETOgaosion (Collaborator)

Tested against torch and the current VeRL implementation; the improvement is huge.

[figure: benchmark results vs. torch and the current VeRL implementation]

Currently integrated into dp_actor.py.

@ETOgaosion (Collaborator)

The integration hits an OOM problem with the current fake-weight approach. We will reconsider how to fuse the linear layer with the cross entropy.

@vermouth1992 (Collaborator)

One measure of a successful integration is that `max_token_len` can be increased significantly compared to not using this kernel.

@Jianbing-D (Contributor, Author)

TP experiment result

[figure: TP experiment results]

@gameofdimension (Contributor)

Liger has a similar kernel called FusedLinearCrossEntropy

@vermouth1992 (Collaborator)

> Liger has a similar kernel called FusedLinearCrossEntropy

The kernel in Liger can't satisfy our requirement: there is additional loss computation after the kernel, which the Liger kernel can't support.
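(Concretely, the signature difference, with the Liger call paraphrased and hypothetical rather than quoted from its API:)

```python
# Liger-style fused op (hypothetical signature): only a reduced loss comes back,
# and the logits are never materialized, so there is nothing left to compute
# per-token entropy from afterwards.
loss = liger_fused_linear_cross_entropy(hidden, weight, labels)

# This PR's kernel also returns per-token entropy, so additional loss terms
# can still be computed after the fused op:
loss, entropy = linear_cross_entropy(hidden, weight, labels, reduction="mean")
```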

@vermouth1992 (Collaborator)

There are multiple CI failures. Could you please fix them? Thanks.

@ETOgaosion ETOgaosion reopened this Jun 8, 2025
@volcengine volcengine deleted a comment from CLAassistant Jun 8, 2025
@ETOgaosion ETOgaosion changed the title Memory efficient cross entropy with a linear layer fused [feat] Memory efficient cross entropy with a linear layer fused Jun 8, 2025
@ETOgaosion (Collaborator) commented Jun 8, 2025

Sorry for the close and reopen operations.

Using the main branch for a PR can be dangerous for maintainers to cooperate on and rebase (QaQ).

Next time we will still use a PR from a separate branch to others' repos.

@ETOgaosion ETOgaosion changed the title [feat] Memory efficient cross entropy with a linear layer fused [fsdp] feat: Memory efficient cross entropy with a linear layer fused Jun 8, 2025
@ETOgaosion ETOgaosion merged commit c8908e1 into volcengine:main Jun 11, 2025
36 checks passed
yellowbee686 pushed a commit to yellowbee686/verl that referenced this pull request Jun 13, 2025
…volcengine#462)


Signed-off-by: Jianbing Dong <jianbingd@nvidia.com>
Co-authored-by: ETOgaosion <gaoziyuan19@mails.ucas.ac.cn>
Co-authored-by: gaoziyuan.955 <gaoziyuan.955@bytedance.com>
Co-authored-by: Blue Space <57280232+ETOgaosion@users.noreply.github.com>
@vadimkantorov

Does this PR improve the GRPO loss computation in terms of peak memory? I've come across https://unsloth.ai/blog/grpo, which describes how to implement GRPO in a chunked/fused style as well, so I wonder whether verl implements such a technique too.

@hljjjmssyh

I was wondering whether the recent introduction of this feature might have contributed to the issue described in #2547.

@vadimkantorov

Curious, why the following?

> For compatibility, `torch.library.triton_op` was not applied to these APIs, so `torch.compile` may not be able to run on top of them.

Wouldn't it be better to also be able to use `torch.compile` on the whole model / loss?

@WindowsXp-Beta (Contributor)

I noticed some weird results after enabling kernel fusion, as described in #2656. Wondering if it's a bug or I didn't use it correctly. @Jianbing-D

@vadimkantorov commented Jul 21, 2025

@WindowsXp-Beta are the problems present with both the torch and the triton fused backends?

@WindowsXp-Beta (Contributor)

> @WindowsXp-Beta are the problems present with both the torch and the triton fused backends?

Sorry for the late response. We were testing whether it's caused by our internal model. Our current results show the torch backend works normally, but the triton backend leads to reward collapse and an entropy mismatch. We're also testing on Qwen2.5-VL to see if the problem still exists.

@vadimkantorov

cc @eric-haibin-lin @vermouth1992

@WindowsXp-Beta (Contributor)

@vadimkantorov sorry for the late update. I spent some time setting up the environment to run Qwen2.5-VL on the mainline code. We found that the log_probs and entropy calculated by the fused kernel and the vanilla torch implementation matched for Qwen2.5-VL. So it looks like the problem is on our side, and we're still working on it.

@WindowsXp-Beta (Contributor)

Hi @vadimkantorov, after more tests we suspect the triton kernel may have bugs for certain hidden_states and weight values; see the latest comment in #2656 for details. I wonder if you have ever seen a similar thing / could reproduce this mismatch on your side.
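(A minimal way to check such a value-dependent mismatch, assuming the suspect tensors were dumped from a failing step; the file names are hypothetical:)

```python
import torch

from verl.utils.kernel import linear_cross_entropy

# hypothetical dumps of the suspect inputs from a failing step
hidden = torch.load("hidden_states.pt").cuda()
weight = torch.load("lm_head_weight.pt").cuda()
labels = torch.load("labels.pt").cuda()

# reference path: the plain-torch run_torch_entropy from the PR description
ref_logprobs, ref_entropy = run_torch_entropy(hidden, weight, labels)

# fused path under test
loss, entropy = linear_cross_entropy(hidden, weight, labels, reduction="mean")

torch.testing.assert_close(entropy, ref_entropy, rtol=1e-2, atol=1e-2)  # tolerance assumed
```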

whatadayG pushed a commit to whatadayG/verl that referenced this pull request Sep 5, 2025
…volcengine#462)
