diff --git a/.ci/docker/requirements.txt b/.ci/docker/requirements.txt
index 2321627e1..9bea6f9e0 100644
--- a/.ci/docker/requirements.txt
+++ b/.ci/docker/requirements.txt
@@ -6,3 +6,5 @@ sentencepiece
 tiktoken
 blobfile
 tabulate
+pwlf
+pulp
diff --git a/scripts/estimate/estimation.py b/scripts/estimate/estimation.py
index 09b1ce4f6..a05fde623 100644
--- a/scripts/estimate/estimation.py
+++ b/scripts/estimate/estimation.py
@@ -7,10 +7,18 @@
 import contextlib
 import gc
 import os
+from typing import Any, List, Set, Union
 
 import torch
+from torch import nn, optim
 from torch._guards import active_fake_mode
 from torch._subclasses.fake_tensor import FakeTensorMode
+from torch.distributed._tools.auto_sac import (
+    AutoSACResult,
+    get_auto_sac_policies,
+    get_module_name_dict,
+    SACAlgorithm,
+)
 from torch.distributed._tools.fsdp2_mem_tracker import FSDPMemTracker
 from torch.testing._internal.distributed.fake_pg import FakeStore
 
@@ -18,14 +26,21 @@
 from torchtitan.config_manager import JobConfig
 from torchtitan.datasets import build_tokenizer
 from torchtitan.float8 import Float8Handler
-from torchtitan.logging import init_logger, logger
+from torchtitan.logging import init_logger, logger, logging
 from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config
 from torchtitan.optimizer import build_lr_schedulers, build_optimizers
 from torchtitan.parallelisms import models_parallelize_fns, ParallelDims
 
 
-def estimate_memory(job_config: JobConfig):
-    init_logger()
+def estimate(job_config: JobConfig) -> Union[AutoSACResult, None]:
+    assert not (
+        job_config.memory_estimation.enabled and job_config.sac_estimation.enabled
+    ), "Enabling SAC estimation and FSDP memory estimation together is not permitted."
+    if job_config.memory_estimation.enabled:
+        init_logger()
+    else:
+        logging.disable()
+
     logger.info("Estimating memory usage...")
     gc.disable()
     gc.collect(1)
@@ -37,10 +52,9 @@ def estimate_memory(job_config: JobConfig):
     if (
         job_config.training.tensor_parallel_degree > 1
         or job_config.experimental.pipeline_parallel_degree > 1
+        or job_config.experimental.context_parallel_degree > 1
     ):
-        logger.info(
-            "Tensor parallelism and pipeline parallelism are not supported yet."
-        )
+        logger.info("Tensor, context, and pipeline parallelism are not yet supported.")
         return
 
     # fake tensor doesn't work with fused rmsnorm
@@ -76,25 +90,12 @@ def estimate_memory(job_config: JobConfig):
     device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
     torch.cuda.set_device(device)
 
-    # init fake pg
-    store = FakeStore()
-    torch.distributed.init_process_group(
-        "fake", rank=int(os.environ["LOCAL_RANK"]), world_size=world_size, store=store
-    )
-
-    # build meshes
-    world_mesh = parallel_dims.build_mesh(device_type="cuda")
-
     if not parallel_dims.dp_enabled:
         logger.info("Data parallelism is not enabled. Skipping memory estimation.")
         return
 
     model_name = job_config.model.name
 
-    # build tokenizer
-    tokenizer_type = model_name_to_tokenizer[model_name]
-    tokenizer = build_tokenizer(tokenizer_type, job_config.model.tokenizer_path)
-
     train_context = utils.get_train_context(
         parallel_dims.loss_parallel_enabled,
         job_config.experimental.enable_compiled_autograd,
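# Aside — a minimal sketch, not part of the patch, of the fake-backend pattern the
# hunks above reorganize: FakeStore plus the "fake" process-group backend let one
# process stand in for an N-rank job (collectives become no-ops), while
# FakeTensorMode produces metadata-only tensors, so shapes and memory usage can be
# traced without real GPU memory or communication.
import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.testing._internal.distributed.fake_pg import FakeStore

def _fake_backend_demo(world_size: int = 8) -> None:
    torch.distributed.init_process_group(
        "fake", rank=0, world_size=world_size, store=FakeStore()
    )
    with FakeTensorMode():
        x = torch.empty(4096, 4096, device="cuda")  # no real device allocation
        y = x @ x  # shape propagation only; no kernel is launched
    torch.distributed.destroy_process_group()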
@@ -111,11 +112,28 @@ def loss_fn(pred, labels):
     model_config = models_config[model_name][job_config.model.flavor]
     # set the model configs from training inputs:
     # 1. norm type to decide which norm layer to use
-    # 2. vocab size from tokenizer
-    # 3. max_seq_len base on inputs
+    # 2. max_seq_len based on inputs
+    # 3. vocab size from tokenizer
+
     model_config.norm_type = job_config.model.norm_type
-    model_config.vocab_size = tokenizer.n_words
     model_config.max_seq_len = job_config.training.seq_len
+    model_config.vocab_size = 128256  # Llama 3 vocab size; overwritten if a tokenizer is built
+    if not job_config.sac_estimation.enabled:
+        # build tokenizer
+        tokenizer_type = model_name_to_tokenizer[model_name]
+        tokenizer = build_tokenizer(tokenizer_type, job_config.model.tokenizer_path)
+        model_config.vocab_size = tokenizer.n_words
+        # init fake pg
+        store = FakeStore()
+        torch.distributed.init_process_group(
+            "fake",
+            rank=int(os.environ["LOCAL_RANK"]),
+            world_size=world_size,
+            store=store,
+        )
+
+        # build meshes
+        world_mesh = parallel_dims.build_mesh(device_type="cuda")
 
     with FakeTensorMode() if not job_config.memory_estimation.disable_fake_mode else contextlib.nullcontext():
@@ -130,8 +148,11 @@ def loss_fn(pred, labels):
         # swap to Float8Linear based on float8 configs
         float8_handler.convert_to_float8_training(model)
 
-        # apply PT-D DP/TP parallelisms and activation checkpointing
-        models_parallelize_fns[model_name](model, world_mesh, parallel_dims, job_config)
+        if not job_config.sac_estimation.enabled:
+            # apply PT-D DP/TP parallelisms and activation checkpointing
+            models_parallelize_fns[model_name](
+                model, world_mesh, parallel_dims, job_config
+            )
 
         model.to_empty(device="cuda")
         if not active_fake_mode():
@@ -158,32 +179,74 @@ def loss_fn(pred, labels):
                 device="cuda",
             ),
         )
+
+        def train_step(models: List[nn.Module], optims: List[optim.Optimizer], batch: Any):
+            # step signature expected by get_auto_sac_policies; closes over outer state
+            input_ids, labels = batch
+            with train_context():
+                pred = model(input_ids)
+                loss = loss_fn(pred, labels)
+                del pred
+                loss.backward()
+
+            # clip gradients
+            torch.nn.utils.clip_grad_norm_(
+                model.parameters(), job_config.training.max_norm, foreach=True
+            )
+            # sync float8 amaxes and scales
+            float8_handler.sync_float8_amax_and_scale_history(model)
+            # optimizer step
+            optimizers.step()
+            lr_schedulers.step()
+            # calculate float8 dynamic amax/scale for all parameters for FSDP2
+            # it issues a single all-reduce for all parameters at once for better performance
+            float8_handler.precompute_float8_dynamic_scale_for_fsdp(model)
+            optimizers.zero_grad()
+
+        if job_config.sac_estimation.enabled:
+            logging.disable(logging.NOTSET)
+            gib = 1024**3
+            budget = float(job_config.activation_checkpoint.auto_sac_budget)
+            recommended_budget = (
+                0.85 * torch.cuda.get_device_properties(device).total_memory / gib
+            )
+            if budget > recommended_budget:
+                logger.warning(
+                    "It is recommended to cap the Auto-SAC memory budget at 85 percent of device memory.\n"
+                    "Current budget is %.2f GiB, reducing it to %.2f GiB.",
+                    budget,
+                    recommended_budget,
+                )
+                budget = recommended_budget
+
+            mod_fqns = get_module_name_dict(model)
+            fsdp_unit_fqns: Set[str] = set()
+            for transformer_block in model.layers.values():
+                fsdp_unit_fqns.add(mod_fqns[transformer_block])
+            fsdp_unit_fqns.add(mod_fqns[model])
+
+            sac_algorithm = SACAlgorithm(
+                job_config.activation_checkpoint.auto_sac_algorithm
+            )
+            auto_sac_result = get_auto_sac_policies(
+                train_step=train_step,
+                models=[model],
+                optimizers=optimizers.optimizers,
+                inputs=batch,
+                dev=device,
+                memory_budget=budget,
+                sac_algo=sac_algorithm,
+                shard_degree=parallel_dims.dp_shard,
+                fsdp_units=fsdp_unit_fqns,
+            )
+            return auto_sac_result
+
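# Aside — a minimal sketch, not part of the patch. get_auto_sac_policies() receives
# the FQN of every module that will become an FSDP unit (each TransformerBlock plus
# the root). get_module_name_dict() is assumed to map each submodule to its fully
# qualified name; a rough equivalent using only public APIs (module_fqns and
# fsdp_unit_names are illustrative helpers, not torchtitan functions):
from typing import Dict, Iterable, Set

import torch.nn as nn

def module_fqns(root: nn.Module) -> Dict[nn.Module, str]:
    # named_modules() yields ("", root), ("layers.0", block), ...
    # so inverting it gives module -> FQN
    return {mod: name for name, mod in root.named_modules()}

def fsdp_unit_names(root: nn.Module, blocks: Iterable[nn.Module]) -> Set[str]:
    fqns = module_fqns(root)
    units = {fqns[b] for b in blocks}
    units.add(fqns[root])  # the root module is wrapped as an FSDP unit too
    return units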
         fsdp_memtracker = FSDPMemTracker(mod=model, optm=optimizers.optimizers[0])
         fsdp_memtracker.track_inputs(batch)
         with fsdp_memtracker:
             for iter_idx in range(2):
-                input_ids, labels = batch
-                # train step
-                with train_context():
-                    pred = model(input_ids)
-                    loss = loss_fn(pred, labels)
-                    del pred
-                    loss.backward()
-
-                # clip gradients
-                torch.nn.utils.clip_grad_norm_(
-                    model.parameters(), job_config.training.max_norm, foreach=True
-                )
-                # sync float8 amaxes and scales
-                float8_handler.sync_float8_amax_and_scale_history(model)
-                # optimizer step
-                optimizers.step()
-                lr_schedulers.step()
-                # calculate float8 dynamic amax/scale for all-parameter for FSDP2
-                # it issues a single all-reduce for all parameters at once for better performance
-                float8_handler.precompute_float8_dynamic_scale_for_fsdp(model)
-                optimizers.zero_grad()
+                train_step([model], optimizers.optimizers, batch)
                 print(f"Peak Memory at iter: {iter_idx}")
                 fsdp_memtracker.display_snapshot("peak", units="MiB", tabulate=True)
                 if iter_idx == 0:
@@ -214,7 +277,7 @@ def loss_fn(pred, labels):
     config = JobConfig()
     config.parse_args()
     try:
-        estimate_memory(config)
+        estimate(config)
     finally:
         if torch.distributed.is_initialized():
             torch.distributed.destroy_process_group()
diff --git a/scripts/generate/test_generate.py b/scripts/generate/test_generate.py
index b069e8bdd..54e01b429 100644
--- a/scripts/generate/test_generate.py
+++ b/scripts/generate/test_generate.py
@@ -26,14 +26,14 @@
 )
 
 from torchtitan import utils
-from torchtitan.utils import device_module, device_type
 from torchtitan.config_manager import JobConfig
 from torchtitan.datasets import build_tokenizer
 from torchtitan.logging import init_logger, logger
-from torchtitan.metrics import build_device_memory_monitor, build_metric_logger
+from torchtitan.metrics import build_device_memory_monitor
 from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config
 from torchtitan.parallelisms import ParallelDims
+from torchtitan.utils import device_module, device_type
 
 # support running w/o installing as package
 wd = Path(__file__).parent.parent.resolve()
@@ -143,7 +143,8 @@ def test_generate(
     # Build world mesh for parallelism
     world_mesh = parallel_dims.build_mesh(device_type=device_type)
 
-    # apply_tp (with Sequence Parallel) on unevenly sharded sequences would require https://github.com/pytorch/torchtitan/pull/686
+    # apply_tp (with Sequence Parallel) on unevenly sharded sequences would require
+    # https://github.com/pytorch/torchtitan/pull/686
     apply_tp_minus_sp(model, world_mesh["tp"])
 
     # materalize model
diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index 814bd80fc..01045626f 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -479,7 +479,7 @@ def __init__(self):
             "--activation_checkpoint.mode",
             type=str,
             default="selective",
-            help="Type of activation checkpointing to use ['none', 'full', 'selective']",
+            help="Type of activation checkpointing to use ['none', 'full', 'selective', 'auto']",
         )
         self.parser.add_argument(
             "--activation_checkpoint.selective_ac_option",
@@ -490,6 +490,25 @@ def __init__(self):
                 'int' (e.g., 2) for every nth layer, or 'op' for op level ac.
             """,
         )
+        self.parser.add_argument(
+            "--activation_checkpoint.auto_sac_budget",
+            type=float,
+            default=65.0,
+            help="""
+                Auto-SAC memory budget in GiB.
+                Recommended: at most 85 percent of total device memory.
+            """,
+        )
+        self.parser.add_argument(
+            "--activation_checkpoint.auto_sac_algorithm",
+            type=str,
+            default="optimal",
+            choices=["greedy", "optimal"],
+            help="""
+                Algorithm to use for determining SAC policies.
+                `greedy` runs in linear time, while `optimal` solves an ILP.
+            """,
+        )
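# Aside — an illustrative toy, not the actual solver. The help text above says the
# `optimal` algorithm solves an ILP; `pulp` (ILP modeling with a CBC backend) and
# `pwlf` (piecewise-linear fitting, presumably for the estimator's cost models) were
# added to requirements.txt for this. The real formulation lives in
# torch.distributed._tools; this knapsack-style sketch only shows the shape of such a
# problem: choose which activations to recompute so that kept-activation memory fits
# the budget at minimum recompute cost.
from pulp import LpBinary, LpMinimize, LpProblem, LpVariable, lpSum, PULP_CBC_CMD

def toy_sac_ilp(mem_gib, recompute_cost, budget_gib):
    # mem_gib[i]: memory freed if activation i is recomputed instead of saved
    # recompute_cost[i]: extra forward time paid to recompute activation i
    n = len(mem_gib)
    prob = LpProblem("toy_sac", LpMinimize)
    x = [LpVariable(f"recompute_{i}", cat=LpBinary) for i in range(n)]
    prob += lpSum(recompute_cost[i] * x[i] for i in range(n))  # minimize recompute time
    prob += lpSum(mem_gib[i] * x[i] for i in range(n)) >= sum(mem_gib) - budget_gib
    prob.solve(PULP_CBC_CMD(msg=False))
    return [i for i in range(n) if x[i].value() == 1]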
+ """, + ) # float8 configs self.parser.add_argument( @@ -570,6 +589,13 @@ def __init__(self): action="store_true", ) + self.parser.add_argument( + "--sac_estimation.enabled", + help="Whether to calculate SAC (Selective Activation Checkpointing) policies", + default=False, + action="store_true", + ) + def parse_args(self, args_list: list = sys.argv[1:]): args, cmd_args = self.parse_args_from_command_line(args_list) config_file = getattr(args, "job.config_file", None) diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py index 4d4c60bc0..c0945c286 100644 --- a/torchtitan/parallelisms/parallelize_llama.py +++ b/torchtitan/parallelisms/parallelize_llama.py @@ -8,6 +8,7 @@ # training techniques (e.g. activation checkpointing and compile) to the Llama model. from collections import defaultdict +from copy import deepcopy import torch import torch.nn as nn @@ -20,6 +21,7 @@ ) from torch.distributed._composable.replicate import replicate from torch.distributed._tensor import Replicate, Shard +from torch.distributed._tools.auto_sac import apply_auto_sac_policies from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( checkpoint_wrapper as ptd_checkpoint_wrapper, ) @@ -66,7 +68,13 @@ def parallelize_llama( ) if job_config.activation_checkpoint.mode != "none": - apply_ac(model, job_config.activation_checkpoint) + if job_config.activation_checkpoint.mode == "auto": + if not apply_auto_sac(model, job_config): + logger.info("Auto-SAC failed, falling back to full AC mode.") + job_config.activation_checkpoint.mode = "full" + apply_ac(model, job_config.activation_checkpoint) + else: + apply_ac(model, job_config.activation_checkpoint) # turn on per-TransformerBlock compile after AC wrapping and before FSDP if job_config.training.compile: @@ -314,6 +322,41 @@ def apply_ac(model: nn.Module, ac_config): logger.info(f"Applied {ac_config.mode} activation checkpointing to the model") +def apply_auto_sac(model: nn.Module, job_config: JobConfig) -> bool: + if ( + job_config.training.tensor_parallel_degree > 1 + or job_config.experimental.pipeline_parallel_degree > 1 + or job_config.experimental.context_parallel_degree > 1 + or job_config.training.enable_cpu_offload + ): + logger.info( + "Tensor, Context and Pipeline parallelism or FSDP with CPU Offload option" + " are not supported yet with Auto-SAC." + ) + return False + est_job_config = deepcopy(job_config) + est_job_config.memory_estimation.disable_fake_mode = False + est_job_config.memory_estimation.enabled = False + est_job_config.sac_estimation.enabled = True + est_job_config.training.compile = False + est_job_config.experimental.enable_compiled_autograd = False + if ( + est_job_config.model.norm_type == "compiled_rmsnorm" + or est_job_config.model.norm_type == "fused_rmsnorm" + ): + est_job_config.model.norm_type = "rmsnorm" + from scripts.estimate.estimation import estimate + + auto_sac_result = estimate(est_job_config) + assert auto_sac_result is not None + if auto_sac_result.peak_mem == -1: + return False + apply_auto_sac_policies( + model, auto_sac_result.sac_policies, preserve_rng_state=False + ) + return True + + def apply_compile(model: nn.Module): """ Apply torch.compile to each TransformerBlock, which makes compilation efficient due to