Commit c921c5c

kwen2501 authored and pytorchmergebot committed
[c10d] Print certain logs only on head rank of each node (#125432)
Recently we added the following warning, which is printed on every rank and makes the log a bit verbose. This PR dedups certain logs that are identical across ranks and prints them only on head rank of each node. Resolves #126275

=========================================

[rank0]:[W502 14:06:55.821964708 ProcessGroupNCCL.cpp:1113] WARNING: process group has NOT been destroyed before it is being destructed. On normal program exit, the application should call destroy_process_group to ensure that any pending NCCL data transfers have finished in this process. In rare cases this process can exit before this point and block the progress of another member of the process group. This constraint has always been present, but this warning has only been added since PyTorch 2.4

[rank1]:[W502 14:06:57.994276972 ProcessGroupNCCL.cpp:1113] WARNING: process group has NOT been destroyed before it is being destructed. On normal program exit, the application should call destroy_process_group to ensure that any pending NCCL data transfers have finished in this process. In rare cases this process can exit before this point and block the progress of another member of the process group. This constraint has always been present, but this warning has only been added since PyTorch 2.4

[rank2]:[W502 14:07:00.353013116 ProcessGroupNCCL.cpp:1113] WARNING: process group has NOT been destroyed before it is being destructed. On normal program exit, the application should call destroy_process_group to ensure that any pending NCCL data transfers have finished in this process. In rare cases this process can exit before this point and block the progress of another member of the process group. This constraint has always been present, but this warning has only been added since PyTorch 2.4

[rank3]:[W502 14:07:02.515511670 ProcessGroupNCCL.cpp:1113] WARNING: process group has NOT been destroyed before it is being destructed. On normal program exit, the application should call destroy_process_group to ensure that any pending NCCL data transfers have finished in this process. In rare cases this process can exit before this point and block the progress of another member of the process group. This constraint has always been present, but this warning has only been added since PyTorch 2.4

Pull Request resolved: #125432
Approved by: https://github.com/wconstab
1 parent 0625f92 commit c921c5c

2 files changed, +25 -18 lines changed
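The mechanism behind the change is simple: a log line that is identical across ranks is emitted only by the head rank of each node, i.e. the rank whose index is a multiple of the number of local devices. Below is a minimal, self-contained sketch of that pattern, not PyTorch code: the isHeadRankOfNode helper and the simulated 2-node / 4-GPU layout are illustrative assumptions, while the actual check in the diff is rank_ % localDeviceCount_ == 0.

    #include <iostream>

    // Sketch of the dedup pattern from this commit: a warning that is identical
    // across ranks is emitted only by the "head rank" of each node. Assumes the
    // common launcher layout where ranks are assigned contiguously per node
    // (ranks 0..k-1 on node 0, k..2k-1 on node 1, ...).
    bool isHeadRankOfNode(int rank, int localDeviceCount) {
      return rank % localDeviceCount == 0;
    }

    int main() {
      const int worldSize = 8;        // e.g. 2 nodes
      const int localDeviceCount = 4; // with 4 GPUs each
      for (int rank = 0; rank < worldSize; ++rank) {
        if (isHeadRankOfNode(rank, localDeviceCount)) {
          std::cout << "[rank" << rank << "] WARNING: process group has NOT been "
                       "destroyed before destruction...\n";
        }
      }
      // Only ranks 0 and 4 print, so the warning appears once per node instead
      // of once per rank.
      return 0;
    }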

torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp

22 additions & 18 deletions
@@ -398,8 +398,7 @@ at::Device ProcessGroupNCCL::guessDeviceForRank() const {
   if (getBoundDeviceId()) {
     return *getBoundDeviceId();
   } else {
-    auto numGPUs = at::cuda::getNumGPUs();
-    int16_t deviceIdx = static_cast<int16_t>(rank_ % numGPUs);
+    int16_t deviceIdx = static_cast<int16_t>(rank_ % localDeviceCount_);
     return at::Device(at::DeviceType::CUDA, deviceIdx);
   }
 }
@@ -740,6 +739,7 @@ ProcessGroupNCCL::ProcessGroupNCCL(
       at::cuda::getNumGPUs() != 0,
       "ProcessGroupNCCL is only supported with GPUs, no GPUs found!");
   this->setGroupName(options_->group_name);
+  this->localDeviceCount_ = at::cuda::getNumGPUs();
   logPrefix_ = createLogPrefix();
   blockingWait_ = getCvarBool(TORCH_NCCL_BLOCKING_WAIT, false);
   asyncErrorHandling_ = static_cast<ErrorHandlingMode>(
@@ -816,20 +816,23 @@ ProcessGroupNCCL::ProcessGroupNCCL(
   std::string torch_distributed_debug =
       getCvarString({"TORCH_DISTRIBUTED_DEBUG"}, OFF.c_str());
   LOG(INFO) << logPrefix() << "ProcessGroupNCCL initialization options: "
-            << "NCCL version: " << getNcclVersion() << ", size: " << size
-            << ", global rank: " << globalRank()
+            << "size: " << size << ", global rank: " << globalRank()
+            << ", TIMEOUT(ms): " << options_->timeout.count()
+            << ", USE_HIGH_PRIORITY_STREAM: "
+            << options_->is_high_priority_stream
+            << ", SPLIT_FROM: " << options_->split_from
+            << ", SPLIT_COLOR: " << options_->split_color
+            << ", PG Name: " << options_->group_name;
+
+  LOG(INFO) << logPrefix() << "ProcessGroupNCCL environments: "
+            << "NCCL version: " << getNcclVersion()
             << ", TORCH_NCCL_ASYNC_ERROR_HANDLING: " << asyncErrorHandling_
             << ", TORCH_NCCL_DUMP_ON_TIMEOUT: " << dumpOnException_
             << ", TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC: "
             << waitTimeoutDumpInMilSec_
             << ", TORCH_NCCL_DESYNC_DEBUG: " << desyncDebug_
             << ", TORCH_NCCL_ENABLE_TIMING: " << enableTiming_.load()
             << ", TORCH_NCCL_BLOCKING_WAIT: " << blockingWait_
-            << ", TIMEOUT(ms): " << options_->timeout.count()
-            << ", USE_HIGH_PRIORITY_STREAM: "
-            << options_->is_high_priority_stream
-            << ", SPLIT_FROM: " << options_->split_from
-            << ", SPLIT_COLOR: " << options_->split_color
             << ", TORCH_DISTRIBUTED_DEBUG: " << torch_distributed_debug
 #ifdef NCCL_HAS_COMM_REGISTER
             << ", TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK: "
@@ -840,8 +843,7 @@ ProcessGroupNCCL::ProcessGroupNCCL(
             << ", TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC: " << heartbeatTimeoutInSec_
             << ", TORCH_NCCL_TRACE_BUFFER_SIZE: " << ncclTraceBufferSize_
             << ", TORCH_NCCL_COORD_CHECK_MILSEC: " << coordCheckIntervalMilSec_
-            << ", TORCH_NCCL_NAN_CHECK: " << enableNanCheck_
-            << ", PG Name: " << options_->group_name;
+            << ", TORCH_NCCL_NAN_CHECK: " << enableNanCheck_;
 
   if (options_->global_ranks_in_group.empty()) {
     this->globalRankStart = 0;
@@ -1119,13 +1121,15 @@ ProcessGroupNCCL::~ProcessGroupNCCL() {
   LOG(INFO) << logPrefix() << "ProcessGroupNCCL destructor entered.";
 
   if (!terminateProcessGroup_.load()) {
-    LOG(WARNING) << c10::str(
-        "WARNING: process group has NOT been destroyed before it is being destructed. ",
-        "On normal program exit, the application should call destroy_process_group to ",
-        "ensure that any pending NCCL data transfers have finished in this process. "
-        "In rare cases this process can exit before this point and block the progress of "
-        "another member of the process group. This constraint has always been present, "
-        " but this warning has only been added since PyTorch 2.4");
+    if (rank_ % localDeviceCount_ == 0) {
+      TORCH_WARN_ONCE(
+          "WARNING: process group has NOT been destroyed before we destruct ProcessGroupNCCL. ",
+          "On normal program exit, the application should call destroy_process_group to ",
+          "ensure that any pending NCCL operations have finished in this process. "
+          "In rare cases this process can exit before this point and block the progress of "
+          "another member of the process group. This constraint has always been present, "
+          " but this warning has only been added since PyTorch 2.4");
+    }
     // If user haven't explicitly destroy/shutdown process group, destructor
     // needs to do so
     shutdown();

torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp

3 additions & 0 deletions
@@ -1093,6 +1093,9 @@ class TORCH_API ProcessGroupNCCL : public Backend {
   std::string logPrefix_;
 
   c10::intrusive_ptr<intra_node_comm::IntraNodeComm> intraNodeComm_;
+
+  // Number of devices on this node.
+  int localDeviceCount_{0};
 };
 
 TORCH_API std::string dump_nccl_trace();
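The header change only adds the cached member. As a rough illustration of how that member gets used, here is a hypothetical standalone sketch: FakeProcessGroup, guessDeviceIndexForRank, and the assumed 4-GPU-per-node layout are made up for this example; the real code lives in ProcessGroupNCCL and fills localDeviceCount_ from at::cuda::getNumGPUs() at construction.

    #include <iostream>

    // Hypothetical stand-in showing the role of the cached device count:
    // read once at construction, then reused to map a global rank to a
    // local device index (cf. guessDeviceForRank() in the .cpp hunk above).
    class FakeProcessGroup {
     public:
      FakeProcessGroup(int rank, int localDeviceCount)
          : rank_(rank), localDeviceCount_(localDeviceCount) {}

      // Mirrors the post-change logic: rank_ % localDeviceCount_.
      int guessDeviceIndexForRank() const {
        return rank_ % localDeviceCount_;
      }

     private:
      int rank_;
      // Number of devices on this node (see the .hpp hunk above).
      int localDeviceCount_{0};
    };

    int main() {
      // Assumed layout for illustration: 2 nodes with 4 GPUs each.
      for (int rank = 0; rank < 8; ++rank) {
        FakeProcessGroup pg(rank, /*localDeviceCount=*/4);
        std::cout << "rank " << rank << " -> cuda:"
                  << pg.guessDeviceIndexForRank() << "\n";
      }
      return 0;
    }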

0 commit comments

Comments
0 (0)
Morty Proxy This is a proxified and sanitized view of the page, visit original site.