Add torch.compile support with max_length padding #1445

Merged

finbarrtimbers merged 41 commits into allenai/open-instruct:main from allenai/open-instruct:finbarr/dpo-mfu-pr3-torch-compile on Feb 2, 2026

Conversation

@finbarrtimbers (Collaborator) commented Jan 30, 2026

Adds a --compile_model flag to dpo.py.

It doesn't improve performance on its own, but it does let us use gradient checkpointing, which we'll need for large models (Wandb):

[Screenshot: Wandb comparison, 2026-02-02]

Runs:

  1. Single GPU (01KGG5MAQ9ENC1Y7ARMN6BQF6W)
  2. Multi-node (01KGG5RV1BER13M28XZXCJXS9Z)

In an earlier version of this PR, we added a flag that padded sequences to a fixed length to avoid recompiles. We ran an ablation comparing compilation with max-length padding (Beaker) to compilation without it (Beaker), and found that max-length padding avoids recompiles but hurts performance:

| Metric | With Padding | Without Padding |
| --- | --- | --- |
| MFU (avg) | 4.58% | 11.85% |
| Seconds/step | 1.04 s | 0.42 s |
| Recompilations | 48 | 396 |

As such, we removed the padding.
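
For readers outside the repo, here is a rough sketch of what wiring a flag like --compile_model to torch.compile can look like. The flag name comes from this PR; the TrainingConfig dataclass and the `model.model.layers` access below are illustrative assumptions, not the repo's exact code.

```python
# Rough sketch: wiring a --compile_model flag to torch.compile.
from dataclasses import dataclass

import torch
import torch.nn as nn


@dataclass
class TrainingConfig:
    compile_model: bool = True  # enable torch.compile on model blocks


def maybe_compile_blocks(model: nn.Module, config: TrainingConfig) -> nn.Module:
    """Compile each transformer block individually when compile_model is set."""
    if not config.compile_model:
        return model
    # Compiling per-block keeps graphs small and composes with gradient
    # checkpointing, which the PR description names as the main motivation.
    layers = model.model.layers  # HF-style layout; an assumption for this sketch
    for i, block in enumerate(layers):
        layers[i] = torch.compile(block)
    return model
```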

@gemini-code-assist (Contributor)

Summary of Changes

Hello @finbarrtimbers, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces the capability to utilize torch.compile for optimizing model execution, which necessitates a change in how data sequences are handled. To meet the static tensor shape requirements of torch.compile, the data collator has been enhanced to pad all sequences to a uniform maximum length when this optimization is enabled. This ensures compatibility and leverages torch.compile's performance benefits.

Highlights

  • Torch.compile Integration: Added support for torch.compile to model blocks, enabling potential performance improvements by compiling the model for faster execution.
  • Fixed-Size Padding for Static Shapes: Implemented fixed-size padding for sequences within DataCollatorForSeq2SeqDPO when torch.compile is active, which is crucial for meeting the static tensor shape requirements of torch.compile.
  • New Padding Utilities: Introduced a new pad_tensor_to_length utility function and a PAD_VALUES constant to standardize and manage padding logic across different tensor types (a rough sketch follows this list).
  • Configuration and Collator Updates: Updated the TrainingConfig with a compile_model flag to control torch.compile activation and modified the data collator instantiation to conditionally pass max_length based on this flag.
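
As a rough illustration of the utilities named above, such helpers could look like the sketch below. The names pad_tensor_to_length and PAD_VALUES come from the PR summary; the pad values and exact signature are assumptions, and a later commit consolidates this into a single pad_to_length.

```python
import torch

# Per-key pad values; -100 is the usual ignore index for labels.
# The exact keys and values here are assumptions, not the repo's definitions.
PAD_VALUES = {"input_ids": 0, "attention_mask": 0, "labels": -100}


def pad_tensor_to_length(tensor: torch.Tensor, length: int, pad_value: int) -> torch.Tensor:
    """Right-pad the last dimension of a tensor up to `length`."""
    if tensor.size(-1) >= length:
        return tensor
    pad_shape = (*tensor.shape[:-1], length - tensor.size(-1))
    padding = tensor.new_full(pad_shape, pad_value)
    return torch.cat([tensor, padding], dim=-1)
```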


@gemini-code-assist (Bot) left a comment


Code Review

This pull request introduces support for torch.compile by adding a compile_model configuration and implementing fixed-size padding to max_length when this option is enabled. The changes are logical and correctly address the requirement of static tensor shapes for torch.compile. My review includes a couple of suggestions to improve code maintainability and efficiency by addressing code duplication and a redundant padding operation.

Comment thread open_instruct/dpo_utils.py Outdated
Comment thread open_instruct/dpo_utils.py Outdated
@chatgpt-codex-connector (Bot) left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: e074b27857


Comment thread open_instruct/dpo.py
- Add compile_model config field to enable torch.compile on model blocks
- Add max_length field to DataCollatorForSeq2SeqDPO for fixed-size padding
- Add PAD_VALUES dict and pad_tensor_to_length() helper function
- Update collator instantiation to pass max_length when compile_model is enabled

torch.compile requires static tensor shapes, so when enabled, all sequences
are padded to max_seq_length.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
finbarrtimbers force-pushed the finbarr/dpo-mfu-pr3-torch-compile branch from e074b27 to 9d90798 on January 31, 2026 at 00:00
finbarrtimbers and others added 2 commits February 1, 2026 09:29
- Delete pad_tensor_to_length, consolidate into pad_to_length
- Update pad_to_length to use torch.nn.functional.pad
- Remove unused dim parameter from pad_to_length
- Combine two padding codepaths in DataCollatorForSeq2SeqDPO into one
- Always pad attention_mask and labels (not just input_ids)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
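
The consolidation this commit describes, a single pad_to_length built on torch.nn.functional.pad, could look roughly like the following (function name from the commit message; the body is an assumption):

```python
import torch
import torch.nn.functional as F


def pad_to_length(tensor: torch.Tensor, length: int, pad_value: int) -> torch.Tensor:
    """Right-pad the last dimension of `tensor` to `length` using F.pad."""
    pad_amount = length - tensor.size(-1)
    if pad_amount <= 0:
        return tensor
    # F.pad's pad tuple is (left, right) for the last dimension.
    return F.pad(tensor, (0, pad_amount), value=pad_value)
```
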
Packing creates variable-length batches which causes torch.compile to
recompile on every batch, making it ineffective.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
@hamishivi (Collaborator) left a comment


I'm okay with the change, but I sorta dislike having to do it. It feels like it's covering up a deeper issue (the modelling code making some assumptions about batch lengths....?).

Comment thread open_instruct/dpo_utils.py Outdated
Comment thread open_instruct/dpo.py
if args.packing and args.compile_model:
    raise ValueError(
        "packing and compile_model cannot be used together. "
        "Packing creates variable-length batches which causes torch.compile to recompile on every batch. "
Collaborator


This is kinda surprising to me? I thought you could use varlen_attn to avoid torch.compile doing this (https://docs.pytorch.org/tutorials/intermediate/variable_length_attention_tutorial.html).

Collaborator Author


Maybe it's bc we use flash attention?

finbarrtimbers added this pull request to the merge queue Feb 1, 2026
hamishivi removed this pull request from the merge queue due to a manual request Feb 1, 2026
finbarrtimbers and others added 17 commits February 1, 2026 22:27
Co-authored-by: Hamish Ivison <hamishivi@gmail.com>
Remove conditional padding based on compile_model and add TORCH_LOGS
environment variable to log graph_breaks and recompiles for measuring
torch.compile impact without padding.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
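
For reference, the recompile and graph-break counts this commit mentions can be surfaced through PyTorch's logging controls; a minimal sketch of two equivalent ways (the repo's launch-script wiring may differ):

```python
# Option 1: set in the environment before launching training, e.g. in the
# launch script:
#   TORCH_LOGS="graph_breaks,recompiles" python open_instruct/dpo.py ...

# Option 2: configure programmatically from inside the process.
import torch._logging

torch._logging.set_logs(graph_breaks=True, recompiles=True)
```
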
…heckpointing

The flag was renamed in commit 5142c19 but these scripts weren't updated.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Add --push_to_hub false to all DPO debug scripts to prevent
unnecessary model uploads during testing.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
PerfCallback was not logging because it depends on train/token_count
metric which was not being recorded in DPOTrainModule.train_batch().

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Update global_num_tokens_in_batch() to use attention_mask.sum() instead of
  numel() so padding tokens are not counted
- Add attention_mask to DPO collator alongside input_ids concatenation
- Reuse global_num_tokens_in_batch() in DPO train module for train/token_count
  metric as single source of truth
- Add tests verifying padding doesn't affect token counts or MFU
- Add coding convention rules about imports to CLAUDE.md

This fixes inflated MFU, TPS, and token_count metrics when padding is enabled
for torch.compile.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
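
The attention_mask.sum() vs. numel() distinction described above is the core of the fix; a standalone sketch is below (the real global_num_tokens_in_batch presumably also reduces the count across ranks, which is omitted here):

```python
import torch


def num_tokens_in_batch(batch: dict[str, torch.Tensor]) -> int:
    """Count only real tokens; padding positions have attention_mask == 0."""
    # batch["attention_mask"].numel() would also count padding, which is what
    # inflated token_count, TPS, and MFU when padding to a fixed length.
    return int(batch["attention_mask"].sum().item())
```
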
The test imports data_loader which requires vllm, so it can only
run on Beaker GPU tests.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
The test now uses an actual ModelDims instance with real model
architecture parameters (OLMo-2-1B-like config) and the real
approximate_learner_utilization calculation. This verifies that
padding tokens are properly excluded from the MFU calculation.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
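
For context on what the test checks, MFU is conventionally approximated as achieved training FLOPs over peak hardware FLOPs; a back-of-the-envelope sketch follows (the repo's ModelDims-based calculation is more detailed, and the numbers here are purely illustrative):

```python
def approximate_mfu(num_params: float, tokens_per_second: float,
                    peak_flops_per_second: float) -> float:
    """Rough MFU: ~6 FLOPs per parameter per trained token (forward + backward)."""
    achieved_flops_per_second = 6.0 * num_params * tokens_per_second
    return achieved_flops_per_second / peak_flops_per_second


# Illustrative numbers: a ~1B-parameter model at 20k tokens/s on a GPU with
# ~312 TFLOP/s peak gives roughly 38% MFU. Padding tokens must be excluded
# from tokens_per_second or this estimate is inflated.
print(f"{approximate_mfu(1.0e9, 2.0e4, 3.12e14):.1%}")
```
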
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Also update CLAUDE.md to clarify that import logging should never
be used directly.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
The expected MFU is arbitrary since time is mocked. The important
thing is consistency across different padding levels.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
When enabled, pads all sequences to max_seq_length instead of just
the longest in the batch. Useful with torch.compile to avoid
recompilation due to varying tensor shapes.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- single_gpu_pad.sh: Single GPU with padding
- multi_node_pad.sh: Multi-node with padding

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
finbarrtimbers and others added 19 commits February 2, 2026 12:21
… metrics

- Add pre_step/post_step timing for accurate seconds_per_step measurement
- Add mfu_avg as simple running average of MFU values
- Rename tokens_per_second_step -> tokens_per_second
- Rename tokens_per_second_total -> tokens_per_second_avg
- Fix multi_node_pad.sh missing --try_launch_beaker_eval_jobs false

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
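
A standalone sketch of the bookkeeping this commit describes: pre_step/post_step wall-clock timing for seconds_per_step plus a running mean for mfu_avg (the real PerfCallback hooks into olmo_core; this only shows the arithmetic):

```python
import time


class StepTimer:
    """Tracks seconds per step and a running average of MFU."""

    def __init__(self) -> None:
        self._start = 0.0
        self._mfu_sum = 0.0
        self._steps = 0

    def pre_step(self) -> None:
        self._start = time.perf_counter()

    def post_step(self, mfu: float) -> dict[str, float]:
        seconds_per_step = time.perf_counter() - self._start
        self._mfu_sum += mfu
        self._steps += 1
        return {
            "seconds_per_step": seconds_per_step,
            "mfu_avg": self._mfu_sum / self._steps,
        }
```
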
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
The olmo_core Callback.pre_step() method receives a batch argument.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Padding to max_seq_length requires more memory.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
16k with padding was OOM even with low activation_memory_budget.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Uses --gradient_checkpointing from main branch for fair comparison.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
…t_checkpointing

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This reverts commit 6bb4599.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
The padding feature didn't help performance, so remove it while keeping
the compile_model flag.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Keep the file rename but remove the TestPerfCallbackMFU test class.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
The collator needs to pad chosen and rejected sequences to the same
length before concatenating them.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
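
The fix described here, padding chosen and rejected sequences to a common length before concatenation, looks roughly like this sketch (the function name and layout are assumptions; the real collator handles attention_mask and labels the same way):

```python
import torch
import torch.nn.functional as F


def concat_chosen_rejected(chosen_ids: torch.Tensor, rejected_ids: torch.Tensor,
                           pad_value: int = 0) -> torch.Tensor:
    """Pad both halves of a DPO batch to the same length, then stack them."""
    target = max(chosen_ids.size(-1), rejected_ids.size(-1))
    chosen_ids = F.pad(chosen_ids, (0, target - chosen_ids.size(-1)), value=pad_value)
    rejected_ids = F.pad(rejected_ids, (0, target - rejected_ids.size(-1)), value=pad_value)
    # Chosen examples first, then rejected, so the model sees one combined batch.
    return torch.cat([chosen_ids, rejected_ids], dim=0)
```
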
Remove --compile_model flag from DPO scripts since it's now the default.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
finbarrtimbers added this pull request to the merge queue Feb 2, 2026
Merged via the queue into main with commit e4a8825, Feb 2, 2026
7 checks passed
finbarrtimbers deleted the finbarr/dpo-mfu-pr3-torch-compile branch February 2, 2026 23:47