Fix ZeRO-2 discarding gradients during manual accumulation #1498

Merged

hamishivi merged 3 commits into allenai/open-instruct:main from allenai/open-instruct:fix/zero2-gradient-accumulation on Feb 26, 2026

Conversation

@hamishivi (Collaborator) commented Feb 26, 2026

Summary

  • With gradient_accumulation_steps=1 in the DS config, every backward() call is treated as a gradient accumulation boundary. In ZeRO-2 (stage_1_and_2.py:881), hitting the boundary resets the internal gradient accumulator (all_grad_tensors[i] = None), so only the last micro-batch's gradient survives to step(). ZeRO-3 doesn't have this bug because it uses a separate micro_step_id counter that only resets in step()/zero_grad().
  • The fix uses DeepSpeed's set_gradient_accumulation_boundary() API to explicitly mark only the final backward before step() as the boundary, which handles varying accumulation step counts naturally (see the sketch below).
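
A minimal sketch of the resulting training-loop pattern, assuming a manual accumulation loop over micro-batches; the loop structure and the names micro_batches and compute_loss are illustrative placeholders, not the exact code in open_instruct/grpo_fast.py:

# Hypothetical manual-accumulation loop; `model` is a deepspeed.DeepSpeedEngine.
for i, micro_batch in enumerate(micro_batches):
    is_boundary = (i == len(micro_batches) - 1)
    # Tell ZeRO-2 whether this backward is the accumulation boundary so that
    # intermediate backwards accumulate instead of resetting the accumulator.
    model.set_gradient_accumulation_boundary(is_boundary)
    loss = compute_loss(model, micro_batch)
    model.backward(loss)
    if is_boundary:
        model.step()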

Details

Root cause (in DeepSpeed stage_1_and_2.py):

# ZeRO-2 epilogue — called after every backward()
if self.is_gradient_accumulation_boundary:
    self.averaged_gradients[i] = self.get_flat_partition(...)
    self.all_grad_tensors[i] = None  # ← RESETS accumulator!

With gradient_accumulation_steps=1, every backward is a boundary, so all_grad_tensors is reset each time and averaged_gradients only ever holds the last backward's gradient.

ZeRO-3 is unaffected (in stage3.py):

# ZeRO-3 uses micro_step_id which only resets in step()/zero_grad()
if self.micro_step_id == 0:
    grad_buffer.copy_(grad_partition)   # first backward: overwrite
else:
    grad_buffer.add_(grad_partition)    # subsequent: accumulate
self.micro_step_id += 1

Impact: With num_mini_batches=1 and ~256 samples per rank, ZeRO-2 was using only the last sample's gradient (divided by the total token count of all 256 samples), giving an effective learning rate ~256x smaller than intended.
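
To make the scaling concrete (an illustration based on the numbers above, not a measurement): the intended update is roughly (g_1 + g_2 + ... + g_256) / T, where g_i is the i-th sample's gradient and T is the total token count, but the buggy path applied only g_256 / T, shrinking the step size, and hence the effective learning rate, by a factor of about 256.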

Validation

Validation run on a Wordle sample:

  • Purple -> ZeRO-3
  • Blue -> ZeRO-2 before the fix
  • Green -> ZeRO-2 after the fix

[Training-curve comparison plot]

🤖 Generated with Claude Code

With gradient_accumulation_steps=1 in the DeepSpeed config, every
backward() call is treated as a gradient accumulation boundary.
In ZeRO-2 (stage_1_and_2.py), hitting the boundary resets the internal
gradient accumulator (all_grad_tensors = None), so only the last
backward's gradient survives — all previous micro-batches are discarded.

ZeRO-3 doesn't have this bug because it uses a separate micro_step_id
counter that only resets in step()/zero_grad(), correctly accumulating
across all backward calls.

Fix: use DeepSpeed's set_gradient_accumulation_boundary() API to
explicitly mark only the final backward before step() as the boundary.
This works with varying accumulation step counts.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@gemini-code-assist (Contributor) commented

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a critical issue in DeepSpeed's ZeRO-2 optimization, which previously caused incorrect gradient accumulation during manual training loops. By explicitly signaling gradient accumulation boundaries, the fix ensures that all micro-batch gradients are correctly aggregated, thereby restoring the intended learning rate and improving training stability and performance.

Highlights

  • ZeRO-2 Gradient Accumulation Fix: Corrected a bug in DeepSpeed's ZeRO-2 where gradients were discarded during manual accumulation when gradient_accumulation_steps=1, leading to an effective learning rate reduction.
  • Explicit Accumulation Boundary: Implemented self.model.set_gradient_accumulation_boundary() to explicitly mark the final backward pass in an accumulation group, ensuring proper gradient accumulation.


Changelog
  • CHANGELOG.md
    • Added a new entry under 'Fixed' for the ZeRO-2 gradient accumulation bug.
  • open_instruct/grpo_fast.py
    • Introduced logic to call set_gradient_accumulation_boundary() before backward() and updated the condition for self.model.step() to use the new boundary flag.
Activity
  • No specific review comments or activity have been recorded for this pull request yet.

@gemini-code-assist (Contributor) left a comment


Code Review

This pull request addresses a critical bug in the gradient accumulation logic when using DeepSpeed ZeRO-2, where gradients were being discarded during manual accumulation. The fix correctly uses set_gradient_accumulation_boundary() to inform DeepSpeed about the accumulation boundaries, ensuring gradients from all micro-batches are properly accumulated. The implementation is correct and improves code clarity. I've only suggested a minor update to the changelog to replace the placeholder pull request URL.

Comment thread on CHANGELOG.md (outdated)
root and others added 2 commits February 26, 2026 21:59
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@natolambert (Collaborator) left a comment


lgtm! Nice find

@hamishivi added this pull request to the merge queue Feb 26, 2026
Merged via the queue into main with commit 48664c5, Feb 26, 2026
7 checks passed
@hamishivi deleted the fix/zero2-gradient-accumulation branch February 26, 2026 22:30
mnoukhov pushed a commit that referenced this pull request Feb 26, 2026
* Fix ZeRO-2 discarding gradients during manual gradient accumulation


* Update CHANGELOG with PR number

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* clean comment

---------

Co-authored-by: root <root@saturn-cs-aus-230.reviz.ai2.in>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>