 import torch.distributed as dist
 import torch.distributed.checkpoint as dcp
 import torch.nn as nn
+from torch.distributed._state_dict_utils import _copy_state_dict, _create_cpu_state_dict
 from torch.distributed.checkpoint.state_dict import (
     get_model_state_dict,
     set_model_state_dict,
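The two private helpers imported above, `_create_cpu_state_dict` and `_copy_state_dict`, were previously imported lazily inside `_async_with_pinned_memory` and are now needed by the shared `cpu_staging` path. A minimal sketch of the staging pattern they enable, assuming the signatures found in recent PyTorch builds (they are private utilities and may change):

```python
# Sketch only: _create_cpu_state_dict / _copy_state_dict are private PyTorch
# utilities; their signatures are assumed from recent builds and may change.
import torch
from torch.distributed._state_dict_utils import _copy_state_dict, _create_cpu_state_dict

gpu_state = {"weight": torch.randn(1024, 1024, device="cuda")}

# Allocate matching CPU buffers once: pinned memory makes the device-to-host
# copy asynchronous, shared memory lets another process read the result
# without a further copy.
cpu_state = _create_cpu_state_dict(gpu_state, pin_memory=True, share_memory=True)

# Reuse the preallocated buffers on every checkpoint.
cpu_state = _copy_state_dict(gpu_state, cpu_state, non_blocking=True)
torch.cuda.synchronize()  # ensure the async copy has finished before using cpu_state
```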
@@ -151,13 +152,18 @@ def __init__(
         lr_schedulers: LRSchedulersContainer,
         states: Dict[str, Any],
         job_config: JobConfig,
+        ft_manager: Optional[Any] = None,
     ) -> None:
         ckpt_config = job_config.checkpoint
         self.enable_checkpoint = ckpt_config.enable_checkpoint
-        self.keep_latest_k = ckpt_config.keep_latest_k
+        self.ft_manager = ft_manager
+        self.enable_staging = (
+            self.enable_checkpoint and async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
+        ) or self.ft_manager
 
-        if not self.enable_checkpoint:
+        if not self.enable_checkpoint and self.ft_manager is None:
             return
+
         """
         Note: Pipeline Parallelism and Virtual Stages
 
@@ -192,6 +198,13 @@ def __init__(
             }
         )
 
+        async_mode = ckpt_config.async_mode.lower()
+        self.staging = False
+        self.sending_to_checkpoint_mp = False
+        self.staging_id = None
+        self.cpu_offload_state_dict = None
+        self.staging_stream = torch.cuda.Stream() if self.enable_staging else None
+
         self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder)
         self.interval_type = (
             IntervalType.SECONDS
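The staging fields added above are initialized unconditionally, and `staging_stream` is created only when staging is enabled. A brief sketch of why the copy runs on a dedicated CUDA stream with pinned host buffers (tensor shapes and names here are illustrative, not from the diff):

```python
# Sketch only: illustrates the side-stream staging pattern used by cpu_staging.
import torch

staging_stream = torch.cuda.Stream()
gpu_tensor = torch.randn(4096, 4096, device="cuda")
cpu_buffer = torch.empty(gpu_tensor.shape, dtype=gpu_tensor.dtype, pin_memory=True)

with torch.cuda.stream(staging_stream):
    # With a pinned destination, the device-to-host copy is truly asynchronous
    # and can overlap with compute issued on the default stream.
    cpu_buffer.copy_(gpu_tensor, non_blocking=True)

# Before handing the buffer off (see wait_for_staging in the last hunk),
# confirm the copy finished; query() avoids a blocking synchronize when it
# already has.
if not staging_stream.query():
    staging_stream.synchronize()
```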
@@ -206,6 +219,7 @@ def __init__(
         if async_mode == AsyncMode.ASYNC or self.interval_type == IntervalType.SECONDS:
             self.pg = dist.new_group(backend="gloo")
 
+        self.keep_latest_k = ckpt_config.keep_latest_k
         self.model_weights_only = ckpt_config.model_weights_only
         self.export_dtype = TORCH_DTYPE_MAP[ckpt_config.export_dtype]
         self.exclude_from_loading = ckpt_config.exclude_from_loading
@@ -230,10 +244,6 @@ def __init__(
                 daemon=True,
             )
             self.mp.start()
-            self.cpu_offload_state_dict = None
-            self.staging = False
-            self.staging_id = None
-            self.staging_stream = torch.cuda.Stream()
         else:
             raise ValueError(f"Unknown checkpoint async_mode {ckpt_config.async_mode}")
 
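For context, the `(state_dict, checkpoint_id)` tuples that `maybe_wait_for_staging` puts on `mp_queue_send` (see the last hunk) are consumed by the background checkpoint process started here. The worker target is not part of this diff; the following is only a rough sketch of what such a consumer typically looks like:

```python
# Rough sketch (assumption): the real checkpoint worker is defined outside this
# diff; a consumer like it drains the queue and writes each staged state dict
# with DCP. Process-group setup inside the worker is omitted here.
import torch.distributed.checkpoint as dcp

def checkpoint_worker(recv_queue) -> None:
    while True:
        item = recv_queue.get()
        if item is None:  # hypothetical shutdown sentinel
            return
        state_dict, checkpoint_id = item
        # The tensors live in shared, pinned CPU memory, so the child process
        # can serialize them without copying from the trainer process.
        dcp.save(state_dict, checkpoint_id=checkpoint_id)
```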
@@ -247,8 +257,61 @@ def __del__(self):
             self.mp.join()
 
     def reset(self) -> None:
+        # We need to stage the local state if another replica joins during the
+        # first step.
+        if self.ft_manager:
+            self.cpu_staging(None)
         self.begin_time = time.monotonic()
 
+    def _initialize_states(
+        self,
+        states: Dict[str, Any],
+        dataloader: DataLoader,
+        model_parts: List[nn.Module],
+        optimizers: OptimizersContainer,
+        lr_schedulers: LRSchedulersContainer,
+    ) -> None:
+        """
+        Note: Pipeline Parallelism and Virtual Stages
+
+        1. Even for simple PP schedules, there is a separate optimizer for each PP rank.
+        rank0's optimizer would have a param_group[0] which refers to layers.0 in the
+        original model. rank1's would _also_ have a param_group[0], since it's index based,
+        but referring to layers.1.
+        When saving, these collide and one of them is lost. Then when reloading, only one
+        stage can restore its optimizer states; the others will error.
+
+        The solution to this problem is optimizer flattening: it landed in #127071
+        and is enabled in TorchTitan by passing the 'flatten_optimizer_state_dict'
+        kwarg to DCP functions called in the OptimizerContainer.
+
+        2. With complex PP schedules, we have multiple model chunks per pp rank. This
+        compounds challenge (1) by also requiring us to reason about multiple 'optim'
+        objects locally.
+
+        We solve this in the Model and Optimizer wrapper classes by flattening the
+        state dicts from each object into one state dict before saving/loading.
+        We rely on the individual state_dicts to not collide, which is guaranteed for
+        the model by correct pipeline splitting and for the optimizer by the flattening
+        support described in (1).
+
+        3. LR schedulers also index model states like optimizers and would need to be
+        flattened properly to support resharding. Unfortunately, the implementations of
+        different lr_schedulers do not follow a clear pattern like optimizers do, so it's
+        hard to write a generic 'flattener' utility.
+
+        TODO: This is currently unsolved and needs a fix.
+        """
+        self.states = states
+        self.states.update(
+            {
+                "model": ModelWrapper(model_parts),
+                "optimizer": optimizers,
+                "dataloader": dataloader,
+                "lr_scheduler": lr_schedulers,
+            }
+        )
+
     def _create_checkpoint_id(self, step: int) -> str:
         return os.path.join(self.folder, f"step-{step}")
 
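Point (1) in the docstring above refers to the optimizer-state flattening that DCP exposes through `StateDictOptions`. A small, self-contained sketch of how that flag is typically passed; the model and optimizer here are placeholders, not TorchTitan's wrappers:

```python
# Sketch only: demonstrates DCP's flatten_optimizer_state_dict option with
# placeholder model/optimizer objects; TorchTitan wires this up inside its
# OptimizersContainer instead.
import torch
import torch.nn as nn
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_optimizer_state_dict,
)

model = nn.Linear(8, 8)
optimizer = torch.optim.Adam(model.parameters())
model(torch.randn(2, 8)).sum().backward()
optimizer.step()  # populate optimizer state

# flatten_optimizer_state_dict=True flattens the nested optimizer state into a
# single-level dict keyed by parameter FQN, which is what lets per-PP-rank
# optimizer state dicts be merged without key collisions (see point 1 above).
optim_sd = get_optimizer_state_dict(
    model,
    optimizer,
    options=StateDictOptions(flatten_optimizer_state_dict=True),
)
```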
@@ -331,31 +394,8 @@ def _async_wait(self) -> None:
             self.async_future.result()
 
     def _async_with_pinned_memory(self, checkpoint_id: str) -> None:
-        try:
-            from torch.distributed._state_dict_utils import (
-                _copy_state_dict,
-                _create_cpu_state_dict,
-            )
-        except ImportError as e:
-            raise ImportError(
-                "Please install the latest PyTorch nightly to use async checkpointing with pinned memory."
-            ) from e
-        state_dict = dcp.state_dict_saver._stateful_to_state_dict(self.states)
-        if self.cpu_offload_state_dict is None:
-            logger.debug(f"Preparing the CPU memory, {time.monotonic()=}.:.2f")
-            self.cpu_offload_state_dict = _create_cpu_state_dict(
-                state_dict, pin_memory=True, share_memory=True
-            )
-
-        logger.debug(f"Staging the state_dict, {time.monotonic()=}.:.2f")
-        with torch.cuda.stream(self.staging_stream):
-            self.cpu_offload_state_dict = _copy_state_dict(
-                state_dict,
-                self.cpu_offload_state_dict,
-                non_blocking=True,
-            )
-        self.staging = True
-        self.staging_id = checkpoint_id
+        self.cpu_staging(checkpoint_id)
+        self.sending_to_checkpoint_mp = True
 
     def save(self, curr_step: int, force: bool = False) -> None:
         """
@@ -365,6 +405,8 @@ def save(self, curr_step: int, force: bool = False) -> None:
         for initial seed checkpoint.
         """
         if not self._should_save(curr_step, force):
+            if self.ft_manager:
+                self.cpu_staging(None)
             return
 
         begin = time.monotonic()
@@ -393,26 +435,51 @@ def save(self, curr_step: int, force: bool = False) -> None:
             f"in {time.monotonic() - begin:.2f} seconds."
         )
 
+    def cpu_staging(self, checkpoint_id: Optional[str]) -> None:
+        """Offload state_dict to CPU memory."""
+        state_dict = dcp.state_dict_saver._stateful_to_state_dict(self.states)
+        if self.cpu_offload_state_dict is None:
+            logger.debug(f"Preparing the CPU memory, {time.monotonic()=:.2f}")
+            self.cpu_offload_state_dict = _create_cpu_state_dict(
+                state_dict, pin_memory=True, share_memory=True
+            )
+
+        logger.debug(f"Staging the state_dict, {time.monotonic()=:.2f}")
+        with torch.cuda.stream(self.staging_stream):
+            self.cpu_offload_state_dict = _copy_state_dict(
+                state_dict,
+                self.cpu_offload_state_dict,
+                non_blocking=True,
+            )
+        self.staging = True
+        self.staging_id = checkpoint_id
+
+    def wait_for_staging(self) -> None:
+        if not self.staging_stream.query():
+            self.staging_stream.synchronize()
+        self.staging = False
+
+    def staging_results(self) -> Dict[str, Any]:
+        self.maybe_wait_for_staging()
+        return self.cpu_offload_state_dict
+
     def maybe_wait_for_staging(self) -> None:
-        if (
-            self.enable_checkpoint
-            and self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
-            and self.staging
-        ):
-            if not self.staging_stream.query():
-                self.staging_stream.synchronize()
-
-            def sync_func():
-                self.mp_queue_send.put_nowait(
-                    (self.cpu_offload_state_dict, self.staging_id)
-                )
-
-            # This may be a faster way to do zero-overhead checkpointing staging
-            # checkpointing but we need more thorough investigation before
-            # swithing to this method.
-            # self.my_thread = threading.Thread(target=func).start()
-            sync_func()
-            self.staging = False
+        if self.enable_staging and self.staging:
+            self.wait_for_staging()
+
+        if self.sending_to_checkpoint_mp:
+            # Send the staged (synchronized) state_dict to the checkpoint process.
+            def sync_func():
+                self.mp_queue_send.put_nowait(
+                    (self.cpu_offload_state_dict, self.staging_id)
+                )
+
+            # This may be a faster way to do zero-overhead checkpoint staging,
+            # but we need more thorough investigation before switching to this
+            # method.
+            # self.my_thread = threading.Thread(target=sync_func).start()
+            sync_func()
+            self.sending_to_checkpoint_mp = False
 
     def load(self, step: int = -1) -> bool:
         if not self.enable_checkpoint:
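Taken together, `save`, `cpu_staging`, and `maybe_wait_for_staging` split checkpointing into a "copy to CPU" phase and a "hand off" phase. A hypothetical sketch of how a training loop built with an `ft_manager` might drive them; everything except the CheckpointManager calls is a placeholder:

```python
# Hypothetical wiring sketch: train_one_step, num_steps, and the objects passed
# in are placeholders; only the CheckpointManager calls come from this diff.
for step in range(1, num_steps + 1):
    train_one_step(model_parts, optimizers, dataloader)  # placeholder

    # With ft_manager set, cpu_staging(None) still runs even when no checkpoint
    # is due, so a replica joining at this step can be served the latest state.
    checkpoint_manager.save(step)

    # Waits for the pending device-to-host copy and, if a real checkpoint was
    # requested, forwards the staged state dict to the checkpoint process.
    checkpoint_manager.maybe_wait_for_staging()

# A fault-tolerance manager can also pull the staged CPU copy directly:
latest_cpu_state = checkpoint_manager.staging_results()
```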