diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index 67c82d53..b0dbdb55 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -332,8 +332,9 @@ def __init__(self):
             default="bfloat16",
             choices=["bfloat16", "float32"],
             help="""
-                torch dtype to use for parameters when applying mixed precision via FSDP.
-                This feature only takes effect when data_parallel_degree > 1
+                torch dtype to use for parameters when applying mixed precision.
+                When data_parallel_shard_degree > 1, this changes FSDP's `param_dtype`.
+                When data_parallel_shard_degree == 1, this enables AMP autocast.
             """,
         )
         self.parser.add_argument(
diff --git a/train.py b/train.py
index d1973b6d..93beb690 100644
--- a/train.py
+++ b/train.py
@@ -14,7 +14,7 @@
 
 from torchtitan import utils
 from torchtitan.checkpoint import CheckpointManager, TrainState
-from torchtitan.config_manager import JobConfig
+from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.datasets import build_hf_data_loader, build_tokenizer
 from torchtitan.float8 import Float8Handler
 from torchtitan.logging import init_logger, logger
@@ -288,8 +288,14 @@ def loss_fn(pred, labels):
         else:
             # Non-PP forward / backward
             with train_context():
-                pred = model(input_ids)
-                loss = loss_fn(pred, labels)
+                with contextlib.nullcontext() if parallel_dims.dp_shard_enabled else torch.autocast(
+                    "cuda",
+                    dtype=TORCH_DTYPE_MAP[
+                        job_config.training.mixed_precision_param
+                    ],
+                ):
+                    pred = model(input_ids)
+                    loss = loss_fn(pred, labels)
                 # pred.shape=(bs, seq_len, vocab_size)
                 # need to free to before bwd to avoid peaking memory
                 del pred
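
For reference, a minimal self-contained sketch of the precision selection that the train.py hunk performs inline: when FSDP sharding is enabled, mixed precision is left to FSDP's `param_dtype`, otherwise the forward pass runs under `torch.autocast` with the configured dtype. The `maybe_autocast` helper below is hypothetical (the diff uses a conditional expression directly in the `with` statement rather than a helper), and the dtype mapping mirrors the `TORCH_DTYPE_MAP` constant imported from `torchtitan.config_manager`.

```python
import contextlib

import torch

# Same mapping that torchtitan.config_manager exposes as TORCH_DTYPE_MAP
# for the two --training.mixed_precision_param choices above.
TORCH_DTYPE_MAP = {"bfloat16": torch.bfloat16, "float32": torch.float32}


def maybe_autocast(dp_shard_enabled: bool, mixed_precision_param: str = "bfloat16"):
    """Hypothetical helper mirroring the selection in the train.py hunk.

    When FSDP sharding is enabled, mixed precision is already handled via
    FSDP's param_dtype, so a no-op context is returned; otherwise AMP
    autocast casts the forward computation to the requested dtype.
    """
    if dp_shard_enabled:
        return contextlib.nullcontext()
    return torch.autocast("cuda", dtype=TORCH_DTYPE_MAP[mixed_precision_param])


if __name__ == "__main__" and torch.cuda.is_available():
    model = torch.nn.Linear(16, 16).cuda()
    x = torch.randn(4, 16, device="cuda")
    with maybe_autocast(dp_shard_enabled=False):
        y = model(x)
    # Under autocast, matmul outputs come back in the low-precision dtype.
    print(y.dtype)  # torch.bfloat16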