-
Notifications
You must be signed in to change notification settings - Fork 5.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[core][compiled graphs] Overlap computation and communication #47586
Conversation
9e13938
to
ccb561c
Compare
6b66049
to
4e752b2
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Main comment is to try to reuse existing codepaths for executing tasks and reading/writing local args. I think the code will be much more robust this way, plus we need to do it anyway to support enabling overlapping per-task or object.
Seems possible to do if we wrap all inputs/outputs with a wrapper class like this, maybe we need to update the channel reading/writing:
@dataclass
class Value:
value: Any
# If set, then reader needs to synchronize before reading value.
cuda_event: Optional[cp.cuda.Event]
Also should think about how we can unit-test this. Ideally we should try to write a mocked test, maybe something like this one.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I really think we need to unify the execution loop. The reason is the test space becomes much larger (we need to also make sure existing case works correctly when overlap is used).
What about we always assume send/recv is unblocking and return Future? And if send/recv is a blocking, the future is returned after wait is finished. Otherwise, it just returns future. It is same as how gloo apis also work
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
2 additional comments;
- Before proceeding refactoring, is it possible to unblock @woshiyyya? Maybe one easy solution is to build a custom wheel (you can ping reef team to learn how to do it) so that he can just use it for multi node tests.
- right now, we have 1 stream per recv/send for 1 process. I wonder if we need a stream per channel or should manage a queue of streams to have multiple nccl channels?
Yes that's my plan.
I think initially we can start with 1 stream. Whether multiple streams can provide better performance would depend on a few factors, and may need a bit of design, although we can perhaps try it. |
Yeah sounds good. Pp anyway has 1:1 only. We can revisit and benchmark when we have more than 1 input output use cases |
Newer commits are for graph visualization prototype, just a draft for now and no need to review new code. |
a419483
to
fc5bbb7
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this is better but I believe we can still clean up the GPUFuture interface to make the code more robust. Ideally we want to use .wait()
everywhere, not just if isinstance(GPUFuture)
.
Will take a look at the scheduling and testing code separately.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am still reviewing the PR. All comments with "nit" are not necessary to be addressed in this PR.
ff439f7
to
00930d2
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm there might be some misunderstanding about the suggestion to use futures. I wasn't imagining that it would need this much code change, but please let me know if my comments don't make sense.
Yeah I think the main challenge is we need to examine write values (exception check, multi-return value extraction) earlier than they are written. Now that these values are futures, we need additional handling: the problematic case is Previously I came up with two approaches:
I think approach2 leaves out changes to the channel interface but leaks the NcclGroup send_stream implementation details, and approach1 touches a lot of interfaces but things are consistent. Originally I thought about implementing approach2 but found approach1 may be a bit cleaner, although it changes the interfaces quite a bit. Let me know how you think about it. |
This can be fixed by passing
I don't understand why we need to do the wait earlier. It should be okay to just do the wait right before starting to execute the |
Yeah what I meant by "earlier" in
is exactly this line, and it needs the
So I think approach2 is what you preferred? If so, I can change to implement it. |
Ah okay, got it, thanks for the explanation. Yes just to be clear, I prefer approach2 for the following reasons:
|
c8cc45e
to
041fdd2
Compare
d181804
to
98f5c38
Compare
Signed-off-by: Rui Qiao <[email protected]>
Signed-off-by: Rui Qiao <[email protected]>
Signed-off-by: Rui Qiao <[email protected]>
98f5c38
to
2497dd5
Compare
Signed-off-by: Rui Qiao <[email protected]>
Signed-off-by: Rui Qiao <[email protected]>
Signed-off-by: Rui Qiao <[email protected]>
Signed-off-by: Rui Qiao <[email protected]>
Signed-off-by: Rui Qiao <[email protected]>
Signed-off-by: Rui Qiao <[email protected]>
Signed-off-by: Rui Qiao <[email protected]>
Signed-off-by: Rui Qiao <[email protected]>
Signed-off-by: Rui Qiao <[email protected]>
) | ||
|
||
|
||
class AbstractNcclGroup(GPUCommunicator): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Need to move this from test_collective_dag.py, otherwise there is the following error:
> raise value.as_instanceof_cause()
E ray.exceptions.RayTaskError(RaySystemError): ray::CPUTorchTensorWorker.__ray_call__() (pid=2553478, ip=172.31.15.128, actor_id=a378363d53bded77aa33e28301000000, repr=<test_collective_dag.CPUTorchTensorWorker object at 0x7f02403f7c40>)
E At least one of the input arguments for this task could not be computed:
E ray.exceptions.RaySystemError: System error: No module named 'test_collective_dag'
E traceback: Traceback (most recent call last):
E ModuleNotFoundError: No module named 'test_collective_dag'
Signed-off-by: Rui Qiao <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. One last request for unit tests
Signed-off-by: Rui Qiao <[email protected]>
Signed-off-by: Rui Qiao <[email protected]>
Signed-off-by: Rui Qiao <[email protected]>
lmk when ti is ready to merge! super excited for this feature. After this can you follow up
|
…oject#47586) This PR supports overlapping computation and communication for GPU tasks, as described in https://docs.google.com/document/d/1AkAqrMPadk1rMyjKE4VN4bq058z36fgBcx0i4dHIW20/edit#heading=h.8jw8z0hmgva0 The scope is send/recv but does not include collectives. Checked perf improvement with test_torch_tensor_dag::test_torch_tensor_nccl_overlap, result is consistent and roughly: overlap_gpu_communication=False, duration=1.0124679207801819 overlap_gpu_communication=True, duration=0.8186687417328358
Why are these changes needed?
This PR supports overlapping computation and communication for GPU tasks, as described in https://docs.google.com/document/d/1AkAqrMPadk1rMyjKE4VN4bq058z36fgBcx0i4dHIW20/edit#heading=h.8jw8z0hmgva0
The scope is send/recv but does not include collectives.
Checked perf improvement with test_torch_tensor_dag::test_torch_tensor_nccl_overlap, result is consistent and roughly:
overlap_gpu_communication=False, duration=1.0124679207801819
overlap_gpu_communication=True, duration=0.8186687417328358
Related issue number
Closes #47016
Closes #48191
Visualization
test_torch_tensor_nccl_overlap:
zero bubble:
Checks
git commit -s
) in this PR.scripts/format.sh
to lint the changes in this PR.method in Tune, I've added it in
doc/source/tune/api/
under thecorresponding
.rst
file.