import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.nn as nn
+from torch.distributed._state_dict_utils import _copy_state_dict, _create_cpu_state_dict
from torch.distributed.checkpoint.state_dict import (
    get_model_state_dict,
    set_model_state_dict,
@@ -144,13 +145,18 @@ def __init__(
        lr_schedulers: LRSchedulersContainer,
        states: Dict[str, Any],
        job_config: JobConfig,
+        ft_manager: Optional[Any] = None,
    ) -> None:
        ckpt_config = job_config.checkpoint
        self.enable_checkpoint = ckpt_config.enable_checkpoint
-        self.keep_latest_k = ckpt_config.keep_latest_k
+        self.ft_manager = ft_manager
+        self.enable_staging = (
+            self.enable_checkpoint and async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
+        ) or self.ft_manager

-        if not self.enable_checkpoint:
+        if not self.enable_checkpoint and self.ft_manager is None:
            return
+
        """
        Note: Pipeline Parallelism and Virtual Stages

@@ -185,6 +191,13 @@ def __init__(
            }
        )

+        async_mode = ckpt_config.async_mode.lower()
+        self.staging = False
+        self.sending_to_checkpoint_mp = False
+        self.staging_id = None
+        self.cpu_offload_state_dict = None
+        self.staging_stream = torch.cuda.Stream() if self.enable_staging else None
+
        self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder)
        self.interval_type = (
            IntervalType.SECONDS
@@ -199,6 +212,7 @@ def __init__(
        if async_mode == AsyncMode.ASYNC or self.interval_type == IntervalType.SECONDS:
            self.pg = dist.new_group(backend="gloo")

+        self.keep_latest_k = ckpt_config.keep_latest_k
        self.model_weights_only = ckpt_config.model_weights_only
        self.export_dtype = TORCH_DTYPE_MAP[ckpt_config.export_dtype]
        self.exclude_from_loading = ckpt_config.exclude_from_loading
@@ -223,10 +237,6 @@ def __init__(
                daemon=True,
            )
            self.mp.start()
-            self.cpu_offload_state_dict = None
-            self.staging = False
-            self.staging_id = None
-            self.staging_stream = torch.cuda.Stream()
        else:
            raise ValueError(f"Unkown checkpoint async_mode {ckpt_config.async_mode}")

@@ -240,8 +250,61 @@ def __del__(self):
            self.mp.join()

    def reset(self) -> None:
+        # We need to stage the local state if another replica joins during the
+        # first step.
+        if self.ft_manager:
+            self.cpu_staging(None)
        self.begin_time = time.monotonic()

+    def _initialize_states(
+        self,
+        states: Dict[str, Any],
+        dataloader: DataLoader,
+        model_parts: List[nn.Module],
+        optimizers: OptimizersContainer,
+        lr_schedulers: LRSchedulersContainer,
+    ) -> None:
+        """
+        Note: Pipeline Parallelism and Virtual Stages
+
+        1. Even for simple PP schedules, there is a separate optimizer for each PP rank.
+        rank0's optimizer would have a param_group[0] which refers to layers.0 in the
+        original model. rank1's would _also_ have a param_group[0], since it's index based,
+        but referring to layers.1.
+        When saving, these collide and one of them is lost. Then when reloading, only one
+        stage can restore its optimizer states, others will error.
+
+        The solution to this problem is optimizer flattening: it landed in #127071
+        and is enabled in TorchTitan by passing the 'flatten_optimizer_state_dict'
+        kwarg to DCP functions called in the OptimizerContainer.
+
+        2. With complex PP schedules, we have multiple model chunks per pp rank. This
+        compounds challenge (1) by also requiring us to reason about multiple 'optim'
+        objects locally.
+
+        We solve this in the Model and Optimizer wrapper classes by flattening the
+        state dicts from each object into one state dict before saving/loading.
+        We rely on the individual state_dicts to not collide, which is guaranteed for
+        the model by correct pipeline splitting and for the optimizer by the flattening
+        support described in (1).
+
+        3. LR schedulers also index model states like optimizers and would need to be
+        flattened properly to support resharding. Unfortunately, the implementations of
+        different lr_schedulers do not follow a clear pattern like optimizers do, so it's
+        hard to write a generic 'flattener' utility.
+
+        TODO: This is currently unsolved and needs a fix.
+        """
+        self.states = states
+        self.states.update(
+            {
+                "model": ModelWrapper(model_parts),
+                "optimizer": optimizers,
+                "dataloader": dataloader,
+                "lr_scheduler": lr_schedulers,
+            }
+        )
+
    def _create_checkpoint_id(self, step: int) -> str:
        return os.path.join(self.folder, f"step-{step}")

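For reference, a minimal sketch of the optimizer flattening described in the docstring above (the toy model and names are illustrative, not TorchTitan's wrapper code). With flatten_optimizer_state_dict=True, DCP keys optimizer state by fully qualified parameter name instead of by param_group index, which is what avoids the cross-PP-rank collisions described in point (1):

# Illustrative sketch, not part of the diff above.
import torch
import torch.nn as nn
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_optimizer_state_dict,
)

model = nn.Linear(8, 8)
optim = torch.optim.AdamW(model.parameters())
model(torch.randn(2, 8)).sum().backward()
optim.step()  # populate exp_avg / exp_avg_sq so there is state to flatten

# Keys embed fully qualified parameter names rather than integer indices,
# so state dicts produced by different pipeline stages no longer collide.
flat_sd = get_optimizer_state_dict(
    model,
    optim,
    options=StateDictOptions(flatten_optimizer_state_dict=True),
)
print(sorted(flat_sd.keys()))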
@@ -324,31 +387,8 @@ def _async_wait(self) -> None:
        self.async_future.result()

    def _async_with_pinned_memory(self, checkpoint_id: str) -> None:
-        try:
-            from torch.distributed._state_dict_utils import (
-                _copy_state_dict,
-                _create_cpu_state_dict,
-            )
-        except ImportError as e:
-            raise ImportError(
-                "Please install the latest PyTorch nightly to use async checkpointing with pinned memory."
-            ) from e
-        state_dict = dcp.state_dict_saver._stateful_to_state_dict(self.states)
-        if self.cpu_offload_state_dict is None:
-            logger.debug(f"Preparing the CPU memory, {time.monotonic()=}.:.2f")
-            self.cpu_offload_state_dict = _create_cpu_state_dict(
-                state_dict, pin_memory=True, share_memory=True
-            )
-
-        logger.debug(f"Staging the state_dict, {time.monotonic()=}.:.2f")
-        with torch.cuda.stream(self.staging_stream):
-            self.cpu_offload_state_dict = _copy_state_dict(
-                state_dict,
-                self.cpu_offload_state_dict,
-                non_blocking=True,
-            )
-        self.staging = True
-        self.staging_id = checkpoint_id
+        self.cpu_staging(checkpoint_id)
+        self.sending_to_checkpoint_mp = True

    def save(self, curr_step: int, force: bool = False) -> None:
        """
@@ -358,6 +398,8 @@ def save(self, curr_step: int, force: bool = False) -> None:
        for initial seed checkpoint.
        """
        if not self._should_save(curr_step, force):
+            if self.ft_manager:
+                self.cpu_staging(None)
            return

        begin = time.monotonic()
@@ -381,26 +423,51 @@ def save(self, curr_step: int, force: bool = False) -> None:
            f"in {time.monotonic() - begin:.2f} seconds."
        )

+    def cpu_staging(self, checkpoint_id: Optional[str]) -> None:
+        """Offload state_dict to CPU memory"""
+        state_dict = dcp.state_dict_saver._stateful_to_state_dict(self.states)
+        if self.cpu_offload_state_dict is None:
+            logger.debug(f"Preparing the CPU memory, {time.monotonic()=}.:.2f")
+            self.cpu_offload_state_dict = _create_cpu_state_dict(
+                state_dict, pin_memory=True, share_memory=True
+            )
+
+        logger.debug(f"Staging the state_dict, {time.monotonic()=}.:.2f")
+        with torch.cuda.stream(self.staging_stream):
+            self.cpu_offload_state_dict = _copy_state_dict(
+                state_dict,
+                self.cpu_offload_state_dict,
+                non_blocking=True,
+            )
+        self.staging = True
+        self.staging_id = checkpoint_id
+
+    def wait_for_staging(self) -> None:
+        if not self.staging_stream.query():
+            self.staging_stream.synchronize()
+        self.staging = False
+
+    def staging_results(self) -> Dict[str, Any]:
+        self.maybe_wait_for_staging()
+        return self.cpu_offload_state_dict
+
    def maybe_wait_for_staging(self) -> None:
-        if (
-            self.enable_checkpoint
-            and self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
-            and self.staging
-        ):
-            if not self.staging_stream.query():
-                self.staging_stream.synchronize()
-
-            def sync_func():
-                self.mp_queue_send.put_nowait(
-                    (self.cpu_offload_state_dict, self.staging_id)
-                )
-
-            # This may be a faster way to do zero-overhead checkpointing staging
-            # checkpointing but we need more thorough investigation before
-            # swithing to this method.
-            # self.my_thread = threading.Thread(target=func).start()
-            sync_func()
-            self.staging = False
+        if self.enable_staging and self.staging:
+            self.wait_for_staging()
+
+        if self.sending_to_checkpoint_mp:
+            # Copy the staged result to the checkpoint process.
+            def sync_func():
+                self.mp_queue_send.put_nowait(
+                    (self.cpu_offload_state_dict, self.staging_id)
+                )
+
+            # This may be a faster way to do zero-overhead checkpoint
+            # staging, but we need more thorough investigation before
+            # switching to this method.
+            # self.my_thread = threading.Thread(target=func).start()
+            sync_func()
+            self.sending_to_checkpoint_mp = False

    def load(self, step: int = -1) -> bool:
        if not self.enable_checkpoint:
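The staging helpers above follow a common pinned-memory offload pattern: allocate pinned CPU buffers once, issue device-to-host copies with non_blocking=True on a side CUDA stream so training can keep running, and synchronize that stream before anything reads the staged copy. A minimal sketch with plain tensors and public PyTorch APIs (the function and variable names here are illustrative, not TorchTitan's implementation):

# Illustrative sketch, not part of the diff above.
from typing import List, Optional
import torch

def stage_to_cpu(
    gpu_tensors: List[torch.Tensor],
    cpu_buffers: Optional[List[torch.Tensor]],
    stream: torch.cuda.Stream,
) -> List[torch.Tensor]:
    # Allocate pinned CPU buffers once and reuse them across checkpoints.
    if cpu_buffers is None:
        cpu_buffers = [
            torch.empty(t.shape, dtype=t.dtype, device="cpu", pin_memory=True)
            for t in gpu_tensors
        ]
    # Let the side stream see all pending work that produced the source tensors.
    stream.wait_stream(torch.cuda.current_stream())
    # Issue async device-to-host copies; the default stream is free to continue.
    with torch.cuda.stream(stream):
        for src, dst in zip(gpu_tensors, cpu_buffers):
            dst.copy_(src, non_blocking=True)
    return cpu_buffers

# Usage, mirroring wait_for_staging(): synchronize before the buffers are read,
# and do not overwrite the source tensors until the copies have finished.
#   staging_stream = torch.cuda.Stream()
#   buffers = stage_to_cpu(tensors, None, staging_stream)
#   staging_stream.synchronize()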