ring-based reduce-scatter pipelining, ATen implementation #2950
Open
samnordmann wants to merge 27 commits into NVIDIA:main from samnordmann:host_irs/ring_reducescatter_ATen
Conversation
I am seeing issues with UCC, both locally and in the CI. I need to investigate. In the meantime, I am disabling UCC from being tested.
Not needed for this PR, but can we run this with larger sizes and compare it against a matmul + reduce-scatter baseline? I'm curious whether we can see any time savings.
ATen implementation of a GEMM+Reduce-scatter operation, with a decomposition of the Reduce-scatter into a ring algorithm. This reproduces a technique present in TransformerEngine for achieving comm/compute overlap in Megatron.
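As a rough illustration of the ring decomposition (a pure-Python sketch of the classical algorithm, not the ATen implementation; all names here are ours), a reduce-scatter over `D` devices can be expressed as `D - 1` ring steps in which each rank forwards a partial sum to its neighbor and reduces the chunk it receives -- which is what lets each step's communication be overlapped with the matching compute chunk:

```python
# Single-process simulation of a ring reduce-scatter over D "ranks".
# inputs[r][c] is rank r's value for chunk c (a scalar here, a tensor chunk
# in the real algorithm). After D - 1 steps, rank r owns the fully reduced
# chunk (r + 1) % D. Illustrative only -- not the nvFuser/ATen code.

def ring_reduce_scatter(inputs):
    D = len(inputs)
    acc = [list(row) for row in inputs]  # running partial sums per rank
    for s in range(D - 1):
        # Snapshot sends so all ranks exchange "simultaneously" at this step.
        sends = [(r, (r - s) % D, acc[r][(r - s) % D]) for r in range(D)]
        for src, chunk, val in sends:
            acc[(src + 1) % D][chunk] += val  # recv + local reduce at the neighbor
    return [acc[r][(r + 1) % D] for r in range(D)]
```

For example, with `D = 3` and `inputs[r][c] = 10*r + c`, each rank ends up holding the sum over ranks of one chunk.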
Experiment
The following Nsight profile clearly shows that we achieve overlap:
Setup: DGX 8×V100 32GB
Params: backend=NCCL with coalescing, M=K=N=2048, S=8, number_of_streams=3
Arbitrary number of steps
The algorithm we provide is slightly more general than the classical case: it allows decomposing the reduce-scatter into an arbitrarily large number of steps, not only `num_devices_` steps as in the classical algorithm. More precisely, the parameter `S` (which stands for the number of steps, i.e., the number of interleaved comms and computes), classically equal to `num_devices_` for the ring algorithm, is only assumed to be a multiple of `num_devices_` in our version. This is an important parameter, as it gives more flexibility in the size of the interleaved chunks and could therefore lead to better overlap -- but a thorough perf analysis remains to be done. If `S > num_devices_`, each classical chunk is split into `S / num_devices_` pieces, and only one such piece of the buffer is computed and communicated to the peers at each iteration.
Coalescing
NCCL
ProcessGroupNCCL provides two methods, `startCoalescing` and `endCoalescing`, which internally correspond to `ncclGroupStart` and `ncclGroupEnd` (see the doc here). Those calls group p2p calls that need to be progressed together: a single global work handle, returned by `endCoalescing`, needs to be progressed. Consider, for example, the schedule where each rank first posts its send and then its recv:
rank0: send to rank1, then recv from rank1
rank1: send to rank0, then recv from rank0
Without coalescing, this situation creates a deadlock because no rank can receive before it has sent.
It is in general preferable to coalesce send/recv calls. The only drawback is that we lose fine-grained control over synchronicity; in other words, we can only synchronize with the bulked communication as a whole.
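The deadlock can be reproduced with a toy model of blocking p2p calls (our own simplification, not NCCL itself: a send/recv pair completes only when both ranks are blocked on the matching ops, which approximates non-coalesced sends that cannot complete eagerly):

```python
# Toy deadlock detector for blocking point-to-point schedules (not NCCL).
# programs[r] lists rank r's ops in order: ('send', peer) or ('recv', peer).
# A pair completes only when both ranks sit on matching head ops.

def deadlocks(programs):
    pc = [0] * len(programs)  # each rank's current op index
    progressed = True
    while progressed:
        progressed = False
        for r, prog in enumerate(programs):
            if pc[r] >= len(prog):
                continue
            op, peer = prog[pc[r]]
            match = ('recv', r) if op == 'send' else ('send', r)
            if pc[peer] < len(programs[peer]) and programs[peer][pc[peer]] == match:
                pc[r] += 1   # both sides of the pair complete together
                pc[peer] += 1
                progressed = True
    # Deadlock if any rank still has pending ops once nothing can progress.
    return any(pc[r] < len(prog) for r, prog in enumerate(programs))
```

Two ranks that both post send before recv never make progress, while reordering one side (or, with the real API, grouping the pair between `startCoalescing`/`endCoalescing`) unblocks the exchange.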
Remark: note that NCCL doesn't support tags in send/recv.
UCC
ProcessGroupUCC does not implement coalesced groups for now. It should not be a problem to achieve full bidirectional bandwidth with two send/recv calls, though; however, having more than two ops in a batch will be suboptimal. Adding UCC coalescing was discussed but not added to the POR for now, for lack of a good use case.
Reducing memory footprint
Further optimizations, which we leave as TODOs, are possible to reduce the footprint of the work buffers `src_buffer_` and `dst_buffer_`: