Merged

21 commits
c346f0f
refactor: moved Main from __main__ to main
le1nux Apr 24, 2025
90c4bbe
feat: added batch generator util
le1nux Apr 24, 2025
c87688a
refactor: split experiment_id syncing into multiple utility functions
le1nux Apr 24, 2025
ebaf13e
feat: implemented grid search setup for profiling
le1nux Apr 24, 2025
bcc0e7b
refactor: added OOM error handling in CudaEnv
le1nux Apr 24, 2025
899401a
feat: added torchrun script for distributed profiling
le1nux Apr 24, 2025
87296d1
feat: added profiling README
le1nux Apr 24, 2025
d6dddc0
feat: added profiler implementation
le1nux Apr 24, 2025
011e41b
feat: drafted profile logs analyzer
le1nux Apr 24, 2025
e6da67f
chore: minor renamings
le1nux Apr 24, 2025
2d4a9ad
refactor: making sure that each compil…
le1nux Apr 27, 2025
d990420
feat: added torchrun launcher
le1nux Apr 27, 2025
f4a5b48
refactor: wrapped up the profiler_starter
le1nux Apr 27, 2025
b2aa30b
feat: added activation checkpoint profiling example
le1nux Apr 27, 2025
94025af
feat: setup forward pass profiling
le1nux Apr 29, 2025
3ea7cf5
feat: ops in selective op activation checkpointing are now configurable
le1nux Apr 29, 2025
0b99db2
feat: added profiling logs analysis notebook
le1nux Apr 29, 2025
1889776
refactor: adapted the configs and evaluation code for selective op AC
le1nux Apr 29, 2025
c718182
chore: added profiling experiments to gitignore
le1nux Apr 29, 2025
522489e
chore: Merge branch 'fsdp2_activation_checkpointing' into profiling_f…
le1nux May 26, 2025
4da7d2c
chore: fix failing unit tests
le1nux May 26, 2025
.gitignore (1 change: 0 additions & 1 deletion)

@@ -163,5 +163,4 @@ tests/tmp/*
*wandb_storage*
.coverage/*
*.pbin

tutorials/profiling/experiments
src/modalities/__main__.py (257 changes: 51 additions & 206 deletions)

@@ -1,19 +1,14 @@
#!/usr/bin/env python

import json
import logging
import os
import shutil
from datetime import datetime
from functools import partial
from pathlib import Path
from typing import Callable, Optional, Type
from typing import Optional

import click
import click_pathlib
import yaml
from omegaconf import DictConfig
from pydantic import BaseModel, FilePath
from pydantic import FilePath

from modalities.api import (
    FileExistencePolicy,

@@ -27,22 +22,12 @@

    shuffle_jsonl_data,
    shuffle_tokenized_data,
)
from modalities.batch import EvaluationResultBatch
from modalities.config.component_factory import ComponentFactory
from modalities.config.config import ProcessGroupBackendType, load_app_config_dict
from modalities.config.instantiation_models import TrainingComponentsInstantiationModel, TrainingReportGenerator
from modalities.evaluator import Evaluator
from modalities.gym import Gym
from modalities.logging_broker.message_broker import MessageBroker
from modalities.logging_broker.messages import MessageTypes, ProgressUpdate
from modalities.logging_broker.publisher import MessagePublisher
from modalities.logging_broker.subscriber import MessageSubscriberIF
from modalities.config.instantiation_models import TrainingComponentsInstantiationModel
from modalities.main import Main
from modalities.models.huggingface_adapters.hf_adapter import HFModelAdapter
from modalities.registry.components import COMPONENTS
from modalities.registry.registry import Registry
from modalities.running_env.cuda_env import CudaEnv
from modalities.trainer import Trainer
from modalities.util import get_experiment_id_of_run, get_total_number_of_trainable_parameters, print_rank_0
from modalities.utils.profilers.modalities_profiler import ModalitiesProfiler


@click.group()
@@ -511,194 +496,54 @@ def CMD_shuffle_jsonl_data(
)


class Main:
    """Main class that orchestrates the training process."""

    def __init__(
        self,
        config_path: Path,
        additional_resolver_funs: Optional[dict[str, Callable]] = None,
        experiment_id: Optional[str] = None,
    ) -> None:
        if experiment_id is None:
            experiment_id = get_experiment_id_of_run(config_path)

        self.config_dict = load_app_config_dict(
            config_file_path=config_path, experiment_id=experiment_id, additional_resolver_funs=additional_resolver_funs
        )
        self.config_path = config_path

        self.registry = Registry(COMPONENTS)
        self.component_factory = ComponentFactory(registry=self.registry)

    def add_custom_component(
        self, component_key: str, variant_key: str, custom_component: Type, custom_config: Type
    ) -> None:
        """Add a custom component to the registry.

        This method comes in especially handy
        when Modalities is used as a library and the user wants to add custom components
        (e.g., custom model or custom loss function) to the registry.

        Args:
            component_key (str): Key of the component to be added to the registry
            variant_key (str): Key of the variant to be added to the registry
            custom_component (Type): The class type of the custom component
            custom_config (Type): The pydantic config type of the custom component
        """
        self.registry.add_entity(
            component_key=component_key,
            variant_key=variant_key,
            component_type=custom_component,
            component_config_type=custom_config,
        )

    def build_components(self, components_model_type: Type[BaseModel]) -> BaseModel:
        """Given a pydantic basemodel, this method builds the components specified in the config file.

        Depending on the use case (e.g., training, inference, etc.), the user can pass different pydantic base models.
        For instance, for tokenization, the basemodel would only have the tokenization-related components specified.

        Args:
            components_model_type (Type[BaseModel]): The pydantic basemodel type that should be
                used to build the components.

        Returns:
            BaseModel: The components built based on the config file.
        """
        components = self.component_factory.build_components(
            config_dict=self.config_dict, components_model_type=components_model_type
        )
        return components

    def run(self, components: TrainingComponentsInstantiationModel):
        """Entrypoint for running the training process.

        We pass in a TrainingComponentsInstantiationModel,
        which is a pydantic model that contains all the components needed for the training process.

        Args:
            components (TrainingComponentsInstantiationModel): The components needed for the training process.
        """
        # save the config file to the checkpointing path
        if components.settings.cuda_env.global_rank == 0:
            experiment_path = components.settings.paths.checkpoint_saving_path / components.settings.experiment_id
            os.makedirs(experiment_path, exist_ok=True)
            shutil.copy(self.config_path, experiment_path / self.config_path.name)
            resolved_config_path = (experiment_path / self.config_path.name).with_suffix(".yaml.resolved")
            with open(resolved_config_path, "w", encoding="utf-8") as f:
                yaml.dump(self.config_dict, f)

        evaluation_result_publisher, progress_publisher = self.get_logging_publishers(
            progress_subscriber=components.progress_subscriber,
            results_subscriber=components.evaluation_subscriber,
            global_rank=components.settings.cuda_env.global_rank,
            local_rank=components.settings.cuda_env.local_rank,
        )

        # Trainer
        global_num_tokens_per_train_step = (
            components.settings.step_profile.local_train_micro_batch_size
            * components.settings.step_profile.sequence_length
            * components.settings.step_profile.gradient_accumulation_steps
            * components.settings.cuda_env.world_size
        )
        trainer = Trainer(
            global_rank=components.settings.cuda_env.global_rank,
            progress_publisher=progress_publisher,
            num_target_steps=components.settings.training_target.num_target_steps,
            num_target_tokens=components.settings.training_target.num_target_tokens,
            num_seen_train_steps=components.settings.training_progress.num_seen_steps,
            global_num_seen_tokens=components.settings.training_progress.global_num_seen_tokens,
            evaluation_result_publisher=evaluation_result_publisher,
            gradient_acc_steps=components.settings.step_profile.gradient_accumulation_steps,
            gradient_clipper=components.gradient_clipper,
            global_num_tokens_per_train_step=global_num_tokens_per_train_step,
            mfu_calculator=components.mfu_calculator,
        )

        # Evaluator
        evaluator = Evaluator(
            progress_publisher=progress_publisher,
            evaluation_result_publisher=evaluation_result_publisher,
        )

        # Gym
        gym = Gym(
            trainer=trainer,
            evaluator=evaluator,
            loss_fun=components.loss_fn,
            num_ranks=components.settings.cuda_env.world_size,
        )
        num_params = get_total_number_of_trainable_parameters(components.app_state.model)
        components.evaluation_subscriber.consume_dict({"No. parameters": num_params})
        logging.info(f"Training model with {num_params} parameters.")

        print_rank_0(f"Model initialized at {datetime.now()}.")

        report = TrainingReportGenerator(
            training_target=components.settings.training_target,
            intervals=components.settings.intervals,
            step_profile=components.settings.step_profile,
            cuda_env=components.settings.cuda_env,
            consistency_enforcement=components.settings.consistency_enforcement,
            train_dataset=components.train_dataset,
            training_progress=components.settings.training_progress,
        ).get_report()

        print_rank_0(report)

        gym.run(
            train_data_loader=components.train_dataloader,
            evaluation_data_loaders=components.eval_dataloaders,
            checkpoint_saving=components.checkpoint_saving,
            app_state=components.app_state,
            checkpointing_interval_in_steps=components.settings.intervals.checkpointing_interval_in_steps,
            evaluation_interval_in_steps=components.settings.intervals.evaluation_interval_in_steps,
            training_log_interval_in_steps=components.settings.intervals.training_log_interval_in_steps,
        )

    def get_logging_publishers(
        self,
        progress_subscriber: MessageSubscriberIF[ProgressUpdate],
        results_subscriber: MessageSubscriberIF[EvaluationResultBatch],
        global_rank: int,
        local_rank: int,
    ) -> tuple[MessagePublisher[EvaluationResultBatch], MessagePublisher[ProgressUpdate]]:
        """Returns the logging publishers for the training.

        These publishers are used to pass the evaluation results and the progress updates to the message broker.
        The message broker is then used to pass the messages to the subscribers, such as WandB.

        Args:
            progress_subscriber (MessageSubscriberIF[ProgressUpdate]): The progress subscriber
            results_subscriber (MessageSubscriberIF[EvaluationResultBatch]): The results subscriber
            global_rank (int): The global rank of the current process
            local_rank (int): The local rank of the current process on the current node

        Returns:
            tuple[MessagePublisher[EvaluationResultBatch], MessagePublisher[ProgressUpdate]]: The evaluation
                result publisher and the progress publisher
        """
        message_broker = MessageBroker()
        progress_publisher = MessagePublisher[ProgressUpdate](
            message_broker=message_broker,
            global_rank=global_rank,
            local_rank=local_rank,
        )
        evaluation_result_publisher = MessagePublisher[EvaluationResultBatch](
            message_broker=message_broker,
            global_rank=global_rank,
            local_rank=local_rank,
        )
@main.group(name="profile")
def profile():
"""
Collection of utilities to profile modalities.
"""
pass

message_broker.add_subscriber(subscription=MessageTypes.EVALUATION_RESULT, subscriber=results_subscriber)
message_broker.add_subscriber(
subscription=MessageTypes.BATCH_PROGRESS_UPDATE,
subscriber=progress_subscriber,
)

return evaluation_result_publisher, progress_publisher
@profile.command(name="train_step")
@click.option(
"--config_file_path",
type=click_pathlib.Path(exists=True),
required=True,
help="Path to the YAML training config file.",
)
@click.option(
"--experiment_folder_path",
type=click_pathlib.Path(file_okay=False),
required=True,
help="Path to the experiment output directory.",
)
@click.option(
"--num_warmup_steps",
type=int,
default=1,
show_default=True,
help="Number of warmup steps to skip in profiling.",
)
@click.option(
"--num_measurement_steps",
type=int,
default=3,
show_default=True,
help="Number of steps to measure during profiling.",
)
def CMD_entry_point_run_train_step_profiler(
config_file_path: Path,
experiment_folder_path: Path,
num_warmup_steps: int,
num_measurement_steps: int,
):
"""Run train step profiler and write result to JSON if RANK=0."""
ModalitiesProfiler.get_train_step_statistics(
config_file_path=config_file_path,
experiment_folder_path=experiment_folder_path,
num_warmup_steps=num_warmup_steps,
num_measurement_steps=num_measurement_steps,
)


if __name__ == "__main__":
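Taken together, this diff moves the Main orchestration class out of __main__.py (it is now imported from modalities.main) and adds a new `profile` command group. As a minimal sketch, the new subcommand could be exercised in-process with click's test runner; the group object name `main` is inferred from the `@main.group` decorator above, and all paths below are placeholder assumptions:

from click.testing import CliRunner

from modalities.__main__ import main

# Drive the new `profile train_step` subcommand in-process. The config file
# must exist and the experiment folder must be writable for this to succeed.
runner = CliRunner()
result = runner.invoke(
    main,
    [
        "profile",
        "train_step",
        "--config_file_path", "configs/train_config.yaml",
        "--experiment_folder_path", "experiments/profiling_run",
        "--num_warmup_steps", "1",
        "--num_measurement_steps", "3",
    ],
)
print(result.exit_code, result.output)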
src/modalities/config/config.py (2 changes: 1 addition & 1 deletion)

@@ -315,7 +315,7 @@ class SelectiveLayerACParams(BaseModel):
        ac_freq: int

    class SelectiveOpACParams(BaseModel):
        pass
        save_ops_keys: list[str]

    sac_variant: SelectiveActivationCheckpointingVariants
    layers_fqn: str
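The new save_ops_keys field replaces the empty placeholder body, making the set of saved ops in selective op activation checkpointing configurable, as the commit "ops in selective op activation checkpointing are now configurable" indicates. A self-contained sketch of the changed model follows; the enclosing config class is not shown in this diff, and the op name below is an illustrative assumption:

from pydantic import BaseModel

class SelectiveOpACParams(BaseModel):
    # Fully qualified names of the ops whose outputs are saved rather than
    # recomputed during selective op activation checkpointing.
    save_ops_keys: list[str]

params = SelectiveOpACParams(save_ops_keys=["torch.ops.aten.mm.default"])
print(params.save_ops_keys)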
src/modalities/config/pydantic_if_types.py (4 changes: 4 additions & 0 deletions)

@@ -25,6 +25,7 @@
from modalities.tokenization.tokenizer_wrapper import TokenizerWrapper
from modalities.training.gradient_clipping.gradient_clipper import GradientClipperIF
from modalities.utils.mfu import MFUCalculatorABC
from modalities.utils.profilers.batch_generator import DatasetBatchGeneratorIF


class PydanticThirdPartyTypeIF:
@@ -79,3 +80,6 @@ def __get_pydantic_core_schema__(
PydanticDeviceMeshIFType = Annotated[DeviceMesh, PydanticThirdPartyTypeIF(DeviceMesh)]
PydanticAppStateType = Annotated[AppState, PydanticThirdPartyTypeIF(AppState)]
PydanticMFUCalculatorABCType = Annotated[MFUCalculatorABC, PydanticThirdPartyTypeIF(MFUCalculatorABC)]
PydanticDatasetBatchGeneratorIFType = Annotated[
    DatasetBatchGeneratorIF, PydanticThirdPartyTypeIF(DatasetBatchGeneratorIF)
]
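Like the other aliases above, the new annotated type lets pydantic instantiation models declare fields against the DatasetBatchGeneratorIF interface without a pydantic-native schema for it. A hedged usage sketch; ProfilingComponentsModel is a hypothetical name, not part of this PR:

from pydantic import BaseModel

from modalities.config.pydantic_if_types import PydanticDatasetBatchGeneratorIFType

# Hypothetical instantiation model for the profiler: pydantic accepts any
# object that is an instance of DatasetBatchGeneratorIF for this field,
# courtesy of the schema generated by PydanticThirdPartyTypeIF.
class ProfilingComponentsModel(BaseModel):
    batch_generator: PydanticDatasetBatchGeneratorIFType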