Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds reduce_scatter into torchft #102

Open
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

allenwang28
Copy link

@allenwang28 allenwang28 commented Feb 6, 2025

What does this PR do?

Partially addresses #97 by adding reduce_scatter into torchft.

Concretely, this consists of a few pieces:

  • Introducing reduce_scatter into the ProcessGroup following the signature [here](https://github.com/pytorch/pytorch/blob/11f69808c64a65c68a4452250ba7719dcff27c78/torch/csrc/distributed/c10d/PyProcessGroup.hpp#L203
  • In ProcessGroup* we essentially follow the behavior of other collectives:
    • In ProcessGroupWrapper, it depends on the parent implementation
    • In ProcessGroupDummy, it writes from the first input into output
    • In ProcessGroupBaby, it asserts inputs and moves underlying storage into shared memory
  • Add ReduceScatterOptions in _PickleSafeOptions
  • Introduces reduce_scatter as an option in _test_pg, however this necessitated a new function (named _should_run_collective) which was needed as e.g. GLOO does not support reduce_scatter. This function essentially takes the collective, backend and device and copies the logic of the published supported collective matrix.

Tests

Presubmits, and:

$ pytest torchft/process_group_test.py 
============================================= test session starts =============================================
platform linux -- Python 3.10.16, pytest-8.3.4, pluggy-1.5.0
rootdir: /home/allencwang/workspace/torchft
configfile: pytest.ini
plugins: typeguard-2.13.3
collected 16 items                                                                                            

torchft/process_group_test.py ................                                                          [100%]

============================================= 16 passed in 31.44s =============================================
[rank0]:[W206 14:54:24.777939032 CudaIPCTypes.cpp:16] Producer process has been terminated before all shared CUDA tensors released. See Note [Sharing CUDA tensors]

Next steps

The logic of _should_run_collective is a bit confusing, as it allows "non defined backends" like ErrorSwallowing* through, to mimic the old behavior before this change. Testing here could become a bit unwieldy as we add more collectives and so a future step could be to refactor the testing.

One nice change could be to parameterize tests by the collective. This will make potentially failing collectives more explicit and will reduce the time it takes to run individual tests. Likely can do this in the next PR.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Meta Open Source bot. label Feb 6, 2025
return True
return False
else: # cpu
if collective_str in ["reduce_scatter", "all_to_all"]:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh wow -- didn't realize we don't support these on Gloo, good to know! cc @c-p-i-o

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ye, we miss many APIs on Gloo.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this approach seems nice and explicit. but is it possible to instead just try: the test, and except: some specific NYI error? (i'm not sure if we raise a consistent type of NYI exception from backends?)

device = example_tensor.device
if type(device) is torch.device:
device = device.type
except NotImplementedError as e:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why are we getting a NotImplementedError? which backend is this?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is just for the ErrorSwallowingProcessGroupWrapper. I have a follow up PR to refactor the tests to get rid of this entire function though!

@allenwang28 allenwang28 marked this pull request as ready for review February 7, 2025 16:55
Copy link
Member

@d4l3k d4l3k left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks for adding this!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Meta Open Source bot.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants