from dataclasses import dataclass, field
from io import BytesIO
from multiprocessing import get_context
- from typing import Any, Dict, List, Union
+ from typing import Any, Dict, List, Optional, Union

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.nn as nn
+ from torch.distributed._state_dict_utils import _copy_state_dict, _create_cpu_state_dict
from torch.distributed.checkpoint.state_dict import (
    get_model_state_dict,
    set_model_state_dict,
@@ -143,49 +144,28 @@ def __init__(
        lr_schedulers: SchedulersContainer,
        states: Dict[str, Any],
        job_config: JobConfig,
+         ft_manager: Optional[Any] = None,
    ) -> None:
        ckpt_config = job_config.checkpoint
        self.enable_checkpoint = ckpt_config.enable_checkpoint
-         self.keep_latest_k = ckpt_config.keep_latest_k
+         self.ft_manager = ft_manager

-         if not self.enable_checkpoint:
+         if not self.enable_checkpoint and self.ft_manager is None:
            return
-         """
-         Note: Pipeline Parallelism and Virtual Stages
-
-         1. even for simple PP schedules, there is a separate optimizer each PP rank.
-         rank0's optimizer would have a param_group[0] which refers to layers.0 in the original model.
-         rank1's would _also_ have a param_group[0], since it's index based, but referring to layers.1.
-         When saving, these collide and one of them is lost. Then when reloading, only one stage can
-         restore its optimizer states, others will error.
-
-         The solution to this problem is optimizer flattening: it landed in #127071 and is enabled in TorchTitan
-         by passing the 'flatten_optimizer_state_dict' kwarg to DCP functions called in the OptimizerContainer.
-
-         2. With complex PP schedules, we have multiple model chunks per pp rank. This compounds challenge (1) by also
-         requiring us to reason about multiple 'optim' objects locally.
-
-         We solve this in the Model and Optimizer wrapper classes by flattening the state dicts from each object
-         into one state dict before saving/loading. We rely on the individual state_dicts to not collide,
-         which is gauranteed for the model by correct pipeline splitting and for the optimizer by the flattening
-         support described in (1).
-
-         3. LR schedulers also index model states like optimizers and would need to be flattened properly to support
-         resharding. Unfortunately, the implementations of different lr_schedulers do not follow a clear pattern like
-         optimizers do, so it's hard to write a generic 'flattener' utility.
-
-         TODO: This is currently unsolved and needs a fix.
-         """
-         self.states = states

-         self.states.update(
-             {
-                 "model": ModelWrapper(model_parts),
-                 "optimizer": optimizers,
-                 "dataloader": dataloader,
-             }
+         self._initialize_states(
+             states, dataloader, model_parts, optimizers, lr_schedulers
        )
-         self.states.update(lr_schedulers.get_lr_scheduler_state())
+
+         async_mode = ckpt_config.async_mode.lower()
+         self.enable_staging = (
+             self.enable_checkpoint and async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
+         ) or self.ft_manager
+         self.staging = False
+         self.sending_to_checkpoint_mp = False
+         self.staging_id = None
+         self.cpu_offload_state_dict = None
+         self.staging_stream = torch.cuda.Stream() if self.enable_staging else None

        self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder)
        self.interval_type = (
@@ -199,11 +179,11 @@ def __init__(
        self.time_sync_result = None
        self.pg = dist.new_group(backend="gloo")

+         self.keep_latest_k = ckpt_config.keep_latest_k
        self.model_weights_only = ckpt_config.model_weights_only
        self.export_dtype = TORCH_DTYPE_MAP[ckpt_config.export_dtype]

        self.mp = None
-         async_mode = ckpt_config.async_mode.lower()
        if async_mode == AsyncMode.DISABLED:
            self.async_mode = AsyncMode.DISABLED
        elif async_mode == AsyncMode.ASYNC:
@@ -223,10 +203,6 @@ def __init__(
                daemon=True,
            )
            self.mp.start()
-             self.cpu_offload_state_dict = None
-             self.staging = False
-             self.staging_id = None
-             self.staging_stream = torch.cuda.Stream()
        else:
            raise ValueError(f"Unkown checkpoint async_mode {ckpt_config.async_mode}")

@@ -240,8 +216,61 @@ def __del__(self):
            self.mp.join()

    def reset(self) -> None:
+         # We need to stage the local state if another replica joins during the
+         # first step.
+         if self.ft_manager:
+             self.cpu_staging(None)
        self.begin_time = time.monotonic()

+     def _initialize_states(
+         self,
+         states: Dict[str, Any],
+         dataloader: DataLoader,
+         model_parts: List[nn.Module],
+         optimizers: OptimizersContainer,
+         lr_schedulers: SchedulersContainer,
+     ) -> None:
+         """
+         Note: Pipeline Parallelism and Virtual Stages
+
+         1. Even for simple PP schedules, there is a separate optimizer for each PP rank.
+         rank0's optimizer would have a param_group[0] which refers to layers.0 in the
+         original model. rank1's would _also_ have a param_group[0], since it's index based,
+         but referring to layers.1.
+         When saving, these collide and one of them is lost. Then when reloading, only one
+         stage can restore its optimizer states, others will error.
+
+         The solution to this problem is optimizer flattening: it landed in #127071
+         and is enabled in TorchTitan by passing the 'flatten_optimizer_state_dict'
+         kwarg to DCP functions called in the OptimizerContainer.
+
+         2. With complex PP schedules, we have multiple model chunks per pp rank. This
+         compounds challenge (1) by also requiring us to reason about multiple 'optim'
+         objects locally.
+
+         We solve this in the Model and Optimizer wrapper classes by flattening the
+         state dicts from each object into one state dict before saving/loading.
+         We rely on the individual state_dicts to not collide, which is guaranteed for
+         the model by correct pipeline splitting and for the optimizer by the flattening
+         support described in (1).
+
+         3. LR schedulers also index model states like optimizers and would need to be
+         flattened properly to support resharding. Unfortunately, the implementations of
+         different lr_schedulers do not follow a clear pattern like optimizers do, so it's
+         hard to write a generic 'flattener' utility.
+
+         TODO: This is currently unsolved and needs a fix.
+         """
+         self.states = states
+         self.states.update(
+             {
+                 "model": ModelWrapper(model_parts),
+                 "optimizer": optimizers,
+                 "dataloader": dataloader,
+             }
+         )
+         self.states.update(lr_schedulers.get_lr_scheduler_state())
+
    def _create_checkpoint_id(self, step: int) -> str:
        return os.path.join(self.folder, f"step-{step}")
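Editor's note on the flattening described in the docstring above: the sketch below is illustrative only (FlatOptimizerWrapper and its structure are assumptions, not TorchTitan's actual OptimizersContainer). It shows how passing flatten_optimizer_state_dict to the DCP state-dict helpers keys optimizer state by parameter FQN, so index-based param_group entries from different pipeline stages cannot collide when merged into one state dict.

# Illustrative sketch only: flatten the optimizer state of several
# pipeline-stage model chunks into one FQN-keyed state dict.
from typing import Any, Dict, List

import torch
import torch.nn as nn
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_optimizer_state_dict,
    set_optimizer_state_dict,
)
from torch.distributed.checkpoint.stateful import Stateful


class FlatOptimizerWrapper(Stateful):
    """Presents N (model chunk, optimizer) pairs as one Stateful object."""

    def __init__(
        self,
        model_parts: List[nn.Module],
        optimizers: List[torch.optim.Optimizer],
    ) -> None:
        self.model_parts = model_parts
        self.optimizers = optimizers
        # Key optimizer state by parameter FQN instead of param_group index.
        self.options = StateDictOptions(flatten_optimizer_state_dict=True)

    def state_dict(self) -> Dict[str, Any]:
        flat: Dict[str, Any] = {}
        for model, optim in zip(self.model_parts, self.optimizers):
            flat.update(get_optimizer_state_dict(model, optim, options=self.options))
        return flat

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        for model, optim in zip(self.model_parts, self.optimizers):
            set_optimizer_state_dict(
                model, optim, optim_state_dict=state_dict, options=self.options
            )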
@@ -324,31 +353,8 @@ def _async_wait(self) -> None:
        self.async_future.result()

    def _async_with_pinned_memory(self, checkpoint_id: str) -> None:
-         try:
-             from torch.distributed._state_dict_utils import (
-                 _copy_state_dict,
-                 _create_cpu_state_dict,
-             )
-         except ImportError as e:
-             raise ImportError(
-                 "Please install the latest PyTorch nightly to use async checkpointing with pinned memory."
-             ) from e
-         state_dict = dcp.state_dict_saver._stateful_to_state_dict(self.states)
-         if self.cpu_offload_state_dict is None:
-             logger.debug(f"Preparing the CPU memory, {time.monotonic()=}.:.2f")
-             self.cpu_offload_state_dict = _create_cpu_state_dict(
-                 state_dict, pin_memory=True, share_memory=True
-             )
-
-         logger.debug(f"Staging the state_dict, {time.monotonic()=}.:.2f")
-         with torch.cuda.stream(self.staging_stream):
-             self.cpu_offload_state_dict = _copy_state_dict(
-                 state_dict,
-                 self.cpu_offload_state_dict,
-                 non_blocking=True,
-             )
-         self.staging = True
-         self.staging_id = checkpoint_id
+         self.cpu_staging(checkpoint_id)
+         self.sending_to_checkpoint_mp = True

    def save(self, curr_step: int, force: bool = False) -> None:
        """
@@ -358,6 +364,8 @@ def save(self, curr_step: int, force: bool = False) -> None:
        for initial seed checkpoint.
        """
        if not self._should_save(curr_step, force):
+             if self.ft_manager:
+                 self.cpu_staging(None)
            return

        begin = time.monotonic()
@@ -381,26 +389,51 @@ def save(self, curr_step: int, force: bool = False) -> None:
            f"in {time.monotonic() - begin:.2f} seconds."
        )

+     def cpu_staging(self, checkpoint_id: Optional[str]) -> None:
+         """Offload state_dict to CPU memory"""
+         state_dict = dcp.state_dict_saver._stateful_to_state_dict(self.states)
+         if self.cpu_offload_state_dict is None:
+             logger.debug(f"Preparing the CPU memory, {time.monotonic()=:.2f}")
+             self.cpu_offload_state_dict = _create_cpu_state_dict(
+                 state_dict, pin_memory=True, share_memory=True
+             )
+
+         logger.debug(f"Staging the state_dict, {time.monotonic()=:.2f}")
+         with torch.cuda.stream(self.staging_stream):
+             self.cpu_offload_state_dict = _copy_state_dict(
+                 state_dict,
+                 self.cpu_offload_state_dict,
+                 non_blocking=True,
+             )
+         self.staging = True
+         self.staging_id = checkpoint_id
+
+     def wait_for_staging(self) -> None:
+         if not self.staging_stream.query():
+             self.staging_stream.synchronize()
+         self.staging = False
+
+     def staging_results(self) -> Dict[str, Any]:
+         self.maybe_wait_for_staging()
+         return self.cpu_offload_state_dict
+
    def maybe_wait_for_staging(self) -> None:
-         if (
-             self.enable_checkpoint
-             and self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
-             and self.staging
-         ):
-             if not self.staging_stream.query():
-                 self.staging_stream.synchronize()
-
-             def sync_func():
-                 self.mp_queue_send.put_nowait(
-                     (self.cpu_offload_state_dict, self.staging_id)
-                 )
-
-             # This may be a faster way to do zero-overhead checkpointing staging
-             # checkpointing but we need more thorough investigation before
-             # swithing to this method.
-             # self.my_thread = threading.Thread(target=func).start()
-             sync_func()
-             self.staging = False
+         if self.enable_staging and self.staging:
+             self.wait_for_staging()
+
+         if self.sending_to_checkpoint_mp:
+             # Send the staged result to the checkpoint background process.
+             def sync_func():
+                 self.mp_queue_send.put_nowait(
+                     (self.cpu_offload_state_dict, self.staging_id)
+                 )
+
+             # This may be a faster way to do zero-overhead checkpoint staging,
+             # but we need more thorough investigation before switching to this method.
+             # self.my_thread = threading.Thread(target=func).start()
+             sync_func()
+             self.sending_to_checkpoint_mp = False

    def load(self, step: int = -1) -> bool:
        if not self.enable_checkpoint:
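Editor's note on the staging logic above: cpu_staging relies on private helpers (_create_cpu_state_dict, _copy_state_dict), but the underlying pattern is standard. The standalone sketch below (function names are illustrative, not part of this diff) shows the same idea with plain tensors: pinned CPU buffers allow the device-to-host copy to run asynchronously on a side stream, and the stream must be synchronized before the buffers are consumed, which is what wait_for_staging guarantees before the state dict is handed to the checkpoint process.

import torch


def make_cpu_buffers(gpu_state: dict) -> dict:
    # Allocate pinned CPU tensors matching the GPU state once; pinned memory is
    # what allows the device-to-host copies below to be truly asynchronous.
    return {
        k: torch.empty(v.shape, dtype=v.dtype, device="cpu", pin_memory=True)
        for k, v in gpu_state.items()
    }


def stage(gpu_state: dict, cpu_buffers: dict, stream: torch.cuda.Stream) -> None:
    # Let the side stream wait for the work that produced the GPU tensors, then
    # enqueue async D2H copies; training can keep running on the default stream.
    stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(stream):
        for k, v in gpu_state.items():
            cpu_buffers[k].copy_(v, non_blocking=True)


if torch.cuda.is_available():
    staging_stream = torch.cuda.Stream()
    gpu_state = {"w": torch.randn(1024, 1024, device="cuda")}
    cpu_buffers = make_cpu_buffers(gpu_state)
    stage(gpu_state, cpu_buffers, staging_stream)
    # Synchronize before reading the CPU buffers (mirrors wait_for_staging);
    # consuming them earlier could observe partially copied data.
    staging_stream.synchronize()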