Add torch.compile support with max_length padding #1445
finbarrtimbers merged 41 commits into allenai/open-instruct:main from allenai/open-instruct:finbarr/dpo-mfu-pr3-torch-compile
Conversation
gemini-code-assist left a comment
Code Review
This pull request introduces support for torch.compile by adding a compile_model configuration and implementing fixed-size padding to max_length when this option is enabled. The changes are logical and correctly address the requirement of static tensor shapes for torch.compile. My review includes a couple of suggestions to improve code maintainability and efficiency by addressing code duplication and a redundant padding operation.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: e074b27857
- Add compile_model config field to enable torch.compile on model blocks
- Add max_length field to DataCollatorForSeq2SeqDPO for fixed-size padding
- Add PAD_VALUES dict and pad_tensor_to_length() helper function
- Update collator instantiation to pass max_length when compile_model is enabled

torch.compile requires static tensor shapes, so when enabled, all sequences are padded to max_seq_length.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
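As a rough sketch of what a pad_tensor_to_length() helper and PAD_VALUES dict like the ones named in this commit could look like (the names come from the commit message; the concrete values and implementation in the PR may differ):

```python
import torch

# Assumed per-key pad values; in the real collator, input_ids would likely use
# the tokenizer's pad_token_id, while labels use -100 so the loss ignores padding.
PAD_VALUES = {
    "input_ids": 0,
    "attention_mask": 0,
    "labels": -100,
}

def pad_tensor_to_length(tensor: torch.Tensor, length: int, pad_value: int) -> torch.Tensor:
    """Right-pad the last dimension of `tensor` to `length` with `pad_value`."""
    if tensor.size(-1) >= length:
        return tensor
    padding = tensor.new_full((*tensor.shape[:-1], length - tensor.size(-1)), pad_value)
    return torch.cat([tensor, padding], dim=-1)
```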
Force-pushed from e074b27 to 9d90798
- Delete pad_tensor_to_length, consolidate into pad_to_length
- Update pad_to_length to use torch.nn.functional.pad
- Remove unused dim parameter from pad_to_length
- Combine two padding codepaths in DataCollatorForSeq2SeqDPO into one
- Always pad attention_mask and labels (not just input_ids)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
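A minimal sketch of the consolidated helper described above, assuming it keeps the same signature but delegates to torch.nn.functional.pad (the PR's exact code may differ):

```python
import torch
import torch.nn.functional as F

def pad_to_length(tensor: torch.Tensor, length: int, pad_value: int) -> torch.Tensor:
    """Right-pad the last dimension to `length`; no-op if already long enough."""
    pad_amount = length - tensor.size(-1)
    if pad_amount <= 0:
        return tensor
    # F.pad's (left, right) tuple pads the final dimension only.
    return F.pad(tensor, (0, pad_amount), value=pad_value)
```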
Packing creates variable-length batches which causes torch.compile to recompile on every batch, making it ineffective. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
hamishivi left a comment
I'm okay with the change, but I sorta dislike having to do it. It feels like it's covering up a deeper issue (the modelling code making some assumptions about batch lengths....?).
```python
if args.packing and args.compile_model:
    raise ValueError(
        "packing and compile_model cannot be used together. "
        "Packing creates variable-length batches which causes torch.compile to recompile on every batch. "
    )
```
This is kinda surprising to me? I thought you could use varlen_attn to avoid torch.compile doing this (https://docs.pytorch.org/tutorials/intermediate/variable_length_attention_tutorial.html).
Maybe it's bc we use flash attention?
Co-authored-by: Hamish Ivison <hamishivi@gmail.com>
Remove conditional padding based on compile_model and add TORCH_LOGS environment variable to log graph_breaks and recompiles for measuring torch.compile impact without padding. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
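TORCH_LOGS is a standard torch.compile logging switch; a hedged example of enabling the graph-break and recompile logging mentioned above (how the PR's scripts actually set it is not shown here):

```python
import os

# Must be visible before torch initializes its logging; in practice it is
# exported in the shell that launches training, e.g.
# TORCH_LOGS="graph_breaks,recompiles" python dpo.py ...
os.environ["TORCH_LOGS"] = "graph_breaks,recompiles"

import torch  # noqa: E402  (imported after the env var on purpose)
```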
…heckpointing The flag was renamed in commit 5142c19 but these scripts weren't updated. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Add --push_to_hub false to all DPO debug scripts to prevent unnecessary model uploads during testing. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
PerfCallback was not logging because it depends on train/token_count metric which was not being recorded in DPOTrainModule.train_batch(). Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Update global_num_tokens_in_batch() to use attention_mask.sum() instead of numel() so padding tokens are not counted
- Add attention_mask to DPO collator alongside input_ids concatenation
- Reuse global_num_tokens_in_batch() in DPO train module for train/token_count metric as single source of truth
- Add tests verifying padding doesn't affect token counts or MFU
- Add coding convention rules about imports to CLAUDE.md

This fixes inflated MFU, TPS, and token_count metrics when padding is enabled for torch.compile.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
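The core of the fix is the switch from numel() to attention_mask.sum(). A simplified sketch (the real global_num_tokens_in_batch() presumably also aggregates the count across ranks, which is omitted here):

```python
import torch

def num_real_tokens_in_batch(batch: dict[str, torch.Tensor]) -> int:
    # attention_mask is 1 for real tokens and 0 for padding, so summing it
    # excludes padding; numel() would count every position and inflate
    # token_count, TPS, and MFU once sequences are padded to max_length.
    return int(batch["attention_mask"].sum().item())
```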
The test imports data_loader which requires vllm, so it can only run on Beaker GPU tests. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
The test now uses an actual ModelDims instance with real model architecture parameters (OLMo-2-1B-like config) and the real approximate_learner_utilization calculation. This verifies that padding tokens are properly excluded from the MFU calculation. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Also update CLAUDE.md to clarify that import logging should never be used directly. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
The expected MFU is arbitrary since time is mocked. The important thing is consistency across different padding levels. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
When enabled, pads all sequences to max_seq_length instead of just the longest in the batch. Useful with torch.compile to avoid recompilation due to varying tensor shapes. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- single_gpu_pad.sh: Single GPU with padding
- multi_node_pad.sh: Multi-node with padding

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
… metrics

- Add pre_step/post_step timing for accurate seconds_per_step measurement
- Add mfu_avg as simple running average of MFU values
- Rename tokens_per_second_step -> tokens_per_second
- Rename tokens_per_second_total -> tokens_per_second_avg
- Fix multi_node_pad.sh missing --try_launch_beaker_eval_jobs false

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
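An illustrative sketch of the pre_step/post_step timing and running MFU average described above; this is not the actual PerfCallback code (the metric names follow the commit message, everything else is assumed):

```python
import time

class StepPerfTracker:
    """Hypothetical stand-in for the callback's timing logic."""

    def __init__(self) -> None:
        self._start = 0.0
        self._mfu_sum = 0.0
        self._steps = 0

    def pre_step(self) -> None:
        # Stamp wall-clock time just before the training step runs.
        self._start = time.perf_counter()

    def post_step(self, step_tokens: int, step_flops: float, peak_flops: float) -> dict:
        seconds_per_step = time.perf_counter() - self._start
        mfu = step_flops / seconds_per_step / peak_flops
        self._mfu_sum += mfu
        self._steps += 1
        return {
            "seconds_per_step": seconds_per_step,
            "tokens_per_second": step_tokens / seconds_per_step,
            "mfu": mfu,
            # mfu_avg is a simple running average over all steps seen so far.
            "mfu_avg": self._mfu_sum / self._steps,
        }
```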
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
The olmo_core Callback.pre_step() method receives a batch argument. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Padding to max_seq_length requires more memory. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
16k with padding was OOM even with low activation_memory_budget. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Uses --gradient_checkpointing from main branch for fair comparison. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
…t_checkpointing Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This reverts commit 6bb4599. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
The padding feature didn't help performance, so remove it while keeping the compile_model flag. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Keep the file rename but remove the TestPerfCallbackMFU test class. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
The collator needs to pad chosen and rejected sequences to the same length before concatenating them. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
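An illustrative version of the constraint described in this commit, assuming the collator concatenates chosen and rejected examples along the batch dimension (the helper name and exact padding policy are assumptions):

```python
import torch
import torch.nn.functional as F

def concat_chosen_rejected(chosen_ids: torch.Tensor, rejected_ids: torch.Tensor,
                           pad_token_id: int) -> torch.Tensor:
    """Pad chosen and rejected batches to a common length, then stack them."""
    target_len = max(chosen_ids.size(-1), rejected_ids.size(-1))
    chosen_ids = F.pad(chosen_ids, (0, target_len - chosen_ids.size(-1)), value=pad_token_id)
    rejected_ids = F.pad(rejected_ids, (0, target_len - rejected_ids.size(-1)), value=pad_token_id)
    # Concatenating along dim=0 only works once both share the same sequence length.
    return torch.cat([chosen_ids, rejected_ids], dim=0)
```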
Remove --compile_model flag from DPO scripts since it's now the default. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Adds a --compile_model flag to dpo.py. It doesn't help performance at all, but it does let us use gradient checkpointing, which will be needed for large models (Wandb):
Runs:
In an earlier version of this PR, to avoid recompiles, we added a flag to pad the sequence lengths to have a fixed size. We ran an ablation comparing compilation with max length padding (Beaker) to compilation without (Beaker), and found that max length padding hurts performance, but avoids recompiles:
As such, we removed the padding.
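For reference, a minimal sketch of what a --compile_model flag that compiles individual model blocks (as described in the first commit) might do; the container attribute names and integration point are assumptions, not the PR's actual code:

```python
import torch

def maybe_compile_blocks(model: torch.nn.Module, compile_model: bool) -> torch.nn.Module:
    """Compile each transformer block separately instead of the whole model."""
    if not compile_model:
        return model
    for name, child in model.named_children():
        # "blocks"/"layers" is a guess at the container name; adjust per model class.
        if name in ("blocks", "layers") and isinstance(child, torch.nn.ModuleList):
            for i, block in enumerate(child):
                child[i] = torch.compile(block)
    return model
```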