[TRTLLM-11471][fix] Eliminate redundant serialization and MPI collectives in safe_allgather/safe_gather#13089
pcastonguay merged 3 commits into NVIDIA/TensorRT-LLM:main from chienchunhung/TensorRT-LLM:fix/safe-mpi-comm-perf-regression
Walkthrough
This change refactors the distributed communicator to optimize collective operations by separating serialization and length exchange from data transfer.
Sequence Diagram(s)
sequenceDiagram
actor Rank0 as Rank 0
actor Rank1 as Rank 1
participant MPI
Rank0->>Rank0: Pickle serialize obj
Rank1->>Rank1: Pickle serialize obj
Note over Rank0,Rank1: Step 1: Exchange Lengths
Rank0->>MPI: Allgather(sendbuf=[len0])
Rank1->>MPI: Allgather(sendbuf=[len1])
MPI-->>Rank0: recvbuf=[len0, len1]
MPI-->>Rank1: recvbuf=[len0, len1]
Note over Rank0,Rank1: Step 2: Conditional Data Transfer
alt Fits in int32
Rank0->>MPI: Gatherv/Allgatherv (int32 counts/displs)
Rank1->>MPI: Gatherv/Allgatherv (int32 counts/displs)
else Exceeds int32
loop For each chunk
Rank0->>MPI: Gatherv/Allgatherv (chunked)
Rank1->>MPI: Gatherv/Allgatherv (chunked)
end
end
Rank0->>Rank0: Unpickle deserialize
Rank1->>Rank1: Unpickle deserialize
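The "Fits in int32" branch in the diagram comes down to arithmetic on the gathered lengths. The following is a minimal sketch of that planning step, not the PR's actual code: `plan_transfer`, its return shape, and the rule for capping the chunk size at `INT32_MAX // world_size` are all assumptions made for illustration.

```python
INT32_MAX = 2**31 - 1  # largest value a 32-bit MPI count/displacement can hold


def plan_transfer(lengths, chunk_size=INT32_MAX):
    """Hypothetical helper: return (counts, displs, fits_int32, rounds).

    `lengths` holds the pickled payload size in bytes for each rank,
    as obtained from the length-exchange step.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    counts = list(lengths)
    displs, offset = [], 0
    for n in counts:
        displs.append(offset)  # each rank's bytes start where the previous rank's end
        offset += n
    if offset <= INT32_MAX:
        # Every count and displacement fits in int32: a single collective suffices.
        return counts, displs, True, 1
    # Assumed capping rule: keep one round's combined bytes within int32.
    capped = min(chunk_size, INT32_MAX // len(counts))
    rounds = -(-max(counts) // capped)  # integer ceiling division
    return counts, displs, False, rounds


small = plan_transfer([3, 5, 2])
large = plan_transfer([2**31, 10], chunk_size=2**20)
```

For the small case the displacements are the running sum of the counts and one round is enough; for the large case the total exceeds the int32 limit, so the sketch falls back to several chunked rounds.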
… safe_allgather/safe_gather

The original implementation serialized objects and exchanged lengths via Python-level comm.allgather (2 internal MPI collectives + 1 serialization), then called comm.allgather/gather again for the data transfer (2 more MPI collectives + 1 more serialization), totaling 4 MPI collectives and 3 serializations per call.

This rewrites the functions to:
1. Serialize once with pickle.dumps
2. Exchange lengths via buffer-based MPI_Allgather (1 collective)
3. Transfer raw bytes via MPI_Allgatherv/Gatherv (1 collective)

This matches the collective count that mpi4py's comm.allgather(obj) uses internally (2 collectives, 1 serialization) while preserving the >2GB chunking safety for payloads exceeding the int32 displacement limit.

Additional fixes:
- Preserve exception chain on serialization failures (from exc)
- Log when entering the chunked transfer path (>2GB payloads)
- Remove dead size>0 guards (MPI guarantees size>=1)
- Fix test docstring referencing wrong filename
- Add CommSpy-based tests verifying exact collective and serialization counts

Signed-off-by: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com>

…ad of backslash continuation

Signed-off-by: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com>
Signed-off-by: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com>
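
The commit message mentions preserving the exception chain on serialization failures. As a rough illustration of what `raise ... from exc` buys, here is a small sketch; the function name `serialize_once` and the choice of `RuntimeError` are assumptions for this example and are not taken from the PR.

```python
import pickle


def serialize_once(obj):
    """Pickle `obj` a single time; on failure, re-raise with the original cause attached."""
    try:
        return pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
    except Exception as exc:
        # `from exc` stores the original error in __cause__, so the traceback
        # shows why pickling failed, not only that it failed.
        raise RuntimeError(f"failed to serialize {type(obj).__name__}") from exc


cause = None
try:
    serialize_once(lambda x: x)  # lambdas cannot be pickled
except RuntimeError as err:
    cause = err.__cause__  # the underlying pickling error

roundtrip = pickle.loads(serialize_once({"rank": 0}))
```

Without the `from exc` clause the original error would still appear as implicit context, but explicit chaining marks it as the direct cause of the wrapper exception.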
Description
Summary
- Fixes a performance regression in `safe_allgather`/`safe_gather` introduced by PR #12174 ([TRTLLM-11471][feat] Add safe version of allgather with chunking): every call was doing 4 MPI collectives + 3 serializations instead of the original 2 + 1.
- The fix serializes once, exchanges lengths via buffer-based `MPI_Allgather` (1 collective), then transfers raw bytes via `MPI_Allgatherv`/`MPI_Gatherv` (1 collective), matching what mpi4py's `comm.allgather(obj)` does internally.

Before PR #12174 (original mpi4py)
PR #12174 (introduced regression)
This PR (fix)
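
The flow described in the summary (serialize once, one length exchange, one raw-byte transfer) can be simulated in a single process to make the collective count concrete. This is a sketch under stated assumptions, not the PR's implementation: `CountingComm` and `simulated_safe_allgather` are hypothetical names, and the fake communicator only stands in for the real `MPI_Allgather`/`MPI_Allgatherv` calls.

```python
import pickle


class CountingComm:
    """Hypothetical in-process stand-in for an MPI communicator that counts collectives."""

    def __init__(self):
        self.collectives = 0

    def allgather_lengths(self, lengths):
        # Stands in for one buffer-based MPI_Allgather of per-rank lengths.
        self.collectives += 1
        return list(lengths)

    def allgatherv_bytes(self, payloads):
        # Stands in for one MPI_Allgatherv of the raw pickled bytes.
        self.collectives += 1
        return b"".join(payloads)


def simulated_safe_allgather(comm, objs):
    """Simulate all ranks at once: objs[i] is the object held by rank i."""
    payloads = [pickle.dumps(o) for o in objs]  # one serialization per rank
    lengths = comm.allgather_lengths([len(p) for p in payloads])
    flat = comm.allgatherv_bytes(payloads)
    gathered, offset = [], 0
    for n in lengths:
        # Slice each rank's bytes back out using the exchanged lengths.
        gathered.append(pickle.loads(flat[offset:offset + n]))
        offset += n
    return gathered


comm = CountingComm()
gathered = simulated_safe_allgather(comm, [{"rank": 0}, [1, 2, 3], None])
```

After the call, `comm.collectives` is 2, which mirrors the "2 collectives + 1 serialization per rank" target stated above.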
Additional improvements
- Preserve the exception chain on serialization failures (`from local_ser_error`)
- Remove dead `size > 0` guards (MPI guarantees `size >= 1`)
- Add `CommSpy`-based tests verifying exact collective and serialization counts

Test coverage
- `TestSafeAllgather`/`TestSafeGather` tests + CommSpy tests
- `test_*_large_object`, `test_*_multi_round_chunking`
- `test_*_displacement_correctness_asymmetric`
- `chunk_size` validation: `test_*_invalid_chunk_size`
- `chunk_size` auto-capping: `test_allgather_chunk_size_auto_capped`
- `test_*_uses_exactly_two_collectives`
- `test_*_serializes_once` (`pickle.dumps`)
- `test_allgather_none`, `test_allgather_empty_collections`
- `test_gather_non_zero_root`
- `test_allgather_cross_rank_consistency`
- `TestMPIDistAllgather`, `TestMPIDistGather`

PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment `/bot help`.