Merged
Changes from 1 commit
Commits (36)
c60369c
init model done.
ZihengJiang Jun 10, 2025
cc42721
rebase.
ZihengJiang Jul 8, 2025
f28d409
update.
ZihengJiang Jul 8, 2025
d61f8f7
use legacy worker impl.
ZihengJiang Jul 8, 2025
fd49b29
format.
ZihengJiang Jul 9, 2025
aa61718
format.
ZihengJiang Jul 9, 2025
514c20d
align config test.
ZihengJiang Jul 9, 2025
bb07a5a
hide engine class.
ZihengJiang Jul 10, 2025
3c9deee
processor class.
ZihengJiang Jul 10, 2025
e8d5d9a
license.
ZihengJiang Jul 10, 2025
390fbad
move value out.
ZihengJiang Jul 14, 2025
a92778f
make loss_fn a member function.
ZihengJiang Jul 14, 2025
a6497d3
update doc.
ZihengJiang Jul 14, 2025
819531c
format.
ZihengJiang Jul 14, 2025
60aa94d
format.
ZihengJiang Jul 14, 2025
1ffabff
format.
ZihengJiang Jul 14, 2025
0b4a502
docs.
ZihengJiang Jul 14, 2025
0de48f8
add type annotation for train_batch interface.
ZihengJiang Jul 15, 2025
1f824b2
support multi outputs in infer_batch.
ZihengJiang Jul 15, 2025
fa5be48
introduce post_fn, remove processor
ZihengJiang Jul 16, 2025
9cf3e5e
format.
ZihengJiang Jul 16, 2025
668bf28
add engine registry.
ZihengJiang Jul 16, 2025
440715b
address comment.
ZihengJiang Jul 16, 2025
0d2835c
address comment.
ZihengJiang Jul 16, 2025
791e231
address comment.
ZihengJiang Jul 16, 2025
0465afe
add e2e test for sanity check.
ZihengJiang Jul 17, 2025
b07d596
update.
ZihengJiang Jul 17, 2025
5a65b6b
format.
ZihengJiang Jul 17, 2025
d37ec8d
update.
ZihengJiang Jul 17, 2025
db52ccb
update.
ZihengJiang Jul 17, 2025
41e2333
update.
ZihengJiang Jul 17, 2025
3863e83
fix.
ZihengJiang Jul 17, 2025
f3116b6
fix.
ZihengJiang Jul 17, 2025
0cebf37
fix.
ZihengJiang Jul 17, 2025
e61145f
update.
ZihengJiang Jul 18, 2025
263ea10
update.
ZihengJiang Jul 18, 2025
init model done.
eval done.

update done.

support rmpad.

checkpoint done.

clean code.

update interface.

add comment.

minor update.

compute_value done.

update done.

add process function back.

update interface.

move ulysses into engine.

Actor worker interface.

init megatron engine code structure.
ZihengJiang committed Jul 17, 2025
commit c60369caaf9a71be2a4451436f52c91be519a3f6
examples/ppo_trainer/run_deepseek7b_llm.sh (1 addition, 0 deletions)
@@ -38,4 +38,5 @@ python3 -m verl.trainer.main_ppo \
     trainer.nnodes=1 \
     trainer.save_freq=20 \
     trainer.test_freq=1 \
+    trainer.use_legacy_worker_impl=False \
     trainer.total_epochs=15 $@
verl/trainer/config/ppo_trainer.yaml (3 additions, 0 deletions)
@@ -322,6 +322,9 @@ trainer:
   # Device to run training on (e.g., "cuda", "cpu")
   device: cuda
 
+  # whether to use legacy role implementation
+  use_legacy_worker_impl: False
+
   # configs related to ray initialization
   ray_init:

verl/trainer/main_ppo.py (9 additions, 1 deletion)
@@ -127,7 +127,15 @@ def run(self, config):
         if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}:
             assert config.critic.strategy in {"fsdp", "fsdp2"}
             from verl.single_controller.ray import RayWorkerGroup
-            from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker, CriticWorker
+            from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker
+
+            if config.trainer.use_legacy_worker_impl:
+                # import warnings
+                # warnings.warn(f"Legacy worker impl is going to be deprecated, will be removed in the future. \
+                #     Please use the new worker impl supported for PPO trainer.")
+                from verl.workers.fsdp_workers import CriticWorker
+            else:
+                from verl.workers.roles import CriticWorker
 
             actor_rollout_cls = (
                 AsyncActorRolloutRefWorker
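
Editor's note: the deprecation warning in this hunk is left commented out (its review thread was later resolved). For reference, a sketch of how such a notice might be emitted once enabled; the helper name and wording below are illustrative assumptions, not code from this PR:

import warnings


def _warn_legacy_worker_impl():
    # Hypothetical helper sketching the commented-out notice in main_ppo.py.
    warnings.warn(
        "The legacy worker implementation is deprecated and will be removed in a "
        "future release. Please use the new worker implementation for the PPO trainer.",
        DeprecationWarning,
        stacklevel=2,
    )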
verl/workers/engine/__init__.py (3 additions, 0 deletions)
@@ -0,0 +1,3 @@
from .base import BaseEngine

__all__ = ["BaseEngine"]
verl/workers/engine/base.py (194 additions, 0 deletions)
@@ -0,0 +1,194 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The abstract base class defining the interface for model training engines.
"""


class BaseEngine(object):
"""
Abstract base class defining the interface for model training engines.

Engine implementations must subclass BaseEngine and provide concrete behavior for all methods.
"""
def __init__(self, config):
"""
Initialize the BaseEngine.

Args:
config: Configuration object containing parameters for engine setup.
"""
raise NotImplementedError

def init_model(self):
"""
Instantiate or load the model, optimizer, and learning rate scheduler.

Should prepare all components necessary for training or evaluation.
"""
raise NotImplementedError

def train_mode(self):
"""
Context manager entry for switching the engine and model into training mode.

Usage:
with engine.train_mode():
# runs in training mode
"""
raise NotImplementedError

def eval_mode(self):
"""
Context manager entry for switching the engine and model into evaluation mode.

Usage:
with engine.eval_mode():
# runs in evaluation mode
"""
raise NotImplementedError


def infer_batch(self,
batch,
ctx=None,
preprocess_fn=None,
postprocess_fn=None):
"""
Execute a forward pass over a batch of data.

Args:
batch: Raw batch data (e.g., tensors or mappings) to process.
ctx: Optional context dict passed to preprocess/postprocess functions.
preprocess_fn: Function(batch, ctx) -> (inputs, ctx), applied before model call.
postprocess_fn: Function(outputs, ctx) -> (predictions, ctx), applied after model call.

Returns:
(predictions, ctx)
"""
raise NotImplementedError


def train_batch(self,
batch,
ctx=None,
preprocess_fn=None,
postprocess_fn=None):
"""
Execute a forward pass and backward pass over a batch of data.

Args:
batch: Raw batch data (e.g., tensors or mappings) to process.
ctx: Optional context dict passed to preprocess/postprocess functions.
preprocess_fn: Function(batch, ctx) -> (inputs, ctx), applied before model call.
postprocess_fn: Function(outputs, ctx) -> (predictions, ctx), applied after model call.

Returns:
(predictions, loss, ctx)
"""
raise NotImplementedError

def optimizer_zero_grad(self):
"""
Zero out gradients of all parameters before starting a new backward pass.
"""
raise NotImplementedError

def optimizer_step(self):
"""
Perform an optimization step to update model parameters based on accumulated gradients.

Returns:
grad_norm (float): The norm of the gradients before clipping or update.
"""
raise NotImplementedError

def lr_scheduler_step(self):
"""
Advance the learning rate scheduler by one step.

Returns:
current_lr (float or list[float]): Updated learning rate(s).
"""
raise NotImplementedError

def shard_data(self, data):
"""
Shard or partition data for distributed training or parallel execution.

Args:
data: Data structure to be sharded across devices/workers.

Returns:
Sharded data in the same format as input.
"""
raise NotImplementedError

def unshard_data(self, data):
"""
Reconstruct or gather sharded data back to a unified format.

Args:
data: Sharded data structure to reconstruct.

Returns:
Unsharded, combined data.
"""
raise NotImplementedError


def set_loss_fn(self, loss_fn):
"""
Set the loss function to be used during training.

Args:
loss_fn: Callable(data, predictions, ctx) -> (loss_tensor, new_ctx)
"""
raise NotImplementedError

def to(self, device: str, model: bool = True, optimizer: bool = True):
"""
Move model parameters, optimizer states, or both to the specified device.

Args:
device: Target device identifier (e.g., "cuda" or "cpu").
model: If True, move the model.
optimizer: If True, move the optimizer states.
"""
raise NotImplementedError


def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):
"""
Save model, optimizer, and scheduler states to a checkpoint.

Args:
local_path: Local filesystem path to save checkpoint.
hdfs_path: Optional HDFS path to copy checkpoint.
global_step: Integer training step number for naming.
max_ckpt_to_keep: Maximum number of recent checkpoints to retain.
"""
raise NotImplementedError


def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=True):
"""
Load model, optimizer, and scheduler states from a checkpoint.

Args:
local_path: Local filesystem path of the checkpoint.
hdfs_path: Optional HDFS path where checkpoint is stored.
del_local_after_load: Whether to delete local copy after loading.
"""
raise NotImplementedError
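
Editor's note: taken together, the methods above imply a driver loop like the following. This is a minimal sketch of the intended call pattern, assuming a concrete engine subclass plus a user-supplied loss_fn and dataloader (BaseEngine itself only raises NotImplementedError):

def run_training(engine, dataloader, loss_fn, ckpt_dir="./ckpts"):
    # Illustrative only: `engine` must be a concrete BaseEngine subclass.
    engine.init_model()
    engine.set_loss_fn(loss_fn)
    step = 0
    with engine.train_mode():
        for batch in dataloader:
            batch = engine.shard_data(batch)
            engine.optimizer_zero_grad()
            predictions, loss, ctx = engine.train_batch(batch)
            grad_norm = engine.optimizer_step()      # grad norm before update
            current_lr = engine.lr_scheduler_step()  # updated learning rate(s)
            step += 1
    engine.save_checkpoint(local_path=ckpt_dir, global_step=step)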
verl/workers/engine/fsdp/__init__.py (16 additions, 0 deletions)
@@ -0,0 +1,16 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .engine_impl import FSDPEngine

__all__ = ["FSDPEngine"]
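
Editor's note: the FSDPEngine implementation in engine_impl.py is not part of this page's excerpt. As a rough illustration of the contract a subclass must satisfy, the mode switches could be written as context managers along the following lines; this toy class, its self.module attribute, and the Linear stand-in are assumptions for illustration, not the actual FSDPEngine code:

from contextlib import contextmanager

import torch

from verl.workers.engine import BaseEngine


class ToyEngine(BaseEngine):
    """Toy illustration of the BaseEngine contract; not the real FSDPEngine."""

    def __init__(self, config):
        self.config = config
        self.module = None  # populated by init_model()

    def init_model(self):
        # Trivial stand-in; a real engine builds the configured model,
        # optimizer, and LR scheduler here.
        self.module = torch.nn.Linear(4, 4)

    @contextmanager
    def train_mode(self):
        # Put the model into training mode for the duration of the block.
        self.module.train()
        try:
            yield self
        finally:
            self.module.eval()

    @contextmanager
    def eval_mode(self):
        # Put the model into evaluation mode, restoring the prior mode on exit.
        was_training = self.module.training
        self.module.eval()
        try:
            yield self
        finally:
            if was_training:
                self.module.train()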