Support ZBVZeroBubbleSchedule
H-Huang committed Feb 4, 2025
1 parent d4c86e3 commit 60ea9f2
Showing 3 changed files with 44 additions and 12 deletions.

tests/integration_tests.py (12 additions, 0 deletions)

@@ -149,6 +149,18 @@ def build_test_list():
             "pp_looped_zero_bubble",
             ngpu=4,
         ),
+        OverrideDefinitions(
+            [
+                [
+                    "--experimental.pipeline_parallel_degree 2",
+                    "--experimental.pipeline_parallel_schedule ZBVZeroBubble",
+                    "--experimental.pipeline_parallel_microbatches 8",
+                ],
+            ],
+            "PP zero bubble test (v shaped)",
+            "pp_zbv",
+            ngpu=2,
+        ),
         OverrideDefinitions(
             [
                 [
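
The new test drives the schedule purely through the existing CLI flags; the string value is resolved to a schedule class by PyTorch's pipelining helpers. Below is a minimal sketch of that lookup (an illustration, not part of the commit; it assumes a PyTorch build that ships ScheduleZBVZeroBubble, as the pipeline_llama.py changes below require):

    # Minimal sketch: the name passed via --experimental.pipeline_parallel_schedule
    # is looked up in torch.distributed.pipelining.schedules
    # (assumes a PyTorch build that includes ScheduleZBVZeroBubble).
    from torch.distributed.pipelining.schedules import (
        ScheduleZBVZeroBubble,
        get_schedule_class,
    )

    schedule_class = get_schedule_class("ZBVZeroBubble")
    assert schedule_class is ScheduleZBVZeroBubble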

torchtitan/parallelisms/pipeline_llama.py (21 additions, 3 deletions)

@@ -13,7 +13,10 @@
 import torch.nn as nn
 from torch.distributed import DeviceMesh
 from torch.distributed.pipelining import PipelineStage
-
+from torch.distributed.pipelining.schedules import (
+    get_schedule_class,
+    ScheduleZBVZeroBubble,
+)
 from torchtitan.config_manager import JobConfig
 from torchtitan.logging import logger
 from torchtitan.models.llama.model import ModelArgs
@@ -43,7 +46,16 @@ def pipeline_llama(
 
     pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn)
 
-    return pp_schedule, models
+    # This is used in the train loop to determine whether to pass in the input_ids and labels
+    has_first_stage = False
+    has_last_stage = False
+    for stage in stages:
+        if stage.is_first:
+            has_first_stage = True
+        if stage.is_last:
+            has_last_stage = True
+
+    return pp_schedule, models, has_first_stage, has_last_stage
 
 
 def pipeline_llama_manual_split(
@@ -103,7 +115,13 @@ def _build_stage(stage_idx, start_layer, stop_layer, is_first=False, is_last=Fal
 
     stages = []
     models = []
-    for stage_idx in stage_ids_this_rank(pp_rank, pp_size, num_stages, style="loop"):
+
+    schedule_class = get_schedule_class(
+        job_config.experimental.pipeline_parallel_schedule
+    )
+    style = "v" if schedule_class == ScheduleZBVZeroBubble else "loop"
+
+    for stage_idx in stage_ids_this_rank(pp_rank, pp_size, num_stages, style=style):
         start_layer = splits[stage_idx - 1] if stage_idx > 0 else None
         stop_layer = splits[stage_idx] if stage_idx < num_stages - 1 else None
         stage, model_chunk = _build_stage(
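
ScheduleZBVZeroBubble runs a V-shaped ("zig-zag") placement, so the split switches stage_ids_this_rank from the looped style to "v". The function below is an illustrative sketch only (not torchtitan's helper, whose signature is assumed here) of the stage-to-rank mapping the two styles are expected to produce for pp_size=2 and num_stages=4; with "v", rank 0 owns both the first and the last stage, which is what the new has_first_stage/has_last_stage flags capture.

    # Illustrative sketch only (not torchtitan's stage_ids_this_rank): the
    # stage-to-rank mapping the two placement styles are expected to produce.
    def sketch_stage_ids(pp_rank: int, pp_size: int, num_stages: int, style: str):
        stages_per_rank = num_stages // pp_size
        if style == "loop":
            # rank r owns stages r, r + pp_size, r + 2 * pp_size, ...
            return [pp_rank + i * pp_size for i in range(stages_per_rank)]
        if style == "v":
            # V placement pairs each "down"-leg stage with its mirror on the
            # "up" leg, so rank 0 ends up with both the first and last stage.
            assert stages_per_rank == 2, "sketch assumes 2 stages per rank"
            return [pp_rank, num_stages - 1 - pp_rank]
        raise ValueError(f"unknown style {style!r}")

    print(sketch_stage_ids(0, 2, 4, "loop"))  # [0, 2]
    print(sketch_stage_ids(1, 2, 4, "loop"))  # [1, 3]
    print(sketch_stage_ids(0, 2, 4, "v"))     # [0, 3]  first and last stage on rank 0
    print(sketch_stage_ids(1, 2, 4, "v"))     # [1, 2]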

train.py (11 additions, 9 deletions)

@@ -151,7 +151,12 @@ def loss_fn(pred, labels):
     # apply parallelisms and initialization
     if parallel_dims.pp_enabled:
         # apply PT-D Pipeline Parallel
-        pp_schedule, model_parts = models_pipelining_fns[model_name](
+        (
+            pp_schedule,
+            model_parts,
+            has_first_stage,
+            has_last_stage,
+        ) = models_pipelining_fns[model_name](
             model, pp_mesh, parallel_dims, job_config, device, model_config, loss_fn
         )
         # when PP is enabled, `model` obj is no longer used after this point, model_parts is used instead
@@ -285,22 +290,19 @@ def loss_fn(pred, labels):
 
            if parallel_dims.pp_enabled:
                # Pipeline Parallel forward / backward inside step() call
-               is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1
-
                with train_context(optional_context_parallel_ctx):
-                   if pp_mesh.get_local_rank() == 0:
-                       pp_schedule.step(input_ids)
-                   elif is_last_stage:
-                       losses = []
-                       pp_schedule.step(target=labels, losses=losses)
+                   targets = labels if has_last_stage else None
+                   losses = [] if has_last_stage else None
+                   if has_first_stage:
+                       pp_schedule.step(input_ids, target=targets, losses=losses)
                    else:
                        pp_schedule.step()
 
                # accumulate losses across pipeline microbatches
                # TODO: PP+FSDP unexpectedly puts the loss back to the CPU
                loss = (
                    torch.mean(torch.stack(losses)).to(device)
-                   if is_last_stage
+                   if has_last_stage
                    else torch.tensor([-1.0], device=device)
                )
            else:
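
Because the V-shaped placement puts the last stage on rank 0 (the same rank as the first stage) rather than on rank pp_size - 1, the train loop can no longer identify the loss-producing rank from pp_mesh.get_local_rank(); it now keys off the returned has_first_stage/has_last_stage flags. Below is a condensed sketch of the resulting dispatch, as a hypothetical helper that mirrors the diff above and assumes the PipelineSchedule.step(*args, target=..., losses=...) signature from torch.distributed.pipelining:

    import torch

    # Hypothetical helper mirroring the diff above; not part of the commit.
    def pp_step_and_loss(pp_schedule, input_ids, labels,
                         has_first_stage, has_last_stage, device):
        # Only ranks that own the last stage pass targets and collect
        # per-microbatch losses.
        targets = labels if has_last_stage else None
        losses = [] if has_last_stage else None
        if has_first_stage:
            # Under ZBVZeroBubble this rank typically owns the last stage too,
            # so it feeds inputs and receives losses in the same step() call.
            pp_schedule.step(input_ids, target=targets, losses=losses)
        else:
            pp_schedule.step()
        # Average microbatch losses on last-stage owners; other ranks report a sentinel.
        return (
            torch.mean(torch.stack(losses)).to(device)
            if has_last_stage
            else torch.tensor([-1.0], device=device)
        )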
