From 60ea9f2f47868dac4d2c42ca05b1cfe9a682ea9c Mon Sep 17 00:00:00 2001
From: Howard Huang
Date: Tue, 4 Feb 2025 10:58:12 -0800
Subject: [PATCH] Support ZBVZeroBubbleSchedule

---
 tests/integration_tests.py                | 12 ++++++++++++
 torchtitan/parallelisms/pipeline_llama.py | 24 ++++++++++++++++++++---
 train.py                                  | 20 ++++++++++---------
 3 files changed, 44 insertions(+), 12 deletions(-)

diff --git a/tests/integration_tests.py b/tests/integration_tests.py
index 7048439c..b7f27c3a 100755
--- a/tests/integration_tests.py
+++ b/tests/integration_tests.py
@@ -149,6 +149,18 @@ def build_test_list():
             "pp_looped_zero_bubble",
             ngpu=4,
         ),
+        OverrideDefinitions(
+            [
+                [
+                    "--experimental.pipeline_parallel_degree 2",
+                    "--experimental.pipeline_parallel_schedule ZBVZeroBubble",
+                    "--experimental.pipeline_parallel_microbatches 8",
+                ],
+            ],
+            "PP zero bubble test (v shaped)",
+            "pp_zbv",
+            ngpu=2,
+        ),
         OverrideDefinitions(
             [
                 [
diff --git a/torchtitan/parallelisms/pipeline_llama.py b/torchtitan/parallelisms/pipeline_llama.py
index 6605a57d..8fe892ab 100644
--- a/torchtitan/parallelisms/pipeline_llama.py
+++ b/torchtitan/parallelisms/pipeline_llama.py
@@ -13,7 +13,10 @@
 import torch.nn as nn
 from torch.distributed import DeviceMesh
 from torch.distributed.pipelining import PipelineStage
-
+from torch.distributed.pipelining.schedules import (
+    get_schedule_class,
+    ScheduleZBVZeroBubble,
+)
 from torchtitan.config_manager import JobConfig
 from torchtitan.logging import logger
 from torchtitan.models.llama.model import ModelArgs
@@ -43,7 +46,16 @@ def pipeline_llama(
 
     pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn)
 
-    return pp_schedule, models
+    # This is used in the train loop to determine whether to pass in the input_ids and labels
+    has_first_stage = False
+    has_last_stage = False
+    for stage in stages:
+        if stage.is_first:
+            has_first_stage = True
+        if stage.is_last:
+            has_last_stage = True
+
+    return pp_schedule, models, has_first_stage, has_last_stage
 
 
 def pipeline_llama_manual_split(
@@ -103,7 +115,13 @@ def _build_stage(stage_idx, start_layer, stop_layer, is_first=False, is_last=Fal
 
     stages = []
     models = []
-    for stage_idx in stage_ids_this_rank(pp_rank, pp_size, num_stages, style="loop"):
+
+    schedule_class = get_schedule_class(
+        job_config.experimental.pipeline_parallel_schedule
+    )
+    style = "v" if schedule_class == ScheduleZBVZeroBubble else "loop"
+
+    for stage_idx in stage_ids_this_rank(pp_rank, pp_size, num_stages, style=style):
         start_layer = splits[stage_idx - 1] if stage_idx > 0 else None
         stop_layer = splits[stage_idx] if stage_idx < num_stages - 1 else None
         stage, model_chunk = _build_stage(
diff --git a/train.py b/train.py
index bac22772..1033675e 100644
--- a/train.py
+++ b/train.py
@@ -151,7 +151,12 @@ def loss_fn(pred, labels):
     # apply parallelisms and initialization
     if parallel_dims.pp_enabled:
         # apply PT-D Pipeline Parallel
-        pp_schedule, model_parts = models_pipelining_fns[model_name](
+        (
+            pp_schedule,
+            model_parts,
+            has_first_stage,
+            has_last_stage,
+        ) = models_pipelining_fns[model_name](
             model, pp_mesh, parallel_dims, job_config, device, model_config, loss_fn
         )
         # when PP is enabled, `model` obj is no longer used after this point, model_parts is used instead
@@ -285,14 +290,11 @@ def loss_fn(pred, labels):
 
             if parallel_dims.pp_enabled:
                 # Pipeline Parallel forward / backward inside step() call
-                is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1
-
                 with train_context(optional_context_parallel_ctx):
-                    if pp_mesh.get_local_rank() == 0:
-                        pp_schedule.step(input_ids)
-                    elif is_last_stage:
-                        losses = []
-                        pp_schedule.step(target=labels, losses=losses)
+                    targets = labels if has_last_stage else None
+                    losses = [] if has_last_stage else None
+                    if has_first_stage:
+                        pp_schedule.step(input_ids, target=targets, losses=losses)
                     else:
                         pp_schedule.step()
 
@@ -300,7 +302,7 @@
                 # TODO: PP+FSDP unexpectedly puts the loss back to the CPU
                 loss = (
                     torch.mean(torch.stack(losses)).to(device)
-                    if is_last_stage
+                    if has_last_stage
                     else torch.tensor([-1.0], device=device)
                 )
             else:
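
Note on the "v" stage placement introduced above: a looped schedule assigns each pipeline rank every pp_size-th stage, while ScheduleZBVZeroBubble requires a V-shaped placement in which the rank that owns the first stage also owns the last stage. That is why the train loop now keys off has_first_stage / has_last_stage rather than the pipeline rank. The standalone sketch below illustrates the two mappings; sketch_stage_ids is a hypothetical helper written for this note, not torchtitan's actual stage_ids_this_rank implementation.

# Illustrative sketch only: sketch_stage_ids is a hypothetical stand-in for
# torchtitan's stage_ids_this_rank; it shows the stage-to-rank mapping that
# the "loop" and "v" styles imply.
def sketch_stage_ids(pp_rank, pp_size, num_stages, style="loop"):
    stages_per_rank = num_stages // pp_size
    if style == "loop":
        # round-robin: rank r holds stages r, r + pp_size, r + 2 * pp_size, ...
        return [pp_rank + s * pp_size for s in range(stages_per_rank)]
    if style == "v":
        # V shape assumes two stages per rank: rank r holds stage r going "down"
        # the pipeline and the mirrored stage num_stages - 1 - r coming back "up".
        assert stages_per_rank == 2, "v style expects num_stages == 2 * pp_size"
        return [pp_rank, num_stages - 1 - pp_rank]
    raise ValueError(f"unknown style: {style}")

if __name__ == "__main__":
    # With pp_size=2 and num_stages=4:
    #   loop -> rank 0: [0, 2], rank 1: [1, 3]
    #   v    -> rank 0: [0, 3], rank 1: [1, 2]
    for style in ("loop", "v"):
        for rank in range(2):
            print(style, rank, sketch_stage_ids(rank, 2, 4, style))

Under the V mapping, rank 0 holds stages 0 and 3, so has_first_stage and has_last_stage are both true on that rank and it passes input_ids as well as target/losses to pp_schedule.step(), which is the case exercised by the pp_zbv integration test above (pipeline_parallel_degree 2, ngpu=2).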