Commit ff3864f

Integrate TorchFT
**Summary**

This is a WIP TorchFT integration PR.

**Current Issues**

This doesn't work at the moment: groups hang when a new group joins.

**Issue 1:**

~Group 0 and group 1 will hang during the first `should_commit` after group 1 applies the pending state_dict from group 0.~

Fixed with: pytorch/torchft#83

**Issue 2:**

~Group 0 and group 1 will pass `should_commit`, but group 0 is incorrectly marked as needing healing, and the healing process causes another hang.~

Fixed with: pytorch/torchft#83

**Issue 3:**

~A byproduct of issues 1 and 2: group 1 will keep printing~

```
[rank0]:devgpu051:76838:80357 [0] misc/socket.cc:50 NCCL WARN socketProgress: Connection closed by remote peer devgpu051.cln3.svc.fbinfra.net<33618>
```

Fixed with pytorch/torchft#91 and several other fixes.

**Issue 4:**

When there are 3 groups, every group requests the state dict on every step.

***How to reproduce?*** Use the `Reproduce steps` below to run 2 groups, then add another group by modifying the command.

This seems to be fixed but needs more testing.

**Issue 5:**

A hang occurs when using functional collectives.

***How to reproduce?*** Pull the latest version of this PR, comment out line 41, and uncomment line 42 in `torchtitan/utils.py`.

**Reproduce steps:**

1. Patch TorchFT with pytorch/torchft#82
2. Start the lighthouse.
3. Run the following command in one terminal:

```
TORCHFT_MANAGER_PORT=29520 REPLICA_GROUP_ID=0 CUDA_VISIBLE_DEVICES=0,1 NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2 --experimental.enable_torchft --experimental.ft_replica_group_id=0
```

4. Wait 10 seconds, then run the following command in another terminal:

```
TORCHFT_MANAGER_PORT=29522 REPLICA_GROUP_ID=1 CUDA_VISIBLE_DEVICES=2,3 NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2 --experimental.enable_torchft --experimental.ft_replica_group_id=1
```

ghstack-source-id: 45c6bee925da6118756135c2447e337c6eaedfe5
Pull Request resolved: #834
1 parent ec82573 commit ff3864f
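For orientation, here is a minimal sketch of how a replica group might drive the TorchFT manager around each optimizer step. This is an illustration only, not the `train.py` change from this commit; it assumes `torchft`'s `Manager` exposes the `should_commit()` and `participating_rank()` calls referenced in the message above and in `checkpoint.py` below, and the training objects are placeholders.

```python
# Illustrative sketch only -- not the actual train.py diff in this commit.
# Assumes torchft's Manager provides should_commit() and participating_rank(),
# as referenced in the commit message and in checkpoint.py below.
from torchtitan.components.ft import init_ft_manager


def train_steps(job_config, model, optimizer, data_iter, checkpointer, num_steps):
    ft_manager = init_ft_manager(job_config)  # None when TorchFT is disabled

    for step in range(1, num_steps + 1):
        inputs, labels = next(data_iter)
        loss = model(inputs, labels)  # placeholder forward/loss for the sketch
        loss.backward()

        # Only apply the step when the replica group agrees it is safe to
        # commit; otherwise optimizer state could diverge across replicas.
        if ft_manager is None or ft_manager.should_commit():
            optimizer.step()
        optimizer.zero_grad()

        # With TorchFT, the checkpointer writes the full checkpoint only on the
        # replica with participating_rank() == 0, plus a per-replica dataloader
        # checkpoint everywhere (see checkpoint.py below).
        checkpointer.save(step)
```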

File tree

10 files changed: +358 −50 lines

run_train.sh (+3)

```
@@ -19,7 +19,10 @@ if [ $# -ne 0 ]; then
     overrides="$*"
 fi
 
+TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE:-"http://localhost:29510"}
+
 PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \
+TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE} \
 torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
 --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
 torchtitan/train.py --job.config_file ${CONFIG_FILE} $overrides
```
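The two added lines give `TORCHFT_LIGHTHOUSE` a default and forward it into the `torchrun` environment. For readers less familiar with the `${VAR:-default}` shell idiom, a rough Python equivalent of the same default-then-forward pattern (illustrative only):

```python
import os

# Rough equivalent of TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE:-"http://localhost:29510"}:
# keep a caller-provided lighthouse URL, otherwise fall back to the local default,
# then re-export it so the launched training processes see the same value.
lighthouse = os.environ.get("TORCHFT_LIGHTHOUSE", "http://localhost:29510")
os.environ["TORCHFT_LIGHTHOUSE"] = lighthouse
```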

tests/unit_tests/test_checkpoint.py (+8)

```
@@ -61,10 +61,18 @@ class DummyJob:
     dump_folder: str = "dummy_folder"
 
 
+@dataclass
+class DummyExperimental:
+    ft_replica_id = 0
+    ft_group_size = 1
+
+
 @dataclass
 class DummyJobConfig:
     checkpoint: DummyCheckpointConfig = field(default_factory=DummyCheckpointConfig)
     job: DummyJob = field(default_factory=DummyJob)
+    experimental: DummyExperimental = field(default_factory=DummyExperimental)
+    ft_manager = None
 
 
 # Dummy instances to supply as constructor arguments.
```

tests/unit_tests/test_model_converter.py (+1)

```
@@ -22,6 +22,7 @@ def build_parallel_dims(job_config, world_size):
         pp=job_config.experimental.pipeline_parallel_degree,
         world_size=world_size,
         enable_loss_parallel=not job_config.training.disable_loss_parallel,
+        ft_manager=None,
     )
     return parallel_dims
 
```

torchtitan/components/checkpoint.py (+118 −28)

```
@@ -16,7 +16,7 @@
 from dataclasses import dataclass, field
 from io import BytesIO
 from multiprocessing import get_context
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union
 
 import torch
 import torch.distributed as dist
@@ -36,6 +36,9 @@
 from torchtitan.tools.logging import init_logger, logger
 from torchtitan.tools.utils import GarbageCollection
 
+if TYPE_CHECKING:
+    import torchft as ft
+
 
 MODEL = "model"
 OPTIMIZER = "optimizer"
@@ -214,6 +217,19 @@ class CheckpointManager:
     3. LR schedulers also index model states like optimizers. Here we flatten the lr_schedulers
     with the assumption that all lr_schedulers have the same state_dict.
 
+    Note: TorchFT checkpointing flow
+
+    There are two types of checkpoints when TorchFT is enabled: 1) the full persistent
+    checkpoint, 2) the per-replica checkpoint.
+
+    The full persistent checkpoint is saved by the replica with
+    ``ft_manager.participating_rank() == 0``. It contains everything including the model,
+    optimizer, lr_scheduler, dataloader, and train_state. Right now the full persistent
+    checkpoint is loaded by all replicas. However, we can optimize it to only load if
+    there are no other alive replicas.
+
+    The per-replica checkpoint contains only the dataloader and is saved/loaded by all
+    replicas to/from its own folder. The folder name is prefixed with the ft_replica_id.
 
     Args:
         dataloader (DataLoader): The dataloader used to load the data.
@@ -223,6 +239,7 @@
         states (Dict[str, Any]): The states that need to be saved, other than the
             previous 4 components.
         job_config (JobConfig): The job config used to configure the checkpointing.
+        ft_manager (Optional[ft.Manager]): The FTManager from TorchFT.
     """
 
     def __init__(
@@ -233,16 +250,41 @@ def __init__(
         lr_schedulers: LRSchedulersContainer,
         states: Dict[str, Any],
         job_config: JobConfig,
+        ft_manager: Optional["ft.Manager"] = None,
     ) -> None:
         ckpt_config = job_config.checkpoint
         self.enable_checkpoint = ckpt_config.enable_checkpoint
+        self.ft_manager = ft_manager
+
+        if self.ft_manager:
+            optimizers.init_cache_state_dict()
+
+            def state_dict():
+                ret = {}
+                for k, v in self.states.items():
+                    if k in {
+                        MODEL,
+                        OPTIMIZER,
+                        LR_SCHEDULER,
+                        TRAIN_STATE,
+                    }:
+                        ret[k] = v.state_dict()
+                return ret
+
+            def load_state_dict(state_dict):
+                assert state_dict is not None
+                for k, v in state_dict.items():
+                    self.states[k].load_state_dict(v)
+
+            ft_manager.set_state_dict_fns(load_state_dict, state_dict)
+            self.ft_replica_id = job_config.experimental.ft_replica_id
 
         async_mode = ckpt_config.async_mode.lower()
         self.enable_staging = (
             self.enable_checkpoint and async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
-        )
+        ) or self.ft_manager
 
-        if not self.enable_checkpoint:
+        if not self.enable_checkpoint and self.ft_manager is None:
            return
 
         self.states = states
@@ -254,6 +296,13 @@ def __init__(
                 LR_SCHEDULER: lr_schedulers,
             }
         )
+        self.ft_states = {DATALOADER: dataloader}
+
+        self.staging = False
+        self.sending_to_checkpoint_mp = False
+        self.staging_id = None
+        self.cpu_offload_state_dict = None
+        self.staging_stream = torch.cuda.Stream() if self.enable_staging else None
 
         self.staging = False
         self.sending_to_checkpoint_mp = False
@@ -264,7 +313,7 @@ def __init__(
         self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder)
         self.interval = ckpt_config.interval
         async_mode = ckpt_config.async_mode.lower()
-        if async_mode == AsyncMode.ASYNC:
+        if async_mode == AsyncMode.ASYNC or self.ft_manager:
             self.pg = dist.new_group(backend="gloo")
 
         self.keep_latest_k = ckpt_config.keep_latest_k
@@ -339,35 +388,44 @@ def save(self, curr_step: int, force: bool = False) -> None:
             None
         """
 
+        if self.ft_manager:
+            self._ft_save(curr_step)
+
         if not self._should_save(curr_step, force):
             return
 
         begin = time.monotonic()
-        logger.info("Saving the checkpoint (or staging if async is enabled).")
-        checkpoint_id = self._create_checkpoint_id(curr_step)
-        self._async_wait()
-        # This GC is called for async checkpoint as it is useless to do
-        # GC right after async_save -- the CPU memory is not able to be
-        # freed until _async_wait()
-        if force:
-            self._save_last_step(curr_step)
-        elif self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
-            GarbageCollection.collect("GC collection invoked by checkpointer.")
-            self._async_with_pinned_memory(checkpoint_id)
-        elif self.async_mode == AsyncMode.ASYNC:
-            GarbageCollection.collect("GC collection invoked by checkpointer.")
-            self.async_future = dcp.async_save(
-                self.states, checkpoint_id=checkpoint_id, process_group=self.pg
-            )
-            GarbageCollection.collect("GC collection invoked by checkpointer.")
-        else:
-            save_with_gc(self.states, checkpoint_id=checkpoint_id)
-        self._purge_stale_checkpoints()
+        if not self.ft_manager or self.ft_manager.participating_rank() == 0:
+            logger.info("Saving the checkpoint (or staging if async is enabled).")
+            checkpoint_id = self._create_checkpoint_id(curr_step)
+            self._async_wait()
+            # This GC is called for async checkpoint as it is useless to do
+            # GC right after async_save -- the CPU memory is not able to be
+            # freed until _async_wait()
+            if force:
+                self._save_last_step(curr_step)
+            elif self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
+                GarbageCollection.collect("GC collection invoked by checkpointer.")
+                self._async_with_pinned_memory(checkpoint_id)
+            elif self.async_mode == AsyncMode.ASYNC:
+                GarbageCollection.collect("GC collection invoked by checkpointer.")
+                self.async_future = dcp.async_save(
+                    self.states, checkpoint_id=checkpoint_id, process_group=self.pg
+                )
+                GarbageCollection.collect("GC collection invoked by checkpointer.")
+            else:
+                save_with_gc(self.states, checkpoint_id=checkpoint_id)
+            self._purge_stale_checkpoints()
 
-        logger.info(
-            "Finished saving the checkpoint (or staging if async is enabled)"
-            f"in {time.monotonic() - begin:.2f} seconds."
-        )
+            logger.info(
+                "Finished saving the checkpoint (or staging if async is enabled)"
+                f"in {time.monotonic() - begin:.2f} seconds."
+            )
+        elif self.ft_manager:
+            logger.info(
+                "Replica %d doesn't save checkpoint.",
+                self.ft_manager.participating_rank(),
+            )
 
     @torch.no_grad()
     def load(self, step: int = -1) -> bool:
@@ -384,6 +442,9 @@ def load(self, step: int = -1) -> bool:
             bool: Whether the checkpoint was loaded successfully.
         """
 
+        if self.ft_manager:
+            self._ft_load()
+
         if not self.enable_checkpoint or not os.path.isdir(self.folder):
             return False
 
@@ -467,10 +528,36 @@ def _find_load_step(self, folder: str = "") -> int:
             return -1
         return max(step_counts)
 
+    def _ft_folder(self) -> str:
+        return os.path.join(self.folder, f"ft-replicat-{self.ft_replica_id}")
+
     def _create_checkpoint_id(self, step: int, folder: str = "") -> str:
         folder = folder if folder else self.folder
         return os.path.join(folder, f"step-{step}")
 
+    def _ft_save(self, step: int) -> None:
+        begin = time.monotonic()
+        self._async_wait()
+        checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder())
+        self.async_future = dcp.async_save(
+            self.ft_states, checkpoint_id=checkpoint_id, process_group=self.pg
+        )
+        logger.info(f"Staging ft checkpoint took {time.monotonic() - begin} secs.")
+
+    def _ft_load(self) -> None:
+        step = self._find_load_step(folder=self._ft_folder())
+        if step == -1:
+            return
+
+        begin = time.monotonic()
+        logger.info(f"Loading the FT checkpoint at step {step}.")
+        checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder())
+        dcp.load(self.ft_states, checkpoint_id=checkpoint_id)
+        GarbageCollection.collect("GC collection for checkpoint loading.")
+        logger.info(
+            f"Finished loading the ft checkpoint in {time.monotonic() - begin:.2f} seconds."
+        )
+
     def _states_to_load(self, step: int) -> Dict[str, Any]:
         """Determines which states to load for the given step.
 
@@ -491,6 +578,8 @@ def _states_to_load(self, step: int) -> Dict[str, Any]:
         for exclude_key in self.exclude_from_loading:
             if exclude_key not in states:
                 raise ValueError(f"{exclude_key} not found in state_dict.")
+        if self.ft_manager:
+            states_to_load.pop(DATALOADER)
         return states_to_load
 
     def _save_last_step(self, curr_step: int) -> None:
@@ -577,6 +666,7 @@ def _purge_stale_checkpoints(self):
             self.keep_latest_k > 0
             and dist.get_rank() == 0
             and os.path.isdir(self.folder)
+            and (not self.ft_manager or self.ft_manager.participating_rank() == 0)
         ):
             discovered_checkpoints = []
             for filename in os.listdir(self.folder):
```
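Taken together, the full persistent checkpoint and the per-replica checkpoint land in different subfolders of the checkpoint directory. Below is a small hedged sketch of the resulting paths, mirroring `_create_checkpoint_id` and `_ft_folder` above; the helper names and the example folder are hypothetical, and the `ft-replicat-` spelling is copied verbatim from the diff.

```python
import os


# Illustrative only: mirrors the path logic in CheckpointManager above.
def full_checkpoint_id(folder: str, step: int) -> str:
    # Written only by the replica with ft_manager.participating_rank() == 0.
    return os.path.join(folder, f"step-{step}")


def per_replica_checkpoint_id(folder: str, ft_replica_id: int, step: int) -> str:
    # Written by every replica; contains only the dataloader state.
    return os.path.join(folder, f"ft-replicat-{ft_replica_id}", f"step-{step}")


print(full_checkpoint_id("outputs/checkpoint", 100))
# outputs/checkpoint/step-100
print(per_replica_checkpoint_id("outputs/checkpoint", 1, 100))
# outputs/checkpoint/ft-replicat-1/step-100
```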

torchtitan/components/ft.py (+47, new file)

```
@@ -0,0 +1,47 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import importlib
+from typing import Optional
+
+from torchtitan.config_manager import JobConfig
+
+if importlib.util.find_spec("torchft") is not None:
+    import torchft as ft
+
+    has_torchft = True
+else:
+    has_torchft = False
+
+
+def init_ft_manager(job: JobConfig) -> Optional["ft.Manager"]:
+    """Initialize the FT manager if TorchFT is enabled.
+
+    Args:
+        job (JobConfig): The job configuration.
+
+    Returns:
+        Optional[ft.Manager]: The FT manager if TorchFT is enabled, otherwise None.
+    """
+    if not job.experimental.enable_torchft:
+        return None
+
+    if not has_torchft:
+        raise ImportError("torchft is not installed. Please install it.")
+
+    if job.experimental.ft_min_replica_size < 1:
+        raise ValueError("At least one FT replica is required.")
+
+    pg = ft.ProcessGroupBabyNCCL()
+
+    return ft.Manager(
+        pg=pg,
+        min_replica_size=job.experimental.ft_min_replica_size,
+        load_state_dict=None,
+        state_dict=None,
+        use_async_quorum=True,
+        replica_id=f"torchtitan_ft_{job.experimental.ft_replica_id}",
+    )
```
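For reference, a hedged sketch of how `init_ft_manager` could be wired into setup code and handed to the `CheckpointManager` shown above. The keyword arguments other than `ft_manager` are assumed from the surrounding torchtitan code and may not match the exact constructor signature.

```python
# Illustrative wiring sketch; parameter names other than ft_manager are assumptions.
from torchtitan.components.checkpoint import CheckpointManager
from torchtitan.components.ft import init_ft_manager
from torchtitan.config_manager import JobConfig


def build_checkpointer(job_config: JobConfig, dataloader, model_parts,
                       optimizers, lr_schedulers, states) -> CheckpointManager:
    # init_ft_manager returns None unless --experimental.enable_torchft is set,
    # so the checkpointer falls back to its non-FT behavior by default.
    ft_manager = init_ft_manager(job_config)

    return CheckpointManager(
        dataloader=dataloader,
        model_parts=model_parts,
        optimizers=optimizers,
        lr_schedulers=lr_schedulers,
        states=states,
        job_config=job_config,
        ft_manager=ft_manager,
    )
```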
