-
Notifications
You must be signed in to change notification settings - Fork 20
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adds reduce_scatter
into torchft
#102
base: main
Are you sure you want to change the base?
Conversation
return True | ||
return False | ||
else: # cpu | ||
if collective_str in ["reduce_scatter", "all_to_all"]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
oh wow -- didn't realize we don't support these on Gloo, good to know! cc @c-p-i-o
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ye, we miss many APIs on Gloo.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this approach seems nice and explicit. but is it possible to instead just try: the test, and except: some specific NYI error? (i'm not sure if we raise a consistent type of NYI exception from backends?)
device = example_tensor.device | ||
if type(device) is torch.device: | ||
device = device.type | ||
except NotImplementedError as e: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why are we getting a NotImplementedError? which backend is this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is just for the ErrorSwallowingProcessGroupWrapper
. I have a follow up PR to refactor the tests to get rid of this entire function though!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks for adding this!
What does this PR do?
Partially addresses #97 by adding
reduce_scatter
intotorchft
.Concretely, this consists of a few pieces:
reduce_scatter
into theProcessGroup
following the signature [here](https://github.com/pytorch/pytorch/blob/11f69808c64a65c68a4452250ba7719dcff27c78/torch/csrc/distributed/c10d/PyProcessGroup.hpp#L203ProcessGroup*
we essentially follow the behavior of other collectives:ProcessGroupWrapper
, it depends on the parent implementationProcessGroupDummy
, it writes from the first input into outputProcessGroupBaby
, it asserts inputs and moves underlying storage into shared memoryReduceScatterOptions
in_PickleSafeOptions
reduce_scatter
as an option in_test_pg
, however this necessitated a new function (named_should_run_collective
) which was needed as e.g. GLOO does not supportreduce_scatter
. This function essentially takes the collective, backend and device and copies the logic of the published supported collective matrix.Tests
Presubmits, and:
Next steps
The logic of
_should_run_collective
is a bit confusing, as it allows "non defined backends" likeErrorSwallowing*
through, to mimic the old behavior before this change. Testing here could become a bit unwieldy as we add more collectives and so a future step could be to refactor the testing.One nice change could be to parameterize tests by the collective. This will make potentially failing collectives more explicit and will reduce the time it takes to run individual tests. Likely can do this in the next PR.