deepspeed.init_distributed() support for TCP protocols (#2905)
Co-authored-by: Jeff Rasley <[email protected]>
noabauma and jeffra authored Mar 7, 2023
1 parent 0acf7e9 commit db15ef5
Showing 2 changed files with 18 additions and 6 deletions.
8 changes: 6 additions & 2 deletions deepspeed/comm/comm.py
@@ -603,7 +603,9 @@ def init_distributed(dist_backend=None,
                      timeout=default_pg_timeout,
                      init_method=None,
                      dist_init_required=None,
-                     config=None):
+                     config=None,
+                     rank=-1,
+                     world_size=-1):
     ''' Initialize dist backend, potentially performing MPI discovery if needed

     Arguments:
@@ -614,6 +616,8 @@ def init_distributed(dist_backend=None,
         timeout: Optional (timedelta). Timeout for operations executed against the process group. Default value equals 30 minutes.
         init_method: Optional (string). Torch distributed, URL specifying how to initialize the process group. Default is “env://” if no init_method or store is specified.
         config: Optional (dict). DeepSpeed configuration for setting up comms options (e.g. Comms profiling)
+        rank: Optional (int). The current manually specified rank. Some init_method like “tcp://” need the rank and world_size as well (see: https://pytorch.org/docs/stable/distributed.html#tcp-initialization)
+        world_size: Optional (int). Desired world_size for the TCP or Shared file-system initialization.
     '''
     global cdb

@@ -658,7 +662,7 @@ def init_distributed(dist_backend=None,
             'Initializing TorchBackend in DeepSpeed with backend {}'.format(
                 dist_backend))
         # Create a torch backend object, initialize torch distributed, and assign to cdb
-        cdb = TorchBackend(dist_backend, timeout, init_method)
+        cdb = TorchBackend(dist_backend, timeout, init_method, rank, world_size)


 def mpi_discovery(distributed_port=TORCH_DISTRIBUTED_DEFAULT_PORT, verbose=True):
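With these new parameters, init_distributed can be used with a TCP (or shared file-system) init_method, where torch.distributed cannot read the rank and world size from environment variables. A minimal usage sketch; the master address, port, rank, and world size below are illustrative placeholders, not part of the commit:

import deepspeed

# Placeholder rendezvous point and job geometry; in practice these come from
# your launcher or scheduler rather than being hard-coded.
MASTER_ADDR = "10.1.1.20"
MASTER_PORT = 29500
rank = 0          # global rank of this process
world_size = 4    # total number of processes in the job

# With a tcp:// init_method there are no env vars to read,
# so rank and world_size must be passed explicitly.
deepspeed.init_distributed(dist_backend="nccl",
                           init_method=f"tcp://{MASTER_ADDR}:{MASTER_PORT}",
                           rank=rank,
                           world_size=world_size)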
16 changes: 12 additions & 4 deletions deepspeed/comm/torch.py
@@ -16,7 +16,13 @@ class TorchBackend(Backend):
         so no need to wrap all the functions. We can keep adding wrappers as
         needed.
     """
-    def __init__(self, backend, timeout, init_method, name='torch'):
+    def __init__(self,
+                 backend,
+                 timeout,
+                 init_method,
+                 rank=-1,
+                 world_size=-1,
+                 name='torch'):
         super(TorchBackend, self).__init__()
         self.torch_version_before_18 = older_torch()
         self.has_allgather_base = has_allgather_base()
@@ -27,13 +33,15 @@ def __init__(self, backend, timeout, init_method, name='torch'):
         # The idea is to fake that dist backend is initialized even when
         # it is not so we can run on a single GPU without doing any init_process_group
         self.single_gpu_mode = True
-        self.init_process_group(backend, timeout, init_method)
+        self.init_process_group(backend, timeout, init_method, rank, world_size)

-    def init_process_group(self, backend, timeout, init_method):
+    def init_process_group(self, backend, timeout, init_method, rank, world_size):
         if not torch.distributed.is_initialized():
             torch.distributed.init_process_group(backend,
                                                  timeout=timeout,
-                                                 init_method=init_method)
+                                                 init_method=init_method,
+                                                 rank=rank,
+                                                 world_size=world_size)
         self.using_mpi = torch.distributed.get_backend() == 'mpi'

     def all_reduce(self,
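Inside TorchBackend the new arguments are simply forwarded to torch.distributed.init_process_group, which treats rank=-1 and world_size=-1 as "unspecified" (the pre-existing env:// behavior). A rough stand-alone sketch of the equivalent call for a TCP rendezvous, with a placeholder backend, address, rank, and world size:

import datetime
import torch

# Mirrors the call made in TorchBackend.init_process_group; the concrete
# values below are for illustration only.
if not torch.distributed.is_initialized():
    torch.distributed.init_process_group(backend="gloo",
                                         timeout=datetime.timedelta(minutes=30),
                                         init_method="tcp://10.1.1.20:29500",
                                         rank=0,
                                         world_size=4)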
