[MoE][PoC] Expert Parallel: dp2ep #732

Draft · wants to merge 2 commits into base: gh/tianyu-l/26/base
18 changes: 13 additions & 5 deletions torchtitan/config_manager.py
@@ -375,15 +375,23 @@ def __init__(self):
The default value is 'allgather'.
""",
)
self.parser.add_argument(
"--experimental.expert_parallel_degree",
type=int,
default=1,
help="""
Expert parallelism degree. 1 means disabled.
When expert_parallel_mode is 'tp' or 'tp2ep', it has to be equal to tensor_parallel_degree.
When expert_parallel_mode is 'dp2ep', it has to be k * context_parallel_degree,
            where k >= 1 and k divides data_parallel_shard_degree.
""",
)
self.parser.add_argument(
"--experimental.expert_parallel_mode",
type=str,
default="none",
choices=["none", "tp", "tp2ep"],
help="""
Expert Parallel mode.
'tp2ep' would use the entire TP mesh to shard non-shared experts on the num_experts dimension.
""",
choices=["none", "tp", "tp2ep", "dp2ep"],
help="Expert Parallel mode",
)
self.parser.add_argument(
"--training.mixed_precision_param",
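As a usage sketch, assuming JobConfig.parse_args accepts an argv-style list (as in the existing config_manager), the two new flags could be combined as follows. The concrete degrees are illustrative only: with data_parallel_shard_degree=4 and context_parallel_degree=2, an expert_parallel_degree of 4 satisfies ep = k * cp with k = 2 and k dividing dp_shard.

from torchtitan.config_manager import JobConfig

config = JobConfig()
config.parse_args(
    [
        "--experimental.expert_parallel_mode",
        "dp2ep",
        "--experimental.expert_parallel_degree",
        "4",
    ]
)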
2 changes: 1 addition & 1 deletion torchtitan/optimizer.py
@@ -146,7 +146,7 @@ def build_optimizers(
"betas": (0.9, 0.95),
"weight_decay": 0.1,
"fused": fused,
"foreach": not fused,
"foreach": False,
}

return (
119 changes: 119 additions & 0 deletions torchtitan/parallelisms/expert_parallel.py
@@ -325,3 +325,122 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
self._prepare_output_fn, self.output_layouts, self.use_local_output
),
)


# This class is for dp2ep with TP (without TP we can just use ExpertParallel)
class ExpertTensorParallel(ParallelStyle):
def __init__(
self,
*,
tp_mesh: DeviceMesh,
ep_mesh: DeviceMesh,
):
super().__init__()
        # TODO: we have to pass in the meshes in addition to device_mesh,
        # as DeviceMesh currently raises
        # "Cannot create a submesh from a submesh."
self.tp_mesh = tp_mesh
self.ep_mesh = ep_mesh

@staticmethod
def _prepare_input_fn(tp_mesh, ep_mesh, mod, inputs, device_mesh):
input_tensor = inputs[0]
        # input_tensor has Shard(1) placement on the tp mesh
assert not isinstance(input_tensor, DTensor)

        # a2a(ep): all-to-all over the ep mesh, Shard(1) -> Shard(0)
input_tensor = DTensor.from_local(input_tensor, ep_mesh, (Shard(1),))
input_tensor = input_tensor.redistribute(placements=(Shard(0),)).to_local()
        # ag(tp): all-gather over the tp mesh, Shard(1) -> Replicate()
input_tensor = DTensor.from_local(input_tensor, tp_mesh, (Shard(1),))
input_tensor = input_tensor.redistribute(placements=(Replicate(),))

return input_tensor

@staticmethod
def _partition_fn(tp_mesh, ep_mesh, name, module, device_mesh):
# TODO: FSDP doesn't support sharding a 2D Tensor into a 3D one yet
# module.register_parameter(
# "gate_proj",
# nn.Parameter(
# distribute_tensor(module.gate_proj, device_mesh, [Shard(0), Shard(2)])
# ),
# ) # Column-wise sharding
# module.register_parameter(
# "down_proj",
# nn.Parameter(
# distribute_tensor(module.down_proj, device_mesh, [Shard(0), Shard(1)])
# ),
# ) # Row-wise sharding
# module.register_parameter(
# "up_proj",
# nn.Parameter(
# distribute_tensor(module.up_proj, device_mesh, [Shard(0), Shard(2)])
# ),
# ) # Column-wise sharding

        # TODO: Instead, for MoE experts, we shard on the EP mesh and then "forget" it
        # by re-wrapping the local shard as a TP-only DTensor. This would become an
        # issue from the DCP resharding perspective.
module.register_parameter(
"gate_proj",
nn.Parameter(
DTensor.from_local(
(
distribute_tensor(
module.gate_proj, device_mesh, [Shard(0), Shard(2)]
).to_local()
),
tp_mesh,
(Shard(2),),
)
),
) # Column-wise sharding
module.register_parameter(
"down_proj",
nn.Parameter(
DTensor.from_local(
(
distribute_tensor(
module.down_proj, device_mesh, [Shard(0), Shard(1)]
).to_local()
),
tp_mesh,
(Shard(1),),
)
),
) # Row-wise sharding
module.register_parameter(
"up_proj",
nn.Parameter(
DTensor.from_local(
(
distribute_tensor(
module.up_proj, device_mesh, [Shard(0), Shard(2)]
).to_local()
),
tp_mesh,
(Shard(2),),
)
),
) # Column-wise sharding

@staticmethod
def _prepare_output_fn(tp_mesh, ep_mesh, mod, outputs, device_mesh):
        # outputs have Partial() placement on the tp mesh

        # rs(tp): reduce-scatter over the tp mesh, Partial() -> Shard(1)
outputs = outputs.redistribute(placements=(Shard(1),)).to_local()
        # a2a(ep): all-to-all over the ep mesh, Shard(0) -> Shard(1)
outputs = DTensor.from_local(outputs, ep_mesh, (Shard(0),))
outputs = outputs.redistribute(placements=(Shard(1),)).to_local()

return outputs

def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
return distribute_module(
module,
device_mesh,
partial(self._partition_fn, self.tp_mesh, self.ep_mesh),
partial(self._prepare_input_fn, self.tp_mesh, self.ep_mesh),
partial(self._prepare_output_fn, self.tp_mesh, self.ep_mesh),
)
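A hedged sketch of how ExpertTensorParallel might be attached to a grouped-experts module via parallelize_module. The names moe.experts and ep_tp_mesh are assumptions for illustration: _partition_fn needs a 2-D (ep, tp) device_mesh for its [Shard(0), Shard(2)] placements, while the 1-D tp/ep submeshes are passed to the style explicitly because of the submesh limitation noted in the TODO above.

from torch.distributed.tensor.parallel import parallelize_module

# ep_tp_mesh: assumed 2-D mesh spanning the (ep, tp) dims
# tp_mesh / ep_mesh: the corresponding 1-D submeshes
parallelize_module(
    module=moe.experts,  # assumed attribute holding gate_proj/down_proj/up_proj
    device_mesh=ep_tp_mesh,
    parallelize_plan=ExpertTensorParallel(tp_mesh=tp_mesh, ep_mesh=ep_mesh),
)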
80 changes: 78 additions & 2 deletions torchtitan/parallelisms/parallel_dims.py
@@ -18,21 +18,24 @@ class ParallelDims:
cp: int
tp: int
pp: int
ep: int
ep_mode: str
world_size: int
enable_loss_parallel: bool

def __post_init__(self):
self._validate()

def _validate(self):
-        dp_replicate, dp_shard, cp, tp, pp = (
+        dp_replicate, dp_shard, cp, tp, pp, ep = (
self.dp_replicate,
self.dp_shard,
self.cp,
self.tp,
self.pp,
self.ep,
)
-        for d in (dp_replicate, cp, tp, pp):
+        for d in (dp_replicate, cp, tp, pp, ep):
assert d >= 1, "Parallelism degree should be >= 1, except for dp_shard"

        assert dp_shard == -1 or dp_shard >= 1, "dp_shard must be -1 or >= 1."
@@ -45,7 +48,80 @@ def _validate(self):
f"cp({cp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
)

if ep > 1:
assert self.ep_mode in ["tp", "tp2ep", "dp2ep"]
if self.ep_mode == "tp" or self.ep_mode == "tp2ep":
assert ep == tp
elif self.ep_mode == "dp2ep":
# EP would borrow all cp and some dp_shard degree
assert ep % cp == 0 and (dp_shard * cp) % ep == 0
else:
self.ep_mode = "none"

def _build_mesh_with_dp2ep(self, device_type):
# In dp2ep, dp_shard and ep are derived submeshes:
# dp_shard = dp_shard_1 * dp_shard_2
# ep = dp_shard_2 * cp
dp_shard_1 = self.dp_shard * self.cp // self.ep
dp_shard_2 = self.ep // self.cp

dims = []
names = []
for d, name in zip(
[self.pp, self.dp_replicate, dp_shard_1, dp_shard_2, self.cp, self.tp],
["pp", "dp_replicate", "dp_shard_1", "dp_shard_2", "cp", "tp"],
):
            # dp_shard_1 is needed even if it's 1: its FSDP wrapping is what
            # lets the MoE layers do mixed precision training
if d > 1 or name == "dp_shard_1":
dims.append(d)
names.append(name)

logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}")
names = tuple(names)
mesh = init_device_mesh(device_type, dims, mesh_dim_names=names)

        # Create all the submeshes here to ensure all required process groups are
        # initialized:
# Mesh for data loading (no communication on this mesh)
dp_mesh_dim_names = []
# Mesh for param sharding
dp_shard_cp_mesh_dim_names = []
# Mesh for loss all-reduce
dp_cp_mesh_dim_names = []
# Mesh for ep
ep_mesh_dim_names = []

if self.dp_replicate_enabled:
dp_mesh_dim_names.append("dp_replicate")
dp_cp_mesh_dim_names.append("dp_replicate")
# dp_shard_1 is always needed, even if it's 1
dp_mesh_dim_names.append("dp_shard_1")
dp_shard_cp_mesh_dim_names.append("dp_shard_1")
dp_cp_mesh_dim_names.append("dp_shard_1")
if "dp_shard_2" in names:
dp_mesh_dim_names.append("dp_shard_2")
dp_shard_cp_mesh_dim_names.append("dp_shard_2")
dp_cp_mesh_dim_names.append("dp_shard_2")
ep_mesh_dim_names.append("dp_shard_2")
if self.cp_enabled:
dp_shard_cp_mesh_dim_names.append("cp")
dp_cp_mesh_dim_names.append("cp")
ep_mesh_dim_names.append("cp")

mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp")
mesh[tuple(dp_shard_cp_mesh_dim_names)]._flatten(
mesh_dim_name="dp_shard_cp"
)
mesh[tuple(dp_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_cp")
mesh[tuple(ep_mesh_dim_names)]._flatten(mesh_dim_name="ep")

return mesh
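Before build_mesh below, a worked sketch of this factorization with illustrative degrees; the field names follow the dataclass in this diff (dp_replicate and dp_shard sit just above the shown hunk).

# 32 GPUs: dp_shard=8, cp=2, tp=2, and ep=4 borrowing all of cp plus half of dp_shard.
parallel_dims = ParallelDims(
    dp_replicate=1, dp_shard=8, cp=2, tp=2, pp=1,
    ep=4, ep_mode="dp2ep",
    world_size=32, enable_loss_parallel=True,
)
mesh = parallel_dims.build_mesh(device_type="cuda")
# _build_mesh_with_dp2ep derives
#   dp_shard_1 = dp_shard * cp // ep = 8 * 2 // 4 = 4
#   dp_shard_2 = ep // cp            = 4 // 2     = 2
# giving a 4-D mesh (dp_shard_1=4, dp_shard_2=2, cp=2, tp=2) and the flattened
# submeshes dp=8 (data loading), dp_shard_cp=16 (param sharding),
# dp_cp=16 (loss all-reduce), and ep=4 (expert parallel).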

def build_mesh(self, device_type):
if self.ep_mode == "dp2ep":
return self._build_mesh_with_dp2ep(device_type)

dims = []
names = []
for d, name in zip(