Conversation

@DtYXs (Contributor) commented May 27, 2025

### Checklist Before Starting

- [x] Search for similar PR(s).

### What does this PR do?

Add support for [PF-PPO](https://arxiv.org/abs/2409.06957) in verl.

Specific Changes

verl/trainer/config/ppo_trainer.yaml: Add config for PF-PPO
verl/trainer/ppo/core_algos.py: Add compute_pf_ppo_reweight_data function.
verl/trainer/ppo/ray_trainer.py: Do PF-PPO in compute_advantage when config.algorithm.use_pf_ppo is True
README.md: Update PF-PPO in README
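
For intuition, here is a minimal sketch of what a `pow`-style reweighting could look like. It is an illustration only: the min-max normalization, the `score ** weight_pow` weights, and the multinomial resampling are assumptions, not the actual `compute_pf_ppo_reweight_data` implementation in `core_algos.py`.

```python
import torch

def pf_ppo_reweight_pow(scores: torch.Tensor, weight_pow: float = 2.0) -> torch.Tensor:
    """Illustrative "pow" reweighting sketch; not verl's compute_pf_ppo_reweight_data."""
    # Min-max normalize sequence-level scores so the power is well defined
    # even when rewards are negative (assumption for this sketch).
    normalized = (scores - scores.min()) / (scores.max() - scores.min() + 1e-8)
    weights = normalized.pow(weight_pow)  # weight_pow=2.0 emphasizes higher-scoring rollouts
    # Resample batch indices in proportion to the weights (with replacement),
    # which biases the batch toward samples the reward signal favors.
    probs = (weights + 1e-6) / (weights + 1e-6).sum()
    return torch.multinomial(probs, num_samples=scores.numel(), replacement=True)

# Example: indices that would be used to resample/reorder the rollout batch.
scores = torch.tensor([0.1, 0.9, 0.4, 1.0])
resample_idx = pf_ppo_reweight_pow(scores, weight_pow=2.0)
```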

### Usage Example

```bash
set -x

python3 -m verl.trainer.main_ppo \
    algorithm.adv_estimator=gae \
    algorithm.use_pf_ppo=True \
    algorithm.pf_ppo.reweight_method=pow \
    algorithm.pf_ppo.weight_pow=2.0 \
    data.train_files=$HOME/data/gsm8k/train.parquet \
    data.val_files=$HOME/data/gsm8k/test.parquet \
    data.train_batch_size=1024 \
    data.max_prompt_length=512 \
    data.max_response_length=512 \
    data.filter_overlong_prompts=True \
    data.truncation='error' \
    actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.ppo_mini_batch_size=256 \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
    actor_rollout_ref.actor.use_kl_loss=False \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=4 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \
    actor_rollout_ref.rollout.n=5 \
    critic.optim.lr=1e-5 \
    critic.model.use_remove_padding=True \
    critic.model.path=deepseek-ai/deepseek-llm-7b-chat \
    critic.model.enable_gradient_checkpointing=True \
    critic.ppo_micro_batch_size_per_gpu=32 \
    critic.model.fsdp_config.param_offload=False \
    critic.model.fsdp_config.optimizer_offload=False \
    algorithm.use_kl_in_reward=False \
    trainer.critic_warmup=0 \
    trainer.logger=['console','wandb'] \
    trainer.project_name='verl_example_gsm8k' \
    trainer.experiment_name='deepseek_llm_7b_function_rm' \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes=1 \
    trainer.save_freq=20 \
    trainer.test_freq=1 \
    trainer.total_epochs=15 $@
```
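
Only three overrides above are PF-PPO-specific: `algorithm.use_pf_ppo`, `algorithm.pf_ppo.reweight_method`, and `algorithm.pf_ppo.weight_pow`; the rest is a standard GSM8K GAE recipe. Below is a hedged sketch of where that gate could sit in the advantage step. The config names mirror the CLI keys, but the classes and the toy `compute_advantage` are illustrative stand-ins, not verl's `ray_trainer.py` code.

```python
from dataclasses import dataclass, field

@dataclass
class PFPPOConfig:
    reweight_method: str = "pow"  # algorithm.pf_ppo.reweight_method
    weight_pow: float = 2.0       # algorithm.pf_ppo.weight_pow

@dataclass
class AlgorithmConfig:
    use_pf_ppo: bool = True       # algorithm.use_pf_ppo
    pf_ppo: PFPPOConfig = field(default_factory=PFPPOConfig)

def compute_advantage(batch: dict, algo: AlgorithmConfig) -> dict:
    """Toy stand-in for the trainer's advantage step, showing the PF-PPO gate."""
    if algo.use_pf_ppo:
        # In verl this is where core_algos.compute_pf_ppo_reweight_data would
        # reweight/resample the rollout batch before advantages are estimated.
        batch["pf_ppo_reweighted"] = True
        batch["reweight_method"] = algo.pf_ppo.reweight_method
    # ... GAE advantage estimation would follow on the (re)weighted batch ...
    return batch

print(compute_advantage({"scores": [0.1, 0.9, 0.4]}, AlgorithmConfig()))
```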

### Test

Simple gsm8k test.

<img width="502" alt="image" src="https://github.com/user-attachments/assets/4298ce20-a691-4edb-8e4a-ef68fb0fb6be" />

### Checklist Before Submitting

- [x] Read the [Contribute Guide](https://github.com/volcengine/verl?tab=readme-ov-file#contribution-guide).
- [x] Apply [pre-commit checks](https://github.com/volcengine/verl?tab=readme-ov-file#code-linting-and-formatting).
- [x] Add `[BREAKING]` to the PR title if it breaks any API.
- [x] Update the documentation about your changes in the [docs](https://github.com/volcengine/verl/tree/main/docs).
- [x] Add CI test(s) if necessary.

@CLAassistant commented May 27, 2025

CLA assistant check
All committers have signed the CLA.

@hiyouga self-requested a review May 27, 2025 09:53
@hiyouga merged commit 75d2b36 into volcengine:main May 28, 2025
35 checks passed
wwwjn pushed a commit to wwwjn/verl that referenced this pull request Jun 10, 2025