72 changes: 15 additions & 57 deletions .github/workflows/e2e_ppo_trainer.yml
@@ -61,7 +61,7 @@ jobs:

e2e_ppo_trainer_vllm:
runs-on: [L20x8]
timeout-minutes: 40 # Increase this timeout value as needed
timeout-minutes: 60 # Increase this timeout value as needed
env:
HTTP_PROXY: ${{ secrets.PROXY_HTTP }}
HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }}
@@ -161,6 +161,14 @@ jobs:
run: |
ray stop --force
LIGER=True bash tests/e2e/ppo_trainer/run_model_reward.sh
- name: Running GSM8K E2E with rmpad using model rm with Fused Kernel enabled (torch backend)
run: |
ray stop --force
FUSED_KERNELS=True bash tests/e2e/ppo_trainer/run_model_reward.sh
- name: Running GSM8K E2E with rmpad using model rm with Fused Kernel enabled (triton backend)
run: |
ray stop --force
FUSED_KERNELS=True FUSED_KERNEL_BACKEND=triton bash tests/e2e/ppo_trainer/run_model_reward.sh

e2e_ppo_trainer_vllm_vlm:
runs-on: [L20x8]
@@ -272,7 +280,7 @@ jobs:
e2e_ppo_trainer_sglang_vlm:
runs-on: [L20x8]
needs: pre_commit_for_ppo
timeout-minutes: 40 # Increase this timeout value as needed
timeout-minutes: 60 # Increase this timeout value as needed
env:
HTTP_PROXY: ${{ secrets.PROXY_HTTP }}
HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }}
@@ -304,71 +312,21 @@ jobs:
ENGINE=sglang GPU_MEMORY_UTILIZATION=0.6 ACTOR_FSDP_PARAM_OFFLOAD=True \
ACTOR_FSDP_OPTIMIZER_OFFLOAD=True REF_FSDP_PARAM_OFFLOAD=True \
bash tests/e2e/ppo_trainer/run_function_reward.sh

e2e_ppo_trainer_fused_kernels_vllm:
runs-on: [L20x8]
needs: pre_commit_for_ppo
timeout-minutes: 40 # Increase this timeout value as needed
env:
HTTP_PROXY: ${{ secrets.PROXY_HTTP }}
HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }}
NO_PROXY: "localhost,127.0.0.1,hf-mirror.com"
HF_ENDPOINT: "https://hf-mirror.com"
HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable
container:
image: hiyouga/verl:ngc-th2.6.0-cu126-vllm0.8.3-flashinfer0.2.2-cxx11abi0
options: --gpus all --shm-size=50g # Visual dataloader requires large memory
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
fetch-depth: 0
- name: Install the current repository
run: |
pip3 install -e .[test,geo,vllm]
# Geo3k
- name: Prepare Geo3k dataset
run: |
ray stop --force
python3 examples/data_preprocess/geo3k.py
- name: Running Geo3k VLM E2E with rmpad using fused kernel (Qwen2.5-VL)
- name: Running Geo3k VLM E2E with rmpad using torch fused kernel (Qwen2.5-VL)
run: |
ray stop --force
FUSED_KERNELS=True TRAIN_FILES=$HOME/data/geo3k/train.parquet VAL_FILES=$HOME/data/geo3k/test.parquet \
MAX_PROMPT_LEN=1536 MAX_RESPONSE_LEN=1536 \
MODEL_ID=Qwen/Qwen2.5-VL-3B-Instruct \
ADV_ESTIMATOR=grpo RM_PAD=True USE_KL=True ENABLE_CHUNKED_PREFILL=False \
GPU_MEMORY_UTILIZATION=0.6 ACTOR_FSDP_PARAM_OFFLOAD=True \
ENGINE=sglang GPU_MEMORY_UTILIZATION=0.6 ACTOR_FSDP_PARAM_OFFLOAD=True \
ACTOR_FSDP_OPTIMIZER_OFFLOAD=True REF_FSDP_PARAM_OFFLOAD=True \
bash tests/e2e/ppo_trainer/run_function_reward.sh

e2e_ppo_trainer_fused_kernels_sglang:
runs-on: [L20x8]
needs: pre_commit_for_ppo
timeout-minutes: 40 # Increase this timeout value as needed
env:
HTTP_PROXY: ${{ secrets.PROXY_HTTP }}
HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }}
NO_PROXY: "localhost,127.0.0.1,hf-mirror.com"
HF_ENDPOINT: "https://hf-mirror.com"
HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable
container:
image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6.post5-mcore0.12.0-te2.3
options: --gpus all --shm-size=50g # Visual dataloader requires large memory
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
fetch-depth: 0
- name: Install the current repository
run: |
pip3 install -e .[test,geo,gpu,sglang]
- name: Prepare Geo3k dataset
- name: Running Geo3k VLM E2E with rmpad using triton fused kernel (Qwen2.5-VL)
run: |
ray stop --force
python3 examples/data_preprocess/geo3k.py
- name: Running Geo3k VLM E2E with rmpad using fused kernel (Qwen2.5-VL)
run: |
ray stop --force
FUSED_KERNELS=True TRAIN_FILES=$HOME/data/geo3k/train.parquet VAL_FILES=$HOME/data/geo3k/test.parquet \
FUSED_KERNELS=True FUSED_KERNEL_BACKEND=triton \
TRAIN_FILES=$HOME/data/geo3k/train.parquet VAL_FILES=$HOME/data/geo3k/test.parquet \
MAX_PROMPT_LEN=1536 MAX_RESPONSE_LEN=1536 \
MODEL_ID=Qwen/Qwen2.5-VL-3B-Instruct \
ADV_ESTIMATOR=grpo RM_PAD=True USE_KL=True ENABLE_CHUNKED_PREFILL=False \
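Note: the two standalone fused-kernel jobs (e2e_ppo_trainer_fused_kernels_vllm and e2e_ppo_trainer_fused_kernels_sglang) are removed above, and their coverage is folded into the existing vLLM and SGLang jobs. A minimal sketch for reproducing the two new GSM8K fused-kernel steps outside CI, assuming the GSM8K data and reward-model setup from the earlier workflow steps are already in place:

# Sketch: reproduce the new fused-kernel CI steps locally.
ray stop --force
# torch backend (the script's default FUSED_KERNEL_BACKEND):
FUSED_KERNELS=True bash tests/e2e/ppo_trainer/run_model_reward.sh
ray stop --force
# triton backend:
FUSED_KERNELS=True FUSED_KERNEL_BACKEND=triton bash tests/e2e/ppo_trainer/run_model_reward.sh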
7 changes: 5 additions & 2 deletions .github/workflows/kernels.yml
@@ -38,7 +38,7 @@ permissions:
contents: read

jobs:
e2e_gsm8k_megatron:
kernels:
runs-on: [L20x8]
timeout-minutes: 40 # Increase this timeout value as needed
env:
@@ -59,4 +59,7 @@ jobs:
pip3 install --no-deps -e .[test]
- name: Testing LinearCrossEntropy Correction, Computation Time and Memory Consumption
run: |
python3 tests/kernels/test_linear_cross_entropy.py
- name: Testing LinearCrossEntropyTP Correction, Computation Time and Memory Consumption
run: |
LOW_MEMORY=True torchrun --standalone --nnodes=1 --nproc-per-node=8 tests/kernels/test_linear_cross_entropy_tp.py
64 changes: 64 additions & 0 deletions examples/ppo_trainer/run_qwen2-7b_rm_seq_balance_fused_kernels.sh
@@ -0,0 +1,64 @@
set -x

gsm8k_train_path=$HOME/data/gsm8k/train.parquet
gsm8k_test_path=$HOME/data/gsm8k/test.parquet
math_train_path=$HOME/data/math/train.parquet
math_test_path=$HOME/data/math/test.parquet

train_files="['$gsm8k_train_path', '$math_train_path']"
test_files="['$gsm8k_test_path', '$math_test_path']"

FUSED_KERNEL_BACKEND=triton # or 'torch' for torch backend

python3 -m verl.trainer.main_ppo \
algorithm.adv_estimator=gae \
data.train_files="$train_files" \
data.val_files="$test_files" \
data.train_batch_size=4096 \
data.max_prompt_length=4096 \
data.max_response_length=4096 \
data.filter_overlong_prompts=True \
data.truncation='error' \
data.return_raw_chat=True \
actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.model.use_fused_kernels=True \
actor_rollout_ref.model.fused_kernel_options.impl_backend=$FUSED_KERNEL_BACKEND \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.ppo_mini_batch_size=512 \
actor_rollout_ref.actor.use_dynamic_bsz=True \
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
actor_rollout_ref.actor.use_kl_loss=False \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=24000 \
critic.optim.lr=1e-5 \
critic.model.use_remove_padding=True \
critic.model.path=Qwen/Qwen2-7B-Instruct \
critic.model.enable_gradient_checkpointing=True \
critic.use_dynamic_bsz=True \
critic.ppo_max_token_len_per_gpu=98304 \
critic.model.fsdp_config.param_offload=False \
critic.model.fsdp_config.optimizer_offload=False \
reward_model.enable=True \
reward_model.model.path=sfairXC/FsfairX-LLaMA3-RM-v0.1 \
reward_model.model.use_remove_padding=True \
reward_model.model.fsdp_config.param_offload=True \
reward_model.micro_batch_size_per_gpu=32 \
reward_model.use_dynamic_bsz=True \
reward_model.forward_max_token_len_per_gpu=98304 \
algorithm.use_kl_in_reward=False \
trainer.critic_warmup=0 \
trainer.logger=['console','wandb'] \
trainer.project_name='verl_example_gsm8k' \
trainer.experiment_name='qwen2-7b_hybrid_rm_bsz8k_p4k_r4k_seq_packing_fused_kernel' \
trainer.n_gpus_per_node=8 \
trainer.val_before_train=False \
trainer.nnodes=1 \
trainer.save_freq=20 \
trainer.test_freq=5 \
trainer.total_epochs=15 $@
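Note: the script forwards trailing arguments to Hydra via $@, so extra overrides can be appended at the command line. A sketch of a short smoke run (the override values here are illustrative only):

# Sketch: pass extra Hydra overrides through "$@" for a quick smoke test.
bash examples/ppo_trainer/run_qwen2-7b_rm_seq_balance_fused_kernels.sh \
    trainer.total_epochs=1 \
    trainer.test_freq=-1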
2 changes: 2 additions & 0 deletions recipe/prime/config/prime_trainer.yaml
@@ -33,6 +33,8 @@ reward_model:
ref_path: ${reward_model.model.path}
use_remove_padding: True
use_fused_kernels: ${actor_rollout_ref.model.use_fused_kernels}
fused_kernel_options:
impl_backend: torch # triton, torch
tokenizer_path: ${actor_rollout_ref.model.path}
enable_gradient_checkpointing: ${actor_rollout_ref.model.enable_gradient_checkpointing}
ref_type: freeze
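Note: the reward model picks up use_fused_kernels from the actor config via OmegaConf interpolation, while impl_backend is set locally and defaults to torch; per the inline comment, triton is the other accepted value.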
2 changes: 2 additions & 0 deletions recipe/prime/prime_dp_rm.py
@@ -77,6 +77,7 @@ def _forward_micro_batch(self, micro_batch, prompt_length):
attention_mask=None,
position_ids=position_ids_rmpad,
use_cache=False,
return_dict=self.use_fused_kernels,
)

if self.use_fused_kernels:
@@ -100,6 +101,7 @@ def _forward_micro_batch(self, micro_batch, prompt_length):
attention_mask=micro_batch["attention_mask"],
position_ids=micro_batch["position_ids"],
use_cache=False,
return_dict=self.use_fused_kernels,
)

if self.use_fused_kernels:
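Note: return_dict=self.use_fused_kernels ties the call signature to the fused-kernels path. The monkey-patched model returns dict-style output, which the if self.use_fused_kernels: branch after each forward call consumes; when fused kernels are off, the original return type is preserved.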
4 changes: 4 additions & 0 deletions recipe/prime/prime_fsdp_workers.py
@@ -129,11 +129,15 @@ def _build_reward_ref_model_optimizer(self, config):
trust_remote_code=trust_remote_code,
)

fused_kernel_options = config.model.get("fused_kernel_options", None)
fused_kernels_backend = fused_kernel_options.get("impl_backend", None) if fused_kernel_options is not None else None

apply_monkey_patch(
model=reward_module,
ulysses_sp_size=self.ulysses_sequence_parallel_size,
use_remove_padding=config.model.get("use_remove_padding", False),
use_fused_kernels=config.model.get("use_fused_kernels", False),
fused_kernels_backend=fused_kernels_backend,
)

# some parameters may not in torch_dtype
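Note: the two-step lookup (fetch the fused_kernel_options block, then read impl_backend only when the block exists) keeps older configs without a fused_kernel_options entry working: both lookups fall back to None, and the backend should only matter when use_fused_kernels is enabled.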
5 changes: 5 additions & 0 deletions recipe/spin/fsdp_workers.py
@@ -76,6 +76,7 @@ def init_model(self):
override_model_config = OmegaConf.to_container(self.config.model.get('override_config', OmegaConf.create()))

use_remove_padding = self.config.model.get('use_remove_padding', False)
use_fused_kernels = self.config.model.get('use_fused_kernels', False)

if self._is_actor or self._is_rollout or self._is_ref:
# we need the model for actor and rollout
@@ -91,6 +92,7 @@ def init_model(self):
optim_config=optim_config,
override_model_config=override_model_config,
use_remove_padding=use_remove_padding,
use_fused_kernels=use_fused_kernels,
enable_gradient_checkpointing=self.config.model.get('enable_gradient_checkpointing', False),
trust_remote_code=self.config.model.get('trust_remote_code', False),
use_liger=self.config.model.get('use_liger', False),
@@ -107,6 +109,7 @@
OmegaConf.set_struct(self.config.actor, True)
with open_dict(self.config.actor):
self.config.actor.use_remove_padding = use_remove_padding
self.config.actor.use_fused_kernels = use_fused_kernels
self.actor = DataParallelPPOActor(config=self.config.actor,
actor_module=self.actor_module_fsdp,
actor_optimizer=self.actor_optimizer)
@@ -121,13 +124,15 @@
optim_config=None,
override_model_config=override_model_config,
use_remove_padding=use_remove_padding,
use_fused_kernels=use_fused_kernels,
trust_remote_code=self.config.model.get(
'trust_remote_code', False),
use_liger=self.config.model.get('use_liger', False),
role='ref')[0]
OmegaConf.set_struct(self.config.ref, True)
with open_dict(self.config.ref):
self.config.ref.use_remove_padding = use_remove_padding
self.config.ref.use_fused_kernels = use_fused_kernels
self.ref_policy = DataParallelPPOActor(config=self.config.ref, actor_module=self.ref_module_fsdp)
self.checkpoint_manager = FSDPCheckpointManager(
model=self.actor_module_fsdp,
5 changes: 5 additions & 0 deletions recipe/sppo/sppo_worker.py
@@ -48,6 +48,7 @@ def init_model(self):
override_model_config = OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create()))

use_remove_padding = self.config.model.get("use_remove_padding", False)
use_fused_kernels = self.config.model.get("use_fused_kernels", False)

if self._is_actor or self._is_rollout:
# we need the model for actor and rollout
@@ -63,6 +64,7 @@ def init_model(self):
optim_config=optim_config,
override_model_config=override_model_config,
use_remove_padding=use_remove_padding,
use_fused_kernels=use_fused_kernels,
enable_gradient_checkpointing=self.config.model.get("enable_gradient_checkpointing", False),
trust_remote_code=self.config.model.get("trust_remote_code", False),
use_liger=self.config.model.get("use_liger", False),
@@ -84,6 +86,7 @@
OmegaConf.set_struct(self.config.actor, True)
with open_dict(self.config.actor):
self.config.actor.use_remove_padding = use_remove_padding
self.config.actor.use_fused_kernels = use_fused_kernels
self.actor = DataParallelSPPOActor(config=self.config.actor, actor_module=self.actor_module_fsdp, actor_optimizer=self.actor_optimizer)

if self._is_rollout:
@@ -96,13 +99,15 @@
optim_config=None,
override_model_config=override_model_config,
use_remove_padding=use_remove_padding,
use_fused_kernels=use_fused_kernels,
trust_remote_code=self.config.model.get("trust_remote_code", False),
use_liger=self.config.model.get("use_liger", False),
role="ref",
)[0]
OmegaConf.set_struct(self.config.ref, True)
with open_dict(self.config.ref):
self.config.ref.use_remove_padding = use_remove_padding
self.config.ref.use_fused_kernels = use_fused_kernels
self.ref_policy = DataParallelSPPOActor(config=self.config.ref, actor_module=self.ref_module_fsdp)

if self._is_actor:
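Note: the spin and sppo workers receive the same treatment: use_fused_kernels is read once from the model config, passed into the model builder, and copied into the actor and ref sub-configs under open_dict, so the DataParallelPPOActor / DataParallelSPPOActor instances see a consistent flag.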
2 changes: 2 additions & 0 deletions tests/e2e/ppo_trainer/run_function_reward.sh
@@ -20,6 +20,7 @@ ACTOR_FSDP_OPTIMIZER_OFFLOAD=${ACTOR_FSDP_OPTIMIZER_OFFLOAD:-False}
REF_FSDP_PARAM_OFFLOAD=${REF_FSDP_PARAM_OFFLOAD:-True}
RM_PAD=${RM_PAD:-True}
FUSED_KERNELS=${FUSED_KERNELS:-False}
FUSED_KERNEL_BACKEND=${FUSED_KERNEL_BACKEND:-torch} # or 'triton' for triton backend
ADV_ESTIMATOR=${ADV_ESTIMATOR:-gae}
USE_KL=${USE_KL:-False}
CUSTOM_REWARD_FN=${CUSTOM_REWARD_FN:-False}
@@ -90,6 +91,7 @@ python3 -m verl.trainer.main_ppo \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding="${RM_PAD}" \
actor_rollout_ref.model.use_fused_kernels=${FUSED_KERNELS} \
actor_rollout_ref.model.fused_kernel_options.impl_backend=${FUSED_KERNEL_BACKEND} \
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \
actor_rollout_ref.actor.strategy=${STRATEGY} \
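Note: these two additions give the function-reward script an env-var contract for fused kernels: FUSED_KERNELS maps to actor_rollout_ref.model.use_fused_kernels, and FUSED_KERNEL_BACKEND (default torch) maps to actor_rollout_ref.model.fused_kernel_options.impl_backend. A sketch of the triton invocation used by the Geo3k CI step:

# Sketch: select fused kernels and the triton backend via the script's env vars.
FUSED_KERNELS=True FUSED_KERNEL_BACKEND=triton \
    bash tests/e2e/ppo_trainer/run_function_reward.sh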
4 changes: 4 additions & 0 deletions tests/e2e/ppo_trainer/run_model_reward.sh
@@ -11,6 +11,8 @@ TRAIN_FILES=${TRAIN_FILES:-$HOME/data/gsm8k/train.parquet}
VAL_FILES=${VAL_FILES:-$HOME/data/gsm8k/test.parquet}

RM_PAD=${RM_PAD:-True}
FUSED_KERNELS=${FUSED_KERNELS:-False}
FUSED_KERNEL_BACKEND=${FUSED_KERNEL_BACKEND:-torch} # or 'triton' for triton backend
SP_SIZE=${SP_SIZE:-1}
SEQ_BALANCE=${SEQ_BALANCE:-False}
LIGER=${LIGER:-False}
@@ -47,6 +49,8 @@ python3 -m verl.trainer.main_ppo \
actor_rollout_ref.model.use_liger="${LIGER}" \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding="${RM_PAD}" \
actor_rollout_ref.model.use_fused_kernels=${FUSED_KERNELS} \
actor_rollout_ref.model.fused_kernel_options.impl_backend=${FUSED_KERNEL_BACKEND} \
actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.1 \
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
actor_rollout_ref.actor.use_dynamic_bsz="${SEQ_BALANCE}" \