@@ -16,7 +16,7 @@
 from dataclasses import dataclass, field
 from io import BytesIO
 from multiprocessing import get_context
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union

 import torch
 import torch.distributed as dist
@@ -36,6 +36,9 @@
 from torchtitan.tools.logging import init_logger, logger
 from torchtitan.tools.utils import GarbageCollection

+if TYPE_CHECKING:
+    import torchft as ft
+

 MODEL = "model"
 OPTIMIZER = "optimizer"
@@ -214,6 +217,19 @@ class CheckpointManager:
     3. LR schedulers also index model states like optimizers. Here we flatten the lr_schedulers
        with the assumption that all lr_schedulers have the same state_dict.

+    Note: TorchFT checkpointing flow
+
+    There are two types of checkpoints when TorchFT is enabled: 1) the full persistent
+    checkpoint, 2) the per-replica checkpoint.
+
+    The full persistent checkpoint is saved by the replica with
+    ``ft_manager.participating_rank() == 0``. It contains everything, including the model,
+    optimizer, lr_scheduler, dataloader, and train_state. Right now the full persistent
+    checkpoint is loaded by all replicas. However, we can optimize it to only load if
+    there are no other alive replicas.
+
+    The per-replica checkpoint contains only the dataloader and is saved/loaded by all
+    replicas to/from their own folders. The folder name includes the ft_replica_id.
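+
+    For illustration, assuming the checkpoint folder resolves to ``outputs/checkpoint``
+    and this replica's id is 0, step 1000 would produce::
+
+        outputs/checkpoint/step-1000               (full persistent checkpoint)
+        outputs/checkpoint/ft-replica-0/step-1000  (per-replica checkpoint, see _ft_folder)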

     Args:
         dataloader (DataLoader): The dataloader used to load the data.
@@ -223,6 +239,7 @@ class CheckpointManager:
         states (Dict[str, Any]): The states that need to be saved, other than the
             previous 4 components.
         job_config (JobConfig): The job config used to configure the checkpointing.
+        ft_manager (Optional[ft.Manager]): The FTManager from TorchFT.
     """

     def __init__(
@@ -233,16 +250,41 @@ def __init__(
         lr_schedulers: LRSchedulersContainer,
         states: Dict[str, Any],
         job_config: JobConfig,
+        ft_manager: Optional["ft.Manager"] = None,
     ) -> None:
         ckpt_config = job_config.checkpoint
         self.enable_checkpoint = ckpt_config.enable_checkpoint
+        self.ft_manager = ft_manager
+
+        if self.ft_manager:
+            optimizers.init_cache_state_dict()
+
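+            # These two callbacks are registered with TorchFT below so it can
+            # snapshot and restore the live training state (model, optimizer,
+            # lr_scheduler, train_state) when a recovering replica catches up.
+            # The dataloader is intentionally excluded; it is covered by the
+            # per-replica checkpoint instead.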
+            def state_dict():
+                ret = {}
+                for k, v in self.states.items():
+                    if k in {
+                        MODEL,
+                        OPTIMIZER,
+                        LR_SCHEDULER,
+                        TRAIN_STATE,
+                    }:
+                        ret[k] = v.state_dict()
+                return ret
+
+            def load_state_dict(state_dict):
+                assert state_dict is not None
+                for k, v in state_dict.items():
+                    self.states[k].load_state_dict(v)
+
+            ft_manager.set_state_dict_fns(load_state_dict, state_dict)
+            self.ft_replica_id = job_config.experimental.ft_replica_id

         async_mode = ckpt_config.async_mode.lower()
         self.enable_staging = (
             self.enable_checkpoint and async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
-        )
+        ) or self.ft_manager

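+        # With TorchFT, keep initializing even when persistent checkpointing is
+        # disabled: staging and the per-replica (dataloader) checkpoint below
+        # still need to be set up.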
-        if not self.enable_checkpoint:
+        if not self.enable_checkpoint and self.ft_manager is None:
             return

         self.states = states
@@ -254,6 +296,13 @@ def __init__(
                 LR_SCHEDULER: lr_schedulers,
             }
         )
+        self.ft_states = {DATALOADER: dataloader}
+
+        self.staging = False
+        self.sending_to_checkpoint_mp = False
+        self.staging_id = None
+        self.cpu_offload_state_dict = None
+        self.staging_stream = torch.cuda.Stream() if self.enable_staging else None

         self.staging = False
         self.sending_to_checkpoint_mp = False
@@ -264,7 +313,7 @@ def __init__(
         self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder)
         self.interval = ckpt_config.interval
         async_mode = ckpt_config.async_mode.lower()
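+        # _ft_save() passes this gloo process group to dcp.async_save(), so it is
+        # also created when TorchFT is enabled, not only for async checkpoint mode.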
-        if async_mode == AsyncMode.ASYNC:
+        if async_mode == AsyncMode.ASYNC or self.ft_manager:
             self.pg = dist.new_group(backend="gloo")

         self.keep_latest_k = ckpt_config.keep_latest_k
@@ -339,35 +388,44 @@ def save(self, curr_step: int, force: bool = False) -> None:
             None
         """

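+        # The per-replica (dataloader) checkpoint is written on every save() call,
+        # before the persistent-checkpoint interval check below.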
+        if self.ft_manager:
+            self._ft_save(curr_step)
+
         if not self._should_save(curr_step, force):
             return

         begin = time.monotonic()
-        logger.info("Saving the checkpoint (or staging if async is enabled).")
-        checkpoint_id = self._create_checkpoint_id(curr_step)
-        self._async_wait()
-        # This GC is called for async checkpoint as it is useless to do
-        # GC right after async_save -- the CPU memory is not able to be
-        # freed until _async_wait()
-        if force:
-            self._save_last_step(curr_step)
-        elif self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
-            GarbageCollection.collect("GC collection invoked by checkpointer.")
-            self._async_with_pinned_memory(checkpoint_id)
-        elif self.async_mode == AsyncMode.ASYNC:
-            GarbageCollection.collect("GC collection invoked by checkpointer.")
-            self.async_future = dcp.async_save(
-                self.states, checkpoint_id=checkpoint_id, process_group=self.pg
-            )
-            GarbageCollection.collect("GC collection invoked by checkpointer.")
-        else:
-            save_with_gc(self.states, checkpoint_id=checkpoint_id)
-        self._purge_stale_checkpoints()
+        if not self.ft_manager or self.ft_manager.participating_rank() == 0:
+            logger.info("Saving the checkpoint (or staging if async is enabled).")
+            checkpoint_id = self._create_checkpoint_id(curr_step)
+            self._async_wait()
+            # This GC is called for async checkpoint as it is useless to do
+            # GC right after async_save -- the CPU memory is not able to be
+            # freed until _async_wait()
+            if force:
+                self._save_last_step(curr_step)
+            elif self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
+                GarbageCollection.collect("GC collection invoked by checkpointer.")
+                self._async_with_pinned_memory(checkpoint_id)
+            elif self.async_mode == AsyncMode.ASYNC:
+                GarbageCollection.collect("GC collection invoked by checkpointer.")
+                self.async_future = dcp.async_save(
+                    self.states, checkpoint_id=checkpoint_id, process_group=self.pg
+                )
+                GarbageCollection.collect("GC collection invoked by checkpointer.")
+            else:
+                save_with_gc(self.states, checkpoint_id=checkpoint_id)
+            self._purge_stale_checkpoints()

-        logger.info(
-            "Finished saving the checkpoint (or staging if async is enabled)"
-            f"in {time.monotonic() - begin:.2f} seconds."
-        )
+            logger.info(
+                "Finished saving the checkpoint (or staging if async is enabled)"
+                f"in {time.monotonic() - begin:.2f} seconds."
+            )
+        elif self.ft_manager:
+            logger.info(
+                "Replica %d doesn't save checkpoint.",
+                self.ft_manager.participating_rank(),
+            )

     @torch.no_grad()
     def load(self, step: int = -1) -> bool:
@@ -384,6 +442,9 @@ def load(self, step: int = -1) -> bool:
             bool: Whether the checkpoint was loaded successfully.
         """

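+        # Restore this replica's dataloader state first; the full checkpoint load
+        # below skips the dataloader (see _states_to_load).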
+        if self.ft_manager:
+            self._ft_load()
+
         if not self.enable_checkpoint or not os.path.isdir(self.folder):
             return False

@@ -467,10 +528,36 @@ def _find_load_step(self, folder: str = "") -> int:
             return -1
         return max(step_counts)

+    def _ft_folder(self) -> str:
+        return os.path.join(self.folder, f"ft-replica-{self.ft_replica_id}")
+
     def _create_checkpoint_id(self, step: int, folder: str = "") -> str:
         folder = folder if folder else self.folder
         return os.path.join(folder, f"step-{step}")

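+    # Saves the per-replica states (currently only the dataloader) asynchronously
+    # into this replica's own folder.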
+    def _ft_save(self, step: int) -> None:
+        begin = time.monotonic()
+        self._async_wait()
+        checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder())
+        self.async_future = dcp.async_save(
+            self.ft_states, checkpoint_id=checkpoint_id, process_group=self.pg
+        )
+        logger.info(f"Staging ft checkpoint took {time.monotonic() - begin} secs.")
+
+    def _ft_load(self) -> None:
+        step = self._find_load_step(folder=self._ft_folder())
+        if step == -1:
+            return
+
+        begin = time.monotonic()
+        logger.info(f"Loading the FT checkpoint at step {step}.")
+        checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder())
+        dcp.load(self.ft_states, checkpoint_id=checkpoint_id)
+        GarbageCollection.collect("GC collection for checkpoint loading.")
+        logger.info(
+            f"Finished loading the ft checkpoint in {time.monotonic() - begin:.2f} seconds."
+        )
+
     def _states_to_load(self, step: int) -> Dict[str, Any]:
         """Determines which states to load for the given step.

@@ -491,6 +578,8 @@ def _states_to_load(self, step: int) -> Dict[str, Any]:
         for exclude_key in self.exclude_from_loading:
             if exclude_key not in states:
                 raise ValueError(f"{exclude_key} not found in state_dict.")
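+        # The dataloader is restored from the per-replica checkpoint in _ft_load(),
+        # so drop it from the full checkpoint load.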
+        if self.ft_manager:
+            states_to_load.pop(DATALOADER)
         return states_to_load

     def _save_last_step(self, curr_step: int) -> None:
@@ -577,6 +666,7 @@ def _purge_stale_checkpoints(self):
             self.keep_latest_k > 0
             and dist.get_rank() == 0
             and os.path.isdir(self.folder)
+            and (not self.ft_manager or self.ft_manager.participating_rank() == 0)
         ):
             discovered_checkpoints = []
             for filename in os.listdir(self.folder):