Add gradient accumulation/microbatching support to DPOTrainModule #1447

Merged

finbarrtimbers merged 14 commits into allenai/open-instruct:main from allenai/open-instruct:finbarr/dpo-mfu-pr5-microbatching on Feb 11, 2026
Conversation

finbarrtimbers (Collaborator) commented Jan 30, 2026

Adds gradient accumulation to DPOTrainModule.


gemini-code-assist (Contributor) commented

Summary of Changes

Hello @finbarrtimbers, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enhances the DPO training module by introducing support for gradient accumulation and microbatching. These changes enable training with larger effective batch sizes without a larger per-device memory footprint, which is especially useful in distributed environments. The modifications involve restructuring the training loop, refining batch size calculations, and integrating mechanisms for proper metric accumulation across micro-batches.

Highlights

  • Gradient Accumulation Support: Implemented gradient accumulation and microbatching for DPO training, allowing for larger effective batch sizes without increasing memory footprint.
  • Batch Splitting Utility: Introduced a new utility function, split_batch_dpo(), to divide DPO batches into smaller micro-batches for sequential processing (a sketch follows this list).
  • Microbatch Size Configuration: Added the rank_microbatch_size parameter to DPOTrainModule's initialization, providing explicit control over the size of micro-batches processed per rank.
  • Accumulator Tensors for Metrics: Integrated dedicated tensors (_total_loss, _total_chosen_logps, etc.) within DPOTrainModule to accurately accumulate metrics across multiple micro-batches before a single optimization step.
  • Distributed Synchronization Control: Created _train_microbatch_context() to manage FSDP/DDP synchronization during gradient accumulation steps, ensuring gradients are only synchronized on the final micro-batch.
  • Refactored Training Loop: Modified the train_batch() method to properly loop through micro-batches, ensuring correct loss scaling and metric accumulation for each micro-batch.
  • Token Count Metric: Added token_count calculation in DataCollatorForSeq2SeqDPO and integrated its recording as a training metric in DPOTrainModule.
  • Batch Size Calculation Fix: Corrected the calculation of rank_batch_size to properly incorporate gradient_accumulation_steps, ensuring accurate effective batch size computation.
  • Optimized Cache Batch Size: Updated the cache_batch_size calculation for inference to a more efficient dp_world_size * 2.
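
To make the splitting concrete, here is a minimal sketch of what split_batch_dpo() might look like, assuming (as a later commit confirms) that DPO batches hold only torch.Tensor values; the actual signature and validation in the PR may differ:

```python
import torch

def split_batch_dpo(batch: dict[str, torch.Tensor], microbatch_size: int) -> list[dict[str, torch.Tensor]]:
    """Split a DPO batch into micro-batches along the first (example) dimension."""
    # DPO batches contain only tensors; fail loudly on anything else.
    for key, value in batch.items():
        if not isinstance(value, torch.Tensor):
            raise TypeError(f"Unexpected value type for {key!r}: {type(value).__name__}")
    num_examples = next(iter(batch.values())).shape[0]
    return [
        {key: value[start : start + microbatch_size] for key, value in batch.items()}
        for start in range(0, num_examples, microbatch_size)
    ]
```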


gemini-code-assist (Bot) left a comment

Code Review

This pull request introduces gradient accumulation and micro-batching for DPO training, a welcome enhancement for training with larger effective batch sizes under limited memory. The changes involve refactoring the training loop, adding a batch splitting utility, and updating metric accumulation.

My review has identified a couple of critical issues:

  1. There's a potential runtime error when using packing=True with concatenated_forward=False, as the separate_forward function doesn't support packing. This issue appears in both dpo.py for reference cache generation and in DPOTrainModule for the training step.
  2. The accuracy metric calculation in DPOTrainModule is incorrect. It compares the mean of rewards instead of calculating the mean of per-sample accuracies, which will lead to misleading metric reports.

I've left specific comments with suggestions on how to address these points. Once these are fixed, the implementation will be much more robust.
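
To illustrate the second issue, here is a toy example (values are illustrative) contrasting the two computations: comparing batch means collapses accuracy to a single 0/1 per batch, while the per-sample comparison yields the fraction of correctly ranked pairs:

```python
import torch

# Illustrative reward values for a 3-sample batch.
chosen_rewards = torch.tensor([1.2, -0.3, 0.8])
rejected_rewards = torch.tensor([0.5, 0.1, 0.9])

# Incorrect: compares the batch means, yielding a single 0/1 value.
wrong = (chosen_rewards.mean() > rejected_rewards.mean()).float()   # -> 1.0

# Correct: per-sample comparison, then mean.
accuracy = (chosen_rewards > rejected_rewards).float().mean()       # -> 0.333...
```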

Comment thread: open_instruct/dpo.py (outdated)
Comment thread: open_instruct/olmo_core_train_modules.py (outdated)
Comment thread: open_instruct/olmo_core_train_modules.py (outdated)
chatgpt-codex-connector (Bot) left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 0f9a766699


Comment thread: open_instruct/olmo_core_train_modules.py (outdated)
Comment thread: open_instruct/olmo_core_train_modules.py (outdated)
finbarrtimbers force-pushed the finbarr/dpo-mfu-pr5-microbatching branch 9 times, most recently from 1ff30b6 to ba2c8c0, on February 3, 2026
finbarrtimbers changed the title from "Add gradient accumulation/microbatching support" to "Add gradient accumulation/microbatching support to DPOTrainModule" on Feb 3, 2026
finbarrtimbers and others added 2 commits February 6, 2026 07:17
- Add split_batch_dpo() function to split DPO batches into micro-batches
- Add rank_microbatch_size parameter to DPOTrainModule.__init__()
- Add accumulator tensors (_total_loss, _total_chosen_logps, etc.)
- Add _train_microbatch_context() for FSDP/DDP sync control
- Refactor train_batch() to loop over micro-batches with proper scaling
- Add token_count metric recording in DPOTrainModule
- Add token_count calculation in DataCollatorForSeq2SeqDPO
- Fix batch size calculation: rank_batch_size = per_device * grad_accum
- Update cache batch size to dp_world_size * 2

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
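
For the FSDP/DDP sync control mentioned above, the usual pattern looks roughly like the following sketch; the real _train_microbatch_context() in olmo_core_train_modules.py may handle OLMo-core's wrappers differently:

```python
import contextlib

def _train_microbatch_context(self, micro_batch_idx: int, num_micro_batches: int):
    # Skip the gradient all-reduce on every micro-batch except the last,
    # so FSDP/DDP synchronizes gradients only once per optimizer step.
    is_last = micro_batch_idx == num_micro_batches - 1
    if not is_last and hasattr(self.model, "no_sync"):
        return self.model.no_sync()
    return contextlib.nullcontext()
```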
finbarrtimbers force-pushed the finbarr/dpo-mfu-pr5-microbatching branch from a07d926 to d270162 on February 6, 2026
finbarrtimbers and others added 8 commits February 6, 2026 10:39
Replace functools.partial wrapping of forward_fn with explicit kwargs
dict passed at call sites. Extract inline loss computation into
_compute_microbatch_loss method on DPOTrainModule for readability.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
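
Roughly, this commit swaps a pre-bound callable for explicit kwargs. A sketch under assumptions (the forward function is a stand-in; the average_log_prob kwarg is hypothetical):

```python
import functools

def concatenated_forward(model, batch, average_log_prob=False):
    """Stand-in for the real forward function; the kwarg is hypothetical."""
    ...

model, batch = object(), {}

# Before: forward_fn pre-bound with functools.partial at setup time.
forward_fn = functools.partial(concatenated_forward, model, average_log_prob=False)
logps = forward_fn(batch)

# After: the plain function plus an explicit kwargs dict at each call site.
forward_kwargs = {"average_log_prob": False}
logps = concatenated_forward(model, batch, **forward_kwargs)
```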
DPO batches only contain torch.Tensor values, so the list splitting
branch was unreachable. Simplify to tensor-only with an error for
unexpected types.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Replace 7 individual self._total_* attributes with self._metrics dict,
and have _compute_microbatch_loss return step metrics alongside loss.
This simplifies zeroing, accumulation, and division to simple loops.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
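
The accumulator shape this implies, sketched with hypothetical keys (the PR replaces seven separate _total_* attributes):

```python
import torch

def _init_metric_accumulators(self, device: torch.device) -> None:
    # Keys are illustrative; one scalar accumulator per metric.
    keys = ("loss", "chosen_logps", "rejected_logps", "chosen_rewards",
            "rejected_rewards", "margin", "accuracy")
    self._metrics = {k: torch.zeros((), device=device) for k in keys}

def _zero_metrics(self) -> None:
    # Zeroing (and later division) become one-line loops over the dict.
    for v in self._metrics.values():
        v.zero_()
```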
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
loss, step_metrics = self._compute_microbatch_loss(micro_batch)
for k, v in step_metrics.items():
    self._metrics[k] += v.detach()
(loss / num_micro_batches).backward()
Collaborator commented:

I think we need to divide by the total number of tokens in the batch rather than by the number of microbatches; otherwise each microbatch is weighted equally even though microbatches contain differing numbers of tokens.

finbarrtimbers (Author) replied:

Nice catch, done.

Comment thread: open_instruct/olmo_core_train_modules.py (outdated)
…obatch count

Ensures each token contributes equally to gradients and logged metrics
regardless of how padding is distributed across microbatches.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
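
Putting the pieces together, the token-weighted micro-batch loop might look like the following sketch (token_count is the per-micro-batch field added by DataCollatorForSeq2SeqDPO; variable names are illustrative):

```python
micro_batches = split_batch_dpo(batch, self.rank_microbatch_size)
batch_tokens = sum(mb["token_count"] for mb in micro_batches)
for i, micro_batch in enumerate(micro_batches):
    # Weight by this micro-batch's share of tokens, not 1/num_micro_batches,
    # so each token contributes equally regardless of padding distribution.
    weight = micro_batch["token_count"] / batch_tokens
    with self._train_microbatch_context(i, len(micro_batches)):
        loss, step_metrics = self._compute_microbatch_loss(micro_batch)
        for k, v in step_metrics.items():
            self._metrics[k] += v.detach() * weight
        (loss * weight).backward()
```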
self.record_metric("train/loss", loss.detach(), ReduceType.mean)
self.record_metric("train/logps_chosen", policy_chosen_logps.mean().detach(), ReduceType.mean)
self.record_metric("train/logps_rejected", policy_rejected_logps.mean().detach(), ReduceType.mean)
self.record_metric("train/loss", self._metrics["loss"], ReduceType.mean)
Collaborator commented:

Should the ReduceType be mean? Shouldn't it be sum, since we accounted for the token weighting above?

finbarrtimbers (Author) replied:

Ah yeah, nice catch. Annoyingly, OLMo-core doesn't provide a ReduceType.weighted_mean, so I had to roll my own:

self._metrics[k] += v.detach() * weight
(loss * weight).backward()

self.model.post_batch(dry_run=dry_run)
Collaborator commented:

Is this where the optim step is taken?

finbarrtimbers (Author) replied:

No, it happens after, in the TransformerTrainModule.optim_step method. See here:

https://github.com/allenai/OLMo-core/blob/3af842521375a373266673dda262debe0748a462/src/olmo_core/train/trainer.py#L1419

Accumulate token-weighted sums during microbatch loop, then do a single
all_reduce(SUM) across DP ranks and divide by global total tokens. This
matches the GRPO pattern and produces globally correct token-weighted
metrics.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
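
A sketch of the reduction this commit describes, assuming the micro-batch loop now accumulates metric × token_count sums plus a token counter (names are illustrative; this supersedes the per-batch weight normalization shown earlier):

```python
import torch
import torch.distributed as dist

def _reduce_token_weighted_metrics(
    metrics: dict[str, torch.Tensor], total_tokens: torch.Tensor
) -> dict[str, torch.Tensor]:
    # Pack all metric sums plus the token count into one tensor so a
    # single all-reduce covers every metric across data-parallel ranks.
    packed = torch.stack([*metrics.values(), total_tokens])
    dist.all_reduce(packed, op=dist.ReduceOp.SUM)
    *sums, global_tokens = packed.unbind()
    # Divide by the global token total for a correct token-weighted mean.
    return {k: s / global_tokens for k, s in zip(metrics, sums)}
```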
finbarrtimbers added this pull request to the merge queue on Feb 11, 2026
github-merge-queue (Bot) removed this pull request from the merge queue due to failed status checks on Feb 11, 2026
finbarrtimbers added this pull request to the merge queue on Feb 11, 2026
Merged via the queue into main with commit 06b90f2 on Feb 11, 2026
7 checks passed
finbarrtimbers deleted the finbarr/dpo-mfu-pr5-microbatching branch on February 11, 2026 22:54