[trainer] refactor: Training Engine Interface and Development Plan #1977
Merged: eric-haibin-lin merged 36 commits into volcengine:main from ZihengJiang:ziheng/dev-0610 on Jul 18, 2025.
Changes from 1 commit
36 commits, all by ZihengJiang:

c60369c: init model done.
cc42721: rebase.
f28d409: update.
d61f8f7: use legacy worker impl.
fd49b29: format.
aa61718: format.
514c20d: align config test.
bb07a5a: hide engine class.
3c9deee: processsor class.
e8d5d9a: license.
390fbad: move value out.
a92778f: make loss_fn as member function.
a6497d3: update doc.
819531c: format.
60aa94d: format.
1ffabff: format.
0b4a502: docs.
0de48f8: add type annotation for train_batch interface.
1f824b2: support multi outputs in infer_batch.
fa5be48: introduce post_fn, remove processor
9cf3e5e: format.
668bf28: add engine registry.
440715b: address comment.
0d2835c: address comment.
791e231: address comment.
0465afe: add e2e test for sanity check.
b07d596: update.
5a65b6b: format.
d37ec8d: update.
db52ccb: update.
41e2333: update.
3863e83: fix.
f3116b6: fix.
0cebf37: fix.
e61145f: update.
263ea10: update.
Commit c60369caaf9a71be2a4451436f52c91be519a3f6: init model done.

Full commit message: eval done. update done. support rmpad. checkpoint done. clean code. update interface. add comment. minor update. compute_value done. update done. add process function back. update interface. move ulysses into engine. Actor worker interface. init megatron engine code structure.
First file (new, `@@ -0,0 +1,3 @@`):

```python
from .base import BaseEngine

__all__ = ["BaseEngine"]
```
Second file (new, `@@ -0,0 +1,194 @@`), the engine interface itself; an inline review comment from eric-haibin-lin between `eval_mode` and `infer_batch` was marked as resolved:

```python
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The abstract base class defining the interface for model training engines.
"""


class BaseEngine(object):
    """
    Abstract base class defining the interface for model training engines.

    Engine implementations must subclass BaseEngine and provide concrete behavior for all methods.
    """

    def __init__(self, config):
        """
        Initialize the BaseEngine.

        Args:
            config: Configuration object containing parameters for engine setup.
        """
        raise NotImplementedError

    def init_model(self):
        """
        Instantiate or load the model, optimizer, and learning rate scheduler.

        Should prepare all components necessary for training or evaluation.
        """
        raise NotImplementedError

    def train_mode(self):
        """
        Context manager entry for switching the engine and model into training mode.

        Usage:
            with engine.train_mode():
                # runs in training mode
        """
        raise NotImplementedError

    def eval_mode(self):
        """
        Context manager entry for switching the engine and model into evaluation mode.

        Usage:
            with engine.eval_mode():
                # runs in evaluation mode
        """
        raise NotImplementedError

    def infer_batch(self,
                    batch,
                    ctx=None,
                    preprocess_fn=None,
                    postprocess_fn=None):
        """
        Execute a forward pass over a batch of data.

        Args:
            batch: Raw batch data (e.g., tensors or mappings) to process.
            ctx: Optional context dict passed to preprocess/postprocess functions.
            preprocess_fn: Function(batch, ctx) -> (inputs, ctx), applied before model call.
            postprocess_fn: Function(outputs, ctx) -> (predictions, ctx), applied after model call.

        Returns:
            (predictions, ctx)
        """
        raise NotImplementedError

    def train_batch(self,
                    batch,
                    ctx=None,
                    preprocess_fn=None,
                    postprocess_fn=None):
        """
        Execute a forward pass and backward pass over a batch of data.

        Args:
            batch: Raw batch data (e.g., tensors or mappings) to process.
            ctx: Optional context dict passed to preprocess/postprocess functions.
            preprocess_fn: Function(batch, ctx) -> (inputs, ctx), applied before model call.
            postprocess_fn: Function(outputs, ctx) -> (predictions, ctx), applied after model call.

        Returns:
            (predictions, loss, ctx)
        """
        raise NotImplementedError

    def optimizer_zero_grad(self):
        """
        Zero out gradients of all parameters before starting a new backward pass.
        """
        raise NotImplementedError

    def optimizer_step(self):
        """
        Perform an optimization step to update model parameters based on accumulated gradients.

        Returns:
            grad_norm (float): The norm of the gradients before clipping or update.
        """
        raise NotImplementedError

    def lr_scheduler_step(self):
        """
        Advance the learning rate scheduler by one step.

        Returns:
            current_lr (float or list[float]): Updated learning rate(s).
        """
        raise NotImplementedError

    def shard_data(self, data):
        """
        Shard or partition data for distributed training or parallel execution.

        Args:
            data: Data structure to be sharded across devices/workers.

        Returns:
            Sharded data in the same format as input.
        """
        raise NotImplementedError

    def unshard_data(self, data):
        """
        Reconstruct or gather sharded data back to a unified format.

        Args:
            data: Sharded data structure to reconstruct.

        Returns:
            Unsharded, combined data.
        """
        raise NotImplementedError

    def set_loss_fn(self, loss_fn):
        """
        Set the loss function to be used during training.

        Args:
            loss_fn: Callable(data, predictions, ctx) -> (loss_tensor, new_ctx)
        """
        raise NotImplementedError

    def to(self, device: str, model: bool = True, optimizer: bool = True):
        """
        Move model parameters, optimizer states, or both to the specified device.

        Args:
            device: Target device identifier (e.g., "cuda" or "cpu").
            model: If True, move the model.
            optimizer: If True, move the optimizer states.
        """
        raise NotImplementedError

    def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):
        """
        Save model, optimizer, and scheduler states to a checkpoint.

        Args:
            local_path: Local filesystem path to save checkpoint.
            hdfs_path: Optional HDFS path to copy checkpoint.
            global_step: Integer training step number for naming.
            max_ckpt_to_keep: Maximum number of recent checkpoints to retain.
        """
        raise NotImplementedError

    def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=True):
        """
        Load model, optimizer, and scheduler states from a checkpoint.

        Args:
            local_path: Local filesystem path of the checkpoint.
            hdfs_path: Optional HDFS path where checkpoint is stored.
            del_local_after_load: Whether to delete local copy after loading.
        """
        raise NotImplementedError
```
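The `BaseEngine` contract can be exercised end to end with a toy, pure-Python engine. The sketch below is illustrative only and is not part of the PR: `ToyEngine`, its single scalar weight standing in for a model, and the hand-derived MSE gradient are hypothetical stand-ins chosen to show how `init_model`, `set_loss_fn`, `train_mode`, `train_batch`, `optimizer_zero_grad`, and `optimizer_step` compose.

```python
from contextlib import contextmanager


class ToyEngine:
    """Hypothetical in-memory engine following the BaseEngine method contract."""

    def __init__(self, config):
        self.config = config
        self.w = 0.0        # a single scalar parameter stands in for the model
        self.grad = 0.0
        self.loss_fn = None
        self._training = False

    def init_model(self):
        self.w = self.config.get("init_w", 0.0)

    @contextmanager
    def train_mode(self):
        # matches the "with engine.train_mode():" usage in the docstrings
        self._training = True
        try:
            yield self
        finally:
            self._training = False

    def set_loss_fn(self, loss_fn):
        # loss_fn: (batch, predictions, ctx) -> (loss, ctx)
        self.loss_fn = loss_fn

    def infer_batch(self, batch, ctx=None, preprocess_fn=None, postprocess_fn=None):
        if preprocess_fn is not None:
            batch, ctx = preprocess_fn(batch, ctx)
        preds = [self.w * x for x in batch["x"]]
        if postprocess_fn is not None:
            preds, ctx = postprocess_fn(preds, ctx)
        return preds, ctx

    def train_batch(self, batch, ctx=None, preprocess_fn=None, postprocess_fn=None):
        preds, ctx = self.infer_batch(batch, ctx, preprocess_fn, postprocess_fn)
        loss, ctx = self.loss_fn(batch, preds, ctx)
        # analytic MSE gradient w.r.t. w for the linear toy model
        n = len(batch["x"])
        self.grad = sum(2 * (p - y) * x for p, y, x in zip(preds, batch["y"], batch["x"])) / n
        return preds, loss, ctx

    def optimizer_zero_grad(self):
        self.grad = 0.0

    def optimizer_step(self):
        self.w -= self.config.get("lr", 0.1) * self.grad
        return abs(self.grad)  # "grad norm" of a single scalar


# Fit y = 2x from two points.
engine = ToyEngine({"lr": 0.1})
engine.init_model()
engine.set_loss_fn(
    lambda batch, preds, ctx: (
        sum((p - y) ** 2 for p, y in zip(preds, batch["y"])) / len(preds),
        ctx,
    )
)
batch = {"x": [1.0, 2.0], "y": [2.0, 4.0]}
with engine.train_mode():
    for _ in range(50):
        engine.optimizer_zero_grad()
        _, loss, _ = engine.train_batch(batch)
        engine.optimizer_step()
# engine.w converges to 2.0
```

The real engines behind this interface (e.g. the FSDP implementation in this PR) replace the scalar arithmetic with actual model forward/backward passes, but the call sequence a trainer drives is the same.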
Third file (new, `@@ -0,0 +1,16 @@`):

```python
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .engine_impl import FSDPEngine

__all__ = ["FSDPEngine"]
```
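A later commit in this PR is "add engine registry", but the registry file itself is not part of the commit shown here. As a rough illustration of that pattern only (the names `register_engine` and `get_engine_cls` are hypothetical, not verl's actual API), a string-keyed registry lets a config value select an engine implementation:

```python
# Hypothetical sketch of an engine registry; names are illustrative,
# not the actual verl API.
_ENGINE_REGISTRY = {}


def register_engine(name):
    """Class decorator that records an engine class under a string key."""
    def decorator(cls):
        _ENGINE_REGISTRY[name] = cls
        return cls
    return decorator


def get_engine_cls(name):
    """Look up a registered engine class, with a helpful error on a miss."""
    try:
        return _ENGINE_REGISTRY[name]
    except KeyError:
        raise ValueError(
            f"unknown engine {name!r}; available: {sorted(_ENGINE_REGISTRY)}"
        ) from None


@register_engine("dummy")
class DummyEngine:
    def __init__(self, config):
        self.config = config


# A trainer would typically do: get_engine_cls(config.engine_name)(config)
engine = get_engine_cls("dummy")({"lr": 0.1})
```

The payoff of this indirection is that adding a new backend (FSDP, Megatron, ...) only requires registering one more class, with no changes to the trainer code that instantiates engines.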