@@ -214,6 +214,19 @@ class CheckpointManager:
    3. LR schedulers also index model states like optimizers. Here we flatten the lr_schedulers
       with the assumption that all lr_schedulers have the same state_dict.

+    Note: TorchFT checkpointing flow
+
+    There are two types of checkpoints when TorchFT is enabled: 1) the full persistent
+    checkpoint and 2) the per-replica checkpoint.
+
+    The full persistent checkpoint is saved by the replica with
+    ``ft_manager.participating_rank() == 0``. It contains everything, including the model,
+    optimizer, lr_scheduler, dataloader, and train_state. Right now the full persistent
+    checkpoint is loaded by all replicas. However, we can optimize it to only load if
+    there are no other alive replicas.
+
+    The per-replica checkpoint contains only the dataloader and is saved/loaded by each
+    replica to/from its own folder. The folder name includes the ft_replica_id.

    Args:
        dataloader (DataLoader): The dataloader used to load the data.
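To make the flow described in the docstring concrete, here is a minimal sketch of how the two checkpoint types map to folders and ranks. The helper function, folder root, and step value are illustrative and not part of CheckpointManager; the path patterns mirror `_ft_folder()` and `_create_checkpoint_id()` added later in this diff, and the rank gating mirrors `save()`.

```python
import os

# Hypothetical illustration of the two-tier layout described in the docstring.
def checkpoint_paths(step: int, ft_replica_id: int, folder: str = "outputs/checkpoint"):
    # Per-replica checkpoint: every replica writes only its dataloader state
    # into its own folder, named after ft_replica_id (mirrors _ft_folder()).
    per_replica = os.path.join(folder, f"ft-replicat-{ft_replica_id}", f"step-{step}")
    # Full persistent checkpoint: only the replica whose
    # ft_manager.participating_rank() == 0 writes the model/optimizer/
    # lr_scheduler/dataloader/train_state here (mirrors the gating in save()).
    full = os.path.join(folder, f"step-{step}")
    return per_replica, full

print(checkpoint_paths(step=1000, ft_replica_id=0))
# ('outputs/checkpoint/ft-replicat-0/step-1000', 'outputs/checkpoint/step-1000')
```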
@@ -223,6 +236,7 @@ class CheckpointManager:
        states (Dict[str, Any]): The states that need to be saved, other than the
            previous 4 components.
        job_config (JobConfig): The job config used to configure the checkpointing.
+        ft_manager (Optional[ft.Manager]): The FTManager from TorchFT.
    """

    def __init__(
@@ -233,16 +247,41 @@ def __init__(
        lr_schedulers: LRSchedulersContainer,
        states: Dict[str, Any],
        job_config: JobConfig,
+        ft_manager: FTManager,
    ) -> None:
        ckpt_config = job_config.checkpoint
        self.enable_checkpoint = ckpt_config.enable_checkpoint
+        self.ft_manager = ft_manager.manager if ft_manager.enabled else None
+
+        if self.ft_manager:
+            optimizers.init_cache_state_dict()
+
+            def state_dict():
+                ret = {}
+                for k, v in self.states.items():
+                    if k in {
+                        MODEL,
+                        OPTIMIZER,
+                        LR_SCHEDULER,
+                        TRAIN_STATE,
+                    }:
+                        ret[k] = v.state_dict()
+                return ret
+
+            def load_state_dict(state_dict):
+                assert state_dict is not None
+                for k, v in state_dict.items():
+                    self.states[k].load_state_dict(v)
+
+            self.ft_manager.set_state_dict_fns(load_state_dict, state_dict)
+            self.ft_replica_id = job_config.experimental.ft_replica_id

        async_mode = ckpt_config.async_mode.lower()
        self.enable_staging = (
            self.enable_checkpoint and async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
-        )
+        ) or self.ft_manager

-        if not self.enable_checkpoint:
+        if not self.enable_checkpoint and self.ft_manager is None:
            return

        self.states = states
@@ -254,6 +293,13 @@ def __init__(
                LR_SCHEDULER: lr_schedulers,
            }
        )
+        self.ft_states = {DATALOADER: dataloader}
+
+        self.staging = False
+        self.sending_to_checkpoint_mp = False
+        self.staging_id = None
+        self.cpu_offload_state_dict = None
+        self.staging_stream = torch.cuda.Stream() if self.enable_staging else None

        self.staging = False
        self.sending_to_checkpoint_mp = False
@@ -264,7 +310,7 @@ def __init__(
        self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder)
        self.interval = ckpt_config.interval
        async_mode = ckpt_config.async_mode.lower()
-        if async_mode == AsyncMode.ASYNC:
+        if async_mode == AsyncMode.ASYNC or self.ft_manager:
            self.pg = dist.new_group(backend="gloo")

        self.keep_latest_k = ckpt_config.keep_latest_k
@@ -339,35 +385,44 @@ def save(self, curr_step: int, force: bool = False) -> None:
            None
        """

+        if self.ft_manager:
+            self._ft_save(curr_step)
+
        if not self._should_save(curr_step, force):
            return

        begin = time.monotonic()
-        logger.info("Saving the checkpoint (or staging if async is enabled).")
-        checkpoint_id = self._create_checkpoint_id(curr_step)
-        self._async_wait()
-        # This GC is called for async checkpoint as it is useless to do
-        # GC right after async_save -- the CPU memory is not able to be
-        # freed until _async_wait()
-        if force:
-            self._save_last_step(curr_step)
-        elif self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
-            GarbageCollection.collect("GC collection invoked by checkpointer.")
-            self._async_with_pinned_memory(checkpoint_id)
-        elif self.async_mode == AsyncMode.ASYNC:
-            GarbageCollection.collect("GC collection invoked by checkpointer.")
-            self.async_future = dcp.async_save(
-                self.states, checkpoint_id=checkpoint_id, process_group=self.pg
-            )
-            GarbageCollection.collect("GC collection invoked by checkpointer.")
-        else:
-            save_with_gc(self.states, checkpoint_id=checkpoint_id)
-        self._purge_stale_checkpoints()
+        if not self.ft_manager or self.ft_manager.participating_rank() == 0:
+            logger.info("Saving the checkpoint (or staging if async is enabled).")
+            checkpoint_id = self._create_checkpoint_id(curr_step)
+            self._async_wait()
+            # This GC is called for async checkpoint as it is useless to do
+            # GC right after async_save -- the CPU memory is not able to be
+            # freed until _async_wait()
+            if force:
+                self._save_last_step(curr_step)
+            elif self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM:
+                GarbageCollection.collect("GC collection invoked by checkpointer.")
+                self._async_with_pinned_memory(checkpoint_id)
+            elif self.async_mode == AsyncMode.ASYNC:
+                GarbageCollection.collect("GC collection invoked by checkpointer.")
+                self.async_future = dcp.async_save(
+                    self.states, checkpoint_id=checkpoint_id, process_group=self.pg
+                )
+                GarbageCollection.collect("GC collection invoked by checkpointer.")
+            else:
+                save_with_gc(self.states, checkpoint_id=checkpoint_id)
+            self._purge_stale_checkpoints()

-        logger.info(
-            "Finished saving the checkpoint (or staging if async is enabled)"
-            f"in {time.monotonic() - begin:.2f} seconds."
-        )
+            logger.info(
+                "Finished saving the checkpoint (or staging if async is enabled)"
+                f"in {time.monotonic() - begin:.2f} seconds."
+            )
+        elif self.ft_manager:
+            logger.info(
+                "Replica %d doesn't save checkpoint.",
+                self.ft_manager.participating_rank(),
+            )

    @torch.no_grad()
    def load(self, step: int = -1) -> bool:
@@ -384,6 +439,9 @@ def load(self, step: int = -1) -> bool:
            bool: Whether the checkpoint was loaded successfully.
        """

+        if self.ft_manager:
+            self._ft_load()
+
        if not self.enable_checkpoint or not os.path.isdir(self.folder):
            return False

@@ -467,10 +525,36 @@ def _find_load_step(self, folder: str = "") -> int:
            return -1
        return max(step_counts)

+    def _ft_folder(self) -> str:
+        return os.path.join(self.folder, f"ft-replicat-{self.ft_replica_id}")
+
    def _create_checkpoint_id(self, step: int, folder: str = "") -> str:
        folder = folder if folder else self.folder
        return os.path.join(folder, f"step-{step}")

+    def _ft_save(self, step: int) -> None:
+        begin = time.monotonic()
+        self._async_wait()
+        checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder())
+        self.async_future = dcp.async_save(
+            self.ft_states, checkpoint_id=checkpoint_id, process_group=self.pg
+        )
+        logger.info(f"Staging ft checkpoint took {time.monotonic() - begin} secs.")
+
+    def _ft_load(self) -> None:
+        step = self._find_load_step(folder=self._ft_folder())
+        if step == -1:
+            return
+
+        begin = time.monotonic()
+        logger.info(f"Loading the FT checkpoint at step {step}.")
+        checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder())
+        dcp.load(self.ft_states, checkpoint_id=checkpoint_id)
+        GarbageCollection.collect("GC collection for checkpoint loading.")
+        logger.info(
+            f"Finished loading the ft checkpoint in {time.monotonic() - begin:.2f} seconds."
+        )
+
    def _states_to_load(self, step: int) -> Dict[str, Any]:
        """Determines which states to load for the given step.

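For context, here is a self-contained sketch of the bare round trip that `_ft_save`/`_ft_load` build on, run in a single process for simplicity. The tensor contents are made up; the real methods save and load `self.ft_states` (the dataloader only) asynchronously with `process_group=self.pg`, and this sketch assumes a recent PyTorch where `dcp.save`/`dcp.load` accept `checkpoint_id` and fall back to single-process mode when torch.distributed is not initialized.

```python
import tempfile

import torch
import torch.distributed.checkpoint as dcp

# Stand-in for the per-replica state; the real code checkpoints
# self.ft_states = {DATALOADER: dataloader} rather than raw tensors.
ft_states = {"dataloader": {"epoch": torch.tensor(3), "sample_idx": torch.tensor(12800)}}

with tempfile.TemporaryDirectory() as ckpt_dir:
    # _ft_save uses dcp.async_save(..., process_group=self.pg); a plain save is
    # enough here to show the on-disk round trip.
    dcp.save(ft_states, checkpoint_id=ckpt_dir)

    # _ft_load restores in place into a dict with the same structure.
    restored = {"dataloader": {"epoch": torch.tensor(0), "sample_idx": torch.tensor(0)}}
    dcp.load(restored, checkpoint_id=ckpt_dir)
    assert restored["dataloader"]["epoch"].item() == 3
```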
@@ -491,6 +575,8 @@ def _states_to_load(self, step: int) -> Dict[str, Any]:
        for exclude_key in self.exclude_from_loading:
            if exclude_key not in states:
                raise ValueError(f"{exclude_key} not found in state_dict.")
+        if self.ft_manager:
+            states_to_load.pop(DATALOADER)
        return states_to_load

    def _save_last_step(self, curr_step: int) -> None:
@@ -577,6 +663,7 @@ def _purge_stale_checkpoints(self):
            self.keep_latest_k > 0
            and dist.get_rank() == 0
            and os.path.isdir(self.folder)
+            and (not self.ft_manager or self.ft_manager.participating_rank() == 0)
        ):
            discovered_checkpoints = []
            for filename in os.listdir(self.folder):