diff --git a/examples/graphbolt/pyg/multigpu/node_classification.py b/examples/graphbolt/pyg/multigpu/node_classification.py index 30e076ce50f8..306a5cd2b618 100644 --- a/examples/graphbolt/pyg/multigpu/node_classification.py +++ b/examples/graphbolt/pyg/multigpu/node_classification.py @@ -391,6 +391,7 @@ def run(rank, world_size, args, dataset): # Set up multiprocessing environment. torch.cuda.set_device(rank) dist.init_process_group( + backend='cpu:gloo,cuda:nccl', init_method="tcp://127.0.0.1:12345", rank=rank, world_size=world_size, diff --git a/tests/python/pytorch/graphbolt/test_dataloader.py b/tests/python/pytorch/graphbolt/test_dataloader.py index 5843264516fc..6f56a3bb54b0 100644 --- a/tests/python/pytorch/graphbolt/test_dataloader.py +++ b/tests/python/pytorch/graphbolt/test_dataloader.py @@ -89,6 +89,7 @@ def test_gpu_sampling_DataLoader( else "tcp://127.0.0.1:12345" ) thd.init_process_group( + backend='cpu:gloo,cuda:nccl', init_method=init_method, world_size=1, rank=0,