Conversation

@CurryRice233 (Contributor) commented Jun 9, 2025

Checklist Before Starting

  • Search for similar PR(s).

What does this PR do?

  1. Add FSDP1 forward prefetch configuration.
  2. Add chunked entropy computation.
  3. Add torch.utils.checkpoint recomputation to the entropy computation.
  4. Move data-to-device transfers from ActorRolloutRefWorker.update_actor to DataParallelPPOActor.update_policy.
  5. Add the npu_cross_entropy_loss fusion kernel.

High-Level Design

  1. For more detail, see FSDP forward_prefetch (https://docs.pytorch.org/docs/stable/fsdp.html#module-torch.distributed.fsdp).
  2. logits is usually a large tensor of shape [bsz*seq_len, voc]; compute_entropy_from_logits uses roughly [bsz*seq_len, voc] * (4 (float32) + 2 (autocast of softmax+logsumexp) + 1 (output of softmax)) memory. To reduce this memory peak, we can compute in chunks, changing [bsz*seq_len, voc] to [chunk_size (2048), voc].
  3. During the training phase, enable_gradient_checkpointing=True does not apply to the entropy calculation, so we add recomputation of the entropy to reduce the training-phase memory peak.
  4. ActorRolloutRefWorker.update_actor moved all batch data to the device, but this is unnecessary: DataParallelPPOActor.update_policy moves the data to the device for each micro batch.
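Point 2 can be sketched in plain Python (illustrative only: the real implementation operates on torch tensors with chunk_size=2048, and the function names here are hypothetical):

```python
import math

def entropy_from_logits_row(row):
    # H = logsumexp(l) - sum(softmax(l) * l), computed stably by
    # subtracting the row max before exponentiating.
    m = max(row)
    lse = m + math.log(sum(math.exp(x - m) for x in row))
    return lse - sum(math.exp(x - lse) * x for x in row)

def entropy_from_logits_chunked(logits, chunk_size=2048):
    # Walk the [bsz*seq_len, voc] rows chunk_size rows at a time, so the
    # softmax/logsumexp intermediates only ever exist for one chunk.
    entropies = []
    for start in range(0, len(logits), chunk_size):
        for row in logits[start:start + chunk_size]:
            entropies.append(entropy_from_logits_row(row))
    return entropies
```

The entropy_checkpointing option (point 3) complements this during training by wrapping the entropy computation in torch.utils.checkpoint, so the intermediates are recomputed in the backward pass instead of being stored.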

Specific Changes

List the specific changes.

API

Add three new configurations in actor/ref and one new configuration in critic/reward:

  • actor_rollout_ref.actor.fsdp_config.forward_prefetch: False
  • actor_rollout_ref.actor.entropy_from_logits_with_chunking: False
  • actor_rollout_ref.actor.entropy_checkpointing: False
  • actor_rollout_ref.ref.fsdp_config.forward_prefetch: False
  • actor_rollout_ref.ref.entropy_from_logits_with_chunking: False
  • actor_rollout_ref.ref.entropy_checkpointing: False
  • critic.model.fsdp_config.forward_prefetch: False
  • reward_model.model.fsdp_config.forward_prefetch: False

Usage Example

Provide usage example(s) for easier usage.

# Add code snippet or script demonstrating how to use this 
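A possible way to enable the new options (illustrative; the entry point is an assumption, but the option names come from the API list above, and all of them default to False):

```shell
# Enable forward prefetch and the entropy memory optimizations via
# hydra-style command-line overrides.
python3 -m verl.trainer.main_ppo \
    actor_rollout_ref.actor.fsdp_config.forward_prefetch=True \
    actor_rollout_ref.actor.entropy_from_logits_with_chunking=True \
    actor_rollout_ref.actor.entropy_checkpointing=True \
    actor_rollout_ref.ref.fsdp_config.forward_prefetch=True \
    critic.model.fsdp_config.forward_prefetch=True
```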

Test

For changes that cannot be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

Additional Info.

  • Issue Number: Fixes issue # or discussion # if any.
  • Training: [Note which backend this PR will affect: FSDP, Megatron, both, or none]
  • Inference: [Note which backend this PR will affect: vLLM, SGLang, both, or none]

Checklist Before Submitting

  • Read the Contribute Guide.
  • Apply pre-commit checks.
  • Add [BREAKING] to the PR title if it breaks any API.
  • Update the documentation about your changes in the docs.
  • New CI unit test(s) are added to cover the code path.
  • Rely on existing unit tests on CI that covers the code path.

A Collaborator commented on this config hunk:

ulysses_sequence_parallel_size: 1

# calculate entropy with chunking to reduce memory peak
entropy_from_logits_with_chunking: False

> Can we remove these two options? I guess they should be on by default for NPU

@CurryRice233 (Contributor, author) replied Jun 10, 2025:

Not only for NPU, but also for GPU: chunking + recomputation reduce entropy memory utilization (the end-to-end memory peak). If memory is sufficient, it's recommended to disable these two options.

@CurryRice233 CurryRice233 changed the title Add FSDP forward pefetch and recompute+chunking entropy Add FSDP forward pefetch and recompute chunking entropy Jun 10, 2025
@CurryRice233 CurryRice233 changed the title Add FSDP forward pefetch and recompute chunking entropy [FSDP] feat: Add FSDP forward pefetch and recompute chunking entropy Jun 10, 2025
@vermouth1992 vermouth1992 merged commit 0bd03d7 into volcengine:main Jun 11, 2025
38 of 39 checks passed
yellowbee686 pushed a commit to yellowbee686/verl that referenced this pull request Jun 13, 2025
@eric-haibin-lin (Collaborator) commented:

@CurryRice233 thx for the contribution! Would you mind adding these options to https://verl.readthedocs.io/en/latest/perf/perf_tuning.html (https://github.com/volcengine/verl/blob/main/docs/perf/perf_tuning.rst)?

@CurryRice233 (Contributor, author) replied:

No problem, I will create a PR within this week.

vermouth1992 pushed a commit that referenced this pull request Jul 2, 2025
…2322)

### What does this PR do?

@eric-haibin-lin As this comment says
#1927 (comment),
add FSDP forward prefetch and entropy calculation memory optimization to
performance tuning guide.

yellowbee686 pushed a commit to yellowbee686/verl that referenced this pull request Jul 2, 2025
alexis-mmm pushed a commit to alexis-mmm/verl that referenced this pull request Jul 15, 2025
oseyosey pushed a commit to oseyosey/verl that referenced this pull request Jul 28, 2025
whatadayG pushed a commit to whatadayG/verl that referenced this pull request Sep 5, 2025