Commit 307ec5b

Integrate TorchFT
**Summary**

This is a WIP TorchFT integration PR.

**Current Issues**

This doesn't work at the moment because groups hang when a new group joins.

**Issue 1:** ~~Group 0 and group 1 hang during the first `should_commit` after group 1 applies the pending state_dict from group 0.~~
Fixed with: pytorch/torchft#83

**Issue 2:** ~~Group 0 and group 1 pass `should_commit`, but group 0 is incorrectly marked as needing healing, and the healing process causes another hang.~~
Fixed with: pytorch/torchft#83

**Issue 3:** ~~A byproduct of issues 1 and 2: group 1 keeps printing~~
```
[rank0]:devgpu051:76838:80357 [0] misc/socket.cc:50 NCCL WARN socketProgress: Connection closed by remote peer devgpu051.cln3.svc.fbinfra.net<33618>
```
Fixed with pytorch/torchft#91 and several other fixes.

**Issue 4:** When there are 3 groups, every group requests the state dict on every step.
***How to reproduce?*** Use the `Reproduce steps` below to run 2 groups, then add another group by modifying the command.
Seems to be fixed; more tests are needed.

**Issue 5:** A hang occurs when using functional collectives.
***How to reproduce?*** Pull the latest version of this PR, comment out line 41, and uncomment line 42 in `torchtitan/utils.py`.

**Reproduce steps:**

1. Patch TorchFT with pytorch/torchft#82
2. Start the lighthouse.
3. Execute the following command in one terminal:
```
TORCHFT_MANAGER_PORT=29520 REPLICA_GROUP_ID=0 CUDA_VISIBLE_DEVICES=0,1 NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2 --experimental.enable_torchft --experimental.ft_replica_group_id=0
```
4. Wait 10 seconds, then execute the following command in another terminal:
```
TORCHFT_MANAGER_PORT=29522 REPLICA_GROUP_ID=1 CUDA_VISIBLE_DEVICES=2,3 NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2 --experimental.enable_torchft --experimental.ft_replica_group_id=1
```

ghstack-source-id: 0a019154208455ddcd6571f0bb4f4bb4cf36f9fe
Pull Request resolved: #834
1 parent ec82573 commit 307ec5b
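
Below is a minimal sketch of how the pieces added in this PR are intended to fit together on the training side. The `train.py` wiring itself is not shown in the diffs here, so treat the function and the leading `CheckpointManager` argument names (`dataloader`, `model_parts`, `optimizers`) as assumptions; only `lr_schedulers`, `states`, `job_config`, and `ft_manager` appear in the diff below.

```python
# Hypothetical wiring sketch; not part of this commit's diff.
from torchtitan.components.checkpoint import CheckpointManager
from torchtitan.components.ft import init_ft_manager


def build_checkpointer(job_config, dataloader, model_parts, optimizers, lr_schedulers, states):
    # init_ft_manager returns a disabled FTManager when
    # --experimental.enable_torchft is off, so downstream code does not need
    # to branch on whether TorchFT is available.
    ft_manager = init_ft_manager(job_config)

    # When ft_manager is enabled, CheckpointManager registers
    # state_dict/load_state_dict callbacks with torchft and additionally
    # saves/loads a per-replica, dataloader-only checkpoint.
    return CheckpointManager(
        dataloader,
        model_parts,
        optimizers,
        lr_schedulers,
        states,
        job_config,
        ft_manager,
    )
```
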

File tree

10 files changed: +384, -48


run_train.sh (+3)

```diff
@@ -19,7 +19,10 @@ if [ $# -ne 0 ]; then
     overrides="$*"
 fi
 
+TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE:-"http://localhost:29510"}
+
 PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \
+TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE} \
 torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
 --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
 torchtitan/train.py --job.config_file ${CONFIG_FILE} $overrides
```
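
The new export gives every process launched by this script a lighthouse address to register with, defaulting to a locally running lighthouse when the caller has not set one. As a rough illustration (not torchtitan code), the `${VAR:-default}` expansion is equivalent to:

```python
import os

# Equivalent of TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE:-"http://localhost:29510"}:
# keep a caller-provided address, otherwise fall back to the local default.
lighthouse = os.environ.get("TORCHFT_LIGHTHOUSE", "http://localhost:29510")
os.environ["TORCHFT_LIGHTHOUSE"] = lighthouse
```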

tests/unit_tests/test_checkpoint.py (+8)

```diff
@@ -61,10 +61,18 @@ class DummyJob:
     dump_folder: str = "dummy_folder"
 
 
+@dataclass
+class DummyExperimental:
+    ft_replica_id = 0
+    ft_group_size = 1
+
+
 @dataclass
 class DummyJobConfig:
     checkpoint: DummyCheckpointConfig = field(default_factory=DummyCheckpointConfig)
     job: DummyJob = field(default_factory=DummyJob)
+    experimental: DummyExperimental = field(default_factory=DummyExperimental)
+    ft_manager = None
 
 
 # Dummy instances to supply as constructor arguments.
```

tests/unit_tests/test_model_converter.py (+1)

```diff
@@ -22,6 +22,7 @@ def build_parallel_dims(job_config, world_size):
         pp=job_config.experimental.pipeline_parallel_degree,
         world_size=world_size,
         enable_loss_parallel=not job_config.training.disable_loss_parallel,
+        ft_manager=None,
     )
     return parallel_dims
```

torchtitan/components/checkpoint.py (+114, -27)

```diff
@@ -214,6 +214,19 @@ class CheckpointManager:
     3. LR schedulers also index model states like optimizers. Here we flatten the lr_schedulers
        with the assumption that all lr_schedulers have the same state_dict.
 
+    Note: TorchFT checkpointing flow
+
+    There are two types of checkpoints when TorchFT is enabled: 1) the full persistent
+    checkpoint, 2) the per-replica checkpoint.
+
+    The full persistent checkpoint is saved by the replica with
+    ``ft_manager.participating_rank() == 0``. It contains everything including the model,
+    optimizer, lr_scheduler, dataloader, and train_state. Right now the full persistent
+    checkpoint is loaded by all replicas. However, we can optimize it to only load if
+    there are no other alive replicas.
+
+    The per-replica checkpoint contains only the dataloader and is saved/loaded by all
+    replicas to/from its own folder. The folder name is prefixed with the ft_replica_id.
 
     Args:
         dataloader (DataLoader): The dataloader used to load the data.
@@ -223,6 +236,7 @@ class CheckpointManager:
         states (Dict[str, Any]): The states that need to be saved, other than the
             previous 4 components.
         job_config (JobConfig): The job config used to configure the checkpointing.
+        ft_manager (Optional[ft.Manager]): The FTManager from TorchFT.
     """
 
     def __init__(
@@ -233,16 +247,41 @@ def __init__(
         lr_schedulers: LRSchedulersContainer,
         states: Dict[str, Any],
         job_config: JobConfig,
+        ft_manager: FTManager,
     ) -> None:
         ckpt_config = job_config.checkpoint
         self.enable_checkpoint = ckpt_config.enable_checkpoint
+        self.ft_manager = ft_manager.manager if ft_manager.enabled else None
+
+        if self.ft_manager:
+            optimizers.init_cache_state_dict()
+
+            def state_dict():
+                ret = {}
+                for k, v in self.states.items():
+                    if k in {
+                        MODEL,
+                        OPTIMIZER,
+                        LR_SCHEDULER,
+                        TRAIN_STATE,
+                    }:
+                        ret[k] = v.state_dict()
+                return ret
+
+            def load_state_dict(state_dict):
+                assert state_dict is not None
+                for k, v in state_dict.items():
+                    self.states[k].load_state_dict(v)
+
+            self.ft_manager.set_state_dict_fns(load_state_dict, state_dict)
+            self.ft_replica_id = job_config.experimental.ft_replica_id
 
         async_mode = ckpt_config.async_mode.lower()
         self.enable_staging = (
             self.enable_checkpoint and async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
-        )
+        ) or self.ft_manager
 
-        if not self.enable_checkpoint:
+        if not self.enable_checkpoint and self.ft_manager is None:
             return
 
         self.states = states
@@ -254,6 +293,13 @@ def __init__(
                 LR_SCHEDULER: lr_schedulers,
             }
         )
+        self.ft_states = {DATALOADER: dataloader}
+
+        self.staging = False
+        self.sending_to_checkpoint_mp = False
+        self.staging_id = None
+        self.cpu_offload_state_dict = None
+        self.staging_stream = torch.cuda.Stream() if self.enable_staging else None
 
         self.staging = False
         self.sending_to_checkpoint_mp = False
@@ -264,7 +310,7 @@ def __init__(
         self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder)
         self.interval = ckpt_config.interval
         async_mode = ckpt_config.async_mode.lower()
-        if async_mode == AsyncMode.ASYNC:
+        if async_mode == AsyncMode.ASYNC or self.ft_manager:
             self.pg = dist.new_group(backend="gloo")
 
         self.keep_latest_k = ckpt_config.keep_latest_k
@@ -339,35 +385,44 @@ def save(self, curr_step: int, force: bool = False) -> None:
             None
         """
 
+        if self.ft_manager:
+            self._ft_save(curr_step)
+
         if not self._should_save(curr_step, force):
             return
 
         begin = time.monotonic()
-        logger.info("Saving the checkpoint (or staging if async is enabled).")
-        checkpoint_id = self._create_checkpoint_id(curr_step)
-        self._async_wait()
-        # This GC is called for async checkpoint as it is useless to do
-        # GC right after async_save -- the CPU memory is not able to be
-        # freed until _async_wait()
-        if force:
-            self._save_last_step(curr_step)
-        elif self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
-            GarbageCollection.collect("GC collection invoked by checkpointer.")
-            self._async_with_pinned_memory(checkpoint_id)
-        elif self.async_mode == AsyncMode.ASYNC:
-            GarbageCollection.collect("GC collection invoked by checkpointer.")
-            self.async_future = dcp.async_save(
-                self.states, checkpoint_id=checkpoint_id, process_group=self.pg
-            )
-            GarbageCollection.collect("GC collection invoked by checkpointer.")
-        else:
-            save_with_gc(self.states, checkpoint_id=checkpoint_id)
-        self._purge_stale_checkpoints()
+        if not self.ft_manager or self.ft_manager.participating_rank() == 0:
+            logger.info("Saving the checkpoint (or staging if async is enabled).")
+            checkpoint_id = self._create_checkpoint_id(curr_step)
+            self._async_wait()
+            # This GC is called for async checkpoint as it is useless to do
+            # GC right after async_save -- the CPU memory is not able to be
+            # freed until _async_wait()
+            if force:
+                self._save_last_step(curr_step)
+            elif self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
+                GarbageCollection.collect("GC collection invoked by checkpointer.")
+                self._async_with_pinned_memory(checkpoint_id)
+            elif self.async_mode == AsyncMode.ASYNC:
+                GarbageCollection.collect("GC collection invoked by checkpointer.")
+                self.async_future = dcp.async_save(
+                    self.states, checkpoint_id=checkpoint_id, process_group=self.pg
+                )
+                GarbageCollection.collect("GC collection invoked by checkpointer.")
+            else:
+                save_with_gc(self.states, checkpoint_id=checkpoint_id)
+            self._purge_stale_checkpoints()
 
-        logger.info(
-            "Finished saving the checkpoint (or staging if async is enabled)"
-            f"in {time.monotonic() - begin:.2f} seconds."
-        )
+            logger.info(
+                "Finished saving the checkpoint (or staging if async is enabled)"
+                f"in {time.monotonic() - begin:.2f} seconds."
+            )
+        elif self.ft_manager:
+            logger.info(
+                "Replica %d doesn't save checkpoint.",
+                self.ft_manager.participating_rank(),
+            )
 
     @torch.no_grad()
     def load(self, step: int = -1) -> bool:
@@ -384,6 +439,9 @@ def load(self, step: int = -1) -> bool:
             bool: Whether the checkpoint was loaded successfully.
         """
 
+        if self.ft_manager:
+            self._ft_load()
+
         if not self.enable_checkpoint or not os.path.isdir(self.folder):
             return False
 
@@ -467,10 +525,36 @@ def _find_load_step(self, folder: str = "") -> int:
             return -1
         return max(step_counts)
 
+    def _ft_folder(self) -> str:
+        return os.path.join(self.folder, f"ft-replicat-{self.ft_replica_id}")
+
     def _create_checkpoint_id(self, step: int, folder: str = "") -> str:
         folder = folder if folder else self.folder
         return os.path.join(folder, f"step-{step}")
 
+    def _ft_save(self, step: int) -> None:
+        begin = time.monotonic()
+        self._async_wait()
+        checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder())
+        self.async_future = dcp.async_save(
+            self.ft_states, checkpoint_id=checkpoint_id, process_group=self.pg
+        )
+        logger.info(f"Staging ft checkpoint took {time.monotonic() - begin} secs.")
+
+    def _ft_load(self) -> None:
+        step = self._find_load_step(folder=self._ft_folder())
+        if step == -1:
+            return
+
+        begin = time.monotonic()
+        logger.info(f"Loading the FT checkpoint at step {step}.")
+        checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder())
+        dcp.load(self.ft_states, checkpoint_id=checkpoint_id)
+        GarbageCollection.collect("GC collection for checkpoint loading.")
+        logger.info(
+            f"Finished loading the ft checkpoint in {time.monotonic() - begin:.2f} seconds."
+        )
+
     def _states_to_load(self, step: int) -> Dict[str, Any]:
         """Determines which states to load for the given step.
 
@@ -491,6 +575,8 @@ def _states_to_load(self, step: int) -> Dict[str, Any]:
         for exclude_key in self.exclude_from_loading:
             if exclude_key not in states:
                 raise ValueError(f"{exclude_key} not found in state_dict.")
+        if self.ft_manager:
+            states_to_load.pop(DATALOADER)
         return states_to_load
 
     def _save_last_step(self, curr_step: int) -> None:
@@ -577,6 +663,7 @@ def _purge_stale_checkpoints(self):
             self.keep_latest_k > 0
             and dist.get_rank() == 0
             and os.path.isdir(self.folder)
+            and (not self.ft_manager or self.ft_manager.participating_rank() == 0)
         ):
            discovered_checkpoints = []
            for filename in os.listdir(self.folder):
```
torchtitan/components/ft.py (+78, new file)

```diff
@@ -0,0 +1,78 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import importlib
+from typing import Optional
+
+from torchtitan.config_manager import JobConfig
+
+if importlib.util.find_spec("torchft") is not None:
+    import torchft as ft
+
+    has_torchft = True
+else:
+    has_torchft = False
+
+
+class FTManager:
+    def __init__(
+        self,
+        manager: Optional["ft.Manager"],
+        group_size: int = 1,
+        replica_id: int = 0,
+    ) -> None:
+        self._manager = manager
+        self.group_size = group_size
+        self.replica_id = replica_id
+
+    @property
+    def enabled(self) -> bool:
+        return self._manager is not None
+
+    @property
+    def manager(self) -> "ft.Manager":
+        assert self._manager is not None
+        return self._manager
+
+    def get_dp_rank(self, dp_degree: int, dp_rank: int) -> int:
+        return dp_degree * self.replica_id + dp_rank
+
+    def get_dp_degree(self, dp_degree: int) -> int:
+        return dp_degree * self.group_size
+
+
+def init_ft_manager(job: JobConfig) -> FTManager:
+    """Initialize the FT manager if TorchFT is enabled.
+
+    Args:
+        job (JobConfig): The job configuration.
+
+    Returns:
+        FTManager: The FT manager wrapper; it is disabled if TorchFT is not enabled.
+    """
+    if not job.experimental.enable_torchft:
+        return FTManager(None)
+
+    if not has_torchft:
+        raise ImportError("torchft is not installed. Please install it.")
+
+    if job.experimental.ft_min_replica_size < 1:
+        raise ValueError("At least one FT replica is required.")
+
+    pg = ft.ProcessGroupBabyNCCL()
+
+    return FTManager(
+        ft.Manager(
+            pg=pg,
+            min_replica_size=job.experimental.ft_min_replica_size,
+            load_state_dict=None,
+            state_dict=None,
+            use_async_quorum=True,
+            replica_id=f"torchtitan_ft_{job.experimental.ft_replica_id}",
+        ),
+        group_size=job.experimental.ft_group_size,
+        replica_id=job.experimental.ft_replica_id,
+    )
```
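
The `get_dp_rank`/`get_dp_degree` helpers fold the replica-group dimension into the data-parallel dimension: each replica group contributes `dp_degree` ranks, offset by its replica id. A small worked example mirroring that arithmetic:

```python
# Mirrors FTManager.get_dp_rank / FTManager.get_dp_degree with concrete numbers:
# two replica groups (ft_group_size=2), each sharding over dp_degree=2 ranks,
# as in the two-terminal reproduce steps above.
dp_degree, group_size = 2, 2

for replica_id in range(group_size):
    for dp_rank in range(dp_degree):
        global_dp_rank = dp_degree * replica_id + dp_rank  # get_dp_rank
        print(f"replica {replica_id}, local dp_rank {dp_rank} -> global dp rank {global_dp_rank}")

print("effective dp degree:", dp_degree * group_size)  # get_dp_degree -> 4
```
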
