Replaces lambda collators with a single_example_collator to use a consistent interface #1472

Merged

finbarrtimbers merged 3 commits into allenai/open-instruct:main from allenai/open-instruct:finbarr/consistent-collation on Feb 14, 2026
Conversation

@finbarrtimbers (Collaborator) commented Feb 13, 2026

This lets us avoid some branching in #1466.

Experiment links:

@gemini-code-assist (Contributor) commented:

Summary of Changes

Hello @finbarrtimbers, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed.

This pull request streamlines data collation and DPO training by replacing ad-hoc lambda functions with a consistent single_example_collator interface. It significantly enhances the handling of packed sequences, allowing for better integration with torch.compile and providing more accurate token and sequence counting. The changes also introduce profiling capabilities for DPO training and refactor core batch processing utilities, leading to more robust and efficient model training workflows.

Highlights

  • Standardized Collator Usage: Replaced inline lambda collators with a dedicated single_example_collator function across data_loader.py, grpo_fast.py, and test_grpo_fast.py for improved consistency and maintainability.
  • Enhanced DPO Training with Packing: Introduced an _overflow mechanism in HFDataLoader to manage examples that don't perfectly fill batches, especially when using packing. The restriction preventing packing and torch.compile from being used together in DPO training was removed, and max_seq_length support was added to TensorDataCollatorWithFlatteningDPO for pre-filtering.
  • Improved Token and Sequence Counting: Enhanced token counting in HFDataLoader and olmo_core_train_modules.py to support both packed and non-packed batches using new utility functions (get_num_tokens, get_num_sequences). The PerfCallback now accurately tracks sequence counts for packed batches and includes a tokens_per_second_per_gpu metric.
  • Refactored DPO Utilities: Centralized batch statistics extraction into _get_batch_stats in dpo_utils.py and moved the pad_to_length function to a new tensor_utils.py module. The logic for calculating log probabilities in padding-free collators was improved, addressing an offset bug and enhancing segment handling.
  • Integrated Profiling: Added a profiling argument to ExperimentConfig and integrated ProfilerCallback for detailed performance analysis during DPO training, with corresponding updates in debug scripts.
  • Updated Debug Scripts and Changelog: Modified DPO debug training scripts (multi_node.sh, single_gpu.sh) to enable packing, adjust batching parameters, and activate profiling. A new changelog entry details these DPO training improvements.
  • Comprehensive Testing: Added extensive unit tests for DPO packing, index preservation, log probability calculations, and pre-filtering in test_padding_free_collator.py to ensure robustness.
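The counting utilities named above (`get_num_tokens`, `get_num_sequences`) are not shown in this summary. A minimal sketch of how such helpers typically distinguish packed from non-packed batches, assuming packed batches carry a `cu_seq_lens` tensor of cumulative sequence boundaries (the bodies below are illustrative, not the repository's actual code):

```python
import torch


def get_num_tokens(batch: dict) -> int:
    """Count tokens consistently for packed and non-packed batches (sketch)."""
    if "cu_seq_lens" in batch:
        # Packed: cumulative boundaries, so the last entry is the total token count.
        return int(batch["cu_seq_lens"][-1])
    # Non-packed: count non-padding positions via the attention mask.
    return int(batch["attention_mask"].sum())


def get_num_sequences(batch: dict) -> int:
    """Count sequences consistently for packed and non-packed batches (sketch)."""
    if "cu_seq_lens" in batch:
        # N + 1 boundaries delimit N packed sequences.
        return batch["cu_seq_lens"].numel() - 1
    return batch["input_ids"].shape[0]
```

Having one pair of helpers means callers such as the data loader and the train module never need to branch on whether packing is enabled.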


Changelog
  • .claude/commands/run-dpo-experiments.md
    • Updated DPO experiment script names and execution instructions.
  • CHANGELOG.md
    • Added a changelog entry for DPO training improvements, including packing with torch.compile, bug fixes, and new metrics.
  • open_instruct/data_loader.py
    • Imported padding_free_collator.
    • Added _overflow attribute to HFDataLoader to handle batch remnants.
    • Modified batch iteration logic to incorporate the _overflow buffer.
    • Updated token counting to use padding_free_collator.get_num_tokens.
    • Introduced single_example_collator function.
  • open_instruct/dpo.py
    • Imported ProfilerCallback.
    • Added profiling setup to _setup_callbacks based on args.profiling.
    • Removed the incompatibility check between packing and compile_model.
    • Adjusted dp_world_size calculation to account for tensor_parallel_degree.
    • Added max_seq_length parameter to TensorDataCollatorWithFlatteningDPO initialization.
    • Modified cache_batch_size calculation for packed scenarios.
  • open_instruct/dpo_utils.py
    • Imported padding_free_collator and tensor_utils.
    • Added a profiling field to ExperimentConfig.
    • Introduced _get_batch_stats function for extracting token count, example count, and sequence lengths from a DPO batch.
    • Updated build_reference_logprobs_cache to utilize _get_batch_stats.
    • Removed the pad_to_length function, moving its functionality to tensor_utils.py.
    • Ensured the index key is always included in the DataCollatorForSeq2SeqDPO output.
  • open_instruct/grpo_fast.py
    • Replaced an inline lambda collator with data_loader_lib.single_example_collator.
  • open_instruct/olmo_core_callbacks.py
    • Imported padding_free_collator.
    • Added _interval_num_sequences attribute to PerfCallback.
    • Updated pre_step to calculate num_sequences using padding_free_collator.get_num_sequences.
    • Revised num_sequences calculation in post_step for accuracy with packed batches.
    • Added tokens_per_second_per_gpu metric to performance logging.
    • Reset _interval_num_sequences at the end of post_step.
  • open_instruct/olmo_core_train_modules.py
    • Imported padding_free_collator.
    • Updated train_batch to use padding_free_collator.get_num_tokens for calculating total_tokens and micro_tokens.
  • open_instruct/padding_free_collator.py
    • Added helper functions _pad_to_max_length, _collect_flattened_features, _filter_feature_dicts, and _split_prefixed_batch.
    • Added max_seq_length to TensorDataCollatorWithFlattening and TensorDataCollatorWithFlatteningDPO for pre-filtering.
    • Refined get_batch_logps logic for calculating log probabilities, improving segment handling and correcting the cu_seq_lens offset.
    • Introduced get_num_tokens and get_num_sequences functions for consistent token and sequence counting.
  • open_instruct/tensor_utils.py
    • Added a new utility file containing the pad_to_length function.
  • open_instruct/test_data_loader_gpu.py
    • Updated import path for the pad_to_length utility function from dpo_utils to tensor_utils.
  • open_instruct/test_grpo_fast.py
    • Replaced inline lambda collators with data_loader_lib.single_example_collator in test setups.
  • open_instruct/test_padding_free_collator.py
    • Imported TensorDataCollatorWithFlatteningDPO, concatenated_inputs, and get_batch_logps.
    • Added new test cases for DPO packing indices, cu_seq_lens consistency, concatenated inputs, log probability calculations, and pre-filtering behavior.
  • scripts/train/debug/dpo/multi_node.sh
    • Modified training parameters to enable packing.
    • Increased per_device_train_batch_size to 16.
    • Reduced gradient_accumulation_steps to 1.
    • Updated output_dir and mixer_list values.
    • Enabled profiling.
  • scripts/train/debug/dpo/single_gpu.sh
    • Added --no-host-networking argument.
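The `pad_to_length` helper that moved to tensor_utils.py is a common DPO utility. A typical implementation of the pattern (shown as a sketch of the standard approach, not necessarily the repository's exact code) right-pads a tensor along one dimension up to a target length:

```python
import torch


def pad_to_length(
    tensor: torch.Tensor, length: int, pad_value: float, dim: int = -1
) -> torch.Tensor:
    """Right-pad `tensor` along `dim` to `length` with `pad_value` (sketch)."""
    if tensor.size(dim) >= length:
        # Already long enough; return unchanged.
        return tensor
    pad_shape = list(tensor.shape)
    pad_shape[dim] = length - tensor.size(dim)
    padding = torch.full(
        pad_shape, pad_value, dtype=tensor.dtype, device=tensor.device
    )
    return torch.cat([tensor, padding], dim=dim)
```

DPO code commonly uses this to bring chosen and rejected sequences to a shared length before concatenating them into one batch.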

- Add single_example_collator to replace lambda x: x[0] everywhere,
  ensuring batch["index"] is always a tensor
- Track overflow in HFDataLoader._iter_batches so features filtered by
  the DPO collator are carried to the next batch instead of being
  silently dropped
- Overflow persists across epoch reshuffles
- Remove unnecessary "index" in features[0] guards in DPO collators

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
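The commit message describes replacing `lambda x: x[0]` with a named collator that also normalizes `batch["index"]` to a tensor. A sketch of what such a function could look like (the body is hypothetical; the real one lives in open_instruct/data_loader.py):

```python
import torch


def single_example_collator(batch: list) -> dict:
    """Unwrap a batch of size one, replacing ad-hoc `lambda x: x[0]` collators.

    Also ensures batch["index"] is always a tensor so downstream code can
    rely on a single interface. Illustrative sketch, not the repo's code.
    """
    assert len(batch) == 1, "this collator expects batch_size=1"
    example = batch[0]
    index = example.get("index")
    if index is not None and not torch.is_tensor(index):
        example["index"] = torch.as_tensor(index)
    return example
```

A named function like this gives every call site (data_loader.py, grpo_fast.py, the tests) the same behavior instead of subtly different inline lambdas.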
@gemini-code-assist (Bot) left a comment:
Code Review

This pull request introduces a significant set of improvements for DPO training. The primary change is enabling packing with torch.compile, which was previously unsupported. This is achieved through a major refactoring of the data collators in padding_free_collator.py, introducing padding to a fixed max_seq_length to work around torch.compile's limitations with variable shapes.

The refactoring is well-executed, with helper functions extracted for better modularity and readability. Several correctness bugs are also fixed, including an off-by-one error in seq_idx calculation for packed DPO batches and a more robust implementation of get_batch_logps using torch.bucketize for better performance and correctness with variable-length sequences.

The PR also includes other valuable changes such as adding a profiling option, improving performance metric calculations in PerfCallback to correctly handle packed sequences, and replacing lambda collators with a named function for consistency, as stated in the title. The addition of comprehensive tests for the new DPO packing logic in test_padding_free_collator.py is excellent and provides high confidence in the correctness of these complex changes.

Overall, this is a high-quality contribution that improves performance, correctness, and maintainability of the DPO training code.
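The torch.bucketize approach mentioned above maps each flattened token position to the packed segment it belongs to, using the cu_seq_lens boundaries. A small illustration of the pattern (not the repository's exact get_batch_logps code):

```python
import torch


def segment_ids_from_cu_seq_lens(cu_seq_lens: torch.Tensor) -> torch.Tensor:
    """Assign each token position in a packed batch to its sequence index.

    cu_seq_lens holds cumulative boundaries, e.g. [0, 3, 7, 10] for three
    sequences of lengths 3, 4, 3 packed into 10 tokens. Illustrative sketch.
    """
    total_tokens = int(cu_seq_lens[-1])
    positions = torch.arange(total_tokens, device=cu_seq_lens.device)
    # With right=True, position p lands in bucket i exactly when
    # cu_seq_lens[i] <= p < cu_seq_lens[i + 1].
    return torch.bucketize(positions, cu_seq_lens[1:], right=True)
```

Because this is a single vectorized lookup, it handles variable-length segments without a Python loop, which is what makes it attractive for per-segment log-probability reductions.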

@finbarrtimbers force-pushed the finbarr/consistent-collation branch from c18148a to 1a578cf on February 13, 2026 17:31

@chatgpt-codex-connector (Bot) left a comment:

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: c18148ab25


Comment on lines 173 to +177:

      if len(batch_examples) == self._per_rank_batch_size:
    -     yield to_device(self._collator(batch_examples), self._device)
    +     all_examples = self._overflow + batch_examples
    +     batch = to_device(self._collator(all_examples), self._device)
    +     self._overflow = all_examples[len(batch["index"]) :]
    +     yield batch
P1 Badge Drain deferred overflow examples before ending an epoch

HFDataLoader._iter_batches() now defers trimmed samples into _overflow, but iteration still stops after effective_size source rows and never flushes any remaining overflow. With packing (TensorDataCollatorWithFlatteningDPO._prefilter_features) this means long-sequence workloads can repeatedly keep fewer than per_rank_batch_size items, leaving many deferred samples untrained in that epoch (and potentially accumulating across epochs), which silently skews DPO sampling and can invalidate experiment results.

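One way to address this would be to flush the overflow buffer once the source rows are exhausted. A schematic, pure-Python version of the batching loop with such a drain step, assuming (as the real code appears to) that the collator keeps a prefix of its input and reports how many items it kept — this is a sketch of the suggested fix, not the repository's code:

```python
def iter_batches_with_drain(examples, per_rank_batch_size, collate):
    """Batch `examples`, deferring items the collator filters out.

    `collate` is assumed to keep a prefix of its input (e.g. a packing
    pre-filter that stops at max_seq_length) and to return
    (batch, num_kept). Schematic sketch only.
    """
    overflow = []
    pending = []
    for example in examples:
        pending.append(example)
        if len(pending) == per_rank_batch_size:
            candidates = overflow + pending
            batch, num_kept = collate(candidates)
            overflow = candidates[num_kept:]
            pending = []
            yield batch
    # Drain: emit deferred examples instead of silently dropping them.
    leftover = overflow + pending
    while leftover:
        batch, num_kept = collate(leftover)
        if num_kept == 0:
            break  # nothing fits in a batch; stop rather than spin forever
        leftover = leftover[num_kept:]
        yield batch
```

The drain loop guarantees that every example either trains in the current epoch or is dropped explicitly, so packing-induced pre-filtering cannot skew the sampling distribution unnoticed.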

Comment on lines +163 to +165 (open_instruct/olmo_core_callbacks.py, outdated):

    +     if num_seqs is None:
    +         num_seqs = self.per_device_train_batch_size * 2
    +     self._interval_num_sequences += num_seqs * self.num_training_gpus
P2 Badge Include grad accumulation in non-packing sequence counts

In PerfCallback.pre_step, the non-packing fallback uses per_device_train_batch_size * 2, but DPO batches are built with per_device_train_batch_size * gradient_accumulation_steps per rank. When gradient_accumulation_steps > 1 and cu_seq_lens_k is absent, _interval_num_sequences is undercounted, so avg_sequence_length and MFU are systematically inflated, making performance metrics misleading for accumulation-heavy runs.


Comment thread open_instruct/data_loader.py
…r guard

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@finbarrtimbers finbarrtimbers added this pull request to the merge queue Feb 14, 2026
Merged via the queue into main with commit f009311 Feb 14, 2026
6 of 7 checks passed
@finbarrtimbers finbarrtimbers deleted the finbarr/consistent-collation branch February 14, 2026 18:15