Adds reduce_scatter into torchft #102

Open · wants to merge 3 commits into base: main
75 changes: 69 additions & 6 deletions torchft/process_group.py
@@ -58,6 +58,7 @@
AllreduceOptions,
BroadcastOptions,
ReduceOp,
ReduceScatterOptions,
Work,
)
from torch.futures import Future
@@ -180,6 +181,20 @@ def broadcast_one(self, tensor: torch.Tensor, root: int) -> Work:
opts.rootRank = root
return self.broadcast([tensor], opts)

# pyre-fixme[14]: inconsistent override
def reduce_scatter(
self,
output_tensors: List[torch.Tensor],
input_tensors: List[List[torch.Tensor]],
opts: ReduceScatterOptions,
) -> Work:
"""
Reduces, then scatters a list of tensors to all processes in a group.

See torch.distributed.reduce_scatter for more details.
"""
raise NotImplementedError("not implemented")

def size(self) -> int:
raise NotImplementedError("not implemented")
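
For context on the new collective's semantics (not part of the diff): with W ranks, each rank passes one input tensor per rank, and rank r receives the element-wise reduction of every rank's r-th input. A minimal single-process sketch of a SUM reduce_scatter, using hypothetical values:

import torch

world_size = 2
# inputs[src][dst]: the tensor rank `src` contributes for destination rank `dst`.
inputs = [
    [torch.full((2, 3), float(src + 1)) for _ in range(world_size)]
    for src in range(world_size)
]

for rank in range(world_size):
    # After every rank calls pg.reduce_scatter([output], [inputs[rank]], opts) with a
    # SUM op, rank `rank` would hold this value in `output`.
    expected = torch.stack([inputs[src][rank] for src in range(world_size)]).sum(dim=0)
    print(rank, expected)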

@@ -288,6 +303,14 @@ def allgather(
def broadcast(self, tensor_list: List[torch.Tensor], opts: object) -> Work:
return self.parent.broadcast(tensor_list, opts)

def reduce_scatter(
self,
output_tensors: List[torch.Tensor],
input_tensors: List[List[torch.Tensor]],
opts: object,
) -> Work:
return self.parent.reduce_scatter(output_tensors, input_tensors, opts)

def size(self) -> int:
return self.parent.size()

@@ -375,11 +398,6 @@ def __init__(self, rank: int, world: int) -> None:
def configure(self, store_addr: str, rank: int, world_size: int) -> None:
self.configure_count += 1

def broadcast(self, tensor_list: List[torch.Tensor], opts: object) -> Work:
res = _DummyWork(tensor_list)
self._work.append(res)
return res

def allgather(
self,
output_tensors: List[List[torch.Tensor]],
@@ -398,6 +416,24 @@ def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
self._work.append(res)
return res

def broadcast(self, tensor_list: List[torch.Tensor], opts: object) -> Work:
res = _DummyWork(tensor_list)
self._work.append(res)
return res

def reduce_scatter(
self,
output_tensors: List[torch.Tensor],
input_tensors: List[List[torch.Tensor]],
opts: object,
) -> Work:
for o, i in zip(output_tensors, input_tensors[0]):
o.copy_(i)

res = _DummyWork(output_tensors)
self._work.append(res)
return res

def size(self) -> int:
return self._world

@@ -960,6 +996,25 @@ def broadcast(

return self._run_func("broadcast", tensor_list, opts)

def reduce_scatter(
self,
output_tensors: List[torch.Tensor],
input_tensors: List[List[torch.Tensor]],
opts: ReduceScatterOptions,
) -> Work:
assert isinstance(output_tensors, list), "output must be a list"
assert isinstance(input_tensors, list), "input must be a list"

for tensor in output_tensors:
if not tensor.is_shared():
tensor.share_memory_()

for tensor_list in input_tensors:
for tensor in tensor_list:
if not tensor.is_shared():
tensor.share_memory_()
return self._run_func("reduce_scatter", output_tensors, input_tensors, opts)

def size(self) -> int:
return self._world_size

@@ -982,7 +1037,15 @@ def safe_args(cls, args: T) -> T:
return tuple(cls.safe_args(arg) for arg in args)
elif isinstance(args, list):
return [cls.safe_args(arg) for arg in args]
elif isinstance(args, (AllreduceOptions, AllgatherOptions, BroadcastOptions)):
elif isinstance(
args,
(
AllreduceOptions,
AllgatherOptions,
BroadcastOptions,
ReduceScatterOptions,
),
):
return cls.from_torch(args)
else:
return args
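
To make the new API concrete, here is a hypothetical call pattern, not taken from the PR: pg is assumed to be an already-configured torchft ProcessGroup whose backend supports the collective (e.g. NCCL on CUDA, per the support matrix checked in the tests below), and ReduceScatterOptions is assumed importable from torch.distributed.distributed_c10d, the module the test file draws its options classes from.

import torch
from torch.distributed.distributed_c10d import ReduceScatterOptions


def reduce_scatter_example(pg) -> torch.Tensor:
    """Sketch: reduce-scatter one (2, 3) chunk per rank and return this rank's chunk."""
    world_size = pg.size()
    inputs = [torch.randn(2, 3) for _ in range(world_size)]  # one chunk per destination rank
    output = torch.empty(2, 3)                               # receives this rank's reduced chunk
    opts = ReduceScatterOptions()                            # reduceOp defaults to SUM
    work = pg.reduce_scatter([output], [inputs], opts)
    work.wait()
    return output
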
42 changes: 42 additions & 0 deletions torchft/process_group_test.py
@@ -23,6 +23,7 @@
AllreduceOptions,
BroadcastOptions,
ReduceOp,
ReduceScatterOptions,
_resolve_process_group,
)
from torch.distributed import (
@@ -60,6 +61,31 @@ def dummy_init_pg() -> None:
)


def _should_run_collective(collective_str: str, backend_str: str, device: str) -> bool:
"""Verify if the collective is supported by the backend and device.

See https://pytorch.org/docs/stable/distributed.html#backends for the
supported collectives / backends / devices matrix.

"""
if "nccl" in backend_str.lower():
# all collectives are supported for NCCL/CUDA but none on CPU.
return device == "cuda"
elif "gloo" in backend_str.lower():
if device == "cuda":
# GLOO/GPU only supports broadcast and all_reduce.
if collective_str in ["broadcast", "all_reduce"]:
return True
return False
else: # cpu
if collective_str in ["reduce_scatter", "all_to_all"]:
[Review comment, Member] oh wow -- didn't realize we don't support these on Gloo, good to know! cc @c-p-i-o

[Review comment, Contributor] ye, we miss many APIs on Gloo.

[Review comment] this approach seems nice and explicit. but is it possible to instead just try: the test, and except: some specific NYI error? (i'm not sure if we raise a consistent type of NYI exception from backends?)

return False
return True
else:
# Backends not listed above (e.g. ErrorSwallowing) should continue to work.
return True
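
One reviewer above asks whether the test could simply attempt each collective and catch a specific not-yet-implemented error rather than maintain this matrix. A minimal sketch of that alternative, with the caveat (also raised in the thread) that backends may not raise a consistent exception type, so the exception types and message strings below are assumptions:

from typing import Optional

import torch.distributed as dist


def _run_collective_or_skip(pg, coll_str: str, args: tuple) -> Optional[dist._Work]:
    """Run pg.<coll_str>(*args); return None if the backend reports it as unsupported."""
    coll = getattr(pg, coll_str)
    try:
        return coll(*args)
    except (NotImplementedError, RuntimeError) as e:
        msg = str(e).lower()
        if "not implemented" in msg or "not supported" in msg:
            return None  # treat as unsupported on this backend/device
        raise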


def _test_pg(
pg: ProcessGroup,
example_tensor: torch.Tensor = torch.randn((2, 3), dtype=torch.float32),
Expand Down Expand Up @@ -94,9 +120,25 @@ def check_tensors(arg: Any) -> None: # pyre-ignore[2]
("allgather", (output_tensors, [input_tensor], AllgatherOptions())),
("broadcast", (tensor_list, BroadcastOptions())),
("broadcast_one", (input_tensor, 0)),
(
"reduce_scatter",
(output_tensors[0], [[input_tensor]], ReduceScatterOptions()),
),
]
works: Dict[str, dist._Work] = {}

try:
backend_str = pg.getBackendName()
device = example_tensor.device
if type(device) is torch.device:
device = device.type
except NotImplementedError as e:
[Review comment, Member] why are we getting a NotImplementedError? which backend is this?

[Review comment, Author] This is just for the ErrorSwallowingProcessGroupWrapper. I have a follow up PR to refactor the tests to get rid of this entire function though!

backend_str = ""
device = ""

for coll_str, args in collectives:
if not _should_run_collective(coll_str, backend_str=backend_str, device=device):
continue
coll = getattr(pg, coll_str)
work = coll(*args)
works[coll_str] = work