import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.nn as nn
+ from torch.distributed._state_dict_utils import _copy_state_dict, _create_cpu_state_dict
from torch.distributed.checkpoint.state_dict import (
    get_model_state_dict,
    set_model_state_dict,
@@ -144,50 +145,29 @@ def __init__(
        lr_schedulers: LRSchedulersContainer,
        states: Dict[str, Any],
        job_config: JobConfig,
+         ft_manager: Optional[Any] = None,
    ) -> None:
        ckpt_config = job_config.checkpoint
        self.enable_checkpoint = ckpt_config.enable_checkpoint
-         self.keep_latest_k = ckpt_config.keep_latest_k
+         self.ft_manager = ft_manager
+         self.enable_staging = (
+             self.enable_checkpoint
+             # Read the mode from the config here; the async_mode local is only
+             # assigned further down in __init__.
+             and ckpt_config.async_mode.lower() == AsyncMode.ASYNC_WITH_PINNED_MEM
+         ) or self.ft_manager

-         if not self.enable_checkpoint:
+         if not self.enable_checkpoint and self.ft_manager is None:
            return
-         """
-         Note: Pipeline Parallelism and Virtual Stages
- 
-         1. even for simple PP schedules, there is a separate optimizer each PP rank.
-         rank0's optimizer would have a param_group[0] which refers to layers.0 in the original model.
-         rank1's would _also_ have a param_group[0], since it's index based, but referring to layers.1.
-         When saving, these collide and one of them is lost. Then when reloading, only one stage can
-         restore its optimizer states, others will error.
- 
-         The solution to this problem is optimizer flattening: it landed in #127071 and is enabled in TorchTitan
-         by passing the 'flatten_optimizer_state_dict' kwarg to DCP functions called in the OptimizerContainer.
- 
-         2. With complex PP schedules, we have multiple model chunks per pp rank. This compounds challenge (1) by also
-         requiring us to reason about multiple 'optim' objects locally.
- 
-         We solve this in the Model and Optimizer wrapper classes by flattening the state dicts from each object
-         into one state dict before saving/loading. We rely on the individual state_dicts to not collide,
-         which is gauranteed for the model by correct pipeline splitting and for the optimizer by the flattening
-         support described in (1).
- 
-         3. LR schedulers also index model states like optimizers and would need to be flattened properly to support
-         resharding. Unfortunately, the implementations of different lr_schedulers do not follow a clear pattern like
-         optimizers do, so it's hard to write a generic 'flattener' utility.
- 
-         TODO: This is currently unsolved and needs a fix.
-         """
-         self.states = states

-         self.states.update(
-             {
-                 "model": ModelWrapper(model_parts),
-                 "optimizer": optimizers,
-                 "dataloader": dataloader,
-                 "lr_scheduler": lr_schedulers,
-             }
+         self._initialize_states(
+             states, dataloader, model_parts, optimizers, lr_schedulers
        )

+         async_mode = ckpt_config.async_mode.lower()
+         self.staging = False
+         self.sending_to_checkpoint_mp = False
+         self.staging_id = None
+         self.cpu_offload_state_dict = None
+         self.staging_stream = torch.cuda.Stream() if self.enable_staging else None
+ 
        self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder)
        self.interval_type = (
            IntervalType.SECONDS
@@ -202,6 +182,7 @@ def __init__(
        if async_mode == AsyncMode.ASYNC or self.interval_type == IntervalType.SECONDS:
            self.pg = dist.new_group(backend="gloo")

+         self.keep_latest_k = ckpt_config.keep_latest_k
        self.model_weights_only = ckpt_config.model_weights_only
        self.export_dtype = TORCH_DTYPE_MAP[ckpt_config.export_dtype]

@@ -225,10 +206,6 @@ def __init__(
                daemon=True,
            )
            self.mp.start()
-             self.cpu_offload_state_dict = None
-             self.staging = False
-             self.staging_id = None
-             self.staging_stream = torch.cuda.Stream()
        else:
            raise ValueError(f"Unkown checkpoint async_mode {ckpt_config.async_mode}")

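For context on the staging machinery being reshuffled in this file: the pinned-memory path copies GPU state into page-locked CPU buffers on a side CUDA stream so the training stream is not blocked. A minimal standalone sketch of that pattern (plain PyTorch, requires a CUDA device; tensor and variable names are illustrative, not part of this PR):

import torch

# Source data on the GPU (stand-in for model/optimizer state).
gpu_tensor = torch.randn(1024, 1024, device="cuda")

# Pinned (page-locked) CPU buffer: required for truly asynchronous D2H copies.
cpu_buffer = torch.empty(
    gpu_tensor.shape, dtype=gpu_tensor.dtype, device="cpu", pin_memory=True
)

staging_stream = torch.cuda.Stream()
# Make the side stream wait for whatever work produced gpu_tensor.
staging_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(staging_stream):
    # Enqueued on the side stream; the default (training) stream keeps running.
    cpu_buffer.copy_(gpu_tensor, non_blocking=True)

# Before the buffer is consumed (e.g. handed to a checkpoint process):
if not staging_stream.query():  # query() is True once all enqueued work is done
    staging_stream.synchronize()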
@@ -242,8 +219,61 @@ def __del__(self):
            self.mp.join()

    def reset(self) -> None:
+         # We need to stage the local state if another replica joins during the
+         # first step.
+         if self.ft_manager:
+             self.cpu_staging(None)
        self.begin_time = time.monotonic()

+     def _initialize_states(
+         self,
+         states: Dict[str, Any],
+         dataloader: DataLoader,
+         model_parts: List[nn.Module],
+         optimizers: OptimizersContainer,
+         lr_schedulers: LRSchedulersContainer,
+     ) -> None:
+         """
+         Note: Pipeline Parallelism and Virtual Stages
+ 
+         1. Even for simple PP schedules, there is a separate optimizer for each PP rank.
+         rank0's optimizer would have a param_group[0] which refers to layers.0 in the
+         original model. rank1's would _also_ have a param_group[0], since it's index based,
+         but referring to layers.1.
+         When saving, these collide and one of them is lost. Then when reloading, only one
+         stage can restore its optimizer states, others will error.
+ 
+         The solution to this problem is optimizer flattening: it landed in #127071
+         and is enabled in TorchTitan by passing the 'flatten_optimizer_state_dict'
+         kwarg to DCP functions called in the OptimizerContainer.
+ 
+         2. With complex PP schedules, we have multiple model chunks per pp rank. This
+         compounds challenge (1) by also requiring us to reason about multiple 'optim'
+         objects locally.
+ 
+         We solve this in the Model and Optimizer wrapper classes by flattening the
+         state dicts from each object into one state dict before saving/loading.
+         We rely on the individual state_dicts to not collide, which is guaranteed for
+         the model by correct pipeline splitting and for the optimizer by the flattening
+         support described in (1).
+ 
+         3. LR schedulers also index model states like optimizers and would need to be
+         flattened properly to support resharding. Unfortunately, the implementations of
+         different lr_schedulers do not follow a clear pattern like optimizers do, so it's
+         hard to write a generic 'flattener' utility.
+ 
+         TODO: This is currently unsolved and needs a fix.
+         """
+         self.states = states
+         self.states.update(
+             {
+                 "model": ModelWrapper(model_parts),
+                 "optimizer": optimizers,
+                 "dataloader": dataloader,
+                 "lr_scheduler": lr_schedulers,
+             }
+         )
+ 
    def _create_checkpoint_id(self, step: int) -> str:
        return os.path.join(self.folder, f"step-{step}")

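The optimizer flattening mentioned in the docstring above is exposed through torch.distributed.checkpoint.state_dict. A rough sketch of what turning it on looks like for a toy, single-process model/optimizer pair (not TorchTitan's OptimizerContainer; the names here are illustrative):

import torch
import torch.nn as nn
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_optimizer_state_dict,
)

model = nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
model(torch.randn(2, 8)).sum().backward()
optimizer.step()  # populate optimizer state

# With flatten_optimizer_state_dict=True, optimizer state is keyed by fully
# qualified parameter names rather than integer param_group indices, so state
# dicts produced by different pipeline stages no longer collide on index 0.
flat_osd = get_optimizer_state_dict(
    model,
    optimizer,
    options=StateDictOptions(flatten_optimizer_state_dict=True),
)
print(sorted(flat_osd.keys()))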
@@ -326,31 +356,8 @@ def _async_wait(self) -> None:
        self.async_future.result()

    def _async_with_pinned_memory(self, checkpoint_id: str) -> None:
-         try:
-             from torch.distributed._state_dict_utils import (
-                 _copy_state_dict,
-                 _create_cpu_state_dict,
-             )
-         except ImportError as e:
-             raise ImportError(
-                 "Please install the latest PyTorch nightly to use async checkpointing with pinned memory."
-             ) from e
-         state_dict = dcp.state_dict_saver._stateful_to_state_dict(self.states)
-         if self.cpu_offload_state_dict is None:
-             logger.debug(f"Preparing the CPU memory, {time.monotonic()=}.:.2f")
-             self.cpu_offload_state_dict = _create_cpu_state_dict(
-                 state_dict, pin_memory=True, share_memory=True
-             )
- 
-         logger.debug(f"Staging the state_dict, {time.monotonic()=}.:.2f")
-         with torch.cuda.stream(self.staging_stream):
-             self.cpu_offload_state_dict = _copy_state_dict(
-                 state_dict,
-                 self.cpu_offload_state_dict,
-                 non_blocking=True,
-             )
-         self.staging = True
-         self.staging_id = checkpoint_id
+         self.cpu_staging(checkpoint_id)
+         self.sending_to_checkpoint_mp = True

    def save(self, curr_step: int, force: bool = False) -> None:
        """
@@ -360,6 +367,8 @@ def save(self, curr_step: int, force: bool = False) -> None:
        for initial seed checkpoint.
        """
        if not self._should_save(curr_step, force):
+             if self.ft_manager:
+                 self.cpu_staging(None)
            return

        begin = time.monotonic()
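The cpu_staging helper added in the next hunk converts self.states through DCP's Stateful protocol (via dcp.state_dict_saver._stateful_to_state_dict). As a reference point, here is a minimal sketch of what one such entry can look like; this TrainState is a simplified illustration, not the class from the repository:

from typing import Any, Dict

import torch
from torch.distributed.checkpoint.stateful import Stateful


class TrainState(Stateful):
    """Simplified illustration of an object stored in CheckpointManager.states."""

    def __init__(self) -> None:
        self.step = 0

    def state_dict(self) -> Dict[str, Any]:
        # Collected by _stateful_to_state_dict() before staging/saving.
        return {"step": torch.tensor(self.step, dtype=torch.int32)}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.step = int(state_dict["step"])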
@@ -383,26 +392,51 @@ def save(self, curr_step: int, force: bool = False) -> None:
            f"in {time.monotonic() - begin:.2f} seconds."
        )

+     def cpu_staging(self, checkpoint_id: Optional[str]) -> None:
+         """Offload state_dict to CPU memory."""
+         state_dict = dcp.state_dict_saver._stateful_to_state_dict(self.states)
+         if self.cpu_offload_state_dict is None:
+             logger.debug(f"Preparing the CPU memory, {time.monotonic()=:.2f}")
+             self.cpu_offload_state_dict = _create_cpu_state_dict(
+                 state_dict, pin_memory=True, share_memory=True
+             )
+ 
+         logger.debug(f"Staging the state_dict, {time.monotonic()=:.2f}")
+         with torch.cuda.stream(self.staging_stream):
+             self.cpu_offload_state_dict = _copy_state_dict(
+                 state_dict,
+                 self.cpu_offload_state_dict,
+                 non_blocking=True,
+             )
+         self.staging = True
+         self.staging_id = checkpoint_id
+ 
+     def wait_for_staging(self) -> None:
+         if not self.staging_stream.query():
+             self.staging_stream.synchronize()
+         self.staging = False
+ 
+     def staging_results(self) -> Dict[str, Any]:
+         self.maybe_wait_for_staging()
+         return self.cpu_offload_state_dict
+ 
    def maybe_wait_for_staging(self) -> None:
-         if (
-             self.enable_checkpoint
-             and self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
-             and self.staging
-         ):
-             if not self.staging_stream.query():
-                 self.staging_stream.synchronize()
- 
-             def sync_func():
-                 self.mp_queue_send.put_nowait(
-                     (self.cpu_offload_state_dict, self.staging_id)
-                 )
- 
-             # This may be a faster way to do zero-overhead checkpointing staging
-             # checkpointing but we need more thorough investigation before
-             # swithing to this method.
-             # self.my_thread = threading.Thread(target=func).start()
-             sync_func()
-             self.staging = False
+         if self.enable_staging and self.staging:
+             self.wait_for_staging()
+ 
+         if self.sending_to_checkpoint_mp:
+             # Hand the synchronized staging result to the checkpoint subprocess.
+             def sync_func():
+                 self.mp_queue_send.put_nowait(
+                     (self.cpu_offload_state_dict, self.staging_id)
+                 )
+ 
+             # This may be a faster way to do zero-overhead checkpoint staging,
+             # but we need more thorough investigation before switching to this
+             # method.
+             # self.my_thread = threading.Thread(target=sync_func).start()
+             sync_func()
+             self.sending_to_checkpoint_mp = False

    def load(self, step: int = -1) -> bool:
        if not self.enable_checkpoint:
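A closing note on the new top-level import: _create_cpu_state_dict and _copy_state_dict are private PyTorch utilities, and the pattern cpu_staging relies on (allocate pinned, shareable CPU buffers once, then repeatedly copy into them) can be exercised on its own roughly as follows. This is a sketch with a toy state dict, assuming a CUDA device is available:

import torch
from torch.distributed._state_dict_utils import (
    _copy_state_dict,
    _create_cpu_state_dict,
)

# Stand-in for what dcp.state_dict_saver._stateful_to_state_dict(self.states) returns.
state_dict = {
    "weight": torch.randn(4, 4, device="cuda"),
    "step": torch.tensor(10),
}

# Allocate CPU buffers once: pinned for fast asynchronous copies, shared so a
# checkpointing subprocess can read them without an extra serialization copy.
cpu_buffers = _create_cpu_state_dict(state_dict, pin_memory=True, share_memory=True)

# Each staging call copies into the same buffers; with non_blocking=True the copy
# is issued on the current CUDA stream (cpu_staging uses a dedicated side stream).
cpu_buffers = _copy_state_dict(state_dict, cpu_buffers, non_blocking=False)
print(cpu_buffers["weight"].device)  # cpu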