
Commit 84e7afb

[WIP][RFC] TorchFT integration
Summary: This is a WIP TorchFT integration PR.

Test Plan:

```
TORCHFT_MANAGER_PORT=29520 REPLICA_GROUP_ID=0 CUDA_VISIBLE_DEVICES=0,1 NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2 --experimental.enable_torchft --experimental.ft_replica_group_id=0
```

```
TORCHFT_MANAGER_PORT=29522 REPLICA_GROUP_ID=1 CUDA_VISIBLE_DEVICES=2,3 NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2 --experimental.enable_torchft --experimental.ft_replica_group_id=1
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 91788fc
Pull Request resolved: #806
1 parent d989842 commit 84e7afb

File tree: 9 files changed (+264, -106 lines)

9 files changed

+264
-106
lines changed

run_llama_train.sh (+4)

```diff
@@ -19,7 +19,11 @@ if [ $# -ne 0 ]; then
     overrides="$*"
 fi
 
+TORCHFT_MANAGER_PORT=${TORCHFT_MANAGER_PORT:-"29512"}
+
 PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \
+TORCHFT_LIGHTHOUSE=http://localhost:29510 \
+TORCHFT_MANAGER_PORT=${TORCHFT_MANAGER_PORT} \
 torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
 --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
 train.py --job.config_file ${CONFIG_FILE} $overrides
```
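These changes point every replica group at a single torchft lighthouse on `localhost:29510` and give each group its own manager port (overridable through `TORCHFT_MANAGER_PORT`, defaulting to 29512). Below is a sketch of how the two-replica-group test plan from the commit message might be launched on one machine. The `torchft_lighthouse` command and its `--min_replicas` flag are an assumption about a recent torchft install and are not part of this diff; only the two `run_llama_train.sh` invocations are taken from the test plan.

```bash
# Hypothetical local launch sequence; the lighthouse invocation is an
# assumption about the torchft CLI and is not part of this PR.

# 1. Start the lighthouse that coordinates quorum across replica groups.
#    It is assumed to serve the TORCHFT_LIGHTHOUSE address used above.
torchft_lighthouse --min_replicas 1 &

# 2. Replica group 0 on GPUs 0-1 (from the test plan).
TORCHFT_MANAGER_PORT=29520 REPLICA_GROUP_ID=0 CUDA_VISIBLE_DEVICES=0,1 NGPU=2 \
    ./run_llama_train.sh --training.data_parallel_shard_degree=2 \
    --experimental.enable_torchft --experimental.ft_replica_group_id=0 &

# 3. Replica group 1 on GPUs 2-3, joining the same lighthouse (from the test plan).
TORCHFT_MANAGER_PORT=29522 REPLICA_GROUP_ID=1 CUDA_VISIBLE_DEVICES=2,3 NGPU=2 \
    ./run_llama_train.sh --training.data_parallel_shard_degree=2 \
    --experimental.enable_torchft --experimental.ft_replica_group_id=1
```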

torchtitan/checkpoint.py (+119, -86)

```diff
@@ -13,12 +13,13 @@
 from dataclasses import dataclass, field
 from io import BytesIO
 from multiprocessing import get_context
-from typing import Any, Dict, List, Union
+from typing import Any, Dict, List, Optional, Union
 
 import torch
 import torch.distributed as dist
 import torch.distributed.checkpoint as dcp
 import torch.nn as nn
+from torch.distributed._state_dict_utils import _copy_state_dict, _create_cpu_state_dict
 from torch.distributed.checkpoint.state_dict import (
     get_model_state_dict,
     set_model_state_dict,
```
```diff
@@ -143,49 +144,28 @@ def __init__(
         lr_schedulers: SchedulersContainer,
         states: Dict[str, Any],
         job_config: JobConfig,
+        ft_manager: Optional[Any] = None,
     ) -> None:
         ckpt_config = job_config.checkpoint
         self.enable_checkpoint = ckpt_config.enable_checkpoint
-        self.keep_latest_k = ckpt_config.keep_latest_k
+        self.ft_manager = ft_manager
 
-        if not self.enable_checkpoint:
+        if not self.enable_checkpoint and self.ft_manager is None:
             return
-        """
-        Note: Pipeline Parallelism and Virtual Stages
-
-        1. even for simple PP schedules, there is a separate optimizer each PP rank.
-        rank0's optimizer would have a param_group[0] which refers to layers.0 in the original model.
-        rank1's would _also_ have a param_group[0], since it's index based, but referring to layers.1.
-        When saving, these collide and one of them is lost. Then when reloading, only one stage can
-        restore its optimizer states, others will error.
-
-        The solution to this problem is optimizer flattening: it landed in #127071 and is enabled in TorchTitan
-        by passing the 'flatten_optimizer_state_dict' kwarg to DCP functions called in the OptimizerContainer.
-
-        2. With complex PP schedules, we have multiple model chunks per pp rank. This compounds challenge (1) by also
-        requiring us to reason about multiple 'optim' objects locally.
-
-        We solve this in the Model and Optimizer wrapper classes by flattening the state dicts from each object
-        into one state dict before saving/loading. We rely on the individual state_dicts to not collide,
-        which is gauranteed for the model by correct pipeline splitting and for the optimizer by the flattening
-        support described in (1).
-
-        3. LR schedulers also index model states like optimizers and would need to be flattened properly to support
-        resharding. Unfortunately, the implementations of different lr_schedulers do not follow a clear pattern like
-        optimizers do, so it's hard to write a generic 'flattener' utility.
-
-        TODO: This is currently unsolved and needs a fix.
-        """
-        self.states = states
 
-        self.states.update(
-            {
-                "model": ModelWrapper(model_parts),
-                "optimizer": optimizers,
-                "dataloader": dataloader,
-            }
+        self._initialize_states(
+            states, dataloader, model_parts, optimizers, lr_schedulers
         )
-        self.states.update(lr_schedulers.get_lr_scheduler_state())
+
+        async_mode = ckpt_config.async_mode.lower()
+        self.enable_staging = (
+            self.enable_checkpoint and async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
+        ) or self.ft_manager
+        self.staging = False
+        self.sending_to_checkpoint_mp = False
+        self.staging_id = None
+        self.cpu_offload_state_dict = None
+        self.staging_stream = torch.cuda.Stream() if self.enable_staging else None
 
         self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder)
         self.interval_type = (
@@ -199,11 +179,11 @@ def __init__(
         self.time_sync_result = None
         self.pg = dist.new_group(backend="gloo")
 
+        self.keep_latest_k = ckpt_config.keep_latest_k
         self.model_weights_only = ckpt_config.model_weights_only
         self.export_dtype = TORCH_DTYPE_MAP[ckpt_config.export_dtype]
 
         self.mp = None
-        async_mode = ckpt_config.async_mode.lower()
         if async_mode == AsyncMode.DISABLED:
             self.async_mode = AsyncMode.DISABLED
         elif async_mode == AsyncMode.ASYNC:
@@ -223,10 +203,6 @@ def __init__(
                 daemon=True,
             )
             self.mp.start()
-            self.cpu_offload_state_dict = None
-            self.staging = False
-            self.staging_id = None
-            self.staging_stream = torch.cuda.Stream()
         else:
             raise ValueError(f"Unkown checkpoint async_mode {ckpt_config.async_mode}")
 
```
```diff
@@ -240,8 +216,61 @@ def __del__(self):
             self.mp.join()
 
     def reset(self) -> None:
+        # We need to stage the local state if another replica joins during
+        # the first step.
+        if self.ft_manager:
+            self.cpu_staging(None)
         self.begin_time = time.monotonic()
 
+    def _initialize_states(
+        self,
+        states: Dict[str, Any],
+        dataloader: DataLoader,
+        model_parts: List[nn.Module],
+        optimizers: OptimizersContainer,
+        lr_schedulers: SchedulersContainer,
+    ) -> None:
+        """
+        Note: Pipeline Parallelism and Virtual Stages
+
+        1. Even for simple PP schedules, there is a separate optimizer for each PP rank.
+        rank0's optimizer would have a param_group[0] which refers to layers.0 in the
+        original model. rank1's would _also_ have a param_group[0], since it's index based,
+        but referring to layers.1.
+        When saving, these collide and one of them is lost. Then when reloading, only one
+        stage can restore its optimizer states, others will error.
+
+        The solution to this problem is optimizer flattening: it landed in #127071
+        and is enabled in TorchTitan by passing the 'flatten_optimizer_state_dict'
+        kwarg to DCP functions called in the OptimizerContainer.
+
+        2. With complex PP schedules, we have multiple model chunks per pp rank. This
+        compounds challenge (1) by also requiring us to reason about multiple 'optim'
+        objects locally.
+
+        We solve this in the Model and Optimizer wrapper classes by flattening the
+        state dicts from each object into one state dict before saving/loading.
+        We rely on the individual state_dicts to not collide, which is guaranteed for
+        the model by correct pipeline splitting and for the optimizer by the flattening
+        support described in (1).
+
+        3. LR schedulers also index model states like optimizers and would need to be
+        flattened properly to support resharding. Unfortunately, the implementations of
+        different lr_schedulers do not follow a clear pattern like optimizers do, so it's
+        hard to write a generic 'flattener' utility.
+
+        TODO: This is currently unsolved and needs a fix.
+        """
+        self.states = states
+        self.states.update(
+            {
+                "model": ModelWrapper(model_parts),
+                "optimizer": optimizers,
+                "dataloader": dataloader,
+            }
+        )
+        self.states.update(lr_schedulers.get_lr_scheduler_state())
+
     def _create_checkpoint_id(self, step: int) -> str:
         return os.path.join(self.folder, f"step-{step}")
 
```
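The note carried over into `_initialize_states` relies on optimizer state-dict flattening (#127071, referenced above). For readers unfamiliar with it, here is a minimal, illustrative sketch of what that flattening looks like at the DCP API level; it is not code from this PR, and the toy model and optimizer are placeholders.

```python
# Illustrative sketch of 'flatten_optimizer_state_dict'; not part of this PR.
import torch
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_optimizer_state_dict,
    set_optimizer_state_dict,
)

model = torch.nn.Linear(8, 8)                      # stand-in for one PP stage's model part
optimizer = torch.optim.AdamW(model.parameters())
options = StateDictOptions(flatten_optimizer_state_dict=True)

# With flattening, optimizer state is keyed by fully qualified parameter names
# rather than positional param_group indices, so state dicts produced by
# different PP ranks no longer collide on param_group[0].
flat_osd = get_optimizer_state_dict(model, optimizer, options=options)

# The same option is passed when restoring.
set_optimizer_state_dict(model, optimizer, optim_state_dict=flat_osd, options=options)
```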

```diff
@@ -324,31 +353,8 @@ def _async_wait(self) -> None:
             self.async_future.result()
 
     def _async_with_pinned_memory(self, checkpoint_id: str) -> None:
-        try:
-            from torch.distributed._state_dict_utils import (
-                _copy_state_dict,
-                _create_cpu_state_dict,
-            )
-        except ImportError as e:
-            raise ImportError(
-                "Please install the latest PyTorch nightly to use async checkpointing with pinned memory."
-            ) from e
-        state_dict = dcp.state_dict_saver._stateful_to_state_dict(self.states)
-        if self.cpu_offload_state_dict is None:
-            logger.debug(f"Preparing the CPU memory, {time.monotonic()=}.:.2f")
-            self.cpu_offload_state_dict = _create_cpu_state_dict(
-                state_dict, pin_memory=True, share_memory=True
-            )
-
-        logger.debug(f"Staging the state_dict, {time.monotonic()=}.:.2f")
-        with torch.cuda.stream(self.staging_stream):
-            self.cpu_offload_state_dict = _copy_state_dict(
-                state_dict,
-                self.cpu_offload_state_dict,
-                non_blocking=True,
-            )
-        self.staging = True
-        self.staging_id = checkpoint_id
+        self.cpu_staging(checkpoint_id)
+        self.sending_to_checkpoint_mp = True
 
     def save(self, curr_step: int, force: bool = False) -> None:
         """
@@ -358,6 +364,8 @@ def save(self, curr_step: int, force: bool = False) -> None:
         for initial seed checkpoint.
         """
         if not self._should_save(curr_step, force):
+            if self.ft_manager:
+                self.cpu_staging(None)
             return
 
         begin = time.monotonic()
```
```diff
@@ -381,26 +389,51 @@ def save(self, curr_step: int, force: bool = False) -> None:
             f"in {time.monotonic() - begin:.2f} seconds."
         )
 
+    def cpu_staging(self, checkpoint_id: Optional[str]) -> None:
+        """Offload state_dict to CPU memory"""
+        state_dict = dcp.state_dict_saver._stateful_to_state_dict(self.states)
+        if self.cpu_offload_state_dict is None:
+            logger.debug(f"Preparing the CPU memory, {time.monotonic()=}.:.2f")
+            self.cpu_offload_state_dict = _create_cpu_state_dict(
+                state_dict, pin_memory=True, share_memory=True
+            )
+
+        logger.debug(f"Staging the state_dict, {time.monotonic()=}.:.2f")
+        with torch.cuda.stream(self.staging_stream):
+            self.cpu_offload_state_dict = _copy_state_dict(
+                state_dict,
+                self.cpu_offload_state_dict,
+                non_blocking=True,
+            )
+        self.staging = True
+        self.staging_id = checkpoint_id
+
+    def wait_for_staging(self) -> None:
+        if not self.staging_stream.query():
+            self.staging_stream.synchronize()
+        self.staging = False
+
+    def staging_results(self) -> Dict[str, Any]:
+        self.maybe_wait_for_staging()
+        return self.cpu_offload_state_dict
+
     def maybe_wait_for_staging(self) -> None:
-        if (
-            self.enable_checkpoint
-            and self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
-            and self.staging
-        ):
-            if not self.staging_stream.query():
-                self.staging_stream.synchronize()
-
-            def sync_func():
-                self.mp_queue_send.put_nowait(
-                    (self.cpu_offload_state_dict, self.staging_id)
-                )
-
-            # This may be a faster way to do zero-overhead checkpointing staging
-            # checkpointing but we need more thorough investigation before
-            # swithing to this method.
-            # self.my_thread = threading.Thread(target=func).start()
-            sync_func()
-            self.staging = False
+        if self.enable_staging and self.staging:
+            self.wait_for_staging()
+
+        if self.sending_to_checkpoint_mp:
+            # Copy the sync staging result to another process.
+            def sync_func():
+                self.mp_queue_send.put_nowait(
+                    (self.cpu_offload_state_dict, self.staging_id)
+                )
+
+            # This may be a faster way to do zero-overhead checkpoint
+            # staging, but we need more thorough investigation before
+            # switching to this method.
+            # self.my_thread = threading.Thread(target=func).start()
+            sync_func()
+            self.sending_to_checkpoint_mp = False
 
     def load(self, step: int = -1) -> bool:
         if not self.enable_checkpoint:
```
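The new `cpu_staging()` generalizes the pinned-memory staging that previously existed only for `async_with_pinned_mem`, so an FT-enabled replica always keeps a CPU copy of its latest state ready to hand to a joining peer via `staging_results()`. As a reference for the underlying mechanism, here is a toy sketch using the same `_create_cpu_state_dict` / `_copy_state_dict` helpers; the tensors are made up and this is not TorchTitan code.

```python
# Toy illustration of pinned-memory CPU staging on a side stream.
import torch
from torch.distributed._state_dict_utils import (
    _copy_state_dict,
    _create_cpu_state_dict,
)

state_dict = {"weight": torch.randn(1024, 1024, device="cuda")}  # made-up state

# Allocate pinned, shareable CPU buffers mirroring the state_dict layout (done once).
cpu_state_dict = _create_cpu_state_dict(state_dict, pin_memory=True, share_memory=True)

staging_stream = torch.cuda.Stream()
with torch.cuda.stream(staging_stream):
    # Device-to-host copies are issued asynchronously so training can continue.
    cpu_state_dict = _copy_state_dict(state_dict, cpu_state_dict, non_blocking=True)

# Consumers must synchronize the stream before reading the CPU copy,
# which is what wait_for_staging() does.
staging_stream.synchronize()
```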

torchtitan/config_manager.py (+13)

```diff
@@ -604,6 +604,19 @@ def __init__(self):
             action="store_true",
         )
 
+        self.parser.add_argument(
+            "--experimental.enable_torchft",
+            action="store_true",
+            help="Enable TorchFT integration.",
+        )
+
+        self.parser.add_argument(
+            "--experimental.ft_replica_group_id",
+            type=int,
+            default=-1,
+            help="The FT replica group of this run.",
+        )
+
     def to_dict(self):
         return self.args_dict
 
```
torchtitan/ft.py (+59, new file)

```diff
@@ -0,0 +1,59 @@
+from typing import Any, Callable, Optional
+import importlib
+
+from torchtitan.config_manager import JobConfig
+from torch.distributed._state_dict_utils import (
+    _copy_state_dict,
+    _create_cpu_state_dict,
+)
+
+if importlib.util.find_spec("torchft") is not None:
+    import torchft as ft
+    has_torchft = True
+else:
+    has_torchft = False
+
+
+def init_ft_manager(job: JobConfig) -> Optional["ft.Manager"]:
+    """
+    Initialize the FT manager for the given job.
+    """
+    if not job.experimental.enable_torchft:
+        return None
+
+    if not has_torchft:
+        raise ImportError("torchft is not installed. Please install it.")
+
+    pg = ft.ProcessGroupBabyNCCL()
+    manager = ft.Manager(
+        pg=pg,
+        min_replica_size=1,
+        load_state_dict=None,
+        state_dict=None,
+        use_async_quorum=True,
+        replica_id=f"torchtitan_ft_{job.experimental.ft_replica_group_id}",
+    )
+
+    return manager
+
+
+def set_ft_state_dict_fns(manager: Optional["ft.Manager"], ckpt_manager) -> None:
+    """
+    Set the state dict for the given manager.
+    """
+    if manager is None:
+        return
+
+    def state_dict():
+        ret = {}
+        for k, v in ckpt_manager.staging_results().items():
+            if k in {"model", "optimizer", "lr_schedulers"}:
+                ret[k] = v
+        return ret
+
+    def load_state_dict(state_dict):
+        assert state_dict is not None
+        for k, v in state_dict.items():
+            ckpt_manager.states[k].load_state_dict(v)
+
+    manager.set_state_dict_fns(load_state_dict, state_dict)
```
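train.py is among the nine changed files but its diff is not shown in this excerpt. The sketch below is a guess at how the pieces introduced here fit together at startup, using only APIs added in this PR; the helper name and the objects passed in (dataloader, model parts, optimizers, LR schedulers, train state) are hypothetical stand-ins for what the training script already builds.

```python
# Hypothetical wiring sketch; the actual train.py changes are not shown above.
from torchtitan.checkpoint import CheckpointManager
from torchtitan.ft import init_ft_manager, set_ft_state_dict_fns


def maybe_enable_ft(job_config, dataloader, model_parts, optimizers, lr_schedulers, train_state):
    # Returns None unless --experimental.enable_torchft is set.
    ft_manager = init_ft_manager(job_config)

    # With ft_manager set, CheckpointManager stages state for torchft even if
    # regular checkpointing is disabled.
    checkpoint = CheckpointManager(
        dataloader=dataloader,
        model_parts=model_parts,
        optimizers=optimizers,
        lr_schedulers=lr_schedulers,
        states={"train_state": train_state},
        job_config=job_config,
        ft_manager=ft_manager,
    )

    # Register the staged CPU state with torchft so a replica that joins mid-run
    # can be bootstrapped from a healthy replica's latest step.
    set_ft_state_dict_fns(ft_manager, checkpoint)
    return ft_manager, checkpoint
```

`set_ft_state_dict_fns` is a no-op when `ft_manager` is `None`, so the same code path works with TorchFT disabled.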
