From 62d0d951666f71fc547c883232740d066f91550f Mon Sep 17 00:00:00 2001
From: 152334H <54623771+152334H@users.noreply.github.com>
Date: Sun, 29 Sep 2024 14:58:11 +0800
Subject: [PATCH 1/4] use autocast for replicate

---
 train.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/train.py b/train.py
index d1973b6d..4173b3f1 100644
--- a/train.py
+++ b/train.py
@@ -14,7 +14,7 @@
 
 from torchtitan import utils
 from torchtitan.checkpoint import CheckpointManager, TrainState
-from torchtitan.config_manager import JobConfig
+from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.datasets import build_hf_data_loader, build_tokenizer
 from torchtitan.float8 import Float8Handler
 from torchtitan.logging import init_logger, logger
@@ -288,8 +288,11 @@ def loss_fn(pred, labels):
         else:
             # Non-PP forward / backward
             with train_context():
-                pred = model(input_ids)
-                loss = loss_fn(pred, labels)
+                with contextlib.nullcontext() if parallel_dims.dp_shard_enabled else torch.autocast(
+                    "cuda", dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
+                ):
+                    pred = model(input_ids)
+                    loss = loss_fn(pred, labels)
                 # pred.shape=(bs, seq_len, vocab_size)
                 # need to free to before bwd to avoid peaking memory
                 del pred

From dae0527e129c52a5c4e6f38923cb8c72f4d26b04 Mon Sep 17 00:00:00 2001
From: 152334H <54623771+152334H@users.noreply.github.com>
Date: Sun, 29 Sep 2024 15:02:33 +0800
Subject: [PATCH 2/4] update mixed_precision_param desc

---
 torchtitan/config_manager.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index 67c82d53..70507dcc 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -332,8 +332,9 @@ def __init__(self):
             default="bfloat16",
             choices=["bfloat16", "float32"],
             help="""
-                torch dtype to use for parameters when applying mixed precision via FSDP.
-                This feature only takes effect when data_parallel_degree > 1
+                torch dtype to use for parameters when applying mixed precision.
+                When data_parallel_degree > 1, this changes FSDP's `param_dtype`.
+                When data_parallel_degree == 1, this enables AMP autocast.
             """,
         )
         self.parser.add_argument(

From 4da73765d6cf345234b2b4aba48eacb96f0b4e6d Mon Sep 17 00:00:00 2001
From: 152334H <54623771+152334H@users.noreply.github.com>
Date: Tue, 1 Oct 2024 03:57:37 +0800
Subject: [PATCH 3/4] possibly fix lint error

---
 train.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/train.py b/train.py
index 4173b3f1..93beb690 100644
--- a/train.py
+++ b/train.py
@@ -289,7 +289,10 @@ def loss_fn(pred, labels):
             # Non-PP forward / backward
             with train_context():
                 with contextlib.nullcontext() if parallel_dims.dp_shard_enabled else torch.autocast(
-                    "cuda", dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
+                    "cuda",
+                    dtype=TORCH_DTYPE_MAP[
+                        job_config.training.mixed_precision_param
+                    ],
                 ):
                     pred = model(input_ids)
                     loss = loss_fn(pred, labels)

From 747c3b59b405b1719cbb64d4f8a6e9048d5e7246 Mon Sep 17 00:00:00 2001
From: 152334H <54623771+152334H@users.noreply.github.com>
Date: Fri, 4 Oct 2024 08:31:15 +0800
Subject: [PATCH 4/4] Update config_manager.py

---
 torchtitan/config_manager.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index 70507dcc..b0dbdb55 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -333,8 +333,8 @@ def __init__(self):
             choices=["bfloat16", "float32"],
             help="""
                 torch dtype to use for parameters when applying mixed precision.
-                When data_parallel_degree > 1, this changes FSDP's `param_dtype`.
-                When data_parallel_degree == 1, this enables AMP autocast.
+                When data_parallel_shard_degree > 1, this changes FSDP's `param_dtype`.
+                When data_parallel_shard_degree == 1, this enables AMP autocast.
             """,
         )
         self.parser.add_argument(