A proof of concept for Distributed Muon #1428

Open · wants to merge 2 commits into main
Conversation

toothacher17

A proof of concept for implementing Distributed Muon, as described in: https://github.com/MoonshotAI/Moonlight

  • Example script: see examples/muon/training.sh

  • Tested with TP=2, PP=2, DP=2, compared against AdamW and against a run without TP/PP

  • Used the data from bigscience and the provided example script

[screenshot of results]


@mactavish91 left a comment


[image]
@toothacher17 Hello, I am comparing the performance between AdamW and Muon. The experiment involves training a 1B parameter MoE model on the H800 cluster, with a maximum learning rate of 1e-3 that decays to 1e-4 using cosine decay. Muon uses default parameters. I observed that Muon has a significant advantage over AdamW in the early stages of training, but after 20k steps, their performance becomes similar, with AdamW sometimes even outperforming Muon. Is this phenomenon normal?

@toothacher17 commented Feb 25, 2025

Hi, @mactavish91,

Thanks a lot for trying it out! I think I know the likely reason:

  1. First, are you reporting training loss or validation loss? It's better to observe validation loss rather than training loss.

  2. Next, how many tokens did you train with Muon over those 20k steps? It's likely that your token count already puts you in the over-training regime. Your curve looks like a typical case of a Muon-trained model that is not properly weight decayed.

  3. In the over-training setting, as our paper mentions (https://arxiv.org/pdf/2502.16982), it is important to apply weight decay to all parameters, including the RMSNorm gamma; see Section 2.1 and Appendix D.

  4. However, Megatron's default is to apply no weight decay to the RMSNorm gamma:
    https://github.com/NVIDIA/Megatron-LM/pull/1428/files#diff-b5fac51ecd0148c2f4f8f2f1e64535089e90be87606c1f9357778d05af823220R100

A simple hack is to add one line after line 114 that forces lr_mult = 1.0 and wd_mult = 1.0 for all parameters.
[image: suggested code change]
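For illustration, here is a rough sketch of what that hack could look like. The function below is a simplified stand-in for Megatron's param-group construction, not the actual upstream code; only the override line marked below is the suggestion being made here.

# Simplified stand-in for Megatron-LM's param-group construction (illustrative only).
# The idea: after the per-parameter wd_mult/lr_mult have been decided, override them
# so every parameter, including RMSNorm gamma, gets full weight decay and learning rate.

def get_param_groups(model, no_weight_decay_cond=None):
    params_map = {}
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue

        # Default behavior: norms and biases get wd_mult = 0.0.
        is_no_wd = (
            no_weight_decay_cond(name, param)
            if no_weight_decay_cond is not None
            else (name.endswith(".bias") or param.ndim == 1)
        )
        wd_mult = 0.0 if is_no_wd else 1.0
        lr_mult = 1.0

        # --- the suggested one-line hack: decay and fully update everything ---
        lr_mult, wd_mult = 1.0, 1.0

        params_map.setdefault((wd_mult, lr_mult), []).append(param)

    return [
        {"params": params, "wd_mult": wd_mult, "lr_mult": lr_mult}
        for (wd_mult, lr_mult), params in params_map.items()
    ]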

@toothacher17

Your exp looks very much like this:

[image]

Let us know if adding weight decay to all params helps!

@mactavish91

Thank you for your kind help. I used training loss; since the dataset is of pretraining scale, it can be treated as an approximation of validation loss. After 20k steps, approximately 50B tokens have been trained. I tried applying weight decay to all parameters, but it doesn't seem to help much.

@toothacher17

Hi, thanks for sharing. All your settings look fine and reasonable. Since the dataset is pretraining scale, reporting the pretraining loss is also reasonable. I am not sure what the root cause is for Muon's advantage diminishing; we found Muon performed well as long as we added the correct weight decay and adjusted the update RMS for each matrix's shape.

If it's OK, would you mind sharing your model architecture details? We can try it in our code base with our data and see what's going on.

@toothacher17

Another thing to debug: monitor your weight RMS, max logit, per-layer output RMS, and update RMS during training, and see whether anything weird is happening.
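As a rough sketch of that kind of monitoring (not code from this PR; names and scope are illustrative), per-parameter weight RMS and update RMS could be computed once per logging interval like this:

import torch


@torch.no_grad()
def weight_rms(model: torch.nn.Module) -> dict[str, float]:
    """RMS of each weight matrix: sqrt(mean(w^2))."""
    return {
        name: param.float().pow(2).mean().sqrt().item()
        for name, param in model.named_parameters()
        if param.ndim >= 2
    }


@torch.no_grad()
def update_rms(prev_params: dict[str, torch.Tensor], model: torch.nn.Module) -> dict[str, float]:
    """RMS of the optimizer update, measured as the parameter delta across one step."""
    return {
        name: (param.float() - prev_params[name].float()).pow(2).mean().sqrt().item()
        for name, param in model.named_parameters()
        if name in prev_params
    }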

@hjlee1371

Hi, thank you for open-sourcing this great work! I have some questions:

  1. It seems that this implementation treats the QKV projection weights as one parameter (the attention's linear_qkv), rather than splitting them as done here. Is this intended?
  2. Can you share some intuition about the effect of splitting the WQ, WK, and WV matrices (e.g., treating each head as a separate parameter, as mentioned in the paper)?

@SeunghyunSEO

SeunghyunSEO commented Feb 26, 2025


I think the reason Megatron uses fused QKV is to reduce memory access and allow the GPU to perform larger matrix multiplications. This seems more like Megatron's built-in optimization than the Moonlight authors' intention.
In the original codebase, it looks like they fused the operation as well, but in a slightly different way; I'm not sure how Muon processes this 3D qkv_w tensor.

That said, I don't know why using separate weights is better. If Muon is just an approximation of a second-order optimizer like Shampoo, shouldn't it perform better when it considers more correlations between the matrices?

@toothacher17

toothacher17 commented Feb 26, 2025

(Replying to @hjlee1371's questions above.)

Very good questions!

  1. For the Moonlight and Moonlight-A models we used MLA, so Q, K, and V are naturally split. Besides, following Keller's blog, splitting seems to perform better. For now we recommend splitting Q, K, and V into three matrices and updating them separately;

  2. For splitting Q, K, and V further into per-head matrices and updating them separately, I think https://leloykun.github.io/ has some experiments. For now we do not split per head; we update all heads of Q, K, and V together. But in Section 3.1 of the paper, you can see that the query projection matrix behaves very differently from the MLP matrices. While the Update Norm method strictly controls the update RMS to match AdamW, the Adjusted LR method we use here does not, so I think there is still room for improvement here.

In general, the concept of a 'matrix' might not be well defined in Muon, and for now we rely on empirical results to decide the matrix split.
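To make the idea concrete, here is an illustrative sketch (an assumption on my part, not this PR's implementation) of splitting a fused, non-interleaved linear_qkv update into Q/K/V blocks and orthogonalizing each block separately, using the standard Newton-Schulz iteration from the Muon reference code:

import torch


def newton_schulz(G: torch.Tensor, steps: int = 5, eps: float = 1e-7) -> torch.Tensor:
    """Newton-Schulz iteration approximating orthogonalization of G (as in the Muon reference code)."""
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G.bfloat16()
    X = X / (X.norm() + eps)
    transposed = X.size(0) > X.size(1)
    if transposed:
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        X = a * X + (b * A + c * A @ A) @ X
    if transposed:
        X = X.T
    return X.to(G.dtype)


def orthogonalize_fused_qkv(grad_qkv: torch.Tensor, q_dim: int, kv_dim: int) -> torch.Tensor:
    """Split the update of a fused [q_dim + 2 * kv_dim, hidden] QKV weight into Q, K, V blocks
    (assuming a non-interleaved [Q; K; V] row layout), orthogonalize each block separately,
    and concatenate them back."""
    g_q, g_k, g_v = torch.split(grad_qkv, [q_dim, kv_dim, kv_dim], dim=0)
    return torch.cat([newton_schulz(g) for g in (g_q, g_k, g_v)], dim=0)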

@toothacher17

@SeunghyunSEO Yeah, splitting them into three matrices performed better empirically, so we followed that. For Moonlight, it uses MLA, so Q, K, and V are naturally split.

@toothacher17

Besides enabling larger matrix multiplications, another advantage of fused QKV is that you only need to gather the input across the TP group once (when TP and SP are enabled) and reuse it for all three projections.

@mactavish91

The following are the settings I used in the experiment

#!/bin/bash

TEXT_DATA_PATH=""

NAME="1b-1e-3-qknorm-factor2.45-muon"
CHECKPOINT_PATH="checkpoints/${NAME}"
TENSORBOARD_PATH="runs/research/${NAME}"
KEEP_LATEST_CKPT=3  

MICRO_BATCH_SIZE=1
GLOBAL_BATCH_SIZE=1152

TP_SIZE=1
PP_SIZE=1
EP_SIZE=1

MOE_ROUTED_EXPERTS=64
MOE_ACTIVE_ROUTED_EXPERTS=6
MOE_SHARED_EXPERTS=2

NHIDDEN=728
MOE_FFN_HIDDEN=408
MOE_SHARED_EXPERT_INTERMEDIATE_SIZE=$(($MOE_FFN_HIDDEN * $MOE_SHARED_EXPERTS))
FFN_HIDDEN=2176
NLAYERS=18
NHEADS=8

SEQ_LEN=2048

SAVE_INTERVAL=50000

TRAIN_TOKENS=100000000000 # 100B tokens
TRAIN_SAMPLES=$((TRAIN_TOKENS / SEQ_LEN))
LR_DECAY_SAMPLES=$((TRAIN_SAMPLES * 98 / 100))
LR_WARMUP_SAMPLES=$((TRAIN_SAMPLES * 1 / 100))

NCCL_IB_QPS_PER_CONNECTION=2

script_path="pretrain_gpt.py"

OPTIMIZER_ARGS="
    --optimizer muon
    --muon-matched-adamw-rms 0.2
    --adam-beta1 0.9
    --adam-beta2 0.95
    --adam-eps 1e-8
    --lr 1e-3
    --min-lr 1e-4
    --lr-decay-style cosine
    --lr-decay-samples $LR_DECAY_SAMPLES
    --lr-warmup-samples $LR_WARMUP_SAMPLES
    --clip-grad 1.0
    --weight-decay 1e-1
    --hidden-dropout 0.0
    --attention-dropout 0.0
    --initial-loss-scale 65536
"

MOE_ARGS="
    --num-experts $MOE_ROUTED_EXPERTS
    --moe-shared-expert-intermediate-size $MOE_SHARED_EXPERT_INTERMEDIATE_SIZE
    --moe-shared-expert-overlap
    --moe-router-topk $MOE_ACTIVE_ROUTED_EXPERTS
    --moe-grouped-gemm
    --moe-num-first-dense-layers 2
    --moe-ffn-hidden-size $MOE_FFN_HIDDEN
    --expert-model-parallel-size $EP_SIZE
    --moe-permute-fusion
    --moe-router-enable-expert-bias
    --moe-router-bias-update-rate 1e-3
    --expert-balance-factor 0
    --device-balance-factor 0
    --moe-global-batch-balance
    --moe-router-activation-type softmax
    --moe-routed-scaling-factor 2.45
"

MODEL_ARGS="
    --bf16
    --num-layers $NLAYERS
    --hidden-size $NHIDDEN
    --ffn-hidden-size $FFN_HIDDEN
    --seq-length $SEQ_LEN
    --no-interleaved-qkv
    --max-position-embeddings $SEQ_LEN
    --num-attention-heads $NHEADS
    --disable-bias-linear
    --add-qkv-bias
    --rotary-percent 0.5
    --swiglu
    --use-flash-attn
    --transformer-impl transformer_engine
    --untie-embeddings-and-output-weights
    --position-embedding-type rope
    --no-position-embedding
    --normalization RMSNorm
    --use-mcore-models
    --manual-gc
    --kv-channels 128
    --qk-layernorm
"

TRAINING_ARGS="
    --micro-batch-size $MICRO_BATCH_SIZE
    --global-batch-size $GLOBAL_BATCH_SIZE
    --train-samples $TRAIN_SAMPLES
    --tensor-model-parallel-size $TP_SIZE
    --pipeline-model-parallel-size $PP_SIZE
    --use-distributed-optimizer
    --overlap-grad-reduce
"

DATA_ARGS="
    --num-workers 1
    --train-data-path $TEXT_DATA_PATH
"

OUTPUT_ARGS="
    --log-throughput \
    --log-interval 1 \
    --eval-interval 0 \
    --timing-log-level 0 \
    --save-interval $SAVE_INTERVAL \
    --tensorboard-dir $TENSORBOARD_PATH/tensorboard \
    --wandb-save-dir $CHECKPOINT_PATH \
    --wandb-exp-name $NAME \
"

gpt_options="
    $MODEL_ARGS
    $MOE_ARGS
    $TRAINING_ARGS
    $OPTIMIZER_ARGS
    $DATA_ARGS
    $OUTPUT_ARGS
    --distributed-timeout-minutes 20
    --init-method-std 0.006
    --save $CHECKPOINT_PATH
    --load $CHECKPOINT_PATH
    --save-async-fast-checkpoint
"

@toothacher17

toothacher17 commented Feb 26, 2025


Hi, @mactavish91, your model arch looks reasonable.

For debugging purposes, I'll need more monitoring than the current open-source Megatron-LM provides, so I'll run it on our internal infra with some slight changes:

  1. We will use our own data, so the seq-len will change from 2048 to 8192. Correspondingly, the batch size will change from 1152 to 288.
  2. We will not add the attention QK bias.
  3. We will update the Q, K, and V matrices separately.
  4. Since we are using MoE with aux-free bias and a scaling factor of 2.45, I'll use the sigmoid gate rather than the softmax gate.

Other settings will remain the same as you posted. We'll keep you posted on our findings.

@toothacher17

toothacher17 commented Feb 26, 2025

@mactavish91

I am running your two configs right now and am not sure about the results yet. But I did some math and probably found the problem: the model might be too small relative to its embedding. We use a 160K vocab size (I am not sure about yours, would you mind sharing it?), so the parameters break down as:

Total Params:
total = 1,351,721,016
embedding = 768 X 163840 X 2 = 251,658,240
total excluding embedding = 1,351,721,016 - 251,658,240 = 1,100,062,776

Activated Params:
not activated = 18 X 408 X 728 X 3 X 58 = 930,279,168
total activated = 1,351,721,016 - 930,279,168 = 421,441,848
total activated excluding embedding: 1,100,062,776 - 930,279,168 = 169,783,608

So you can see, the model has ~170M non-embedding activated params, about 1.1B non-embedding total params, and ~252M params in the word embedding and LM head. Because the word embedding and LM head are updated by AdamW, maybe in the long run there is not too much difference.
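As a quick sanity check, the breakdown above can be reproduced with a few lines (numbers copied directly from the comment; the layer/expert counts are as stated there):

total = 1_351_721_016
embedding = 768 * 163_840 * 2          # word embedding + LM head
non_embedding = total - embedding      # 1,100,062,776

# Inactive expert params: layers * moe_ffn_hidden * hidden * 3 (SwiGLU) * (64 - 6) experts
inactive_experts = 18 * 408 * 728 * 3 * 58     # 930,279,168
activated_total = total - inactive_experts     # 421,441,848
activated_non_embedding = non_embedding - inactive_experts  # 169,783,608

print(f"non-embedding total:     {non_embedding:,}")
print(f"activated total:         {activated_total:,}")
print(f"activated non-embedding: {activated_non_embedding:,}")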

I would recommend trying a larger model as well, for example the 822M one listed below. We ran that model with AdamW and Muon for 100B tokens and still see big differences:

[image: 822M model configuration]

@mactavish91

Our tokenizer size is 150k, and that is very likely the reason behind the issue. I will switch to a 60k tokenizer and increase the hidden size and number of layers for a new experiment.

@toothacher17

toothacher17 commented Feb 26, 2025


Yeah, that would be better for removing the impact of the large embedding. I am still running the two comparison jobs based on your previous, smaller model setting.

Besides increasing the model size, another thing worth mentioning is to use/report OOD validation data rather than in-domain validation data for a more accurate evaluation of the model.

@toothacher17

toothacher17 commented Feb 27, 2025

Hi, @mactavish91. We have run your settings for about ~17K steps so far, roughly ~40B+ tokens (you mentioned the advantage diminishes around ~20K steps). Even with the big-embedding issue, I actually think the result is promising. We plot the figures as shown below:

  1. With proper smoothing, we can see that Muon's training-loss gap is not diminishing.

  2. We define a new metric, Muon Leading Steps, to capture how many extra steps AdamW needs to match Muon's performance.

  3. Besides, we can use a simple ratio, Muon_Leading_Steps / Muon_Trained_Steps, to help understand whether Muon is consistently leading.

[image: training loss comparison, AdamW catch-up steps, and Muon leading-step ratio]

@toothacher17

toothacher17 commented Feb 27, 2025

For reproducibility, we provide the script used to generate these figures. @mactavish91, can you try producing the same figures from your previous small-run data as well?

if "validation" or 'training' in tag:
    # smooth the data by emw
    ewm_alpha = 0.005
    muon_data = muon_data.ewm(alpha=ewm_alpha).mean()
    adam_data = adam_data.ewm(alpha=ewm_alpha).mean()

    # Subsample both dataframes to ~1000 rows for cleaner plotting
    num_samples = 1000
    stride = max(1, len(muon_data) // num_samples)
    muon_data = muon_data.iloc[::stride]
    adam_data = adam_data.iloc[::stride]

# columns = wall_time, step, value

# assuming steps are sorted
# losses are not strictly sorted, but they are generally decreasing, and lower is better
# for each muon step, find the smallest adam step whose loss is lower than the muon loss
# plot the step difference between the two: that is how many extra steps adam needs to match the muon loss
import matplotlib.pyplot as plt


def find_matching_step(target_value: float, reference_df: pd.DataFrame):
    """Find the first step where the reference loss is lower than the target value.

    Args:
        target_value: Loss value to match or beat
        reference_df: DataFrame containing step and value columns

    Returns:
        Step number where the reference loss first beats the target, or None if it never does
    """
    mask = reference_df["value"] <= target_value
    if not mask.any():
        return None
    return reference_df.loc[mask, "step"].iloc[0]


step_differences = []
for _, muon_row in muon_data.iterrows():
    muon_step = muon_row["step"]
    muon_loss = muon_row["value"]
    adam_matching_step = find_matching_step(muon_loss, adam_data)
    if adam_matching_step is None:
        break
    step_diff = adam_matching_step - muon_step
    step_differences.append((muon_step, step_diff))

step_diff_df = pd.DataFrame(step_differences, columns=["muon_step", "step_difference"])
num_plot_rows = 3
fig, (ax1, ax2, ax3) = plt.subplots(num_plot_rows, 1, figsize=(10, 6 * num_plot_rows))

# Plot losses
muon_run_name = muon_run.split("/")[0]
adam_run_name = adam_run.split("/")[0]
ax1.plot(muon_data["step"], muon_data["value"], label=f"muon: {muon_run_name}")
ax1.plot(adam_data["step"], adam_data["value"], label=f"adam: {adam_run_name}")
ax1.set_xlabel("Training Steps")
#ax1.set_ylabel("Validation Loss Value")
ax1.set_ylabel("Training Loss Value")
ax1.set_title(
    f"Training Loss Comparison\ncollection={collection_id}\ntag={tag}\nsmoothing: {ewm_alpha}, subsampling: {num_samples}"
)

ylim_min = min(muon_data["value"].min(), adam_data["value"].min()) - 0.02
ylim_max = ylim_min + 0.5
ax1.set_ylim(ylim_min, ylim_max)

ax1.grid(True)
ax1.legend()

# Plot step differences
ax2.plot(step_diff_df["muon_step"], step_diff_df["step_difference"])
ax2.set_xlabel("Muon Training Steps")
ax2.set_ylabel("Additional Steps Needed by AdamW")
ax2.set_title("Steps AdamW Needs to Match Muon Performance")
ax2.grid(True)

# plot step diff/step
leading_step_ratio = step_diff_df["step_difference"] / step_diff_df["muon_step"]
ax3.plot(step_diff_df["muon_step"], leading_step_ratio)
ax3.set_xlabel("Muon Training Steps")
ax3.set_ylabel("Muon leading step ratio")
ax3.set_title("Muon leading step ratio")
ax3.set_ylim(0, leading_step_ratio.max() * 1.1)
ax3.grid(True)

plt.tight_layout()

Here adam_data and muon_data are the runs we fetched from TensorBoard; the tag is simply 'lm-loss-training/lm loss'.
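For completeness, here is one possible way (not part of the original script; the paths and run names are placeholders) to build muon_data and adam_data from TensorBoard event files:

import pandas as pd
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator


def load_scalar(run_dir: str, tag: str) -> pd.DataFrame:
    """Load one scalar tag from a TensorBoard run into a DataFrame
    with the columns the plotting script expects: wall_time, step, value."""
    acc = EventAccumulator(run_dir, size_guidance={"scalars": 0})  # 0 = keep all points
    acc.Reload()
    events = acc.Scalars(tag)
    return pd.DataFrame(
        {"wall_time": [e.wall_time for e in events],
         "step": [e.step for e in events],
         "value": [e.value for e in events]}
    )


tag = "lm-loss-training/lm loss"
muon_data = load_scalar("runs/research/muon-run/tensorboard", tag)   # placeholder path
adam_data = load_scalar("runs/research/adamw-run/tensorboard", tag)  # placeholder path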

[image: generated comparison figure]

@toothacher17

@mactavish91 Besides, we also evaluated the OOD LM validation loss, and it shows pretty good results:

[image: OOD validation loss comparison]

@toothacher17

@mactavish91 We'll wait for your visualization results and see how it goes. Thanks!

@SeunghyunSEO

Hi guys, let me share my vibe-check results.
I tested small-scale proxy models with 64*4096 = 262k batch tokens and a 40k-step horizon, so they consumed 10.5B tokens.
My model config uses standard parameterization (SP) with 0.2 init std, GQA, non-separated QKV, lr 0.00195, and weight decay 0.1 (not decaying the RMSNorm gamma), for 12 layers at two widths (hidden sizes): 1024 and 4096.
The larger one (4096) is approximately 3.5B params.
Here are my results.

For the smaller one, Muon outperforms AdamW.

[screenshot: loss curves, 1024 hidden size]

For the larger model, Muon looks promising too, but there is an issue where AdamW diverges.

[screenshot: loss curves, 4096 hidden size]

However, the real problem is that Muon's throughput is bad in a multi-node setup.
I used 4 A100 nodes for the larger model, and its throughput clearly needs to be optimized.
I think the gradient all-gather for NS should be overlapped, or the grad bucketing should be carefully designed.

[screenshot: throughput comparison]

@toothacher17

@SeunghyunSEO Thanks for sharing! I have some comments on your runs:

  1. Regarding Muon's performance, this is what we expected to see! For better stability, apply weight decay to your RMSNorm gamma (as mentioned in the appendix); also, your AdamW lr may simply be too large, which is why it does not converge.

  2. For the throughput issue, this is probably because the distributed optimizer implementation changed. We noticed this gap when porting our internal impl to the open-source one. Previously in Megatron-LM, the distributed optimizer states were concatenated and flattened into one list (the concatenation order defined by the params' init order), and that list was split into DP parts. So only the params that straddle a DP boundary (very few, if you think about it) need the extra gather.

However, the current impl of the distributed optimizer first groups params into buckets, and every param in a bucket is split into DP parts and needs a gather! This pushes the number of extra all-gathers to its upper bound: every param on every rank needs a gather. For Distributed Muon to work efficiently as described in our paper, we need the original way of DP-sharding optimizer states, which requires the extra gathering for only a very limited number of params.
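A toy illustration (not Megatron code; the sizes below are made up) of why the two sharding schemes differ in how many params need the extra gather:

def params_needing_gather_flat(param_sizes: list[int], dp_size: int) -> int:
    # Flat scheme: concatenate all params into one buffer and cut it into
    # dp_size equal shards; only params straddling a shard boundary are split.
    total = sum(param_sizes)
    shard = total // dp_size
    boundaries = {shard * r for r in range(1, dp_size)}
    count, offset = 0, 0
    for size in param_sizes:
        if any(offset < b < offset + size for b in boundaries):
            count += 1
        offset += size
    return count


def params_needing_gather_bucketed(param_sizes: list[int]) -> int:
    # Bucketed scheme: every bucket is cut into dp_size shards, so essentially
    # every param is split across ranks and needs a gather (the upper bound).
    return len(param_sizes)


sizes = [4096 * 4096] * 48 + [4096] * 96   # hypothetical mix of matrices and vectors
print(params_needing_gather_flat(sizes, dp_size=8))   # at most dp_size - 1 params
print(params_needing_gather_bucketed(sizes))          # essentially all params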

@SeunghyunSEO

SeunghyunSEO commented Feb 27, 2025

@toothacher17 Wow, your response is as fast as the speed of light, lol. I didn't even know that Megatron changed its sharding logic (I'm also familiar with the sharding strategy you mentioned in point 2). I'll dig into the codebase and come back if I find any clues to improve performance.

(Edited) Can you share the related PR for the param and grad bucketing refactor? I'm not sure this one is the right one.

@toothacher17

@SeunghyunSEO Thanks for your kind words! However, I am not sure when they changed the logic, since we have not merged upstream for a while; we only noticed it while preparing this PR. The commit you found looks very related, but I am not sure it is the only relevant one.

BTW, would you mind visualizing your results using the script I shared above? It is exciting to see other people reproducing similar results with Muon!

@mactavish91

Thank you for your detailed guidance! Here are my findings from today:

  1. In my initial setup (1B MoE model + 150k tokenizer), Muon didn't show an advantage over Adam. However, after using your plotting code, I did observe that Muon's loss appeared lower. I believe this was due to the ewm_alpha in the plotting code being set too low, causing excessive smoothing. After adjusting it to 0.05, the results, shown in the attached image, align with the original experiment logs in Wandb.
    [image: comparison plot, 1B MoE + 150k tokenizer]

  2. The new setup (4B MoE model + 60k tokenizer) has yielded excellent results! Muon consistently outperforms Adam. This is very encouraging, and I will experiment with larger models in the future.
    [image: comparison plot, 4B MoE + 60k tokenizer]

  3. Unfortunately, I also observed results similar to @SeunghyunSEO's findings: when using Muon, cluster throughput decreases significantly. I look forward to further community optimizations in this area.
    [image: throughput comparison]

  4. I also noticed that when using Muon, the L2 norm of the model output is several times higher than with Adam. Do you have any tuning recommendations for this?
    [image]
    [image]

@toothacher17

toothacher17 commented Feb 27, 2025


Glad to see it works! Regarding your questions:

  1. Thanks for the visualization. It clearly shows that in this setting AdamW is close to Muon. I still do not know why, as my experiments in a similar setting worked fine for me, with Muon keeping a reasonable leading ratio;

  2. Glad to see the figure for your second setting!

  3. For the throughput issue, I think it is caused by the same thing @SeunghyunSEO hit. However, changing the distributed optimizer's impl might be non-trivial. I'll discuss with the community and see if we can find an easy way to hack it. The key is to avoid having too many parameters split across DP ranks, so that the extra communication is minimal. This would also naturally reduce the amount of NS computation needed.

  4. For the large output RMS issue, we also monitor it, but we do not see any problem. Can you check whether your RMSNorm gamma is properly weight decayed?

[image: output RMS monitoring]

Also, your leading-step-ratio figure would be better with ylim(0, 1). Would you mind re-plotting it, and would you mind if I share it on X, since it reproduces our results?

@toothacher17

toothacher17 commented Feb 27, 2025

@mactavish91 @SeunghyunSEO I discussed with my colleagues, and we might need more investigation into this performance-drop issue: if more gathers happen, more NS iteration compute also happens for more params.

@mactavish91

Glad to share our replication results.
[image: replication comparison plot]

@SeunghyunSEO

@mactavish91 BTW, can I ask how you log the per-param activation norm? It's too heavy to reduce or gather all the stats in the forward pass when naively registering forward hooks, so I'm curious how you do the logging efficiently!

@toothacher17

I think @mactavish91 showed the model's output L2 norm rather than a per-param activation norm? For the output L2 norm, you can keep a log buffer holding the detached L2 norm or RMS of the output (do not take the sqrt yet, only the squared sum), accumulate it across the fwd-bwd passes, and reduce only once after all the fwd-bwd passes are done. This is cheap, since you do it only once per global step.
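A minimal sketch of that scheme (an assumption of how it could look, not the internal implementation; it assumes a CUDA/NCCL setup):

import torch
import torch.distributed as dist


class OutputNormLogger:
    def __init__(self):
        self.sq_sum = 0.0
        self.count = 0

    @torch.no_grad()
    def accumulate(self, output: torch.Tensor) -> None:
        # Only the squared sum; no sqrt and no communication here.
        self.sq_sum += output.detach().float().pow(2).sum()
        self.count += output.numel()

    def reduce_and_reset(self) -> float:
        # One all-reduce per global step, after all micro-batches are done.
        stats = torch.tensor([float(self.sq_sum), float(self.count)], device="cuda")
        if dist.is_initialized():
            dist.all_reduce(stats)
        rms = (stats[0] / stats[1]).sqrt().item()
        self.sq_sum, self.count = 0.0, 0
        return rms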

@toothacher17

Some other issues asked why Distributed Muon is efficient; I tried to explain it in detail here: MoonshotAI/Moonlight#16

@SeunghyunSEO

SeunghyunSEO commented Feb 28, 2025

@toothacher17 Thanks for sharing!
Oh, it's just the layer-wise output norm.
I mean, I want to log both per-param activation and grad norms, but it's kind of messy.
I know there are cleaner ways, like Lingua's probing module, so I just wanted to ask whether he uses a clean and efficient logging module for Megatron :)
