a proof of concept for Distributed Muon #1428

Open · wants to merge 2 commits into main
88 changes: 88 additions & 0 deletions examples/muon/training.sh
@@ -0,0 +1,88 @@
#!/bin/bash

# Runs the "340M" parameter model with Distributed Muon
# See more details at: https://github.com/MoonshotAI/Moonlight/blob/master/Moonlight.pdf

export CUDA_DEVICE_MAX_CONNECTIONS=1

GPUS_PER_NODE=8
# Change for multinode config
MASTER_ADDR=localhost
MASTER_PORT=6000
NUM_NODES=1
NODE_RANK=0
WORLD_SIZE=$(($GPUS_PER_NODE*$NUM_NODES))

CHECKPOINT_PATH=/PATH/TO/CKPT
TENSORBOARD_LOGS_PATH=/PATH/TO/TB
VOCAB_FILE=/PATH/TO/VOCAB
MERGE_FILE=/PATH/TO/MERGE
# data is preprocessed as described in the Megatron-LM README
DATA_PATH=/PATH/TO/DATA

DISTRIBUTED_ARGS=(
--nproc_per_node $GPUS_PER_NODE
--nnodes $NUM_NODES
--master_addr $MASTER_ADDR
--master_port $MASTER_PORT
)

GPT_MODEL_ARGS=(
--num-layers 12
--hidden-size 1536
--num-attention-heads 12
--num-query-groups 12
--seq-length 1024
--max-position-embeddings 1024
--transformer-impl local
)

TRAINING_ARGS=(
--optimizer muon
--micro-batch-size 1
--global-batch-size 64
--train-iters 5000
--weight-decay 0.1
--adam-beta1 0.9
--adam-beta2 0.95
--init-method-std 0.006
--clip-grad 1.0
--bf16
--lr 1e-3
--lr-decay-style cosine
--min-lr 1e-4
--muon-matched-adamw-rms 0.2
--lr-warmup-fraction 0.02
--lr-decay-iters 5000
--use-distributed-optimizer
--ckpt-format torch

)

MODEL_PARALLEL_ARGS=(
--tensor-model-parallel-size 2
--pipeline-model-parallel-size 2
)

DATA_ARGS=(
--data-path $DATA_PATH
--vocab-file $VOCAB_FILE
--merge-file $MERGE_FILE
--split 949,50,1
)

EVAL_AND_LOGGING_ARGS=(
--log-interval 100
--save-interval 10000
--eval-interval 1000
--save $CHECKPOINT_PATH
--eval-iters 10
--tensorboard-dir $TENSORBOARD_LOGS_PATH
)

torchrun ${DISTRIBUTED_ARGS[@]} pretrain_gpt.py \
${GPT_MODEL_ARGS[@]} \
${TRAINING_ARGS[@]} \
${MODEL_PARALLEL_ARGS[@]} \
${DATA_ARGS[@]} \
${EVAL_AND_LOGGING_ARGS[@]}
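The `muon.py` module that `--optimizer muon` selects is referenced in the diff (`from .muon import Muon`) but not shown. Purely as background, here is a hedged sketch of what a Muon step for one 2-D weight typically looks like: momentum accumulation, a few Newton-Schulz iterations to approximately orthogonalize the momentum matrix (the role of the `ns_steps` constructor argument), and the Moonlight-style RMS matching that `--muon-matched-adamw-rms 0.2` configures so AdamW-scale learning rates such as `--lr 1e-3` can be reused. The coefficients and helper names below follow the public Muon reference implementation and are assumptions, not this PR's code.

```python
import torch

def newton_schulz(G: torch.Tensor, steps: int = 5, eps: float = 1e-7) -> torch.Tensor:
    """Approximately orthogonalize G with a quintic Newton-Schulz iteration (sketch)."""
    a, b, c = 3.4445, -4.7750, 2.0315          # coefficients from the public Muon reference
    X = G.bfloat16()
    X = X / (X.norm() + eps)                   # normalize so the iteration converges
    if X.size(0) > X.size(1):                  # iterate on the wide orientation
        X = X.mT
    for _ in range(steps):
        A = X @ X.mT
        X = a * X + (b * A + c * A @ A) @ X
    if G.size(0) > G.size(1):
        X = X.mT
    return X.to(G.dtype)

def muon_update_sketch(momentum_buf, grad, lr, beta=0.95, ns_steps=5, matched_adamw_rms=0.2):
    """One non-distributed Muon step for a single 2-D weight -- a sketch, not the PR's code."""
    momentum_buf.mul_(beta).add_(grad)                      # momentum accumulation
    update = newton_schulz(momentum_buf, steps=ns_steps)    # orthogonalize the momentum
    scale = matched_adamw_rms * max(grad.size(0), grad.size(1)) ** 0.5
    return lr * scale * update                              # delta subtracted from the weight
```

The distributed machinery this PR adds (`MuonDistMeta`, `enable_distributed_mode`) presumably reconstructs full matrices from the optimizer shards before this step; the sketch above covers only the single-GPU math.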
59 changes: 55 additions & 4 deletions megatron/core/optimizer/__init__.py
@@ -6,6 +6,7 @@
import torch
from torch.optim import SGD as CPUSGD
from torch.optim import AdamW as CPUAdam
from .muon import Muon

try:
from transformer_engine.pytorch.optimizers import FusedAdam as Adam
@@ -52,6 +53,8 @@ def _get_param_groups(
min_lr: float,
decoupled_lr: Optional[float],
decoupled_min_lr: Optional[float],
muon_matched_adamw_rms: Optional[float],
use_muon: bool = False,
) -> List[Dict]:
"""Create parameter groups for optimizer.
@@ -82,6 +85,7 @@

# Map (wd_mult, lr_mult, is_expert_parallel, is_decoupled_lr) to params.
params_map = {}
muon_params_map = {}
for model_chunk in model_chunks:
for name, param in model_chunk.named_parameters():
if not param.requires_grad:
@@ -117,10 +121,22 @@
):
is_decoupled_lr = True

key = (wd_mult, _lr_mult, is_expert_parallel, is_decoupled_lr)
if key not in params_map:
params_map[key] = []
params_map[key].append(param)
# route only 2-D linear weights (no biases, no embeddings / output layer) to Muon
bias_flag = name.endswith(".bias")
shape_flag = param.dim() == 2
embedding_flag = "embedding" in name or "output_layer" in name
muon_flag = use_muon and shape_flag \
and (not bias_flag) and (not embedding_flag)
if muon_flag:
key = (wd_mult, _lr_mult, is_expert_parallel)
if key not in muon_params_map:
muon_params_map[key] = []
muon_params_map[key].append(param)
else:
key = (wd_mult, _lr_mult, is_expert_parallel, is_decoupled_lr)
if key not in params_map:
params_map[key] = []
params_map[key].append(param)

param_groups = []
for (wd_mult, _lr_mult, is_expert_parallel, is_decoupled_lr), params in params_map.items():
@@ -142,6 +158,20 @@
decoupled_min_lr=decoupled_min_lr,
)

for (wd_mult, _lr_mult, is_expert_parallel), params in muon_params_map.items():
if len(params) == 0:
continue
param_groups.append(
{
'params': params,
'wd_mult': wd_mult,
'lr_mult': _lr_mult,
'is_expert_parallel': is_expert_parallel,
'use_muon': True,
'is_decoupled_lr': False,
}
)

return param_groups


@@ -224,6 +254,8 @@ def _get_param_groups_and_buffers(
min_lr=config.min_lr,
decoupled_lr=config.decoupled_lr,
decoupled_min_lr=config.decoupled_min_lr,
muon_matched_adamw_rms=config.muon_matched_adamw_rms,
use_muon=(config.optimizer == 'muon'),
)
param_groups = list(filter(filter_fn, param_groups))
buffers = {}
@@ -350,6 +382,25 @@ def init_state_fn(opt, config=None):
momentum=config.sgd_momentum,
)
init_state_fn = None
elif config.optimizer == 'muon':
optimizer = Muon(param_groups,
lr=config.lr, weight_decay=config.weight_decay,
matched_adamw_rms=config.muon_matched_adamw_rms,
momentum=config.muon_momentum,
nesterov=config.muon_nesterov,
ns_steps=config.muon_ns_steps,
adamw_betas=(config.adam_beta1, config.adam_beta2),
adamw_eps=config.adam_eps)

def init_state_fn(opt, config=None):
for group in opt.param_groups:
for p in group['params']:
if len(opt.state[p]) == 0:
if config is None or not config.use_precision_aware_optimizer:
opt.state[p]['exp_avg'] = torch.zeros_like(p.data)
opt.state[p]['exp_avg_sq'] = torch.zeros_like(p.data)
else:
opt.initialize_state(p)
else:
raise Exception('{} optimizer is not supported.'.format(config.optimizer))
else:
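To make the parameter routing added to `_get_param_groups` concrete: only 2-D weights that are neither biases nor embedding / output-layer weights are placed in Muon param groups; everything else keeps the ordinary groups and is handled by the AdamW fallback inside the optimizer. A minimal illustration of that rule (the helper and the example parameter names are hypothetical, not part of the PR):

```python
def routed_to_muon(name: str, param) -> bool:
    # Mirrors the condition in _get_param_groups above.
    is_bias = name.endswith(".bias")
    is_matrix = param.dim() == 2
    is_embedding = "embedding" in name or "output_layer" in name
    return is_matrix and not is_bias and not is_embedding

# Hypothetical examples of how parameters would be routed:
#   decoder.layers.0.self_attention.linear_qkv.weight  -> Muon   (2-D linear weight)
#   decoder.layers.0.mlp.linear_fc1.bias               -> AdamW  (bias)
#   embedding.word_embeddings.weight                   -> AdamW  (embedding)
#   decoder.final_layernorm.weight                     -> AdamW  (1-D parameter)
```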
56 changes: 50 additions & 6 deletions megatron/core/optimizer/distrib_optimizer.py
@@ -44,6 +44,8 @@
_zero_grad_group_helper,
)
from .optimizer_config import OptimizerConfig
from .muon import Muon, MuonDistMeta
from megatron.core.parallel_state import get_tensor_model_parallel_group

try:
# This will be used when "--fp8-param-gather" is enabled.
@@ -148,6 +150,7 @@ def _build_model_gbuf_param_range_map(
sub_param_start = max(0, gbuf_world_range.start - param_world_start)
sub_param_range = param_local_range.normalize(sub_param_start)
param_range_map[param] = {
"world_indexes": (param_world_start, param_world_end),
"gbuf_world": param_world_range,
"gbuf_world_in_bucket": param_world_range_in_bucket,
"gbuf_local": param_local_range,
@@ -335,13 +338,22 @@ def _build_model_and_main_param_groups(
shard_fp32_groups.append(shard_fp32_params_this_group)
shard_fp32_from_float16_groups.append(shard_fp32_from_float16_params_this_group)

dist_metas = {}

for model_param in group_range["params"]:

assert model_param.requires_grad

gbuf_index, dtype, bucket_index = param_gbuf_map[model_param]
gbuf_range = gbuf_ranges[gbuf_index][dtype][bucket_index]
param_range = gbuf_range["param_map"][model_param]["param"]
param_gbuf_ranges = gbuf_range["param_map"][model_param]
param_range = param_gbuf_ranges["param"]

# generate distributed metadata for the Muon optimizer
param_world_indexes = param_gbuf_ranges["world_indexes"]
tp_split_dim = -1 if not getattr(model_param, 'tensor_model_parallel', False) else \
getattr(model_param, 'partition_dim')
dist_meta = MuonDistMeta(gbuf_index, bucket_index, model_param.shape, param_world_indexes, tp_split_dim)

# fp16, bf16 params.
if model_param.type() in ['torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor']:
@@ -395,6 +407,9 @@ def _build_model_and_main_param_groups(
shard_float16_params_this_group.append(shard_model_param)
shard_fp32_from_float16_params_this_group.append(shard_main_param)

# add to dist metas
dist_metas[shard_main_param] = dist_meta

# fp32 params.
elif model_param.type() == 'torch.cuda.FloatTensor':
shard_model_param = model_param.view(-1)[param_range.start : param_range.end]
@@ -433,7 +448,7 @@ def _build_model_and_main_param_groups(
shard_float16_groups,
shard_fp32_groups,
shard_fp32_from_float16_groups,
)
), dist_metas

def __init__(
self,
@@ -489,8 +504,8 @@ def __init__(
for model_chunk in self.model_chunks:
assert self.ddp_config == model_chunk.ddp_config

assert isinstance(optimizer, (Adam, HybridDeviceOptimizer)) or optimizer is None, (
"Only Adam and HybridDeviceOptimizer currently supported, "
assert isinstance(optimizer, (Adam, HybridDeviceOptimizer, Muon)) or optimizer is None, (
"Only Adam / HybridDeviceOptimizer / Muon currently supported, "
"due to checkpointing requirements."
)

@@ -567,8 +582,8 @@ def __init__(
self.shard_float16_groups,
self.shard_fp32_groups,
self.shard_fp32_from_float16_groups,
) = self._build_model_and_main_param_groups(
self.gbuf_ranges, self.model_param_gbuf_map, self.opt_group_ranges, config
), dist_metas = self._build_model_and_main_param_groups(
self.gbuf_ranges, self.model_param_gbuf_map, self.opt_group_ranges, config
)

if isinstance(self.optimizer, HybridDeviceOptimizer):
@@ -579,6 +594,18 @@ def __init__(
self.optimizer.param_groups = [g["orig_group"] for g in self.opt_group_ranges]
self.optimizer.load_state_dict(self.optimizer.state_dict())

# for muon optimizer, enable distributed mode
if isinstance(self.optimizer, Muon):
assert all(grad_buffer.grad_dtype == torch.float32 for grad_buffer in self.buffers), \
"all grad buffers must contain only float32 for the Muon optimizer"
gbuf_sizes = [ [(bucket.grad_data.numel(), bucket.offset) for bucket in buffer.buckets ]
for buffer in self.buffers ]
self.optimizer.enable_distributed_mode(
gbuf_sizes, self.data_parallel_group,
get_tensor_model_parallel_group(),
dist_metas,
)

self.is_stub_optimizer = False

def _get_model_param_range_map(self, param: torch.nn.Parameter):
@@ -714,6 +741,11 @@ def load_state_dict(self, state_dict):
)

tensors = {"exp_avg": init_shard(), "exp_avg_sq": init_shard()}
# for the Muon optimizer, alias its states to the Adam names (for load_state_dict)
if isinstance(self.optimizer, Muon):
tensors["muon_buffer"] = tensors["exp_avg"]
tensors["adamw_exp_avg"] = tensors["exp_avg"]
tensors["adamw_exp_avg_sq"] = tensors["exp_avg_sq"]
if self.config.use_precision_aware_optimizer:
tensors["master_param"] = init_shard()
state_dict_state.append((state_order, tensors))
@@ -808,6 +840,16 @@ def _get_main_param_and_optimizer_states(self, model_param):
main_param = self.optimizer.param_groups[group_index]["params"][group_order]
optim_state = self.optimizer.state[main_param]
tensors = {"param": main_param, **optim_state}

# make Muon state compatible with Adam checkpoints (always save to exp_avg / exp_avg_sq)
if isinstance(self.optimizer, Muon):
use_muon = self.optimizer.param_groups[group_index].get("use_muon", False)
if use_muon:
tensors["exp_avg"] = tensors["muon_buffer"]
tensors["exp_avg_sq"] = torch.zeros_like(tensors["param"])
else:
tensors["exp_avg"] = tensors["adamw_exp_avg"]
tensors["exp_avg_sq"] = tensors["adamw_exp_avg_sq"]
return tensors

def _set_main_param_and_optimizer_states(self, model_param, tensors):
Expand Down Expand Up @@ -839,6 +881,8 @@ def _set_main_param_and_optimizer_states(self, model_param, tensors):
optim_state = self.optimizer.state[main_param]
dst_tensors = {"param": main_param, **optim_state}
for key in dst_tensors:
if key not in tensors:
continue
dst_tensors[key].copy_(tensors[key])

def get_parameter_state_fs_bucket_space(self):
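The checkpoint-related changes above (`load_state_dict` and `_get_main_param_and_optimizer_states`) work by aliasing Muon's optimizer state onto the Adam key names, so the existing distributed-checkpoint machinery needs no format changes. A rough sketch of the save-side mapping, assuming Muon params carry a single `muon_buffer` and fallback params carry `adamw_exp_avg` / `adamw_exp_avg_sq` (as the diff suggests):

```python
import torch

def to_checkpoint_tensors(param, state, use_muon):
    # Sketch only: expose Muon / AdamW-fallback state under the Adam key names
    # that the distributed checkpoint format already expects.
    if use_muon:
        return {
            "param": param,
            "exp_avg": state["muon_buffer"],        # momentum buffer saved as exp_avg
            "exp_avg_sq": torch.zeros_like(param),  # placeholder: Muon keeps no second moment
        }
    return {
        "param": param,
        "exp_avg": state["adamw_exp_avg"],
        "exp_avg_sq": state["adamw_exp_avg_sq"],
    }
```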