From 82bcdd465a4d5c71fdf4a025c5316eec10c160a6 Mon Sep 17 00:00:00 2001 From: Tianyu Liu Date: Wed, 11 Dec 2024 19:55:11 -0800 Subject: [PATCH] [MoE][PoC] Expert Parallel: dp2ep [ghstack-poisoned] --- torchtitan/config_manager.py | 18 ++- torchtitan/optimizer.py | 2 +- torchtitan/parallelisms/expert_parallel.py | 119 +++++++++++++++++++ torchtitan/parallelisms/parallel_dims.py | 74 +++++++++++- torchtitan/parallelisms/parallelize_llama.py | 98 +++++++++++++-- train.py | 14 ++- train_configs/debug_model.toml | 3 +- 7 files changed, 306 insertions(+), 22 deletions(-) diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index b8fd4f21..55fd6392 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -364,15 +364,23 @@ def __init__(self): default=1, help="Context parallelism degree. 1 means disabled.", ) + self.parser.add_argument( + "--experimental.expert_parallel_degree", + type=int, + default=1, + help=""" + Expert parallelism degree. 1 means disabled. + When expert_parallel_mode is 'tp' or 'tp2ep', it has to be equal to tensor_parallel_degree. + When expert_parallel_mode is 'dp2ep', it has to be k * context_parallel_degree, + where k >= 1 and k | data_parallel_shard_degree. + """, + ) self.parser.add_argument( "--experimental.expert_parallel_mode", type=str, default="none", - choices=["none", "tp", "tp2ep"], - help=""" - Expert Parallel mode. - 'tp2ep' would use the entire TP mesh to shard non-shared experts on the num_experts dimension. - """, + choices=["none", "tp", "tp2ep", "dp2ep"], + help="Expert Parallel mode", ) self.parser.add_argument( "--training.mixed_precision_param", diff --git a/torchtitan/optimizer.py b/torchtitan/optimizer.py index d88d431f..8e44f1c5 100644 --- a/torchtitan/optimizer.py +++ b/torchtitan/optimizer.py @@ -86,7 +86,7 @@ def build_optimizers(model_parts, job_config: JobConfig): "betas": (0.9, 0.95), "weight_decay": 0.1, "fused": fused, - "foreach": not fused, + "foreach": False, } return ( diff --git a/torchtitan/parallelisms/expert_parallel.py b/torchtitan/parallelisms/expert_parallel.py index a42a9a15..9d451b98 100644 --- a/torchtitan/parallelisms/expert_parallel.py +++ b/torchtitan/parallelisms/expert_parallel.py @@ -325,3 +325,122 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: self._prepare_output_fn, self.output_layouts, self.use_local_output ), ) + + +# This class is for dp2ep with TP (without TP we can just use ExpertParallel) +class ExpertTensorParallel(ParallelStyle): + def __init__( + self, + *, + tp_mesh: DeviceMesh, + ep_mesh: DeviceMesh, + ): + super().__init__() + # TODO: has to pass in the meshes in addition to device_mesh, + # as there's an issue from DeviceMesh that + # "Cannot create a submesh from a submesh." 
+ self.tp_mesh = tp_mesh + self.ep_mesh = ep_mesh + + @staticmethod + def _prepare_input_fn(tp_mesh, ep_mesh, mod, inputs, device_mesh): + input_tensor = inputs[0] + # input_tensor of placements Shard(1) on the tp mesh + assert not isinstance(input_tensor, DTensor) + + # a2a(ep) + input_tensor = DTensor.from_local(input_tensor, ep_mesh, (Shard(1),)) + input_tensor = input_tensor.redistribute(placements=(Shard(0),)).to_local() + # ag(tp) + input_tensor = DTensor.from_local(input_tensor, tp_mesh, (Shard(1),)) + input_tensor = input_tensor.redistribute(placements=(Replicate(),)) + + return input_tensor + + @staticmethod + def _partition_fn(tp_mesh, ep_mesh, name, module, device_mesh): + # TODO: FSDP doesn't support sharding a 2D Tensor into a 3D one yet + # module.register_parameter( + # "gate_proj", + # nn.Parameter( + # distribute_tensor(module.gate_proj, device_mesh, [Shard(0), Shard(2)]) + # ), + # ) # Column-wise sharding + # module.register_parameter( + # "down_proj", + # nn.Parameter( + # distribute_tensor(module.down_proj, device_mesh, [Shard(0), Shard(1)]) + # ), + # ) # Row-wise sharding + # module.register_parameter( + # "up_proj", + # nn.Parameter( + # distribute_tensor(module.up_proj, device_mesh, [Shard(0), Shard(2)]) + # ), + # ) # Column-wise sharding + + # TODO: Instead, for MoE experts, we shard on the EP mesh and then "forget" it. + # This would become an issue from DCP resharding perspective. + module.register_parameter( + "gate_proj", + nn.Parameter( + DTensor.from_local( + ( + distribute_tensor( + module.gate_proj, device_mesh, [Shard(0), Shard(2)] + ).to_local() + ), + tp_mesh, + (Shard(2),), + ) + ), + ) # Column-wise sharding + module.register_parameter( + "down_proj", + nn.Parameter( + DTensor.from_local( + ( + distribute_tensor( + module.down_proj, device_mesh, [Shard(0), Shard(1)] + ).to_local() + ), + tp_mesh, + (Shard(1),), + ) + ), + ) # Row-wise sharding + module.register_parameter( + "up_proj", + nn.Parameter( + DTensor.from_local( + ( + distribute_tensor( + module.up_proj, device_mesh, [Shard(0), Shard(2)] + ).to_local() + ), + tp_mesh, + (Shard(2),), + ) + ), + ) # Column-wise sharding + + @staticmethod + def _prepare_output_fn(tp_mesh, ep_mesh, mod, outputs, device_mesh): + # outputs of placements Partial() on the tp mesh + + # rs(tp) + outputs = outputs.redistribute(placements=(Shard(1),)).to_local() + # a2a(ep) + outputs = DTensor.from_local(outputs, ep_mesh, (Shard(0),)) + outputs = outputs.redistribute(placements=(Shard(1),)).to_local() + + return outputs + + def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: + return distribute_module( + module, + device_mesh, + partial(self._partition_fn, self.tp_mesh, self.ep_mesh), + partial(self._prepare_input_fn, self.tp_mesh, self.ep_mesh), + partial(self._prepare_output_fn, self.tp_mesh, self.ep_mesh), + ) diff --git a/torchtitan/parallelisms/parallel_dims.py b/torchtitan/parallelisms/parallel_dims.py index 9af771a2..736543c6 100644 --- a/torchtitan/parallelisms/parallel_dims.py +++ b/torchtitan/parallelisms/parallel_dims.py @@ -18,6 +18,8 @@ class ParallelDims: cp: int tp: int pp: int + ep: int + ep_mode: str world_size: int enable_loss_parallel: bool @@ -25,14 +27,15 @@ def __post_init__(self): self._validate() def _validate(self): - dp_replicate, dp_shard, cp, tp, pp = ( + dp_replicate, dp_shard, cp, tp, pp, ep = ( self.dp_replicate, self.dp_shard, self.cp, self.tp, self.pp, + self.ep, ) - for d in (dp_replicate, cp, tp, pp): + for d in (dp_replicate, cp, tp, pp, ep): assert d >= 1, 
"Parallelism degree should be >= 1, except for dp_shard" assert dp_shard == -1 or dp_shard >= 1, " dp_shard must -1 or >=1." @@ -45,7 +48,74 @@ def _validate(self): f"cp({cp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})" ) + if ep > 1: + assert self.ep_mode in ["tp", "tp2ep", "dp2ep"] + if self.ep_mode == "tp" or self.ep_mode == "tp2ep": + assert ep == tp + elif self.ep_mode == "dp2ep": + # EP would borrow all cp and some dp_shard degree + assert ep % cp == 0 and (dp_shard * cp) % ep == 0 + else: + self.ep_mode = "none" + + def _build_mesh_with_dp2ep(self, device_type): + # In dp2ep, dp_shard and ep are derived submeshes: + # dp_shard = dp_shard_1 * dp_shard_2 + # ep = dp_shard_2 * cp + dp_shard_1 = self.dp_shard * self.cp // self.ep + dp_shard_2 = self.ep // self.cp + + dims = [] + names = [] + for d, name in zip( + [self.pp, self.dp_replicate, dp_shard_1, dp_shard_2, self.cp, self.tp], + ["pp", "dp_replicate", "dp_shard_1", "dp_shard_2", "cp", "tp"], + ): + # dp_shard_1 is needed even if it's 1, whose FSDP wrapping + # helps the MoE layers do mixed precision training + if d > 1 or name == "dp_shard_1": + dims.append(d) + names.append(name) + + logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}") + names = tuple(names) + mesh = init_device_mesh(device_type, dims, mesh_dim_names=names) + + # Create all the submesh here to ensure all required process groups are + # initialized: + # Mesh for data loading + dp_mesh_dim_names = [] + if self.dp_replicate_enabled: + dp_mesh_dim_names.append("dp_replicate") + dp_mesh_dim_names.append("dp_shard_1") + if "dp_shard_2" in names: + dp_mesh_dim_names.append("dp_shard_2") + mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp") + + # Mesh for param sharding + dp_shard_cp_mesh_dim_name = [] + dp_shard_cp_mesh_dim_name.append("dp_shard_1") + if "dp_shard_2" in names: + dp_shard_cp_mesh_dim_name.append("dp_shard_2") + if self.cp_enabled: + dp_shard_cp_mesh_dim_name.append("cp") + mesh[tuple(dp_shard_cp_mesh_dim_name)]._flatten(mesh_dim_name="dp_shard_cp") + + # Mesh for ep + ep_mesh_dim_names = [] + if "dp_shard_2" in names: + ep_mesh_dim_names.append("dp_shard_2") + if self.cp_enabled: + ep_mesh_dim_names.append("cp") + assert len(ep_mesh_dim_names) > 0 + mesh[tuple(ep_mesh_dim_names)]._flatten(mesh_dim_name="ep") + + return mesh + def build_mesh(self, device_type): + if self.ep_mode == "dp2ep": + return self._build_mesh_with_dp2ep(device_type) + dims = [] names = [] for d, name in zip( diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py index 56df423d..f097e6ba 100644 --- a/torchtitan/parallelisms/parallelize_llama.py +++ b/torchtitan/parallelisms/parallelize_llama.py @@ -65,12 +65,17 @@ def parallelize_llama( enable_async_tp=job_config.experimental.enable_async_tensor_parallel, ) - ep_mode = job_config.experimental.expert_parallel_mode - if ep_mode != "none": + if parallel_dims.ep_mode != "none": apply_ep( model, - ep_mode=ep_mode, + ep_mode=parallel_dims.ep_mode, + ep_mesh=world_mesh["ep"] if parallel_dims.ep_mode == "dp2ep" else None, tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None, + ep_tp_mesh=( + world_mesh["ep", "tp"] + if parallel_dims.ep_mode == "dp2ep" and parallel_dims.tp_enabled + else None + ), ) if job_config.activation_checkpoint.mode != "none": @@ -86,20 +91,30 @@ def parallelize_llama( apply_compile(model) if ( - parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled + parallel_dims.dp_shard_enabled + or parallel_dims.cp_enabled + 
or parallel_dims.ep_mode == "dp2ep"
     ):
         # apply FSDP or HSDP, potentially with Context Parallel
         if not parallel_dims.dp_shard_enabled and parallel_dims.dp_replicate_enabled:
             # Composability of DDP + CP is not supported.
