@@ -254,12 +254,16 @@ class ProcessGroupNCCLErrorsTest : public ::testing::Test {
254
254
// Fixture setup (overrides ::testing::Test::SetUp): initializes logging,
// creates the FileStore held in store_, and fills tensors_.
// The diff below replaces the single-tensor setup (one 3x3 tensor in
// tensors_[0]) with one 3x3 tensor per entry of numDevices.
void SetUp () override {
255
255
// Enable LOG(INFO) messages.
256
256
c10::initLogging ();
257
- size_t numDevices = 1 ; // One device per rank (thread)
257
// NOTE(review): cudaNumDevices() is defined outside this view; presumably it
// returns the number of visible CUDA devices -- confirm.
+ size_t numDevices = cudaNumDevices ();
258
258
TemporaryFile file;
259
259
// The store is created from the temporary file's path.
// NOTE(review): the second argument (1) is presumably the number of
// participating workers -- confirm against the FileStore constructor.
// NOTE(review): `file` is local to SetUp while store_ is a member; whether
// the backing file must outlive the store is not visible here -- confirm.
store_ = c10::make_intrusive<::c10d::FileStore>(file.path , 1 );
260
260
261
// Declared without a device; the index is set per iteration in the loop below.
+ at::cuda::OptionalCUDAGuard deviceGuard;
261
262
tensors_.resize (numDevices);
262
- tensors_[0 ] = at::empty ({3 , 3 }, at::kCUDA );
263
+ for (const auto i : c10::irange (numDevices)) {
264
// Presumably makes device i current so the at::kCUDA allocation that
// follows is placed on device i -- confirm against OptionalCUDAGuard docs.
+ deviceGuard.set_index (i);
265
// at::ones replaces the removed at::empty call, so each tensor is created
// filled with ones.
+ tensors_[i] = at::ones ({3 , 3 }, at::kCUDA );
266
+ }
263
267
}
264
268
265
269
void TearDown () override {
@@ -282,6 +286,7 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsBlocking) {
282
286
283
287
auto work = pg.allreduce (tensors_);
284
288
work->wait ();
289
+ EXPECT_TRUE (work->isSuccess ());
285
290
EXPECT_EQ (1 , pg.getNCCLCommCacheSize ());
286
291
287
292
// Now run all reduce with errors.
@@ -291,6 +296,7 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsBlocking) {
291
296
292
297
// Verify the work item failed.
293
298
EXPECT_TRUE (work->isCompleted ());
299
+ EXPECT_FALSE (work->isSuccess ());
294
300
EXPECT_THROW (work->wait (), std::runtime_error);
295
301
296
302
// Communicators might be aborted here, further operations would fail.
@@ -308,6 +314,7 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLTimedoutErrorsBlocking) {
308
314
309
315
auto work = pg.allreduce (tensors_);
310
316
work->wait ();
317
+ EXPECT_TRUE (work->isSuccess ());
311
318
EXPECT_EQ (1 , pg.getNCCLCommCacheSize ());
312
319
313
320
// Now run all reduce with errors.
@@ -329,6 +336,7 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNonBlocking) {
329
336
330
337
auto work = pg.allreduce (tensors_);
331
338
pg.barrier ()->wait ();
339
+ EXPECT_TRUE (work->isSuccess ());
332
340
EXPECT_EQ (1 , pg.getNCCLCommCacheSize ());
333
341
334
342
// Now run all reduce with errors.
@@ -339,7 +347,10 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNonBlocking) {
339
347
work->wait ();
340
348
pg.barrier ()->wait ();
341
349
350
+ // Verify the work item failed.
342
351
EXPECT_TRUE (work->isCompleted ());
352
+ EXPECT_FALSE (work->isSuccess ());
353
+
343
354
// Communicators might be aborted here, further operations would fail.
344
355
}
345
356
@@ -415,6 +426,7 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNoHeartbeat) {
415
426
// Normal collective case.
416
427
auto work = pg.allreduce (tensors_);
417
428
work->wait ();
429
+ EXPECT_TRUE (work->isSuccess ());
418
430
419
431
work = pg.allreduce (tensors_);
420
432
{
@@ -428,6 +440,7 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNoHeartbeat) {
428
440
EXPECT_TRUE (pg.getErrorCaughtFlag ());
429
441
}
430
442
work->wait ();
443
+ EXPECT_TRUE (work->isSuccess ());
431
444
EXPECT_TRUE (traces.size () > 0 );
432
445
auto filename = c10::str (tempFilename, 0 );
433
446
auto traceFromStorage = readTraceFromFile (filename, traces.size ());
0 commit comments