[None][fix] Fix errors in KV cache manager V2 and scheduler V2 #13104

Merged
jiaganc merged 8 commits into NVIDIA/TensorRT-LLM:main from jiaganc:fix-v2 on Apr 22, 2026

Conversation

jiaganc (Collaborator) commented Apr 16, 2026

Summary by CodeRabbit

  • Bug Fixes
    • Fixed a KV cache memory allocation issue in distributed inference where memory would accumulate across skipped iterations, improving memory efficiency during multi-rank inference.

Description

Two fixes for KV cache manager V2 in multi-rank scenarios:

  1. Sync KV cache quota across TP ranks: When mapping.world_size > 1, use allreduce(MIN) to ensure all TP ranks allocate the same KV cache quota. Without this, ranks with different available memory would get different quotas, causing the scheduler to produce different batches across ranks.

  2. Revert spurious KV cache capacity growth: When attention DP causes can_queue=False after scheduling (i.e., another rank has an empty batch), the forward pass is skipped, but the V2 scheduler has already grown each generation request's KV cache capacity. Add revert_allocate_generation() to undo that growth so it does not accumulate across skipped iterations and overflow the host page-index buffer. Applied in both the non-overlap and overlap executor loops (a minimal sketch follows this list).
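A minimal sketch of the fix-2 guard, written as a standalone function rather than the in-loop code from py_executor.py; the flag and method names (can_queue, can_queue_this_rank, the scheduler-manages-KV-suspension check, revert_allocate_generation) follow the description and review comments in this PR, while the function wrapper and its signature are illustrative only.

```python
def maybe_revert_generation_growth(kv_cache_manager, scheduled_batch,
                                   can_queue: bool, can_queue_this_rank: bool,
                                   scheduler_manages_kv_suspend: bool) -> None:
    """Undo the V2 scheduler's KV capacity growth when the forward pass is skipped.

    Hypothetical free-function form of the guard; in the PR it lives inside
    the executor loops of py_executor.py.
    """
    if can_queue:
        # Every attention-DP rank has work, so the batch runs and the grown
        # capacity is actually used.
        return
    if not (can_queue_this_rank and scheduler_manages_kv_suspend):
        # Either this rank scheduled nothing (no growth to undo) or the
        # scheduler does not manage KV suspension (V1 path), so there is
        # nothing to revert.
        return
    for req in scheduled_batch.generation_requests:
        kv_cache_manager.revert_allocate_generation(req)
```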

Test Coverage

  • Multi-GPU TP tests with KV cache manager V2
  • Attention DP scenarios with mismatched batch sizes across ranks

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

jiaganc added 2 commits April 15, 2026 22:21
Signed-off-by: Jiagan Cheng <jiaganc@nvidia.com>
When attention DP causes can_queue=False after scheduling, the forward
pass is skipped but the V2 scheduler already grew each generation
request's KV cache capacity. Add revert_allocate_generation() to undo
that spurious growth so it does not accumulate across skipped iterations
and overflow the host page-index buffer.

Partial cherry-pick of d8ff758 from tekit (revert_allocate_generation
parts only).

Signed-off-by: Jiagan Cheng <jiaganc@nvidia.com>
@jiaganc jiaganc requested review from a team as code owners April 16, 2026 05:25
@jiaganc jiaganc requested a review from joyang-nv April 16, 2026 05:25
coderabbitai (Bot) commented Apr 16, 2026

📝 Walkthrough

Updated the KV cache management system to revert generation-phase allocations when certain rank scheduling conditions occur. Modified the executor loop to detect when only some ranks can queue, and added a new method to reverse KV cache capacity growth allocated during generation attempts.

Changes

KV Cache Allocation Reversion: tensorrt_llm/_torch/pyexecutor/py_executor.py
  Modified _executor_loop and _executor_loop_overlap to capture both the can_queue and can_queue_this_rank flags, then conditionally call revert_allocate_generation() on each request when can_queue is False, can_queue_this_rank is True, and the scheduler manages KV suspension.

KV Cache Manager Updates: tensorrt_llm/_torch/pyexecutor/resource_manager.py
  Synchronized quota computation across ranks using allreduce with ReduceOp.MIN in KVCacheManagerV2.__init__. Added a new revert_allocate_generation(req) method to reverse generation-phase KV capacity increases.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks | ✅ 1 | ❌ 2

❌ Failed checks (2 warnings)

  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 50.00%, which is below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
  • Title check: ⚠️ Warning. The PR title mentions 'Fix errors in KV cache manager V2 and scheduler V2', but the actual changes focus specifically on KV cache quota synchronization across ranks and a new revert_allocate_generation() method, so the title is only partially aligned with the detailed objectives. Resolution: revise the title to more specifically reflect the main changes, e.g. 'Sync KV cache quota across TP ranks and add revert_allocate_generation()'.

✅ Passed checks (1 passed)

  • Description check: ✅ Passed. The PR description provides clear explanations of both fixes, test coverage details, and confirms completion of the PR checklist items.



Comment @coderabbitai help to get the list of available commands and usage tips.

coderabbitai (Bot) left a comment

Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tensorrt_llm/_torch/pyexecutor/py_executor.py`:
- Around lines 2064-2069: The PP executor loop _executor_loop_pp() currently
drops scheduled batches when ADP sets can_queue=False without undoing the
temporary KV-cache capacity growth. Replicate the revert logic used in the
main executor loop: when the skip condition is detected (can_queue is False,
can_queue_this_rank is True, and self._scheduler_manages_kv_suspend is set),
iterate over scheduled_batch.generation_requests and call
self.kv_cache_manager.revert_allocate_generation(req) for each request, the
same as the block around scheduled_batch.generation_requests in the non-PP
executor, so that skipped iterations do not leave KV capacity allocated.
Apply the same change to the other similar block referenced (around the other
occurrence).

In `@tensorrt_llm/_torch/pyexecutor/resource_manager.py`:
- Around lines 1741-1743: The current allreduce on `quota` (using
`Distributed.get(mapping).allreduce(quota, op=ReduceOp.MIN)`) operates over
world ranks and therefore incorrectly syncs a PP-rank-dependent byte quota.
Either perform the reduction only within the TP group, or convert to a
TP-consistent metric before reducing: locate the reduction near
`mapping`/`Distributed.get(mapping)` and change it to use the appropriate TP
group (pass the group parameter to `allreduce`), or compute a token-capacity
value by dividing `quota` by the local `bytes_per_token` (from
`get_cache_bytes_per_token()` / `mapping.pp_layers()`), allreduce that
TP-consistent token count with `ReduceOp.MIN`, and then convert back to
per-rank bytes using each rank's `bytes_per_token`. Ensure `ReduceOp.MIN`
remains correct for the chosen normalization.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro Plus

Run ID: 186fe76f-d388-4c70-830a-b56e9cd2adc8

📥 Commits

Reviewing files that changed from the base of the PR and between b01ff5e and bdf22ac.

📒 Files selected for processing (2)
  • tensorrt_llm/_torch/pyexecutor/py_executor.py
  • tensorrt_llm/_torch/pyexecutor/resource_manager.py

Check the return value of kv_cache.resize and raise RuntimeError with
request ID and capacity details if it fails.

Signed-off-by: Jiagan Cheng <jiaganc@nvidia.com>
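
As a rough illustration of the checked resize described in this commit, here is a hedged sketch of what revert_allocate_generation() could look like on the KVCacheManagerV2 side; the bookkeeping attributes (py_prev_kv_capacity, py_kv_capacity) and the per-request cache lookup are hypothetical, and only the raise-RuntimeError-on-failed-resize behaviour comes from the commit message.

```python
def revert_allocate_generation(self, req) -> None:
    """Undo the KV capacity growth added for `req` by a skipped generation step.

    Sketch only: the capacity bookkeeping attributes shown here are
    hypothetical; the real KVCacheManagerV2 tracks this state internally.
    """
    prev_capacity = req.py_prev_kv_capacity        # capacity before the scheduler grew it
    kv_cache = self._kv_caches[req.py_request_id]  # hypothetical per-request cache handle
    if not kv_cache.resize(prev_capacity):
        raise RuntimeError(
            f"Failed to revert KV cache capacity for request {req.py_request_id}: "
            f"resize from {req.py_kv_capacity} back to {prev_capacity} failed")
    req.py_kv_capacity = prev_capacity
```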
jiaganc changed the title from "[None][fix] Fix KV cache manager V2 multi-rank consistency" to "[None][fix] Fix error in KV cache manager V2 and scheduler V2" on Apr 16, 2026
jiaganc changed the title from "[None][fix] Fix error in KV cache manager V2 and scheduler V2" to "[None][fix] Fix errors in KV cache manager V2 and scheduler V2" on Apr 16, 2026
Add the same KV cache capacity revert logic to _executor_loop_pp() so
that PP+ADP scenarios also undo spurious growth when can_queue is False
but this rank had a non-empty batch.

Signed-off-by: Jiagan Cheng <jiaganc@nvidia.com>
@jiaganc jiaganc requested review from lancelly and yizhang-nv April 16, 2026 06:06
bytes_per_token varies across PP ranks (different local layers), so
allreducing raw byte quotas with MIN would under-size stages with fewer
layers.  Convert to token capacity first (matching V1 behavior), take
MIN, then convert back to bytes per rank.

Signed-off-by: Jiagan Cheng <jiaganc@nvidia.com>
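
A sketch of the PP-safe quota sync this commit describes: reduce a rank-independent token count rather than raw bytes, then convert back per rank. The Distributed and ReduceOp helpers are the ones named in the review comment above; the function wrapper and exact signatures here are assumptions, not the resource_manager.py code.

```python
def sync_kv_quota_bytes(quota_bytes: int, bytes_per_token: int, mapping) -> int:
    """Sync the KV cache quota across ranks without under-sizing PP stages.

    bytes_per_token differs per PP rank (each stage holds different local
    layers), so reduce a rank-independent token capacity instead of raw bytes.
    Distributed and ReduceOp stand in for TensorRT-LLM's distributed helpers
    (names taken from the review comment, signatures assumed).
    """
    if mapping.world_size <= 1:
        return quota_bytes
    token_capacity = quota_bytes // bytes_per_token   # rank-independent unit
    token_capacity = Distributed.get(mapping).allreduce(token_capacity,
                                                        op=ReduceOp.MIN)
    return token_capacity * bytes_per_token           # back to this rank's bytes
```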

jiaganc commented Apr 16, 2026

/bot run

@tensorrt-cicd

PR_Github #43698 [ run ] triggered by Bot. Commit: f3dfc0b Link to invocation

@tensorrt-cicd

PR_Github #43698 [ run ] completed with state SUCCESS. Commit: f3dfc0b
/LLM/main/L0_MergeRequest_PR pipeline #34180 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

lancelly (Collaborator) left a comment


Under the current scheduler architecture, we can work around it like this. The fix looks good to me, thanks~

jiaganc added 2 commits April 16, 2026 23:22
The revert_allocate_generation guard was duplicated across three executor
loops (PP, non-overlap, overlap).  Extract into a single helper method
_maybe_revert_kv_growth to reduce repetition.

Signed-off-by: Jiagan Cheng <jiaganc@nvidia.com>
Extract the repeated revert_allocate_generation guard into a single
helper method.  Drop the redundant can_queue_this_rank check (when it is
False the batch has no generation requests so the loop is a no-op).

Signed-off-by: Jiagan Cheng <jiaganc@nvidia.com>

jiaganc commented Apr 17, 2026

/bot run

@tensorrt-cicd

PR_Github #43988 [ run ] triggered by Bot. Commit: f276ed7 Link to invocation

Move the 'not can_queue' guard out of the helper into the call sites so
the helper is unconditional; the V1/V2 check remains inside.  Rename
from _maybe_revert_gen_alloc to _revert_gen_alloc to reflect the new
signature.  Drop can_queue_this_rank captures where it's no longer used.

Signed-off-by: Jiagan Cheng <jiaganc@nvidia.com>
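
The refactored shape after this commit might look roughly as follows; the _scheduler_manages_kv_suspend flag is the name used earlier in the review thread, and the bodies here are a sketch rather than the merged code.

```python
def _revert_gen_alloc(self, scheduled_batch) -> None:
    # Unconditional with respect to can_queue; the V1/V2 check stays inside.
    if not self._scheduler_manages_kv_suspend:
        return  # V1 manager: capacity was not grown at schedule time
    for req in scheduled_batch.generation_requests:
        self.kv_cache_manager.revert_allocate_generation(req)

# At each executor-loop call site (PP, non-overlap, overlap):
#     if not can_queue:
#         self._revert_gen_alloc(scheduled_batch)
```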
@tensorrt-cicd

PR_Github #43988 [ run ] completed with state SUCCESS. Commit: f276ed7
/LLM/main/L0_MergeRequest_PR pipeline #34427 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation


jiaganc commented Apr 20, 2026

/bot run --disable-fail-fast

@tensorrt-cicd

PR_Github #44281 [ run ] triggered by Bot. Commit: a477741 Link to invocation

@tensorrt-cicd

PR_Github #44281 [ run ] completed with state FAILURE. Commit: a477741
/LLM/main/L0_MergeRequest_PR pipeline #34702 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@nvpohanh

@lowsfer could you review this? thanks


jiaganc commented Apr 20, 2026

/bot run --disable-fail-fast

@tensorrt-cicd

PR_Github #44404 [ run ] triggered by Bot. Commit: a477741 Link to invocation

@tensorrt-cicd

PR_Github #44404 [ run ] completed with state SUCCESS. Commit: a477741
/LLM/main/L0_MergeRequest_PR pipeline #34819 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation


jiaganc commented Apr 21, 2026

/bot run --disable-fail-fast

@tensorrt-cicd

PR_Github #44593 [ run ] triggered by Bot. Commit: a477741 Link to invocation

@tensorrt-cicd

PR_Github #44593 [ run ] completed with state SUCCESS. Commit: a477741
/LLM/main/L0_MergeRequest_PR pipeline #34979 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation


jiaganc commented Apr 21, 2026

/bot run

@tensorrt-cicd

PR_Github #44669 [ run ] triggered by Bot. Commit: a477741 Link to invocation

@tensorrt-cicd

PR_Github #44669 [ run ] completed with state SUCCESS. Commit: a477741
/LLM/main/L0_MergeRequest_PR pipeline #35040 completed with status: 'SUCCESS'
Pipeline passed with automatic retried tests. Check the rerun report for details.

CI Report

Link to invocation
