
Question about Distributed Muon #16

Closed · Niccolo-Ajroldi opened this issue Feb 27, 2025 · 5 comments

Niccolo-Ajroldi commented Feb 27, 2025

This is more of a question than an issue.

Question

I can't really wrap my head around how DP Gather incurs a lower communication cost than a classical AllGather op.

My understanding is that if each parameter weight matrix is sharded across all the devices (in ZeRO-1 style), then each device has to collect all the remaining shards from the other devices, which amounts to an AllGather call.

Here is a small example:

Consider a model with 2 parameter weight matrices: p1, p2 with corresponding moments m1, m2 and grads g1, g2
We shard the optimizer state across 2 devices. After ReduceScatter on the gradients and after applying momentum, each device holds:
DP1: p11, p21, m11, m21, g'11, g'21
DP2: p12, p22, m12, m22, g'12, g'22

where pij is the j-th shard of parameter i, the optimizer state consists of p and m, and g' denotes the gradient after the momentum update (following Algorithm 1 notation).

In order to perform Newton-Schulz and compute the updates for p11 and p21, DP1 needs to collect all the remaining gradient shards. Since it already has g'11 and g'21, it will send those to DP2 and receive g'12 and g'22 from it. How is this different from a normal AllGather op?
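To make my mental model concrete, here is a rough single-process sketch (toy shapes, plain PyTorch, the cross-device gather simulated with a plain list) of what I think each device would have to do:

```python
# Sketch of my mental model (not actual Distributed Muon code): every
# gradient matrix is sharded across both DP ranks, so rebuilding any full
# gradient needs a shard from every other rank -- an AllGather per matrix.
import torch

world_size = 2
shapes = {"g1": (4, 4), "g2": (4, 4)}  # toy gradient shapes

# Each rank holds one row-shard of every gradient.
shards = {
    name: list(torch.randn(shape).chunk(world_size, dim=0))
    for name, shape in shapes.items()
}

# What DP1 (rank 0 here) would do before Newton-Schulz: collect the missing
# shard of *every* matrix from the other rank, then concatenate.
full_grads = {}
for name, pieces in shards.items():
    gathered = [pieces[r] for r in range(world_size)]  # stands in for an AllGather
    full_grads[name] = torch.cat(gathered, dim=0)

print({k: tuple(v.shape) for k, v in full_grads.items()})
```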

Any help in understanding is greatly appreciated!

toothacher17 (Collaborator) commented Feb 28, 2025

Hi, @Niccolo-Ajroldi

This is a brilliant question, and thanks for pointing it out! Let me try to explain it here:

1. DP-based ZeRO1 Distributed Optimizer

a. Non-Distributed Optimizer
Say you have two DP ranks and 5 params. Then with a non-distributed optimizer, you'll get (capital letters indicating a full matrix without sharding):

DP0: params P0-P4, momentum M0-M4, master weights MP0-MP4, gradients G0_dp0-G4_dp0
DP1: params P0-P4, momentum M0-M4, master weights MP0-MP4, gradients G0_dp1-G4_dp1

Since this is a non-distributed optimizer, P, M, and MP are identical on both ranks and kept in sync. The gradients differ because each DP rank used different data, so they need communication (e.g. an all-reduce) to sync. Since all matrices are full, each rank can directly perform the update. The downside is that M and MP are stored in full on every rank and cost a lot of memory.

b. ZeRO1
Now we move to ZeRO1. You still need the full params and gradients to do the fwd/bwd pass, but instead of keeping the full momentum and master weights, you can split them among the DP ranks, like this:

DP0: params P0-P4, gradients G0_dp0-G4_dp0, momentum M0, M1, m2 (1st half), master weights MP0, MP1, mp2 (1st half)
DP1: params P0-P4, gradients G0_dp1-G4_dp1, momentum m2 (2nd half), M3, M4, master weights mp2 (2nd half), MP3, MP4

You might think that every optimizer state (M and MP) is sharded across all DP ranks. It is not: ZeRO1 actually first concatenates all of them into a list, flattens it, and then shards the flat buffer across DP ranks. As a result, most matrices land entirely on a single rank, and only the matrices that straddle a shard boundary get split.
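To see why, here is a small, hypothetical sketch of the flatten-then-shard idea (plain PyTorch on one process, a simplified stand-in for Megatron-LM's distributed optimizer rather than its actual code):

```python
# Simplified illustration of ZeRO1-style sharding: concatenate all params
# into one flat buffer, then cut the buffer into equal per-rank pieces.
# Only matrices that straddle a cut point end up split across ranks.
import torch

world_size = 2
params = {f"P{i}": torch.randn(4, 4) for i in range(5)}  # P0..P4, 16 elements each

flat = torch.cat([p.flatten() for p in params.values()])  # 80 elements total
shard_size = flat.numel() // world_size                   # 40 elements per rank

offset = 0
for name, p in params.items():
    start, end = offset, offset + p.numel()
    owners = sorted({start // shard_size, (end - 1) // shard_size})
    print(name, "owned by rank(s)", owners)
    offset = end
```

With the cut at element 40, P0 and P1 land entirely on rank 0, P3 and P4 on rank 1, and only P2 is split across both, matching the M0, M1, m2 / m2, M3, M4 layout above.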

2. Adapting ZeRO1 to Distributed Muon

a. Operations
Since we need the full matrix to perform the NS steps and get the update, each DP rank follows the steps below (a short sketch follows the two lists):

DP0:
step1. communication: Gather m2(2nd half) from DP1 (actually g'2(2nd half), which is the same as m2(2nd half))
step2. calculation: NS steps on M0, M1, M2
step3. Update MP0, MP1, mp2 (1st half). We only update the optimizer states stored locally.

DP1:
step1. communication: Gather m2 (1st half) from DP0 (actually g'2 (1st half))
step2. calculation: NS steps on M2, M3, M4
step3. Update MP3, MP4, mp2 (2nd half). We only update the optimizer states stored locally.
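To make the data flow concrete, here is a minimal single-process sketch of DP0's path (simulated ranks and toy shapes; the newton_schulz below uses the coefficients from the public Muon reference implementation and is only a stand-in for this repo's actual NS step):

```python
# Sketch of DP0's update path in Distributed Muon (simulated, not the
# actual Megatron-LM code). DP0 locally owns M0, M1 and the 1st half of
# M2; only M2's 2nd half has to be gathered from DP1, in BF16.
import torch

def newton_schulz(G, steps=5, eps=1e-7):
    # Quintic NS iteration; coefficients taken from the public Muon
    # reference implementation (Algorithm 1 in the repo is the source of truth).
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G / (G.norm() + eps)
    for _ in range(steps):
        A = X @ X.T
        X = a * X + (b * A + c * A @ A) @ X
    return X

torch.manual_seed(0)
# Local momentum shards after the grad reduce-scatter + momentum update.
M0 = torch.randn(8, 8)               # fully owned by DP0
M1 = torch.randn(8, 8)               # fully owned by DP0
m2_first_half = torch.randn(4, 8)    # DP0 owns rows 0..3 of M2
m2_second_half = torch.randn(4, 8)   # lives on DP1; pretend it arrived via a gather

# step1: gather only the missing half of the split matrix (BF16 on the wire)
M2 = torch.cat([m2_first_half, m2_second_half.bfloat16().float()], dim=0)

# step2: NS steps on the full matrices DP0 is responsible for
U0, U1, U2 = newton_schulz(M0), newton_schulz(M1), newton_schulz(M2)

# step3: update only the locally stored shards (MP0, MP1, mp2 1st half);
# the 2nd half of U2 is simply discarded, since DP1 computes it too.
update_for_mp2_local = U2[:4]
print(U0.shape, U1.shape, update_for_mp2_local.shape)
```

The key point is that only the second half of M2 crosses the network; M0 and M1 never need any gather.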

b. Analysis
As you can see above, the gathering only happens for the params that are split across multiple DP ranks. Besides, those split matrices get computed redundantly (each rank holding a piece runs NS on the same full matrix). This extra gathering only happens for some parameters and is in BF16, while the initial grad reduce-scatter and the final params all-gather happen on all parameters in FP32, so the extra workload ratio is (<=2 + 4 + 4) / (4 + 4), which is why it is <=1.25.
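For concreteness, the bound can be checked with a quick back-of-the-envelope calculation (worst case where every parameter is split and has to be gathered):

```python
# Per-element communication volume relative to the baseline ZeRO1 data path
# (FP32 grad reduce-scatter + FP32 param all-gather).
bf16_bytes, fp32_bytes = 2, 4

baseline = fp32_bytes + fp32_bytes   # reduce-scatter + all-gather = 8 bytes
extra_gather = bf16_bytes            # <= 2 bytes, and only for split params

ratio = (extra_gather + baseline) / baseline
print(ratio)  # 1.25 upper bound; closer to 1.0 when few matrices are split
```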

Some friends have suggested even smarter partitioning to avoid splitting any single matrix. If you handle this carefully, the ratio will be close to 1.0.

toothacher17 (Collaborator) commented:

@Niccolo-Ajroldi

See a proof-of-concept implementation and more discussion here: NVIDIA/Megatron-LM#1428

Niccolo-Ajroldi (Author) commented Feb 28, 2025

Amazing answer, love it!

My main misunderstanding was that I thought that under ZeRO-1 every optimizer state is sharded across all DP ranks, but as you pointed out, it's not.

Your comment makes it very clear, thank you!

toothacher17 (Collaborator) commented:

@Niccolo-Ajroldi

Thanks a lot for your interest, and this is a great question! It has been discussed a bit on X as well. Your understanding is not really a misunderstanding, as it can be correct under some parallel settings. It's just that we found Megatron-LM's ZeRO-1 design to be perfectly suited for Distributed Muon!

toothacher17 (Collaborator) commented:

[Image: diagram from Megatron-LM's distributed optimizer documentation]

Found by SeunghyunSEO from https://github.com/NVIDIA/Megatron-LM/tree/main/docs/source/images/distrib_optimizer