-            raise RuntimeError("Composability of DDP + CP is not supported.")
+            raise RuntimeError(
+                "Composability of DDP + CP or DDP + EP is not supported."
+            )
 
         # the mesh dim names of which the model params are sharded on
         dp_mesh_dim_names = []
         if parallel_dims.dp_replicate_enabled:
             dp_mesh_dim_names.append("dp_replicate")
         dp_mesh_dim_names.append("dp_shard_cp")
 
+        # the mesh dim names of which the MoE params are sharded on
+        dp_mod_ep_mesh_dim_names = []
+        if parallel_dims.ep_mode == "dp2ep":
+            if parallel_dims.dp_replicate_enabled:
+                dp_mod_ep_mesh_dim_names.append("dp_replicate")
+            dp_mod_ep_mesh_dim_names.append("dp_shard_1")
+
         apply_fsdp(
             model,
             world_mesh[tuple(dp_mesh_dim_names)],
@@ -107,6 +122,8 @@ def parallelize_llama(
             reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
             pp_enabled=parallel_dims.pp_enabled,
             cpu_offload=job_config.training.enable_cpu_offload,
+            ep_enabled=(parallel_dims.ep_mode == "dp2ep"),
+            dp_mod_ep_mesh=world_mesh[tuple(dp_mod_ep_mesh_dim_names)],
         )
 
         if parallel_dims.dp_replicate_enabled:
@@ -226,11 +243,15 @@ def apply_tp(
 def apply_ep(
     model: nn.Module,
     ep_mode: str,
+    ep_mesh: Optional[DeviceMesh] = None,
     tp_mesh: Optional[DeviceMesh] = None,
+    ep_tp_mesh: Optional[DeviceMesh] = None,
 ):
     from torch.distributed.tensor import Partial
+    from torch.distributed.tensor.parallel import PrepareModuleOutput
     from torchtitan.parallelisms.expert_parallel import (
         ExpertParallel,
+        ExpertTensorParallel,
         PrepareModuleInputOutput,
         TensorParallel,
     )
@@ -286,6 +307,57 @@ def apply_ep(
             parallelize_plan=moe_plan,
         )
 
+    elif ep_mode == "dp2ep":
+        if not tp_mesh:
+            assert ep_mesh is not None
+
+            for _, transformer_block in model.layers.items():
+                parallelize_module(
+                    module=transformer_block.moe.experts,
+                    device_mesh=ep_mesh,
+                    # input / output sharding on the tokens dim
+                    parallelize_plan=ExpertParallel(
+                        input_layouts=Shard(1),
+                        output_layouts=Shard(1),
+                    ),
+                )
+
+        else:  # dp2ep with TP (no Router Parallel)
+            assert ep_tp_mesh is not None
+
+            for _, transformer_block in model.layers.items():
+                moe_plan = {
+                    # input / output sharding on the seqlen dim
+                    "moe": PrepareModuleInputOutput(
+                        input_layouts=(Shard(1),),
+                        desired_input_layouts=(Replicate(),),
+                        output_layouts=(Partial(),),
+                        desired_output_layouts=(Shard(1),),
+                    ),
+                    # no Router Parallel
+                    # NOTE: still need to explicitly or implicitly turn the router into DTensor
+                    # so that gradient clipping and the optimizer can use DTensor foreach
+                    # top_scores, selected_token_indices sharded on the seqlen dim
+                    "moe.router": PrepareModuleOutput(
+                        output_layouts=(Replicate(), Replicate()),
+                        desired_output_layouts=(Shard(1), Shard(1)),
+                    ),
+                    "moe.shared_expert": TensorParallel(),
+                }
+                parallelize_module(
+                    module=transformer_block,
+                    device_mesh=tp_mesh,
+                    parallelize_plan=moe_plan,
+                )
+
+                parallelize_module(
+                    module=transformer_block.moe.experts,
+                    device_mesh=ep_tp_mesh,
+                    parallelize_plan=ExpertTensorParallel(
+                        tp_mesh=tp_mesh, ep_mesh=ep_tp_mesh
+                    ),
+                )
+
     logger.info(f"Applied {ep_mode} Expert Parallelism to the model")
 
 
@@ -379,7 +451,7 @@ def apply_compile(model: nn.Module):
     repeated structure. Alternatively one can compile the whole model (after applying DP).
""" for layer_id, transformer_block in model.layers.named_children(): - transformer_block = torch.compile(transformer_block, fullgraph=True) + transformer_block = torch.compile(transformer_block, fullgraph=False) model.layers.register_module(layer_id, transformer_block) logger.info("Compiling each TransformerBlock with torch.compile") @@ -392,6 +464,8 @@ def apply_fsdp( reduce_dtype: torch.dtype, pp_enabled: bool, cpu_offload: bool = False, + ep_enabled: bool = False, + dp_mod_ep_mesh: Optional[DeviceMesh] = None, ): """ Apply data parallelism to the model. FSDP2 is used here. @@ -410,6 +484,16 @@ def apply_fsdp( # As an optimization, do not reshard after forward for the last # transformer block since FSDP would prefetch it immediately reshard_after_forward = int(layer_id) < len(model.layers) - 1 + + fsdp_mod_ep_config = fsdp_config.copy() + fsdp_mod_ep_config["mesh"] = dp_mod_ep_mesh + if ep_enabled: + fully_shard( + transformer_block.moe.experts, + **fsdp_mod_ep_config, + reshard_after_forward=reshard_after_forward, + ) + fully_shard( transformer_block, **fsdp_config, diff --git a/train.py b/train.py index 84417b5b..2c71b0e2 100644 --- a/train.py +++ b/train.py @@ -50,6 +50,8 @@ def main(job_config: JobConfig): cp=job_config.experimental.context_parallel_degree, tp=job_config.training.tensor_parallel_degree, pp=job_config.experimental.pipeline_parallel_degree, + ep=job_config.experimental.expert_parallel_degree, + ep_mode=job_config.experimental.expert_parallel_mode, world_size=world_size, enable_loss_parallel=job_config.training.enable_loss_parallel, ) @@ -299,12 +301,12 @@ def loss_fn(pred, labels): loss.backward() # clip gradients - utils.clip_grad_norm_( - [p for m in model_parts for p in m.parameters()], - job_config.training.max_norm, - foreach=True, - pp_mesh=pp_mesh if parallel_dims.pp_enabled else None, - ) + # utils.clip_grad_norm_( + # [p for m in model_parts for p in m.parameters()], + # job_config.training.max_norm, + # foreach=True, + # pp_mesh=pp_mesh if parallel_dims.pp_enabled else None, + # ) # sync float8 amaxes and scales float8_handler.sync_float8_amax_and_scale_history(model_parts) diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml index 276c8e71..318c9655 100644 --- a/train_configs/debug_model.toml +++ b/train_configs/debug_model.toml @@ -46,7 +46,8 @@ dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M) context_parallel_degree = 1 pipeline_parallel_degree = 1 enable_async_tensor_parallel = false -expert_parallel_mode = "tp2ep" +expert_parallel_degree = 8 +expert_parallel_mode = "dp2ep" [checkpoint] enable_checkpoint = false