Commit 58510b5

Integrate TorchFT
**Summary**
This is a WIP TorchFT integration PR.

**Current Issues**
This doesn't work at the moment because groups hang when a new group joins.

**Issue 1:**
~Group 0 and group 1 will hang during the first `should_commit` after group 1 applies the pending state_dict from group 0.~
Fixed with: pytorch/torchft#83

**Issue 2:**
~Group 0 and group 1 will pass `should_commit`, but group 0 incorrectly requires healing, and the healing process causes another hang.~
Fixed with: pytorch/torchft#83

**Issue 3:**
~A byproduct of issues 1 and 2: group 1 will keep printing~
```
[rank0]:devgpu051:76838:80357 [0] misc/socket.cc:50 NCCL WARN socketProgress: Connection closed by remote peer devgpu051.cln3.svc.fbinfra.net<33618>
```
Fixed with pytorch/torchft#91 and several other fixes.

**Issue 4:**
When there are 3 groups, every group requests the state dict on every step.
***How to reproduce?*** Use the `Reproduce steps` below to run 2 groups, then add another group by modifying the command. This seems to be fixed, but needs more testing.

**Issue 5:**
A hang occurs when using functional collectives.
***How to reproduce?*** Pull the latest version of this PR, then comment out line 41 and uncomment line 42 in `torchtitan/utils.py`.

**Reproduce steps:**

1. Patch TorchFT with pytorch/torchft#82
2. Execute the lighthouse.
3. Execute the following command in one terminal:
```
TORCHFT_MANAGER_PORT=29520 REPLICA_GROUP_ID=0 CUDA_VISIBLE_DEVICES=0,1 NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2 --experimental.enable_torchft --experimental.ft_replica_group_id=0
```
4. Wait 10 seconds, then execute the following command in another terminal:
```
TORCHFT_MANAGER_PORT=29522 REPLICA_GROUP_ID=1 CUDA_VISIBLE_DEVICES=2,3 NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2 --experimental.enable_torchft --experimental.ft_replica_group_id=1
```

ghstack-source-id: ec7fd5c
Pull Request resolved: #834
1 parent: 37b92a0

File tree

8 files changed (+275, -41 lines)

run_llama_train.sh (+4)

```diff
@@ -19,7 +19,11 @@ if [ $# -ne 0 ]; then
     overrides="$*"
 fi
 
+TORCHFT_MANAGER_PORT=${TORCHFT_MANAGER_PORT:-"29512"}
+
 PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \
+TORCHFT_LIGHTHOUSE=http://localhost:29510 \
+TORCHFT_MANAGER_PORT=${TORCHFT_MANAGER_PORT} \
 torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
 --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
 train.py --job.config_file ${CONFIG_FILE} $overrides
```

torchtitan/checkpoint.py (+101, -27)

```diff
@@ -228,6 +228,7 @@ class CheckpointManager:
         states (Dict[str, Any]): The states that need to be saved, other than the
             previous 4 components.
         job_config (JobConfig): The job config used to configure the checkpointing.
+        ft_manager (Optional[ft.Manager]): The FTManager from TorchFT.
     """
 
     def __init__(
@@ -238,16 +239,41 @@ def __init__(
         lr_schedulers: LRSchedulersContainer,
         states: Dict[str, Any],
         job_config: JobConfig,
+        ft_manager: Optional["ft.Manager"] = None,
     ) -> None:
         ckpt_config = job_config.checkpoint
         self.enable_checkpoint = ckpt_config.enable_checkpoint
+        self.ft_manager = ft_manager
+
+        if self.ft_manager:
+            optimizers.init_cache_state_dict()
+
+            def state_dict():
+                ret = {}
+                for k, v in self.states.items():
+                    if k in {
+                        MODEL,
+                        OPTIMIZER,
+                        LR_SCHEDULER,
+                        TRAIN_STATE,
+                    }:
+                        ret[k] = v.state_dict()
+                return ret
+
+            def load_state_dict(state_dict):
+                assert state_dict is not None
+                for k, v in state_dict.items():
+                    self.states[k].load_state_dict(v)
+
+            ft_manager.set_state_dict_fns(load_state_dict, state_dict)
+            self.ft_replica_id = job_config.experimental.ft_replica_id
 
         async_mode = ckpt_config.async_mode.lower()
         self.enable_staging = (
             self.enable_checkpoint and async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
-        )
+        ) or self.ft_manager
 
-        if not self.enable_checkpoint:
+        if not self.enable_checkpoint and self.ft_manager is None:
             return
 
         self.states = states
@@ -259,6 +285,13 @@ def __init__(
                 LR_SCHEDULER: lr_schedulers,
             }
         )
+        self.ft_states = {DATALOADER: dataloader}
+
+        self.staging = False
+        self.sending_to_checkpoint_mp = False
+        self.staging_id = None
+        self.cpu_offload_state_dict = None
+        self.staging_stream = torch.cuda.Stream() if self.enable_staging else None
 
         self.staging = False
         self.sending_to_checkpoint_mp = False
@@ -269,7 +302,7 @@ def __init__(
         self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder)
         self.interval = ckpt_config.interval
         async_mode = ckpt_config.async_mode.lower()
-        if async_mode == AsyncMode.ASYNC:
+        if async_mode == AsyncMode.ASYNC or self.ft_manager:
             self.pg = dist.new_group(backend="gloo")
 
         self.keep_latest_k = ckpt_config.keep_latest_k
@@ -343,35 +376,44 @@ def save(self, curr_step: int, force: bool = False) -> None:
             None
         """
 
+        if self.ft_manager:
+            self._ft_save(curr_step)
+
         if not self._should_save(curr_step, force):
             return
 
         begin = time.monotonic()
-        logger.info("Saving the checkpoint (or staging if async is enabled).")
-        checkpoint_id = self._create_checkpoint_id(curr_step)
-        self._async_wait()
-        # This GC is called for async checkpoint as it is useless to do
-        # GC right after async_save -- the CPU memory is not able to be
-        # freed until _async_wait()
-        if force:
-            self._save_last_step(curr_step)
-        elif self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
-            GarbageCollection.collect("GC collection invoked by checkpointer.")
-            self._async_with_pinned_memory(checkpoint_id)
-        elif self.async_mode == AsyncMode.ASYNC:
-            GarbageCollection.collect("GC collection invoked by checkpointer.")
-            self.async_future = dcp.async_save(
-                self.states, checkpoint_id=checkpoint_id, process_group=self.pg
-            )
-            GarbageCollection.collect("GC collection invoked by checkpointer.")
-        else:
-            save_with_gc(self.states, checkpoint_id=checkpoint_id)
-        self._purge_stale_checkpoints()
+        if not self.ft_manager or self.ft_manager.participating_rank() == 0:
+            logger.info("Saving the checkpoint (or staging if async is enabled).")
+            checkpoint_id = self._create_checkpoint_id(curr_step)
+            self._async_wait()
+            # This GC is called for async checkpoint as it is useless to do
+            # GC right after async_save -- the CPU memory is not able to be
+            # freed until _async_wait()
+            if force:
+                self._save_last_step(curr_step)
+            elif self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
+                GarbageCollection.collect("GC collection invoked by checkpointer.")
+                self._async_with_pinned_memory(checkpoint_id)
+            elif self.async_mode == AsyncMode.ASYNC:
+                GarbageCollection.collect("GC collection invoked by checkpointer.")
+                self.async_future = dcp.async_save(
+                    self.states, checkpoint_id=checkpoint_id, process_group=self.pg
+                )
+                GarbageCollection.collect("GC collection invoked by checkpointer.")
+            else:
+                save_with_gc(self.states, checkpoint_id=checkpoint_id)
+            self._purge_stale_checkpoints()
 
-        logger.info(
-            "Finished saving the checkpoint (or staging if async is enabled)"
-            f"in {time.monotonic() - begin:.2f} seconds."
-        )
+            logger.info(
+                "Finished saving the checkpoint (or staging if async is enabled)"
+                f"in {time.monotonic() - begin:.2f} seconds."
+            )
+        elif self.ft_manager:
+            logger.info(
+                "Replica %d doesn't save checkpoint.",
+                self.ft_manager.participating_rank(),
+            )
 
     @torch.no_grad()
     def load(self, step: int = -1) -> bool:
@@ -388,6 +430,9 @@ def load(self, step: int = -1) -> bool:
             bool: Whether the checkpoint was loaded successfully.
         """
 
+        if self.ft_manager:
+            self._ft_load()
+
         if not self.enable_checkpoint or not os.path.isdir(self.folder):
             return False
 
@@ -471,10 +516,36 @@ def _find_load_step(self, folder: str = "") -> int:
             return -1
         return max(step_counts)
 
+    def _ft_folder(self) -> str:
+        return os.path.join(self.folder, f"ft-replicat-{self.ft_replica_id}")
+
     def _create_checkpoint_id(self, step: int, folder: str = "") -> str:
         folder = folder if folder else self.folder
         return os.path.join(folder, f"step-{step}")
 
+    def _ft_save(self, step: int) -> None:
+        begin = time.monotonic()
+        self._async_wait()
+        checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder())
+        self.async_future = dcp.async_save(
+            self.ft_states, checkpoint_id=checkpoint_id, process_group=self.pg
+        )
+        logger.info(f"Staging ft checkpoint took {time.monotonic() - begin} secs.")
+
+    def _ft_load(self) -> None:
+        step = self._find_load_step(folder=self._ft_folder())
+        if step == -1:
+            return
+
+        begin = time.monotonic()
+        logger.info(f"Loading the FT checkpoint at step {step}.")
+        checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder())
+        dcp.load(self.ft_states, checkpoint_id=checkpoint_id)
+        GarbageCollection.collect("GC collection for checkpoint loading.")
+        logger.info(
+            f"Finished loading the ft checkpoint in {time.monotonic() - begin:.2f} seconds."
+        )
+
     def _states_to_load(self, step: int) -> Dict[str, Any]:
         """Determines which states to load for the given step.
 
@@ -495,6 +566,8 @@ def _states_to_load(self, step: int) -> Dict[str, Any]:
         for exclude_key in self.exclude_from_loading:
            if exclude_key not in states:
                raise ValueError(f"{exclude_key} not found in state_dict.")
+        if self.ft_manager:
+            states_to_load.pop(DATALOADER)
        return states_to_load
 
     def _save_last_step(self, curr_step: int) -> None:
@@ -579,6 +652,7 @@ def _cpu_staging(self, checkpoint_id: Optional[str]) -> None:
     def _purge_stale_checkpoints(self):
         if (
             self.keep_latest_k > 0
+            and self.ft_manager.participating_rank() == 0
             and dist.get_rank() == 0
             and os.path.isdir(self.folder)
         ):
```
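For orientation, here is a minimal, standalone sketch (not code from this PR) of the callback contract the diff registers through `ft_manager.set_state_dict_fns`: the `state_dict` callback returns only the model, optimizer, lr-scheduler, and train-state entries, while the dataloader state is deliberately kept out of the exchanged payload and persisted per replica through `_ft_save`/`_ft_load`. `ToyStateful` and the `states` dict below are hypothetical stand-ins for the real containers.

```python
# Hypothetical stand-ins; only the callback shapes mirror the diff above.
MODEL, OPTIMIZER, LR_SCHEDULER, TRAIN_STATE = "model", "optimizer", "lr_scheduler", "train_state"


class ToyStateful:
    """Minimal object with the Stateful-style interface the callbacks rely on."""

    def __init__(self, value: int) -> None:
        self.value = value

    def state_dict(self) -> dict:
        return {"value": self.value}

    def load_state_dict(self, sd: dict) -> None:
        self.value = sd["value"]


states = {k: ToyStateful(i) for i, k in enumerate((MODEL, OPTIMIZER, LR_SCHEDULER, TRAIN_STATE))}


def state_dict() -> dict:
    # What a healthy replica would hand to TorchFT when a peer needs healing.
    return {k: v.state_dict() for k, v in states.items()}


def load_state_dict(sd: dict) -> None:
    # What a recovering replica would apply on receiving that payload.
    assert sd is not None
    for k, v in sd.items():
        states[k].load_state_dict(v)


if __name__ == "__main__":
    payload = state_dict()
    load_state_dict(payload)  # round-trips cleanly; dataloader state is not part of it
```

Keeping the dataloader out of the exchanged payload is also why `_states_to_load` pops `DATALOADER` when `ft_manager` is set.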

torchtitan/config_manager.py (+24)

```diff
@@ -652,6 +652,30 @@ def __init__(self):
             action="store_true",
         )
 
+        self.parser.add_argument(
+            "--experimental.enable_torchft",
+            action="store_true",
+            help="Enable TorchFT integration.",
+        )
+
+        self.parser.add_argument(
+            "--experimental.ft_replica_id",
+            type=int,
+            default=0,
+            help="The TorchFT replica ID of this run.",
+        )
+
+        self.parser.add_argument(
+            "--experimental.ft_group_size",
+            type=int,
+            default=1,
+            help="""
+                The number of TorchFT replicate groups. This number will be used for
+                dataloader to split the dataset across the replicate groups and FSDP
+                dimension.
+            """,
+        )
+
     def to_dict(self):
         return self.args_dict
```
torchtitan/ft.py (+44)

```diff
@@ -0,0 +1,44 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import importlib
+from dataclasses import dataclass
+from typing import Optional
+
+from torchtitan.config_manager import JobConfig
+
+if importlib.util.find_spec("torchft") is not None:
+    import torchft as ft
+
+    has_torchft = True
+else:
+    has_torchft = False
+
+
+def init_ft_manager(job: JobConfig) -> Optional["ft.Manager"]:
+    """Initialize the FT manager if TorchFT is enabled.
+
+    Args:
+        job (JobConfig): The job configuration.
+
+    Returns:
+        Optional[ft.Manager]: The FT manager if TorchFT is enabled, otherwise None.
+    """
+    if not job.experimental.enable_torchft:
+        return None
+
+    if not has_torchft:
+        raise ImportError("torchft is not installed. Please install it.")
+
+    pg = ft.ProcessGroupBabyNCCL()
+    return ft.Manager(
+        pg=pg,
+        min_replica_size=1,
+        load_state_dict=None,
+        state_dict=None,
+        use_async_quorum=True,
+        replica_id=f"torchtitan_ft_{job.experimental.ft_replica_id}",
+    )
```
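The PR does not show `train.py`, but based on the signatures added in this commit, the wiring presumably looks roughly like the sketch below. The function name `build_training_components` and the `model_parts` argument are placeholders for objects the training script already has.

```python
from torchtitan.config_manager import JobConfig
from torchtitan.ft import init_ft_manager
from torchtitan.optimizer import build_optimizers


def build_training_components(job_config: JobConfig, model_parts):
    """Rough sketch of the FT wiring; not the actual train.py from this PR."""
    # None unless --experimental.enable_torchft is passed; every consumer
    # treats None as "run without fault tolerance".
    ft_manager = init_ft_manager(job_config)

    # With a manager, build_optimizers returns FTOptimizersContainer, which
    # wraps each optimizer in torchft's ft.Optimizer.
    optimizers = build_optimizers(model_parts, job_config, ft_manager)

    # The same manager would then be handed to
    # CheckpointManager(..., ft_manager=ft_manager), which registers the
    # state_dict/load_state_dict callbacks shown in the checkpoint.py diff.
    return ft_manager, optimizers
```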

torchtitan/optimizer.py (+51, -7)

```diff
@@ -6,7 +6,7 @@
 
 import copy
 import functools
-from typing import Any, Callable, Dict, Iterable, List
+from typing import Any, Callable, Dict, Iterable, List, Optional
 
 import torch
 import torch.nn as nn
@@ -177,8 +177,49 @@ def zero_grad(self) -> None:
         pass
 
 
+class FTOptimizersContainer(OptimizersContainer):
+    def __init__(
+        self,
+        model_parts: List[nn.Module],
+        optimizer_kwargs: Dict[str, Any],
+        name: str,
+        ft_manager: Optional["ft.Manager"],
+    ) -> None:
+        import torchft as ft
+
+        super().__init__(model_parts, optimizer_kwargs, name)
+
+        # Force to initialize the optimizer state so that `optim.step()`
+        # won't be called by state_dict() and load_state_dict().
+        _ = {
+            k: v
+            for sd in map(get_optimizer_state_dict, model_parts, self.optimizers)
+            for k, v in sd.items()
+        }
+        self.optimizers = [
+            ft.Optimizer(ft_manager, optim) for optim in self.optimizers
+        ]
+        self.cache_state_dict: Dict[str, Any] = {}
+
+    def init_cache_state_dict(self) -> None:
+        self.cache_state_dict = super().state_dict()
+
+    def state_dict(self) -> Dict[str, Any]:
+        return self.cache_state_dict
+
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        # We have to invalidate the `cache_state_dict` because optimizer uses
+        # assign instead of copy when doing `load_state_dict()`. Without
+        # invalidating the `cache_state_dict`, there will be memory leakage.
+        self.cache_state_dict = {}
+        super().load_state_dict(state_dict)
+        self.init_cache_state_dict()
+
+
 def build_optimizers(
-    model_parts: List[nn.Module], job_config: JobConfig
+    model_parts: List[nn.Module],
+    job_config: JobConfig,
+    ft_manager: Optional["ft.Manager"] = None,
 ) -> OptimizersContainer:
     """Create a OptimizersContainer for the given model parts and job config.
 
@@ -213,11 +254,14 @@ def build_optimizers(
         "foreach": not fused,
     }
 
-    return (
-        OptimizersContainer(model_parts, optimizer_kwargs, name)
-        if not optim_in_bwd
-        else OptimizersInBackwardContainer(model_parts, optimizer_kwargs, name)
-    )
+    if optim_in_bwd and ft_manager:
+        raise ValueError("TorchFT is not supported with optimizers in backward.")
+    elif optim_in_bwd:
+        return OptimizersInBackwardContainer(model_parts, optimizer_kwargs, name)
+    elif ft_manager:
+        return FTOptimizersContainer(model_parts, optimizer_kwargs, name, ft_manager)
+    else:
+        return OptimizersContainer(model_parts, optimizer_kwargs, name)
 
 
 class LRSchedulersContainer(Stateful):
```
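The comment in `load_state_dict` is the key design point: a torch optimizer re-assigns (rather than copies into) its state tensors on load, so a state_dict cached before the load would keep the old tensors alive. The standalone snippet below (not PR code; it uses a plain `torch.optim.AdamW` instead of `get_optimizer_state_dict`) illustrates that behavior and hence why `init_cache_state_dict()` is re-run after every load.

```python
import torch

# Materialize some optimizer state.
param = torch.nn.Parameter(torch.zeros(4))
optim = torch.optim.AdamW([param], lr=1e-3)
param.grad = torch.ones_like(param)
optim.step()

cached = optim.state_dict()    # analogous to cache_state_dict
optim.load_state_dict(cached)  # load deep-copies and re-assigns the internal state
fresh = optim.state_dict()

# The pre-load snapshot still references the old tensors, so a container that
# kept returning `cached` would pin them in memory; rebuilding the cache after
# load_state_dict() avoids that.
print(cached["state"][0]["exp_avg"] is fresh["state"][0]["exp_avg"])  # False
```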
