[None][fix] Fix compute token accounting for KV cache reuse with context chunking #12976

Merged

longlee0622 merged 10 commits into NVIDIA/TensorRT-LLM:main from lancelly/TensorRT-LLM:fix_resue_compute on Apr 18, 2026

Conversation

@lancelly (Collaborator) commented Apr 13, 2026

Summary

  • Fix compute token accounting when KV cache reuse is combined with context chunking: setPrepopulatedPromptLen shifts the chunk window right by the reused amount rather than shrinking it, so non-last chunks still process ~chunkSize tokens (see the sketch after this list)
  • C++: Add reuse_adjusted_compute() helper and replace all 6 manual max(0, chunkSize - reusable) formulas in microBatchScheduler.cpp
  • Python: Rewrite _compute_scheduled_tokens with V1/V2 scheduler awareness and correct non-last-chunk handling in py_executor.py
  • Add unit tests for both C++ (ReusableTokensChunkShiftNonLastChunk) and Python (TestComputeScheduledTokens) covering the chunk-shift scenario
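
For concreteness, a minimal Python sketch of the accounting rule (the actual helper is C++ code in microBatchScheduler.cpp; the signature and parameter names here are assumptions):

```python
def reuse_adjusted_compute(chunk_size: int, reusable: int, context_remaining: int) -> int:
    """Forward-pass token cost of one context chunk under KV cache reuse.

    Sketch only: setPrepopulatedPromptLen shifts the chunk window right by
    the reused amount instead of shrinking it, so only the last chunk
    actually gets cheaper.
    """
    if chunk_size >= context_remaining:
        # Last chunk: the reused prefix reduces what is left to compute.
        return max(context_remaining - reusable, 0)
    # Non-last chunk: ~chunk_size tokens are still processed; charging
    # max(0, chunk_size - reusable) here was the bug being fixed.
    return chunk_size
```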

@lancelly requested a review from a team as a code owner April 13, 2026 03:32
@lancelly requested a review from dongxuy04 April 13, 2026 03:32
@coderabbitai (Contributor, Bot) commented Apr 13, 2026

📝 Walkthrough

This change introduces a reuse_adjusted_compute helper function to the micro batch scheduler to properly account for reusable token credits in chunk-based scheduling, updating compute accounting across multiple scheduling paths. The PyExecutor was refactored to remove DwdpManager support, KV-cache baselines, and MPI workarounds while implementing a new _compute_scheduled_tokens method to improve scheduled token estimation. Unit tests were added to validate both the scheduler's compute accounting and the executor's token estimation logic.

Changes

Micro Batch Scheduler Core (cpp/tensorrt_llm/batch_manager/microBatchScheduler.cpp)
Introduced the reuse_adjusted_compute helper function to properly account for reusable token credits in chunk scheduling. Updated compute accounting in fitDraftTokens, setCtxRequestsChunkSize (both scheduling policies), and operator() to use the new helper instead of direct subtraction logic.

PyExecutor Refactoring (tensorrt_llm/_torch/pyexecutor/py_executor.py)
Removed DwdpManager support from __init__; eliminated KV-cache warmup baselines, the MPI deadlock workaround, and the fast-transfer response path. Implemented the new _compute_scheduled_tokens method to account for reusable-prefix credits (see the sketch after this summary). Simplified disagg context status checks and adjusted resource update calls.

Scheduler Unit Tests (cpp/tests/unit_tests/batch_manager/microBatchSchedulerTest.cpp)
Added a ReusableTokensChunkShiftNonLastChunk test case validating that reusable-token compute accounting for non-last chunks charges the full contextChunkSize rather than subtracting reusable tokens.

PyExecutor Unit Tests (tests/unittest/_torch/executor/test_py_executor.py)
Added a TestComputeScheduledTokens test class with helper factories and test cases covering no-reuse, last-chunk reuse, non-last-chunk chunk-shift reuse, generation token accounting, and mixed context/generation scenarios.
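
A hedged Python sketch of what the new _compute_scheduled_tokens is described as doing (attribute names such as context_current_position and py_draft_tokens are illustrative, and the V1/V2 scheduler distinction is omitted):

```python
def compute_scheduled_tokens(context_requests, generation_requests):
    """Estimate forward-pass tokens for one step, crediting reusable prefixes."""
    total = 0
    for req in context_requests:
        remaining = req.prompt_len - req.context_current_position
        if req.context_chunk_size >= remaining:
            # Last (or only) chunk: the reused prefix genuinely shrinks it.
            total += max(remaining - req.num_reused_tokens, 0)
        else:
            # Non-last chunk: the window is shifted, not shrunk, so a full
            # chunk of tokens is still processed.
            total += req.context_chunk_size
    for req in generation_requests:
        # One new token per generation request, plus any draft tokens.
        total += 1 + len(req.py_draft_tokens or [])
    return total
```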

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 43.48%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.

✅ Passed checks (2 passed)

  • Title check: ✅ Passed. The title clearly and specifically describes the main change: fixing compute token accounting for KV cache reuse with context chunking.
  • Description check: ✅ Passed. The PR description covers the core issue, C++ and Python changes, and test additions. However, the test plan checkbox items are incomplete (unchecked), as noted in the objectives.


@coderabbitai (Contributor, Bot) left a comment


Actionable comments posted: 2

🧹 Nitpick comments (1)
tensorrt_llm/_torch/pyexecutor/py_executor.py (1)

Line 1737: Consider whether a 1-second sleep is appropriate for the benchmark fill loop.

This sleep is in a tight loop that fills benchmark requests. While 1 second is more reasonable than some other sleep changes in this PR, it could still slow down the benchmark initialization if many iterations are needed.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/pyexecutor/py_executor.py` at line 1737, The tight
benchmark fill loop currently uses time.sleep(1), which can unnecessarily slow
benchmark initialization; change this to a small, configurable sleep and/or
remove the long fixed delay: replace the literal time.sleep(1) in the benchmark
fill loop with a configurable constant (e.g., BENCH_FILL_SLEEP_SEC) set to a
small value like 0.01 by default, or use time.sleep(0) to only yield the thread,
and expose that constant as a parameter or config option in the surrounding
function/class so callers can tune it for their workload (locate the literal
time.sleep(1) in the benchmark fill loop and update it and the surrounding
signature/config accordingly).
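
As a rough illustration of that suggestion (BENCH_FILL_SLEEP_SEC is the reviewer's hypothetical name, not an existing constant in the codebase):

```python
import time

# Hypothetical knob per the review suggestion: small by default, tunable.
BENCH_FILL_SLEEP_SEC = 0.01

def fill_benchmark_requests(pending_requests, submit, sleep_sec=BENCH_FILL_SLEEP_SEC):
    """Illustrative stand-in for the benchmark fill loop in py_executor.py."""
    while pending_requests:
        submit(pending_requests.pop())
        time.sleep(sleep_sec)  # previously a fixed time.sleep(1)
```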
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@cpp/tensorrt_llm/batch_manager/microBatchScheduler.cpp`:
- Around lines 27-44: reuse_adjusted_compute currently returns 0 when a reusable
prefix fully covers the last chunk, letting scheduled (non-zero) chunks become
free; update the logic so any non-zero scheduled chunk is floored to at least 1
token, as the no-chunk path does. Compute the adjusted cost as now, but if the
original chunkSize was > 0 and the computed value is 0, return 1 instead of 0;
ensure all callers of reuse_adjusted_compute (including the other chunked
sections noted) rely on that floored value so batchNumTokens stays consistent
with the no-chunking path. (A Python sketch of this floor follows these
comments.)

In `@tests/unittest/_torch/executor/test_py_executor.py`:
- Around line 251-370: Run the project formatter (ruff format) on the added test
block containing the helpers _make_ctx_request and _make_gen_request and the
TestComputeScheduledTokens class so the new tests (methods like test_no_reuse,
test_last_chunk_with_reuse, etc.) conform to the repo's ruff formatting rules;
re-stage the file after formatting and push the changes to fix the failing
ruff-format CI hook.
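
In Python terms, the floor suggested in the first inline comment would look roughly like this (the real change belongs in the C++ helper; this is only a sketch):

```python
def reuse_adjusted_compute_floored(chunk_size, reusable, context_remaining):
    # Same accounting as reuse_adjusted_compute, plus the suggested floor:
    # a scheduled non-zero chunk is never charged zero tokens, matching the
    # no-chunking path.
    if chunk_size >= context_remaining:
        cost = max(context_remaining - reusable, 0)  # last chunk
    else:
        cost = chunk_size  # non-last chunk
    if chunk_size > 0 and cost == 0:
        return 1  # floor keeps batchNumTokens consistent with no chunking
    return cost
```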


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 84f31d79-e8ec-4ea6-b627-3e2929ba4f5c

📥 Commits

Reviewing files that changed from the base of the PR and between cea18ab and d173fb1.

📒 Files selected for processing (4)
  • cpp/tensorrt_llm/batch_manager/microBatchScheduler.cpp
  • cpp/tests/unit_tests/batch_manager/microBatchSchedulerTest.cpp
  • tensorrt_llm/_torch/pyexecutor/py_executor.py
  • tests/unittest/_torch/executor/test_py_executor.py

Comment thread cpp/tensorrt_llm/batch_manager/microBatchScheduler.cpp
Comment thread tests/unittest/_torch/executor/test_py_executor.py
@lancelly (Collaborator, Author)

/bot run --disable-fail-fast

@lancelly requested a review from liji-nv April 13, 2026 04:23
@lancelly (Collaborator, Author)

@liji-nv @SimengLiu-nv Hi, could you please help review this PR? Combined with #12878, it should fix the MNT errors we encountered.

@tensorrt-cicd (Collaborator)

PR_Github #42946 [ run ] triggered by Bot. Commit: 576c810 Link to invocation

…ext chunking

When KV cache reuse is combined with context chunking,
setPrepopulatedPromptLen shifts the chunk window right by the reused
amount rather than shrinking it. Non-last chunks still process
~chunkSize tokens in the forward pass.

The old formula (chunkSize - reusable) underestimated compute cost for
non-last chunks, causing the scheduler to over-pack batches beyond the
token budget. This adds reuse_adjusted_compute() to correctly account
for chunk-shift behavior, and updates _waiting_requests in PyExecutor
to subtract reusable tokens from scheduled token counts.

Signed-off-by: Lanyu Liao <lancelly@users.noreply.github.com>
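
To make the underestimate concrete (numbers chosen for illustration, not taken from the PR):

```python
chunk_size, reusable = 256, 512  # reused prefix spans two full chunks

old_cost = max(0, chunk_size - reusable)  # 0: the non-last chunk looks free
new_cost = chunk_size                     # 256: the forward pass still runs ~chunk_size tokens

assert (old_cost, new_cost) == (0, 256)
```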
@lancelly force-pushed the fix_resue_compute branch from 07e89ae to 4655598 on April 13, 2026 14:46
@lancelly (Collaborator, Author)

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator)

PR_Github #43063 [ run ] triggered by Bot. Commit: 4655598 Link to invocation

@longlee0622 enabled auto-merge (squash) April 14, 2026 01:43
@lancelly (Collaborator, Author)

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator)

PR_Github #43140 [ run ] triggered by Bot. Commit: 87be42c Link to invocation

Comment thread tensorrt_llm/_torch/pyexecutor/py_executor.py Outdated
Comment thread cpp/tensorrt_llm/batch_manager/microBatchScheduler.cpp
@tensorrt-cicd (Collaborator)

PR_Github #43140 [ run ] completed with state SUCCESS. Commit: 87be42c
/LLM/main/L0_MergeRequest_PR pipeline #33772 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

lancelly and others added 3 commits April 14, 2026 23:02
…ext chunking

Introduce _reuse_adjusted_compute() to correctly calculate forward-pass
cost when setPrepopulatedPromptLen shifts the chunk window. Non-last
chunks cost chunkSize; last chunks cost contextRemaining - reusable.
Update C++ and Python tests to match the corrected behavior.

Signed-off-by: Lanyu Liao <lancelly@users.noreply.github.com>
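
The corrected costs can be pictured with a pytest-style sketch against the reuse_adjusted_compute sketch shown under the PR summary above (all values illustrative):

```python
def test_non_last_chunk_charges_full_chunk():
    # 1024 tokens remaining, 512 reusable, 256-token chunks: early chunks
    # are shifted past the reused prefix, not shrunk.
    assert reuse_adjusted_compute(256, 512, 1024) == 256

def test_last_chunk_gets_reuse_credit():
    # Final chunk: only the tokens beyond the reused prefix are computed.
    assert reuse_adjusted_compute(256, 128, 256) == 128
```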
@lancelly (Collaborator, Author)

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator)

PR_Github #43249 [ run ] triggered by Bot. Commit: b8db9e7 Link to invocation

@tensorrt-cicd (Collaborator)

PR_Github #43249 [ run ] completed with state DISABLED
Freeze main and open the PR merge only after CI is back to healthy https://nvidia.slack.com/archives/C059LSY62BT/p1776141760843319?thread_ts=1775985925.442509&cid=C059LSY62BT

Link to invocation

@lancelly (Collaborator, Author)

/bot run --disable-fail-fast

1 similar comment
@lancelly (Collaborator, Author)

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator)

PR_Github #43322 [ run ] triggered by Bot. Commit: f850396 Link to invocation

@lancelly (Collaborator, Author)

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator)

PR_Github #43482 [ run ] triggered by Bot. Commit: 7702dc6 Link to invocation

@tensorrt-cicd (Collaborator)

PR_Github #43482 [ run ] completed with state SUCCESS. Commit: 7702dc6
/LLM/main/L0_MergeRequest_PR pipeline #33999 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@lancelly (Collaborator, Author)

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator)

PR_Github #43599 [ run ] triggered by Bot. Commit: 3a371ff Link to invocation

@lancelly (Collaborator, Author)

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator)

PR_Github #43740 [ run ] triggered by Bot. Commit: 40305b8 Link to invocation

@longlee0622 (Collaborator)

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator)

PR_Github #43760 [ run ] triggered by Bot. Commit: 40305b8 Link to invocation

@tensorrt-cicd (Collaborator)

PR_Github #43760 [ run ] completed with state SUCCESS. Commit: 40305b8
/LLM/main/L0_MergeRequest_PR pipeline #34244 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@lancelly (Collaborator, Author)

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator)

PR_Github #43888 [ run ] triggered by Bot. Commit: 40305b8 Link to invocation

@tensorrt-cicd (Collaborator)

PR_Github #43888 [ run ] completed with state FAILURE. Commit: 40305b8
/LLM/main/L0_MergeRequest_PR pipeline #34340 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@lancelly (Collaborator, Author)

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator)

PR_Github #43918 [ run ] triggered by Bot. Commit: 40305b8 Link to invocation

@tensorrt-cicd (Collaborator)

PR_Github #43918 [ run ] completed with state ABORTED. Commit: 40305b8

Link to invocation

@lancelly (Collaborator, Author)

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator)

PR_Github #44095 [ run ] triggered by Bot. Commit: 40305b8 Link to invocation

Saddss added a commit to Saddss/TensorRT-LLM that referenced this pull request Apr 18, 2026
…ext chunking

Squash-merge of upstream PR NVIDIA#12976 (author: @lancelly).
setPrepopulatedPromptLen shifts the chunk window right by the reused amount
rather than shrinking it, so non-last chunks still process ~chunk_size tokens;
both microBatchScheduler.cpp and the Python PyMicroBatchScheduler are updated
to go through a single reuse_adjusted_compute helper, and py_executor.py gets
a new _compute_scheduled_tokens used by the batch-waiting heuristic.

Fixes the 'total_num_tokens should be less than or equal to max_num_tokens'
AssertionError at model_engine.py:2678 that appeared sporadically (2-7 h to
reproduce at production QPS) on native-offload configs with
enable_chunked_prefill + enable_block_reuse + host_cache_size.

Upstream PR: NVIDIA#12976

Signed-off-by: Saddss <2872669061@qq.com>
Made-with: Cursor
@tensorrt-cicd (Collaborator)

PR_Github #44095 [ run ] completed with state SUCCESS. Commit: 40305b8
/LLM/main/L0_MergeRequest_PR pipeline #34524 completed with status: 'SUCCESS'

CI Report

Link to invocation

@longlee0622 merged commit ea117de into NVIDIA:main Apr 18, 2026
5 checks passed
nv-yna added a commit to nv-yna/TensorRT-LLM that referenced this pull request Apr 29, 2026
…t overflow

Cherry-pick of Python portions from feat/bench_y PRs NVIDIA#12682 and NVIDIA#12806
that were not included in NVIDIA#12976 (which ported only the C++ fix to main).

Adds a remaining_budget re-validation guard in KVCacheManager.prepare_resources()
that re-probes the radix tree for actual reusable blocks after KV cache allocation
and skips requests whose forward cost exceeds the remaining budget. This catches
the estimation-vs-reality gap when cache eviction between scheduling and
prepare_resources() reduces actual reuse below the scheduler's estimate.

Original authors: Liao Lanyu (@lancelly), Jin Li (@liji-nv)

Signed-off-by: Yuewei Na <nv-yna@users.noreply.github.com>
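
A hedged sketch of the described guard; apart from KVCacheManager.prepare_resources() itself, every name below (probe_reusable_tokens, remaining_context_tokens, remaining_budget) is an assumption for illustration:

```python
def prepare_resources_with_guard(kv_cache_manager, scheduled_requests, remaining_budget):
    """Re-validate reuse estimates after KV allocation; skip over-budget requests."""
    admitted = []
    for req in scheduled_requests:
        # Re-probe the radix tree: eviction between scheduling and this call
        # can shrink the actually reusable prefix below the estimate.
        actual_reusable = kv_cache_manager.probe_reusable_tokens(req)
        forward_cost = max(req.remaining_context_tokens - actual_reusable, 0)
        if forward_cost > remaining_budget:
            continue  # would overflow this step's token budget; defer it
        remaining_budget -= forward_cost
        admitted.append(req)
    return admitted
```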