Refactors process_group_tests.py
#103
base: main
Conversation
torchft/process_group_test.py (Outdated)
```diff
@@ -391,18 +394,19 @@ def test_error_swallowing_process_group_wrapper(self) -> None:
         wrapper = ErrorSwallowingProcessGroupWrapper(pg)
         self.assertIs(wrapper.parent, pg)

-        works = _test_pg(wrapper)
+        works = run_collective(pg=wrapper, collective="allreduce")
```
Reviewer: This seems like a pretty big decrease in coverage?
Author: Added functionality back.
torchft/process_group_test.py (Outdated)
```python
shape: torch.Size = example_tensor.shape
dtype: torch.dtype = example_tensor.dtype
coll = getattr(pg, collective)
args_list = _build_args(pg=pg, collective=collective, example_tensor=example_tensor)
```
Reviewer: What's the intention behind pulling this out? I'm not really convinced that this makes it all that much cleaner.

In some ways I think I'd prefer if we got rid of the arg generation and instead flattened this out with direct calls, i.e.
```python
if collective == "allreduce":
    work = pg.allreduce(...)
    work.wait()
...
```
Author: I agree. I've removed the arg generation and included it in place within run_collective.
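For illustration, a minimal sketch of what that in-place dispatch could look like. run_collective is the helper name used in this PR, but the body below (the optional example_tensor handling, the branches covered, and the call signatures) is an assumption, not the actual implementation:

```python
from typing import Optional

import torch
from torch.distributed import ReduceOp


def run_collective(pg, collective: str, example_tensor: Optional[torch.Tensor] = None) -> list:
    # Sketch only: dispatch directly on the collective name instead of building
    # argument lists generically. Signatures and defaults here are assumptions.
    tensor = (example_tensor if example_tensor is not None else torch.rand(2, 3)).clone()
    if collective == "allreduce":
        work = pg.allreduce([tensor], ReduceOp.SUM)
    # elif collective == "allgather": ...  (additional branches per collective)
    else:
        raise ValueError(f"unsupported collective: {collective}")
    work.wait()
    return [work]
```

Keeping each branch literal keeps the per-collective arguments visible at the call site, which is the readability point the review is making.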
```python
pg = ProcessGroupBabyNCCL(timeout=timedelta(seconds=10))
try:
    pg.configure(self.store_addr, 0, 1)
```
Reviewer: This seems really slow -- how fast does this run? Launching the subprocess is pretty slow, so I would actually prefer to run these all on the same PG.

If you want prettier printing, we can use subtests, i.e.
```python
for collective in collectives:
    with self.subTest(collective=collective):
        ...
```
Author: Good callout. With parameterized it took ~36s, without it took ~16s. I've removed parameterized.
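A sketch of the single-process-group pattern with one subtest per collective; the class name, the collective tuple, and the store setup below are illustrative assumptions rather than the PR's actual code:

```python
import unittest
from datetime import timedelta

from torchft.process_group import ProcessGroupBabyNCCL


class NCCLTest(unittest.TestCase):
    def test_collectives(self) -> None:
        # Illustrative: configure one ProcessGroupBabyNCCL and reuse it for
        # every collective, reporting each collective as its own subtest.
        # self.store_addr is assumed to be prepared in setUp, as in the diff above;
        # run_collective is the PR's helper in this test module.
        pg = ProcessGroupBabyNCCL(timeout=timedelta(seconds=10))
        try:
            pg.configure(self.store_addr, 0, 1)
            for collective in ("allreduce",):  # assumed list of collectives
                with self.subTest(collective=collective):
                    run_collective(pg=pg, collective=collective)
        finally:
            pg.shutdown()
```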
What does this PR do?

As part of #97, this PR refactors `process_group_test`:
- Renames `_test_pg` to `run_collectives` and extends it to accept a given list of collectives by name.
- Splits `ProcessGroupTest` into three tests: `GlooTest`, `NCCLTests`, and `DummyTests`: `GlooTest` logically runs every test using `gloo`, `NCCLTest` with NCCL, etc.
- Adds `shutdown()` and garbage collection etc. to avoid extraneous messages & warnings (a rough sketch of this cleanup follows below).
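A minimal sketch of the cleanup idea in the last bullet; the tearDown placement and the self.pg attribute are assumptions made for illustration, not the PR's exact code:

```python
import gc
import unittest


class GlooTest(unittest.TestCase):
    # Assumed sketch of the cleanup described above.
    def tearDown(self) -> None:
        pg = getattr(self, "pg", None)
        if pg is not None:
            pg.shutdown()  # stop the worker behind the process group
        gc.collect()  # drop lingering references so nothing warns at interpreter exit
```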
Why is this needed?

As part of #102, I noticed that there were some mismatches between which collectives ran on which backends (matrix is here). This logical grouping of tests by backend therefore allows us to define explicitly which collectives should be tested.