A proof of concept for Distributed Muon #1428
base: main
Conversation
@toothacher17 Hello, I am comparing the performance of AdamW and Muon. The experiment trains a 1B-parameter MoE model on an H800 cluster, with a maximum learning rate of 1e-3 that decays to 1e-4 with a cosine schedule. Muon uses default parameters. I observed that Muon has a significant advantage over AdamW in the early stages of training, but after 20k steps their performance becomes similar, with AdamW sometimes even outperforming Muon. Is this phenomenon normal?
Hi @mactavish91, thanks a lot for trying it out! I actually probably know the reason:
A simple hack is to add one line after line 114 to force lr_mult = 1.0 and wd_mult = 1.0 for all parameters.
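For concreteness, here is a minimal sketch of the intent of that override, assuming Megatron-style param groups that expose `lr_mult` and `wd_mult` keys; the real change is a one-liner inside Megatron's param-group construction, not this helper, and the exact line number depends on your checkout.

```python
# Hedged sketch: force every param group to use the base learning rate and
# weight decay, mirroring the "lr_mult = 1.0 and wd_mult = 1.0 for all
# parameters" hack described above. Assumes Megatron-style param groups
# carrying 'lr_mult' and 'wd_mult' keys.
def force_uniform_multipliers(param_groups):
    for group in param_groups:
        group["wd_mult"] = 1.0  # apply weight decay to every parameter
        group["lr_mult"] = 1.0  # no per-group learning-rate scaling
    return param_groups
```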
Your exp looks very much like this: [figure omitted]. Let us know if adding weight decay to all params helps!
Thank you for your kind help. I used the train loss; since the dataset is pretrain-level in size, it can be approximated as the val loss. After 20k steps, approximately 50B tokens have been trained. I tried applying weight decay to all parameters, but it doesn't seem to help much.
Hi, thanks for sharing. Yeah, all your settings look fine and reasonable. Since the dataset is pretrain-level in size, reporting the pretrain loss is also reasonable. I am not sure what the root cause of Muon's diminishing advantage is; we found Muon performed well as long as we added the correct weight decay and adjusted the update RMS for each matrix's shape. If it is OK, do you mind sharing your model arch details? We can try them in our code base with our data and see what's going on.
Another thing to debug is to observe your weight RMS, max logit, per-layer output RMS, and update RMS during training and see if anything weird is happening.
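For illustration, a small sketch of that kind of monitoring for weight RMS and update RMS is below; max logit and per-layer output RMS would come from forward hooks instead (see the logging sketch later in the thread). The `updates` dict and `log_fn` names are placeholders, not Megatron's logging API.

```python
import torch

@torch.no_grad()
def log_param_stats(model, updates=None, log_fn=print):
    # Hedged sketch: report per-parameter weight RMS, plus update RMS when the
    # optimizer's pending update tensors are available in `updates`.
    for name, param in model.named_parameters():
        stats = {"weight_rms": param.detach().float().pow(2).mean().sqrt().item()}
        if updates and name in updates:
            stats["update_rms"] = updates[name].float().pow(2).mean().sqrt().item()
        log_fn(f"{name}: {stats}")
```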
Hi, thank you for open-sourcing this great work! I have some questions:
I think the reason Megatron uses fused QKV is to reduce memory access and let the GPU perform larger matrix multiplications. This seems more like Megatron's built-in optimization than the Moonlight authors' intention. Well, but... I don't know why using separate weights is better.
Very good questions!
In general, the concept of a 'matrix' might not be well defined in Muon, and for now we rely on empirical results to decide the matrix split.
Yeah, splitting them into three matrices performed better empirically, so we followed that. Moonlight uses MLA, so it is naturally split.
Besides the larger matrix multiplications, another advantage of fused QKV is that you only need to gather the input across the TP group once (if TP and SP are enabled) and then use it for all three projections.
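As a plain-PyTorch illustration of fused vs. separate projections (toy sizes, no TP/SP collectives; module names are ours, not Megatron's): the fused version is one large matmul over the (gathered) input, while the separate version is three smaller matmuls that would each need that input.

```python
import torch
import torch.nn as nn

hidden = 1024
x = torch.randn(8, 128, hidden)               # (batch, seq, hidden), toy sizes

# fused QKV: a single larger matmul, split into Q/K/V afterwards
fused_qkv = nn.Linear(hidden, 3 * hidden, bias=False)
q, k, v = fused_qkv(x).chunk(3, dim=-1)

# separate projections: three smaller matmuls over the same input
q_proj = nn.Linear(hidden, hidden, bias=False)
k_proj = nn.Linear(hidden, hidden, bias=False)
v_proj = nn.Linear(hidden, hidden, bias=False)
q2, k2, v2 = q_proj(x), k_proj(x), v_proj(x)
```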
The following are the settings I used in the experiment: [settings table omitted]
Hi @mactavish91, your model arch looks reasonable. For the purpose of debugging, I'll need more monitoring than the current open-source Megatron-LM provides, so I'll run it on our internal infra with some slight changes:
Other settings will remain the same as you posted. We'll keep you posted about our findings.
Yeah, that would be better to get rid of the impact of large embeddings. I am still running the two comparison jobs based on your previous smaller model setting. Besides increasing the, another thing worth mentioning is to use/report OOD validation data rather than in-domain validation data for a more accurate eval of the model.
Hi @mactavish91, we have run your settings for about ~17K steps so far, for ~40+B tokens (you mentioned before that the advantage diminishes around ~20K steps). Even with the big-embedding issue, I actually think the result is promising. We plot the figure as shown below:
[figure omitted]
For the purpose of reproducing, we provide the script used to generate these figures. @mactavish91, can you help try such figures on your previous small-run data as well?
Here adam_data or muon_data is the run we fetched from TensorBoard, and the tag is simply 'lm-loss-training/lm loss'. [script screenshot omitted]
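For reference, a minimal sketch (not the authors' actual script) of how such curves might be pulled from TensorBoard event files and plotted; the run directories are hypothetical, and the tag is the one mentioned above.

```python
import matplotlib.pyplot as plt
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

TAG = "lm-loss-training/lm loss"

def load_scalars(logdir, tag=TAG):
    # read one run's scalar series from its TensorBoard event files
    acc = EventAccumulator(logdir)
    acc.Reload()
    events = acc.Scalars(tag)
    return [e.step for e in events], [e.value for e in events]

adam_steps, adam_loss = load_scalars("runs/adamw")  # hypothetical run dirs
muon_steps, muon_loss = load_scalars("runs/muon")

plt.plot(adam_steps, adam_loss, label="AdamW")
plt.plot(muon_steps, muon_loss, label="Muon")
plt.xlabel("training step")
plt.ylabel("lm loss")
plt.legend()
plt.savefig("loss_comparison.png")
```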
@mactavish91 Besides, we also evaluated on OOD LM validation loss data, and it showed pretty good results. [figure omitted]
@mactavish91 We'll wait for your visualization results and see how it goes! Thanks!
@SeunghyunSEO Thanks for sharing! I have some comments regarding your runs:
However, the current implementation of the distributed optimizer first groups several params into a bucket, and every param in that bucket is split across DP ranks and needs a gather! This pushes the number of extra all-gathers to its upper bound, meaning every param on every rank needs a gather. For Distributed Muon to work efficiently as described in our paper, we need the original way of DP-sharding optimizer states, which requires only a very limited set of params to do the extra gathering.
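As a rough, hypothetical sketch of the per-parameter flow being contrasted here (none of these names are the PR's actual API): the rank that holds a DP shard all-gathers the momentum shards into the full matrix, runs the orthogonalization on it, and keeps only its own slice of the update. Under the bucket-wise sharding described above, every rank would have to run this gather for every parameter it updates.

```python
import torch
import torch.distributed as dist

def muon_update_for_sharded_param(momentum_shard, full_shape, dp_group, orthogonalize):
    """Hypothetical per-param step: gather DP-sharded momentum, orthogonalize
    the full matrix (e.g. via Newton-Schulz), return this rank's update slice."""
    world = dist.get_world_size(group=dp_group)
    rank = dist.get_rank(group=dp_group)
    # the extra all-gather discussed above: reconstruct the full momentum
    shards = [torch.empty_like(momentum_shard) for _ in range(world)]
    dist.all_gather(shards, momentum_shard, group=dp_group)
    full_momentum = torch.cat(shards).reshape(full_shape)
    update = orthogonalize(full_momentum)      # NS runs on the full 2-D matrix
    # keep only this rank's slice (assumes the flat param splits evenly)
    return update.reshape(-1).chunk(world)[rank]
```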
@toothacher17 Wow, your response is as fast as the speed of light, lol. I didn't even know that Megatron changed its sharding logic. (I'm also familiar with the sharding strategy you mentioned in point 2.) I'll dig into the codebase and come back if I find any clues to improve performance. (edited) Can you share the related PR for refactoring the param and grad bucketing? I'm not sure this one is right.
@SeunghyunSEO Thanks for your kind words! However, I am not sure when they changed the logic, since we have not merged the upstream for a while; we only noticed while preparing this PR. The commit you found looks very related, but I am not sure if it's the only relevant one. BTW, do you mind visualizing your results using the script I mentioned above? It is exciting to see other people reproducing similar results with Muon!
Thank you for your detailed guidance! Here are my findings from today: [figures omitted]
Glad to see it works! Regarding your questions:
[figure omitted]
Besides, your leading-step-ratio figure would be better if you set ylim(0, 1). Would you mind re-plotting the figure, and do you mind if I share it on X, since it reproduces our results?
@mactavish91 @SeunghyunSEO I discussed with my colleagues, and we might need more investigation into this performance-drop issue. If more gathers happen, more Newton-Schulz iteration computation will also happen for more params.
@mactavish91 btw, can I ask how you log the per-param activation norm? It's too heavy to reduce or gather all stats in the forward pass when naively registering forward hooks, so I'm curious how you do the logging efficiently!
I think @mactavish91 showed the output L2 norm of the model rather than the per-param activation norm? For the output L2 norm, you can open a log buffer to record the detached L2 norm or RMS of the output (do not take the sqrt, only accumulate the square sum), accumulate it across the fwd-bwd passes, and reduce only once after all the fwd-bwd is done. This is cheap as you only do it once per global step.
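A rough sketch of that scheme (our own illustration, not @mactavish91's code): accumulate detached square sums of the watched layers' outputs on-device during the micro-batch forwards, then all-reduce and take the sqrt once per global step. The class name and `layers_to_watch` are placeholders.

```python
import torch
import torch.distributed as dist

class OutputNormLogger:
    """Accumulates per-layer square sums of module outputs on-device during the
    micro-batch forward passes, then reduces and reports RMS once per global step."""

    def __init__(self, model, layers_to_watch):
        self.layers = list(layers_to_watch)
        self.sq_sums = {name: None for name in self.layers}  # on-device scalars
        self.counts = {name: 0 for name in self.layers}
        for name, module in model.named_modules():
            if name in self.layers:
                module.register_forward_hook(self._make_hook(name))

    def _make_hook(self, name):
        def hook(module, inputs, output):
            out = output[0] if isinstance(output, tuple) else output
            sq = out.detach().float().pow(2).sum()  # square sum only, no sqrt yet
            prev = self.sq_sums[name]
            self.sq_sums[name] = sq if prev is None else prev + sq
            self.counts[name] += out.numel()
        return hook

    def reduce_and_report(self):
        """Call once per global step, after all fwd-bwd micro-batches are done."""
        device = next((s.device for s in self.sq_sums.values() if s is not None),
                      torch.device("cpu"))
        sums = torch.stack([self.sq_sums[n] if self.sq_sums[n] is not None
                            else torch.zeros((), device=device) for n in self.layers])
        counts = torch.tensor([float(self.counts[n]) for n in self.layers], device=device)
        if dist.is_initialized():
            dist.all_reduce(sums)    # a single collective per global step
            dist.all_reduce(counts)
        rms = (sums / counts.clamp(min=1.0)).sqrt()
        report = {name: rms[i].item() for i, name in enumerate(self.layers)}
        for name in self.layers:     # reset buffers for the next global step
            self.sq_sums[name] = None
            self.counts[name] = 0
        return report
```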
Some other issues asked why Distributed Muon is efficient, and we tried to explain it in detail there: MoonshotAI/Moonlight#16
@toothacher17 ty for sharing!
A proof of concept for implementing Distributed Muon as described in: https://github.com/MoonshotAI/Moonlight
Example script: see examples/muon/training.sh
Tested with TP=2, PP=2, DP=2, as well as with no TP/PP, and compared with AdamW
Used the data from BigScience and the provided example script
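For readers new to the optimizer itself: the core of Muon is an approximate orthogonalization of each 2-D momentum matrix via a quintic Newton-Schulz iteration. The sketch below follows the commonly used reference Muon implementation and is illustrative only; it is not necessarily identical to the kernel in this PR.

```python
import torch

@torch.no_grad()
def zeropower_via_newtonschulz(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
    """Approximately replace the singular values of a 2-D matrix G with 1."""
    assert G.ndim == 2
    a, b, c = 3.4445, -4.7750, 2.0315      # quintic iteration coefficients
    X = G.float()
    transposed = X.shape[0] > X.shape[1]
    if transposed:                          # work on the wide orientation
        X = X.T
    X = X / (X.norm() + 1e-7)               # ensure spectral norm <= 1
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A
        X = a * X + B @ X
    if transposed:
        X = X.T
    return X.to(G.dtype)
```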