
Commit 11ace68 (committed Feb 12, 2025)

Integrate TorchFT
**Summary**

This is a WIP TorchFT integration PR.

**Current Issues**

This doesn't work at the moment because groups hang when a new group joins.

**Issue 1:** ~Group 0 and group 1 hang during the first `should_commit` after group 1 applies the pending state_dict from group 0.~ Fixed with pytorch/torchft#83.

**Issue 2:** ~Group 0 and group 1 pass `should_commit`, but group 0 is incorrectly marked as needing healing, and the healing process causes another hang.~ Fixed with pytorch/torchft#83.

**Issue 3:** ~As a byproduct of issues 1 and 2, group 1 keeps printing:~

```
[rank0]:devgpu051:76838:80357 [0] misc/socket.cc:50 NCCL WARN socketProgress: Connection closed by remote peer devgpu051.cln3.svc.fbinfra.net<33618>
```

Fixed with pytorch/torchft#91 and several other fixes.

**Issue 4:** When there are 3 groups, every group requests the state dict on every step.

***How to reproduce?*** Use the `Reproduce steps` below to run 2 groups, then add another group by modifying the command (a hedged example of such a command is sketched right after this message). Seems to be fixed; needs more testing.

**Issue 5:** A hang occurs when using functional collectives.

***How to reproduce?*** Pull the latest version of this PR, then comment out line 41 and uncomment line 42 in `torchtitan/utils.py`.

**Reproduce steps:**

1. Patch TorchFT with pytorch/torchft#82.
2. Start the lighthouse.
3. Execute the following command in one terminal:

   ```
   TORCHFT_MANAGER_PORT=29520 REPLICA_GROUP_ID=0 CUDA_VISIBLE_DEVICES=0,1 NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2 --experimental.enable_torchft --experimental.ft_replica_group_id=0
   ```

4. Wait 10 seconds, then execute the following command in another terminal:

   ```
   TORCHFT_MANAGER_PORT=29522 REPLICA_GROUP_ID=1 CUDA_VISIBLE_DEVICES=2,3 NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2 --experimental.enable_torchft --experimental.ft_replica_group_id=1
   ```

ghstack-source-id: c90068ba4f5d937f31596de91c5b08416a48b7d3
Pull Request resolved: #834
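For Issue 4, a third replica group can presumably be added by following the same pattern as the two commands above. The manager port, GPU indices, and group id below are illustrative extrapolations, not values taken from this PR:

```
# Hypothetical third replica group: bump the manager port, pick two free GPUs,
# and use the next replica group id.
TORCHFT_MANAGER_PORT=29524 REPLICA_GROUP_ID=2 CUDA_VISIBLE_DEVICES=4,5 NGPU=2 ./run_llama_train.sh --training.data_parallel_shard_degree=2 --experimental.enable_torchft --experimental.ft_replica_group_id=2
```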
1 parent e7305a5 commit 11ace68

File tree

8 files changed: +269 -63 lines
 

‎run_llama_train.sh

+4

```diff
@@ -19,7 +19,11 @@ if [ $# -ne 0 ]; then
     overrides="$*"
 fi
 
+TORCHFT_MANAGER_PORT=${TORCHFT_MANAGER_PORT:-"29512"}
+
 PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \
+TORCHFT_LIGHTHOUSE=http://localhost:29510 \
+TORCHFT_MANAGER_PORT=${TORCHFT_MANAGER_PORT} \
 torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
 --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
 train.py --job.config_file ${CONFIG_FILE} $overrides
```

‎torchtitan/checkpoint.py

+117 -50

```diff
@@ -19,6 +19,7 @@
 import torch.distributed as dist
 import torch.distributed.checkpoint as dcp
 import torch.nn as nn
+from torch.distributed._state_dict_utils import _copy_state_dict, _create_cpu_state_dict
 from torch.distributed.checkpoint.state_dict import (
     get_model_state_dict,
     set_model_state_dict,
@@ -144,13 +145,18 @@ def __init__(
         lr_schedulers: LRSchedulersContainer,
         states: Dict[str, Any],
         job_config: JobConfig,
+        ft_manager: Optional[Any] = None,
     ) -> None:
         ckpt_config = job_config.checkpoint
         self.enable_checkpoint = ckpt_config.enable_checkpoint
-        self.keep_latest_k = ckpt_config.keep_latest_k
+        self.ft_manager = ft_manager
+        self.enable_staging = (
+            self.enable_checkpoint and async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
+        ) or self.ft_manager
 
-        if not self.enable_checkpoint:
+        if not self.enable_checkpoint and self.ft_manager is None:
             return
+
         """
         Note: Pipeline Parallelism and Virtual Stages
 
@@ -185,6 +191,13 @@ def __init__(
             }
         )
 
+        async_mode = ckpt_config.async_mode.lower()
+        self.staging = False
+        self.sending_to_checkpoint_mp = False
+        self.staging_id = None
+        self.cpu_offload_state_dict = None
+        self.staging_stream = torch.cuda.Stream() if self.enable_staging else None
+
         self.folder = os.path.join(job_config.job.dump_folder, ckpt_config.folder)
         self.interval_type = (
             IntervalType.SECONDS
@@ -199,6 +212,7 @@ def __init__(
         if async_mode == AsyncMode.ASYNC or self.interval_type == IntervalType.SECONDS:
             self.pg = dist.new_group(backend="gloo")
 
+        self.keep_latest_k = ckpt_config.keep_latest_k
         self.model_weights_only = ckpt_config.model_weights_only
         self.export_dtype = TORCH_DTYPE_MAP[ckpt_config.export_dtype]
         self.exclude_from_loading = ckpt_config.exclude_from_loading
@@ -223,10 +237,6 @@ def __init__(
                 daemon=True,
             )
             self.mp.start()
-            self.cpu_offload_state_dict = None
-            self.staging = False
-            self.staging_id = None
-            self.staging_stream = torch.cuda.Stream()
         else:
             raise ValueError(f"Unkown checkpoint async_mode {ckpt_config.async_mode}")
 
@@ -240,8 +250,61 @@ def __del__(self):
             self.mp.join()
 
     def reset(self) -> None:
+        # We need to stage the local state if another replica joins during the
+        # first step.
+        if self.ft_manager:
+            self.cpu_staging(None)
         self.begin_time = time.monotonic()
 
+    def _initialize_states(
+        self,
+        states: Dict[str, Any],
+        dataloader: DataLoader,
+        model_parts: List[nn.Module],
+        optimizers: OptimizersContainer,
+        lr_schedulers: LRSchedulersContainer,
+    ) -> None:
+        """
+        Note: Pipeline Parallelism and Virtual Stages
+
+        1. Even for simple PP schedules, there is a separate optimizer for each PP rank.
+        rank0's optimizer would have a param_group[0] which refers to layers.0 in the
+        original model. rank1's would _also_ have a param_group[0], since it's index based,
+        but referring to layers.1.
+        When saving, these collide and one of them is lost. Then when reloading, only one
+        stage can restore its optimizer states, others will error.
+
+        The solution to this problem is optimizer flattening: it landed in #127071
+        and is enabled in TorchTitan by passing the 'flatten_optimizer_state_dict'
+        kwarg to DCP functions called in the OptimizerContainer.
+
+        2. With complex PP schedules, we have multiple model chunks per pp rank. This
+        compounds challenge (1) by also requiring us to reason about multiple 'optim'
+        objects locally.
+
+        We solve this in the Model and Optimizer wrapper classes by flattening the
+        state dicts from each object into one state dict before saving/loading.
+        We rely on the individual state_dicts to not collide, which is guaranteed for
+        the model by correct pipeline splitting and for the optimizer by the flattening
+        support described in (1).
+
+        3. LR schedulers also index model states like optimizers and would need to be
+        flattened properly to support resharding. Unfortunately, the implementations of
+        different lr_schedulers do not follow a clear pattern like optimizers do, so it's
+        hard to write a generic 'flattener' utility.
+
+        TODO: This is currently unsolved and needs a fix.
+        """
+        self.states = states
+        self.states.update(
+            {
+                "model": ModelWrapper(model_parts),
+                "optimizer": optimizers,
+                "dataloader": dataloader,
+                "lr_scheduler": lr_schedulers,
+            }
+        )
+
     def _create_checkpoint_id(self, step: int) -> str:
         return os.path.join(self.folder, f"step-{step}")
 
@@ -324,31 +387,8 @@ def _async_wait(self) -> None:
         self.async_future.result()
 
     def _async_with_pinned_memory(self, checkpoint_id: str) -> None:
-        try:
-            from torch.distributed._state_dict_utils import (
-                _copy_state_dict,
-                _create_cpu_state_dict,
-            )
-        except ImportError as e:
-            raise ImportError(
-                "Please install the latest PyTorch nightly to use async checkpointing with pinned memory."
-            ) from e
-        state_dict = dcp.state_dict_saver._stateful_to_state_dict(self.states)
-        if self.cpu_offload_state_dict is None:
-            logger.debug(f"Preparing the CPU memory, {time.monotonic()=}.:.2f")
-            self.cpu_offload_state_dict = _create_cpu_state_dict(
-                state_dict, pin_memory=True, share_memory=True
-            )
-
-        logger.debug(f"Staging the state_dict, {time.monotonic()=}.:.2f")
-        with torch.cuda.stream(self.staging_stream):
-            self.cpu_offload_state_dict = _copy_state_dict(
-                state_dict,
-                self.cpu_offload_state_dict,
-                non_blocking=True,
-            )
-        self.staging = True
-        self.staging_id = checkpoint_id
+        self.cpu_staging(checkpoint_id)
+        self.sending_to_checkpoint_mp = True
 
     def save(self, curr_step: int, force: bool = False) -> None:
         """
@@ -358,6 +398,8 @@ def save(self, curr_step: int, force: bool = False) -> None:
         for initial seed checkpoint.
        """
         if not self._should_save(curr_step, force):
+            if self.ft_manager:
+                self.cpu_staging(None)
             return
 
         begin = time.monotonic()
@@ -381,26 +423,51 @@ def save(self, curr_step: int, force: bool = False) -> None:
             f"in {time.monotonic() - begin:.2f} seconds."
         )
 
+    def cpu_staging(self, checkpoint_id: Optional[str]) -> None:
+        """Offload state_dict to CPU memory"""
+        state_dict = dcp.state_dict_saver._stateful_to_state_dict(self.states)
+        if self.cpu_offload_state_dict is None:
+            logger.debug(f"Preparing the CPU memory, {time.monotonic()=}.:.2f")
+            self.cpu_offload_state_dict = _create_cpu_state_dict(
+                state_dict, pin_memory=True, share_memory=True
+            )
+
+        logger.debug(f"Staging the state_dict, {time.monotonic()=}.:.2f")
+        with torch.cuda.stream(self.staging_stream):
+            self.cpu_offload_state_dict = _copy_state_dict(
+                state_dict,
+                self.cpu_offload_state_dict,
+                non_blocking=True,
+            )
+        self.staging = True
+        self.staging_id = checkpoint_id
+
+    def wait_for_staging(self) -> None:
+        if not self.staging_stream.query():
+            self.staging_stream.synchronize()
+        self.staging = False
+
+    def staging_results(self) -> Dict[str, Any]:
+        self.maybe_wait_for_staging()
+        return self.cpu_offload_state_dict
+
     def maybe_wait_for_staging(self) -> None:
-        if (
-            self.enable_checkpoint
-            and self.async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM
-            and self.staging
-        ):
-            if not self.staging_stream.query():
-                self.staging_stream.synchronize()
-
-            def sync_func():
-                self.mp_queue_send.put_nowait(
-                    (self.cpu_offload_state_dict, self.staging_id)
-                )
-
-            # This may be a faster way to do zero-overhead checkpointing staging
-            # checkpointing but we need more thorough investigation before
-            # swithing to this method.
-            # self.my_thread = threading.Thread(target=func).start()
-            sync_func()
-            self.staging = False
+        if self.enable_staging and self.staging:
+            self.wait_for_staging()
+
+        if self.sending_to_checkpoint_mp:
+            # Copy the staged result to another process.
+            def sync_func():
+                self.mp_queue_send.put_nowait(
+                    (self.cpu_offload_state_dict, self.staging_id)
+                )
+
+            # This may be a faster way to do zero-overhead checkpoint staging,
+            # but we need more thorough investigation before
+            # switching to this method.
+            # self.my_thread = threading.Thread(target=func).start()
            sync_func()
+            self.sending_to_checkpoint_mp = False
 
     def load(self, step: int = -1) -> bool:
         if not self.enable_checkpoint:
```
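The staging hooks above are consumed by TorchFT's state-dict callbacks (see `torchtitan/ft.py` below). A minimal sketch of the intended call order, assuming `checkpoint` is a `CheckpointManager` constructed with `ft_manager` set; this is an illustration of the diff above, not code from the PR:

```python
# Stage the registered states into pinned, shared CPU memory without writing a
# checkpoint to disk; checkpoint_id=None means "stage only" (as reset() does).
checkpoint.cpu_staging(None)

# Later, TorchFT's state_dict callback pulls the staged copy. staging_results()
# calls maybe_wait_for_staging(), which synchronizes staging_stream before the
# staged CPU state_dict is handed out.
staged = checkpoint.staging_results()
model_and_optim = {k: v for k, v in staged.items() if k in {"model", "optimizer"}}
```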

‎torchtitan/config_manager.py

+13

```diff
@@ -631,6 +631,19 @@ def __init__(self):
             action="store_true",
         )
 
+        self.parser.add_argument(
+            "--experimental.enable_torchft",
+            action="store_true",
+            help="Enable TorchFT integration.",
+        )
+
+        self.parser.add_argument(
+            "--experimental.ft_replica_group_id",
+            type=int,
+            default=-1,
+            help="The FT replica group of this run.",
+        )
+
     def to_dict(self):
         return self.args_dict
```

‎torchtitan/ft.py

+58 (new file)

```diff
@@ -0,0 +1,58 @@
+import importlib
+from typing import Any, Callable, Optional
+
+from torch.distributed._state_dict_utils import _copy_state_dict, _create_cpu_state_dict
+
+from torchtitan.config_manager import JobConfig
+
+if importlib.util.find_spec("torchft") is not None:
+    import torchft as ft
+
+    has_torchft = True
+else:
+    has_torchft = False
+
+
+def init_ft_manager(job: JobConfig) -> Optional["ft.Manager"]:
+    """
+    Initialize the FT manager for the given job.
+    """
+    if not job.experimental.enable_torchft:
+        return None
+
+    if not has_torchft:
+        raise ImportError("torchft is not installed. Please install it.")
+
+    pg = ft.ProcessGroupBabyNCCL()
+    manager = ft.Manager(
+        pg=pg,
+        min_replica_size=1,
+        load_state_dict=None,
+        state_dict=None,
+        use_async_quorum=True,
+        replica_id=f"torchtitan_ft_{job.experimental.ft_replica_group_id}",
+    )
+
+    return manager
+
+
+def set_ft_state_dict_fns(manager: Optional["ft.Manager"], ckpt_manager) -> None:
+    """
+    Set the state dict functions for the given manager.
+    """
+    if manager is None:
+        return
+
+    def state_dict():
+        ret = {}
+        for k, v in ckpt_manager.staging_results().items():
+            if k in {"model", "optimizer", "lr_schedulers"}:
+                ret[k] = v
+        return ret
+
+    def load_state_dict(state_dict):
+        assert state_dict is not None
+        for k, v in state_dict.items():
+            ckpt_manager.states[k].load_state_dict(v)
+
+    manager.set_state_dict_fns(load_state_dict, state_dict)
```
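A minimal sketch of how these two helpers are meant to be wired together, mirroring the `train.py` changes later in this commit; `checkpoint_manager_factory` is a hypothetical stand-in for the real `CheckpointManager` construction:

```python
from torchtitan.ft import init_ft_manager, set_ft_state_dict_fns


def setup_ft(job_config, checkpoint_manager_factory):
    # Returns None unless --experimental.enable_torchft was passed.
    ft_manager = init_ft_manager(job_config)

    # The CheckpointManager is built with the manager so that it stages states
    # to CPU even when disk checkpointing is disabled.
    checkpoint = checkpoint_manager_factory(ft_manager=ft_manager)

    # Register the callbacks that read from / write into the staged CPU copy.
    set_ft_state_dict_fns(ft_manager, checkpoint)
    return ft_manager, checkpoint
```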

‎torchtitan/optimizer.py

+33 -6

```diff
@@ -177,8 +177,32 @@ def zero_grad(self) -> None:
         pass
 
 
+class FTOptimizersContainer(Optimizer):
+    def __init__(
+        self,
+        model_parts: List[nn.Module],
+        optimizer_kwargs: Dict[str, Any],
+        name: str,
+        ft_manager: Any,
+    ) -> None:
+        import torchft as ft
+
+        super().__init__()
+
+        # Force the optimizer state to be initialized so that `optim.step()`
+        # won't be called by state_dict() and load_state_dict().
+        _ = {
+            k: v
+            for sd in map(get_optimizer_state_dict, model_parts, self.optimizers)
+            for k, v in sd.items()
+        }
+        self.optimizers = [ft.Optimizer(ft_manager, optim) for optim in self.optimizers]
+
+
 def build_optimizers(
-    model_parts: List[nn.Module], job_config: JobConfig
+    model_parts: List[nn.Module],
+    job_config: JobConfig,
+    ft_manager: Optional[Any] = None,
 ) -> OptimizersContainer:
     """Create a OptimizersContainer for the given model parts and job config.
 
@@ -213,11 +237,14 @@ def build_optimizers(
         "foreach": not fused,
     }
 
-    return (
-        OptimizersContainer(model_parts, optimizer_kwargs, name)
-        if not optim_in_bwd
-        else OptimizersInBackwardContainer(model_parts, optimizer_kwargs, name)
-    )
+    if optim_in_bwd and ft_manager:
+        raise ValueError("TorchFT is not supported with optimizers in backward.")
+    elif optim_in_bwd:
+        return OptimizersInBackwardContainer(model_parts, optimizer_kwargs, name)
+    elif ft_manager:
+        return FTOptimizersContainer(model_parts, optimizer_kwargs, name, ft_manager)
+    else:
+        return OptimizersContainer(model_parts, optimizer_kwargs, name)
 
 
 class LRSchedulersContainer(Stateful):
```
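For reference, the call site in `train.py` (later in this commit) simply threads the manager through, so behavior is unchanged when it is `None`. A brief call-site sketch of the dispatch above, not code taken from the PR:

```python
# ft_manager may be None; build_optimizers then behaves exactly as before.
optimizers = build_optimizers(model_parts, job_config, ft_manager)

# Dispatch implemented above:
#   optim_in_bwd and ft_manager -> ValueError (combination unsupported)
#   optim_in_bwd                -> OptimizersInBackwardContainer
#   ft_manager                  -> FTOptimizersContainer (wraps each optim in ft.Optimizer)
#   otherwise                   -> OptimizersContainer
```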

‎torchtitan/parallelisms/parallel_dims.py

+17 -4

```diff
@@ -6,6 +6,7 @@
 
 from dataclasses import dataclass
 from functools import cached_property
+from typing import Any, Optional
 
 from torch.distributed.device_mesh import init_device_mesh
 
@@ -24,6 +25,7 @@ class ParallelDims:
     pp: int
     world_size: int
     enable_loss_parallel: bool
+    ft_manager: Optional[Any]
 
     def __post_init__(self):
         self._validate()
@@ -56,13 +58,24 @@ def build_mesh(self, device_type):
             [self.pp, self.dp_replicate, self.dp_shard, self.cp, self.tp],
             ["pp", "dp_replicate", "dp_shard", "cp", "tp"],
         ):
-            if d > 1:
+            if d > 1 or (name == "dp_replicate" and self.ft_manager is not None):
                 dims.append(d)
                 names.append(name)
 
         logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}")
         names = tuple(names)
-        mesh = init_device_mesh(device_type, dims, mesh_dim_names=names)
+        if self.ft_manager is None:
+            mesh = init_device_mesh(device_type, dims, mesh_dim_names=names)
+        else:
+            from torchft.process_group import ft_init_device_mesh
+
+            mesh = ft_init_device_mesh(
+                device_type=device_type,
+                mesh_shape=dims,
+                mesh_dim_names=names,
+                replicate_dim=names.index("dp_replicate"),
+                manager=self.ft_manager,
+            )
 
         # Create all the submesh here to ensure all required process groups are
         # initialized:
@@ -73,7 +86,7 @@ def build_mesh(self, device_type):
         # Mesh for loss all-reduce
         dp_cp_mesh_dim_names = []
 
-        if self.dp_replicate_enabled:
+        if self.dp_replicate_enabled or ft_manager is not None:
             dp_mesh_dim_names.append("dp_replicate")
             dp_cp_mesh_dim_names.append("dp_replicate")
         if self.dp_shard_enabled:
@@ -101,7 +114,7 @@ def dp_enabled(self):
 
     @property
     def dp_replicate_enabled(self):
-        return self.dp_replicate > 1
+        return self.dp_replicate > 1 or self.ft_manager is not None
 
     @property
     def dp_shard_enabled(self):
```
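A hedged sketch of how `ParallelDims` is expected to be constructed with an FT manager, following the `train.py` diff below; the degrees shown are illustrative, and `ft_manager` is assumed to come from `init_ft_manager`:

```python
from torchtitan.parallelisms import ParallelDims

parallel_dims = ParallelDims(
    dp_shard=2,
    dp_replicate=1,  # kept as a size-1 "dp_replicate" mesh dim when ft_manager is set
    cp=1,
    tp=1,
    pp=1,
    world_size=2,
    enable_loss_parallel=True,
    ft_manager=ft_manager,  # from init_ft_manager(job_config)
)

# With ft_manager set, build_mesh() routes through torchft's ft_init_device_mesh,
# so the replicate dimension is backed by the fault-tolerant process group.
mesh = parallel_dims.build_mesh(device_type="cuda")
```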

‎torchtitan/utils.py

+18

```diff
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import contextlib
+import copy
 import gc
 import importlib
 import math
@@ -18,6 +19,7 @@
 import torch
 import torch.distributed._functional_collectives as funcol
 import torch.distributed.distributed_c10d as c10d
+import torchft as ft
 from torch import distributed as dist
 from torch._utils import _get_available_device_type, _get_device_module
 from torch.distributed.device_mesh import DeviceMesh
@@ -38,6 +40,11 @@ def get_device_info():
 
 
 def dist_reduce(x: torch.Tensor, reduceOp: str, mesh: DeviceMesh) -> float:
+    if isinstance(mesh, ft.process_group._FlattenDeviceMesh):
+        torch.distributed.all_reduce(x, group=mesh.managed_mesh.replicate_pg)
+        # x = funcol.all_reduce(x, reduceOp=reduceOp, group=mesh.managed_mesh.replicate_pg)
+        mesh = mesh.managed_mesh.mesh
+
     if isinstance(x, DTensor):
         # functional collectives do not support DTensor inputs
         x = x.full_tensor()
@@ -402,6 +409,17 @@ def clip_grad_norm_(
     if isinstance(total_norm, DTensor):
         # Will reach here if any non-PP parallelism is used.
         # If only using PP, total_norm will be a local tensor.
+        mesh = total_norm._spec.mesh
+        if isinstance(mesh, ft.process_group.ManagedDeviceMesh):
+            # The gradients along the replicated dim have already been reduced,
+            # so we don't need another reduction before removing the
+            # replicate dimension.
+            local_tensor = total_norm.to_local()
+            placements = list(copy.copy(total_norm._spec.placements))
+            placements.pop(mesh.replicate_dim)
+            mesh = mesh.mesh
+            total_norm = DTensor.from_local(local_tensor, mesh, placements)
+
         total_norm = total_norm.full_tensor()
 
     if pp_mesh is not None:
```

‎train.py

+9 -3

```diff
@@ -16,6 +16,7 @@
 from torchtitan.checkpoint import CheckpointManager, TrainState
 from torchtitan.config_manager import JobConfig
 from torchtitan.float8 import Float8Handler
+from torchtitan.ft import init_ft_manager, set_ft_state_dict_fns
 from torchtitan.logging import init_logger, logger
 from torchtitan.metrics import build_device_memory_monitor, build_metric_logger
 from torchtitan.parallelisms import ParallelDims
@@ -42,6 +43,10 @@ def main(job_config: JobConfig):
     # take control of garbage collection to avoid stragglers
     gc_handler = utils.GarbageCollection(gc_freq=job_config.training.gc_freq)
 
+    device = torch.device(f"{device_type}:{int(os.environ['LOCAL_RANK'])}")
+    device_module.set_device(device)
+    ft_manager = init_ft_manager(job_config)
+
     # init distributed
     world_size = int(os.environ["WORLD_SIZE"])
     parallel_dims = ParallelDims(
@@ -52,9 +57,8 @@ def main(job_config: JobConfig):
         pp=job_config.experimental.pipeline_parallel_degree,
         world_size=world_size,
         enable_loss_parallel=not job_config.training.disable_loss_parallel,
+        ft_manager=ft_manager,
     )
-    device = torch.device(f"{device_type}:{int(os.environ['LOCAL_RANK'])}")
-    device_module.set_device(device)
     utils.init_distributed(job_config)
     # initialize device memory monitor and get peak flops for MFU calculation
     device_memory_monitor = build_device_memory_monitor()
@@ -186,7 +190,7 @@ def loss_fn(pred, labels):
     )
 
     # build optimizer after applying parallelisms to the model
-    optimizers = train_spec.build_optimizers_fn(model_parts, job_config)
+    optimizers = train_spec.build_optimizers_fn(model_parts, job_config, ft_manager)
     lr_schedulers = train_spec.build_lr_schedulers_fn(optimizers, job_config)
 
     train_state = TrainState()
@@ -199,7 +203,9 @@ def loss_fn(pred, labels):
         lr_schedulers=lr_schedulers,
         states={"train_state": train_state},
         job_config=job_config,
+        ft_manager=ft_manager,
     )
+    set_ft_state_dict_fns(ft_manager, checkpoint)
 
     if job_config.checkpoint.create_seed_checkpoint:
         assert (
```

Comments (0)