Add gradient accumulation/microbatching support to DPOTrainModule #1447

finbarrtimbers merged 14 commits into allenai/open-instruct:main from allenai/open-instruct:finbarr/dpo-mfu-pr5-microbatching

Conversation
Summary of Changes

Hello @finbarrtimbers, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request significantly enhances the DPO training module by introducing robust support for gradient accumulation and microbatching. These changes enable more efficient training with larger effective batch sizes, which is crucial for optimizing resource utilization and improving model performance, especially in distributed environments. The modifications involve restructuring the training loop, refining batch size calculations, and integrating mechanisms for proper metric accumulation across micro-batches.
Code Review
This pull request introduces gradient accumulation and micro-batching for DPO training, which is a great enhancement for training on larger batch sizes with limited memory. The changes involve refactoring the training loop, adding a batch splitting utility, and updating metric accumulation.
My review has identified a couple of critical issues:
- There's a potential runtime error when using `packing=True` with `concatenated_forward=False`, as the `separate_forward` function doesn't support packing. This issue appears in both `dpo.py` for reference cache generation and in `DPOTrainModule` for the training step.
- The accuracy metric calculation in `DPOTrainModule` is incorrect. It compares the mean of rewards instead of calculating the mean of per-sample accuracies, which will lead to misleading metric reports (see the sketch below).
I've left specific comments with suggestions on how to address these points. Once these are fixed, the implementation will be much more robust.
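To make the accuracy point concrete, here is a toy sketch of the difference (the variable names are illustrative, not the module's actual ones):

```python
import torch

chosen_rewards = torch.tensor([2.0, -3.0, 1.0])   # per-sample reward for the chosen response
rejected_rewards = torch.tensor([1.0, 0.5, 0.5])  # per-sample reward for the rejected response

# Comparing the means collapses the whole batch into a single 0/1 value.
misleading = (chosen_rewards.mean() > rejected_rewards.mean()).float()  # tensor(0.) here

# Averaging per-sample comparisons reports the fraction of pairs ranked correctly.
accuracy = (chosen_rewards > rejected_rewards).float().mean()           # tensor(0.6667)
```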
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 0f9a766699
Force-pushed from 1ff30b6 to ba2c8c0
Add gradient accumulation/microbatching support to DPOTrainModule

- Add split_batch_dpo() function to split DPO batches into micro-batches
- Add rank_microbatch_size parameter to DPOTrainModule.__init__()
- Add accumulator tensors (_total_loss, _total_chosen_logps, etc.)
- Add _train_microbatch_context() for FSDP/DDP sync control
- Refactor train_batch() to loop over micro-batches with proper scaling
- Add token_count metric recording in DPOTrainModule
- Add token_count calculation in DataCollatorForSeq2SeqDPO
- Fix batch size calculation: rank_batch_size = per_device * grad_accum
- Update cache batch size to dp_world_size * 2

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
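Taken together, these pieces follow the usual gradient-accumulation shape. Below is a rough sketch of the loop, with names borrowed from the commit messages (`split_batch_dpo`, `rank_microbatch_size`, `_train_microbatch_context`, `_compute_microbatch_loss`, `_metrics`); the real `train_batch` differs in details and in how it hooks into olmo-core:

```python
import contextlib
from typing import Callable

import torch


def train_batch_sketch(
    batch: dict[str, torch.Tensor],
    rank_microbatch_size: int,
    split_batch_dpo: Callable,
    compute_microbatch_loss: Callable,
    metrics: dict[str, torch.Tensor],
    microbatch_context: Callable = lambda i, n: contextlib.nullcontext(),
) -> None:
    """Gradient-accumulation loop: split the rank batch, run backward per micro-batch."""
    micro_batches = split_batch_dpo(batch, rank_microbatch_size)
    num_micro_batches = len(micro_batches)

    for key in metrics:  # reset the accumulators for this optimizer step
        metrics[key].zero_()

    for i, micro_batch in enumerate(micro_batches):
        # Gradient sync across DP ranks is typically deferred to the last micro-batch.
        with microbatch_context(i, num_micro_batches):
            loss, step_metrics = compute_microbatch_loss(micro_batch)
            for key, value in step_metrics.items():
                metrics[key] += value.detach()
            # Scale so the accumulated gradient matches a single large-batch step.
            # (Later in this PR the scaling changes to token weighting; see the review below.)
            (loss / num_micro_batches).backward()
```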
Force-pushed from a07d926 to d270162
Replace functools.partial wrapping of forward_fn with explicit kwargs dict passed at call sites. Extract inline loss computation into _compute_microbatch_loss method on DPOTrainModule for readability. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
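A small illustration of that refactor (the function and argument names here are assumptions, not the module's actual signature):

```python
import functools


def forward_fn(model, batch, *, average_log_prob=False, packing=False):
    """Stand-in for the DPO forward pass."""
    ...


# Before: keyword arguments bound up front, hiding the configuration from call sites.
bound_forward = functools.partial(forward_fn, average_log_prob=False, packing=True)
# logps = bound_forward(model, micro_batch)

# After: an explicit kwargs dict, unpacked wherever the forward pass is invoked.
forward_kwargs = {"average_log_prob": False, "packing": True}
# logps = forward_fn(model, micro_batch, **forward_kwargs)
```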
DPO batches only contain torch.Tensor values, so the list splitting branch was unreachable. Simplify to tensor-only with an error for unexpected types. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
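A minimal sketch of what a tensor-only splitter like this could look like (the real `split_batch_dpo` may differ, e.g. in how scalar fields such as `token_count` are handled):

```python
import torch


def split_batch_dpo(batch: dict[str, torch.Tensor], microbatch_size: int) -> list[dict[str, torch.Tensor]]:
    """Split a DPO batch dict into micro-batches along the first (batch) dimension."""
    for key, value in batch.items():
        if not isinstance(value, torch.Tensor):
            raise ValueError(f"Expected a tensor for key {key!r}, got {type(value)}")

    batch_size = next(iter(batch.values())).shape[0]
    return [
        {key: value[start : start + microbatch_size] for key, value in batch.items()}
        for start in range(0, batch_size, microbatch_size)
    ]
```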
Replace 7 individual self._total_* attributes with self._metrics dict, and have _compute_microbatch_loss return step metrics alongside loss. This simplifies zeroing, accumulation, and division to simple loops. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
```python
loss, step_metrics = self._compute_microbatch_loss(micro_batch)
for k, v in step_metrics.items():
    self._metrics[k] += v.detach()
(loss / num_micro_batches).backward()
```
I think we need to divide by total tokens in the batch, rather than the number of microbatches. This is because each microbatch is weighted evenly even though they have differing numbers of tokens in them.
Nice catch, done.
…obatch count

Ensures each token contributes equally to gradients and logged metrics regardless of how padding is distributed across microbatches. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
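As a toy illustration of the point above (numbers are made up, and the per-microbatch losses are taken to be per-token means): dividing by the micro-batch count weights a padding-heavy micro-batch the same as a full one, while token weighting recovers the per-token mean over the whole batch.

```python
import torch

# Two micro-batches with very different amounts of real (non-padding) tokens.
microbatch_token_counts = torch.tensor([900.0, 100.0])
total_tokens = microbatch_token_counts.sum()
num_micro_batches = len(microbatch_token_counts)

# Per-token mean loss of each micro-batch.
microbatch_losses = torch.tensor([2.0, 4.0])

# Dividing by the micro-batch count weights both equally: (2 + 4) / 2 = 3.0
uniform = (microbatch_losses / num_micro_batches).sum()

# Weighting by token count reproduces the per-token mean over the whole batch:
# (900 * 2 + 100 * 4) / 1000 = 2.2
token_weighted = (microbatch_losses * microbatch_token_counts / total_tokens).sum()
```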
| self.record_metric("train/loss", loss.detach(), ReduceType.mean) | ||
| self.record_metric("train/logps_chosen", policy_chosen_logps.mean().detach(), ReduceType.mean) | ||
| self.record_metric("train/logps_rejected", policy_rejected_logps.mean().detach(), ReduceType.mean) | ||
| self.record_metric("train/loss", self._metrics["loss"], ReduceType.mean) |
Should ReduceType be mean? Shouldn't it be sum, since we accounted for the token weighting above?
Ah yeah, nice catch. Annoyingly, Olmo-core doesn't provide a ReduceType.weighted_mean, so I had to roll my own.
```python
self._metrics[k] += v.detach() * weight
(loss * weight).backward()
...
self.model.post_batch(dry_run=dry_run)
```
is this where the optim step is taken?
No, it happens after, in the TransformerTrainModule.optim_step method. See here:
Accumulate token-weighted sums during microbatch loop, then do a single all_reduce(SUM) across DP ranks and divide by global total tokens. This matches the GRPO pattern and produces globally correct token-weighted metrics. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
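A sketch of that roll-your-own weighted mean: token-weighted sums and token counts are summed with a single SUM all-reduce, then divided. The `dist.all_reduce` call is standard torch.distributed; the surrounding names are illustrative rather than the module's actual ones.

```python
import torch
import torch.distributed as dist


def global_token_weighted_mean(local_weighted_sum: torch.Tensor, local_token_count: torch.Tensor) -> torch.Tensor:
    """Token-weighted mean across all data-parallel ranks.

    `local_weighted_sum` is sum(metric_i * tokens_i) over this rank's micro-batches,
    `local_token_count` is the matching sum of tokens.
    """
    # Stack numerator and denominator so one collective covers both.
    packed = torch.stack([local_weighted_sum, local_token_count])
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(packed, op=dist.ReduceOp.SUM)
    global_weighted_sum, global_token_count = packed[0], packed[1]
    return global_weighted_sum / global_token_count
```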
Adds gradient accumulation to DPOTrainModule. Runs