
Commit 1b1a41d

Integrate TorchFT
**Summary**

This is a WIP TorchFT integration PR.

**Current Issues**

This does not work at the moment: groups hang when a new group joins.

**Issue 1:** ~Group 0 and group 1 hang during the first `should_commit` after group 1 applies the pending state_dict from group 0.~ Fixed with pytorch/torchft#83.

**Issue 2:** ~Group 0 and group 1 pass `should_commit`, but group 0 incorrectly needs healing, and the healing process causes another hang.~ Fixed with pytorch/torchft#83.

**Issue 3:** ~As a byproduct of issues 1 and 2, group 1 keeps printing:~

```
[rank0]:devgpu051:76838:80357 [0] misc/socket.cc:50 NCCL WARN socketProgress: Connection closed by remote peer devgpu051.cln3.svc.fbinfra.net<33618>
```

Fixed with pytorch/torchft#91 and several other fixes.

**Issue 4:** When there are 3 groups, every group requests the state dict on every step.

***How to reproduce?*** Use the `Reproduce steps` below to run 2 groups, then add another group by modifying the command. Seems to be fixed; more tests are needed.

**Issue 5:** A hang occurs when using functional collectives.

***How to reproduce?*** Pull the latest version of this PR, then comment out line 41 and uncomment line 42 in `torchtitan/utils.py`.

**Reproduce steps:**

1. Patch TorchFT with pytorch/torchft#82.
2. Start the lighthouse.
3. Run the following command in one terminal:

```
TORCHFT_MANAGER_PORT=29520 REPLICA_GROUP_ID=0 CUDA_VISIBLE_DEVICES=0,1 NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2 --experimental.enable_torchft --experimental.ft_replica_group_id=0
```

4. Wait 10 seconds, then run the following command in another terminal:

```
TORCHFT_MANAGER_PORT=29522 REPLICA_GROUP_ID=1 CUDA_VISIBLE_DEVICES=2,3 NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2 --experimental.enable_torchft --experimental.ft_replica_group_id=1
```

ghstack-source-id: 40d49640772abac268fa97147d889aa652559116
Pull Request resolved: #834
1 parent 91a494b commit 1b1a41d

File tree

8 files changed: +270 −97 lines changed

run_llama_train.sh (+4)

```diff
@@ -19,7 +19,11 @@ if [ $# -ne 0 ]; then
     overrides="$*"
 fi
 
+TORCHFT_MANAGER_PORT=${TORCHFT_MANAGER_PORT:-"29512"}
+
 PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \
+TORCHFT_LIGHTHOUSE=http://localhost:29510 \
+TORCHFT_MANAGER_PORT=${TORCHFT_MANAGER_PORT} \
 torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
 --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
 train.py --job.config_file ${CONFIG_FILE} $overrides
```
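The two launch commands in the reproduce steps differ only in `TORCHFT_MANAGER_PORT`, `REPLICA_GROUP_ID`, `CUDA_VISIBLE_DEVICES`, and `--experimental.ft_replica_group_id`. A small helper like the following (not part of this PR; the port spacing and GPU layout are assumptions inferred from those commands) derives the per-group environment from a single group id, which also makes adding a third group straightforward:

```python
import os
import subprocess

def replica_env(group_id: int, gpus_per_group: int = 2) -> dict:
    """Per-replica-group environment mirroring the repro commands (assumed layout)."""
    first_gpu = group_id * gpus_per_group
    return {
        "REPLICA_GROUP_ID": str(group_id),
        "TORCHFT_MANAGER_PORT": str(29520 + 2 * group_id),  # 29520, 29522, ...
        "CUDA_VISIBLE_DEVICES": ",".join(
            str(first_gpu + i) for i in range(gpus_per_group)
        ),
        "NGPU": str(gpus_per_group),
    }

# Example: launch replica group 1 the same way as the second repro command.
group_id = 1
env = {**os.environ, **replica_env(group_id)}
subprocess.run(
    [
        "./run_llama_train.sh",
        "--training.data_parallel_shard_degree=2",
        "--experimental.enable_torchft",
        f"--experimental.ft_replica_group_id={group_id}",
    ],
    env=env,
    check=True,
)
```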

torchtitan/checkpoint.py (+118 −84)

```diff
@@ -19,6 +19,7 @@
 import torch.distributed as dist
 import torch.distributed.checkpoint as dcp
 import torch.nn as nn
+from torch.distributed._state_dict_utils import _copy_state_dict, _create_cpu_state_dict
 from torch.distributed.checkpoint.state_dict import (
     get_model_state_dict,
     set_model_state_dict,
@@ -144,50 +145,29 @@ def __init__(
         lr_schedulers: LRSchedulersContainer,
         states: Dict[str, Any],
         job_config: JobConfig,
+        ft_manager: Optional[Any] = None,
     ) -> None:
         ckpt_config = job_config.checkpoint
         self.enable_checkpoint = ckpt_config.enable_checkpoint
-        self.keep_latest_k = ckpt_config.keep_latest_k
+        self.ft_manager = ft_manager
+        self.enable_staging = (
+            self.enable_checkpoint and async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
+        ) or self.ft_manager
 
-        if not self.enable_checkpoint:
+        if not self.enable_checkpoint and self.ft_manager is None:
             return
-        """
-        Note: Pipeline Parallelism and Virtual Stages
-
-        1. even for simple PP schedules, there is a separate optimizer each PP rank.
-        rank0's optimizer would have a param_group[0] which refers to layers.0 in the original model.
-        rank1's would _also_ have a param_group[0], since it's index based, but referring to layers.1.
-        When saving, these collide and one of them is lost. Then when reloading, only one stage can
-        restore its optimizer states, others will error.
-
-        The solution to this problem is optimizer flattening: it landed in #127071 and is enabled in TorchTitan
-        by passing the 'flatten_optimizer_state_dict' kwarg to DCP functions called in the OptimizerContainer.
-
-        2. With complex PP schedules, we have multiple model chunks per pp rank. This compounds challenge (1) by also
-        requiring us to reason about multiple 'optim' objects locally.
-
-        We solve this in the Model and Optimizer wrapper classes by flattening the state dicts from each object
-        into one state dict before saving/loading. We rely on the individual state_dicts to not collide,
-        which is gauranteed for the model by correct pipeline splitting and for the optimizer by the flattening
-        support described in (1).
-
-        3. LR schedulers also index model states like optimizers and would need to be flattened properly to support
-        resharding. Unfortunately, the implementations of different lr_schedulers do not follow a clear pattern like
-        optimizers do, so it's hard to write a generic 'flattener' utility.
-
-        TODO: This is currently unsolved and needs a fix.
-        """
-        self.states = states
 
-        self.states.update(
-            {
-                "model": ModelWrapper(model_parts),
-                "optimizer": optimizers,
-                "dataloader": dataloader,
-                "lr_scheduler": lr_schedulers,
-            }
+        self._initialize_states(
+            states, dataloader, model_parts, optimizers, lr_schedulers
         )
 
+        async_mode = ckpt_config.async_mode.lower()
+        self.staging = False
+        self.sending_to_checkpoint_mp = False
+        self.staging_id = None
+        self.cpu_offload_state_dict = None
+        self.staging_stream = torch.cuda.Stream() if self.enable_staging else None
+
         self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder)
         self.interval_type = (
             IntervalType.SECONDS
@@ -202,6 +182,7 @@ def __init__(
         if async_mode == AsyncMode.ASYNC or self.interval_type == IntervalType.SECONDS:
             self.pg = dist.new_group(backend="gloo")
 
+        self.keep_latest_k = ckpt_config.keep_latest_k
         self.model_weights_only = ckpt_config.model_weights_only
         self.export_dtype = TORCH_DTYPE_MAP[ckpt_config.export_dtype]
 
@@ -225,10 +206,6 @@ def __init__(
                 daemon=True,
             )
             self.mp.start()
-            self.cpu_offload_state_dict = None
-            self.staging = False
-            self.staging_id = None
-            self.staging_stream = torch.cuda.Stream()
         else:
             raise ValueError(f"Unkown checkpoint async_mode {ckpt_config.async_mode}")
 
@@ -242,8 +219,61 @@ def __del__(self):
             self.mp.join()
 
     def reset(self) -> None:
+        # We need to stage the local state if another replicate joins during the
+        # first step.
+        if self.ft_manager:
+            self.cpu_staging(None)
         self.begin_time = time.monotonic()
 
+    def _initialize_states(
+        self,
+        states: Dict[str, Any],
+        dataloader: DataLoader,
+        model_parts: List[nn.Module],
+        optimizers: OptimizersContainer,
+        lr_schedulers: LRSchedulersContainer,
+    ) -> None:
+        """
+        Note: Pipeline Parallelism and Virtual Stages
+
+        1. Even for simple PP schedules, there is a separate optimizer each PP rank.
+        rank0's optimizer would have a param_group[0] which refers to layers.0 in the
+        original model. rank1's would _also_ have a param_group[0], since it's index based,
+        but referring to layers.1.
+        When saving, these collide and one of them is lost. Then when reloading, only one
+        stage can restore its optimizer states, others will error.
+
+        The solution to this problem is optimizer flattening: it landed in #127071
+        and is enabled in TorchTitan by passing the 'flatten_optimizer_state_dict'
+        kwarg to DCP functions called in the OptimizerContainer.
+
+        2. With complex PP schedules, we have multiple model chunks per pp rank. This
+        compounds challenge (1) by also requiring us to reason about multiple 'optim'
+        objects locally.
+
+        We solve this in the Model and Optimizer wrapper classes by flattening the
+        state dicts from each object into one state dict before saving/loading.
+        We rely on the individual state_dicts to not collide, which is gauranteed for
+        the model by correct pipeline splitting and for the optimizer by the flattening
+        support described in (1).
+
+        3. LR schedulers also index model states like optimizers and would need to be
+        flattened properly to support resharding. Unfortunately, the implementations of
+        different lr_schedulers do not follow a clear pattern like optimizers do, so it's
+        hard to write a generic 'flattener' utility.
+
+        TODO: This is currently unsolved and needs a fix.
+        """
+        self.states = states
+        self.states.update(
+            {
+                "model": ModelWrapper(model_parts),
+                "optimizer": optimizers,
+                "dataloader": dataloader,
+                "lr_scheduler": lr_schedulers,
+            }
+        )
+
     def _create_checkpoint_id(self, step: int) -> str:
         return os.path.join(self.folder, f"step-{step}")
 
@@ -326,31 +356,8 @@ def _async_wait(self) -> None:
         self.async_future.result()
 
     def _async_with_pinned_memory(self, checkpoint_id: str) -> None:
-        try:
-            from torch.distributed._state_dict_utils import (
-                _copy_state_dict,
-                _create_cpu_state_dict,
-            )
-        except ImportError as e:
-            raise ImportError(
-                "Please install the latest PyTorch nightly to use async checkpointing with pinned memory."
-            ) from e
-        state_dict = dcp.state_dict_saver._stateful_to_state_dict(self.states)
-        if self.cpu_offload_state_dict is None:
-            logger.debug(f"Preparing the CPU memory, {time.monotonic()=}.:.2f")
-            self.cpu_offload_state_dict = _create_cpu_state_dict(
-                state_dict, pin_memory=True, share_memory=True
-            )
-
-        logger.debug(f"Staging the state_dict, {time.monotonic()=}.:.2f")
-        with torch.cuda.stream(self.staging_stream):
-            self.cpu_offload_state_dict = _copy_state_dict(
-                state_dict,
-                self.cpu_offload_state_dict,
-                non_blocking=True,
-            )
-        self.staging = True
-        self.staging_id = checkpoint_id
+        self.cpu_staging(checkpoint_id)
+        self.sending_to_checkpoint_mp = True
 
     def save(self, curr_step: int, force: bool = False) -> None:
         """
@@ -360,6 +367,8 @@ def save(self, curr_step: int, force: bool = False) -> None:
         for initial seed checkpoint.
         """
         if not self._should_save(curr_step, force):
+            if self.ft_manager:
+                self.cpu_staging(None)
             return
 
         begin = time.monotonic()
@@ -383,26 +392,51 @@ def save(self, curr_step: int, force: bool = False) -> None:
             f"in {time.monotonic() - begin:.2f} seconds."
         )
 
+    def cpu_staging(self, checkpoint_id: Optional[str]) -> None:
+        """Offload state_dict to CPU memory"""
+        state_dict = dcp.state_dict_saver._stateful_to_state_dict(self.states)
+        if self.cpu_offload_state_dict is None:
+            logger.debug(f"Preparing the CPU memory, {time.monotonic()=}.:.2f")
+            self.cpu_offload_state_dict = _create_cpu_state_dict(
+                state_dict, pin_memory=True, share_memory=True
+            )
+
+        logger.debug(f"Staging the state_dict, {time.monotonic()=}.:.2f")
+        with torch.cuda.stream(self.staging_stream):
+            self.cpu_offload_state_dict = _copy_state_dict(
+                state_dict,
+                self.cpu_offload_state_dict,
+                non_blocking=True,
+            )
+        self.staging = True
+        self.staging_id = checkpoint_id
+
+    def wait_for_staging(self) -> None:
+        if not self.staging_stream.query():
+            self.staging_stream.synchronize()
+        self.staging = False
+
+    def staging_results(self) -> Dict[str, Any]:
+        self.maybe_wait_for_staging()
+        return self.cpu_offload_state_dict
+
     def maybe_wait_for_staging(self) -> None:
-        if (
-            self.enable_checkpoint
-            and self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
-            and self.staging
-        ):
-            if not self.staging_stream.query():
-                self.staging_stream.synchronize()
-
-            def sync_func():
-                self.mp_queue_send.put_nowait(
-                    (self.cpu_offload_state_dict, self.staging_id)
-                )
-
-            # This may be a faster way to do zero-overhead checkpointing staging
-            # checkpointing but we need more thorough investigation before
-            # swithing to this method.
-            # self.my_thread = threading.Thread(target=func).start()
-            sync_func()
-            self.staging = False
+        if self.enable_staging and self.staging:
+            self.wait_for_staging()
+
+        if self.sending_to_checkpoint_mp:
+            # Copy the sync staging result to another process.
+            def sync_func():
+                self.mp_queue_send.put_nowait(
+                    (self.cpu_offload_state_dict, self.staging_id)
+                )
+
+            # This may be a faster way to do zero-overhead checkpointing staging
+            # checkpointing but we need more thorough investigation before
+            # swithing to this method.
+            # self.my_thread = threading.Thread(target=func).start()
            sync_func()
+            self.sending_to_checkpoint_mp = False
 
     def load(self, step: int = -1) -> bool:
         if not self.enable_checkpoint:
```
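The net effect of this diff is that the old `_async_with_pinned_memory` body becomes a reusable `cpu_staging` path shared by async checkpointing and TorchFT: `reset()` and `save()` now stage to pinned CPU memory even when no checkpoint file is written, and `staging_results()` is what a joining replica reads (see `torchtitan/ft.py` below). Here is a minimal, self-contained sketch of that staging pattern; it is my simplification, not the PR's class, and it uses plain pinned tensors instead of `_create_cpu_state_dict`/`_copy_state_dict`, assuming a CUDA device is available:

```python
import torch

class StagingSketch:
    """Toy version of the cpu_staging/staging_results flow (requires CUDA)."""

    def __init__(self) -> None:
        self.cpu_state = None              # reusable pinned CPU buffers
        self.stream = torch.cuda.Stream()  # side stream so training isn't blocked

    def cpu_staging(self, gpu_state: dict) -> None:
        if self.cpu_state is None:
            # Allocate pinned CPU tensors once, shaped like the GPU state.
            self.cpu_state = {
                k: torch.empty_like(v, device="cpu").pin_memory()
                for k, v in gpu_state.items()
            }
        # Don't start copying before the producing stream has finished writing.
        self.stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(self.stream):
            for k, v in gpu_state.items():
                # Async device-to-host copies queued on the side stream.
                self.cpu_state[k].copy_(v, non_blocking=True)

    def staging_results(self) -> dict:
        # Block until the copies finish before handing the state out
        # (to the checkpoint process or to a healing TorchFT replica).
        self.stream.synchronize()
        return self.cpu_state

# Usage sketch:
# stager = StagingSketch()
# stager.cpu_staging({"w": torch.randn(4, 4, device="cuda")})
# cpu_copy = stager.staging_results()
```

The PR's version additionally passes `share_memory=True` so the same buffers can be handed to the checkpoint subprocess via `mp_queue_send`.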

torchtitan/config_manager.py (+13)

```diff
@@ -620,6 +620,19 @@ def __init__(self):
             action="store_true",
         )
 
+        self.parser.add_argument(
+            "--experimental.enable_torchft",
+            action="store_true",
+            help="Enable TorchFT integration.",
+        )
+
+        self.parser.add_argument(
+            "--experimental.ft_replica_group_id",
+            type=int,
+            default=-1,
+            help="The FT replicate group of this run.",
+        )
+
     def to_dict(self):
         return self.args_dict
 
```
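For orientation, this is how the two new experimental flags are expected to surface once parsed. Calling `JobConfig.parse_args` directly, as `train.py` does, is an assumption about this torchtitan version rather than something shown in this diff:

```python
from torchtitan.config_manager import JobConfig

# Parse the flags the same way the repro commands pass them on the CLI.
job_config = JobConfig()
job_config.parse_args(
    ["--experimental.enable_torchft", "--experimental.ft_replica_group_id", "1"]
)

print(job_config.experimental.enable_torchft)       # True
print(job_config.experimental.ft_replica_group_id)  # 1
```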

torchtitan/ft.py (new file, +58)

```diff
@@ -0,0 +1,58 @@
+import importlib
+from typing import Any, Callable, Optional
+
+from torch.distributed._state_dict_utils import _copy_state_dict, _create_cpu_state_dict
+
+from torchtitan.config_manager import JobConfig
+
+if importlib.util.find_spec("torchft") is not None:
+    import torchft as ft
+
+    has_torchft = True
+else:
+    has_torchft = False
+
+
+def init_ft_manager(job: JobConfig) -> Optional["ft.Manager"]:
+    """
+    Initialize the FT manager for the given job.
+    """
+    if not job.experimental.enable_torchft:
+        return None
+
+    if not has_torchft:
+        raise ImportError("torchft is not installed. Please install it.")
+
+    pg = ft.ProcessGroupBabyNCCL()
+    manager = ft.Manager(
+        pg=pg,
+        min_replica_size=1,
+        load_state_dict=None,
+        state_dict=None,
+        use_async_quorum=True,
+        replica_id=f"torchtitan_ft_{job.experimental.ft_replica_group_id}",
+    )
+
+    return manager
+
+
+def set_ft_state_dict_fns(manager: Optional["ft.Manager"], ckpt_manager) -> None:
+    """
+    Set the state dict for the given manager.
+    """
+    if manager is None:
+        return
+
+    def state_dict():
+        ret = {}
+        for k, v in ckpt_manager.staging_results().items():
+            if k in {"model", "optimizer", "lr_schedulers"}:
+                ret[k] = v
+        return ret
+
+    def load_state_dict(state_dict):
+        assert state_dict is not None
+        for k, v in state_dict.items():
+            ckpt_manager.states[k].load_state_dict(v)
+
+    manager.set_state_dict_fns(load_state_dict, state_dict)
```
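`train.py` is among the eight changed files but is not shown in this excerpt. Based only on the APIs above and the new `ft_manager` parameter of `CheckpointManager`, the wiring presumably looks roughly like the following sketch (an assumption, not the PR's actual `train.py` changes):

```python
from torchtitan.checkpoint import CheckpointManager
from torchtitan.ft import init_ft_manager, set_ft_state_dict_fns

def build_checkpointing(
    job_config, dataloader, model_parts, optimizers, lr_schedulers, states
):
    # Returns None unless --experimental.enable_torchft is set.
    ft_manager = init_ft_manager(job_config)

    checkpoint = CheckpointManager(
        dataloader=dataloader,
        model_parts=model_parts,
        optimizers=optimizers,
        lr_schedulers=lr_schedulers,
        states=states,
        job_config=job_config,
        ft_manager=ft_manager,
    )

    # Register the staging-backed callbacks: a joining replica heals by loading
    # this group's most recently staged state_dict. No-op when ft_manager is None.
    set_ft_state_dict_fns(ft_manager, checkpoint)
    return checkpoint, ft_manager
```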
