Enabling Alternative Path ABC implementations #1393

Open · wants to merge 1 commit into main
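In short, the diff relaxes every `checkpoint_dir` argument from `str` (implicitly a local `pathlib.Path`) to `str | PathBase`, so checkpoints can be addressed through alternative `pathlib_abc.PathBase` implementations. A minimal usage sketch of the intent, assuming a hypothetical `S3Path` class that implements `PathBase` (no such class ships with this PR) and an already-constructed Megatron Core `model`:

```python
from megatron.core import dist_checkpointing

# Sharded state dict of any Megatron Core module (`model` is assumed to exist).
sharded_state_dict = model.sharded_state_dict()

# Unchanged: plain string paths keep working as before.
dist_checkpointing.save(sharded_state_dict, '/ckpts/iter_0000100')

# New: any pathlib_abc.PathBase implementation is accepted as well,
# e.g. a hypothetical object-store-backed path type.
ckpt_dir = S3Path('my-bucket/ckpts/iter_0000100')  # assumed PathBase subclass
loaded = dist_checkpointing.load(sharded_state_dict, ckpt_dir)
```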
13 changes: 9 additions & 4 deletions megatron/core/dist_checkpointing/core.py
@@ -5,6 +5,7 @@
import json
from dataclasses import asdict, dataclass
from pathlib import Path
+from pathlib_abc import PathBase
from typing import Optional

CONFIG_FNAME = 'metadata.json'
@@ -45,7 +46,7 @@ def check_is_distributed_checkpoint(checkpoint_dir):
return maybe_load_config(checkpoint_dir) is not None


-def maybe_load_config(checkpoint_dir: str) -> Optional[CheckpointingConfig]:
+def maybe_load_config(checkpoint_dir: str | PathBase) -> Optional[CheckpointingConfig]:
"""Returns checkpoint config if `checkpoint_dir` is a distributed checkpoint and None otherwise

Args:
@@ -54,15 +55,17 @@ def maybe_load_config(checkpoint_dir: str) -> Optional[CheckpointingConfig]:
Returns:
CheckpointingConfig (optional): None if checkpoint is not a valid distributed checkpoint
"""
-    config_path = Path(checkpoint_dir, CONFIG_FNAME)
+    if isinstance(checkpoint_dir, str):
+        checkpoint_dir = Path(checkpoint_dir)
+    config_path = checkpoint_dir / CONFIG_FNAME
if not config_path.exists():
return None
with config_path.open() as f:
config_dict = json.load(f)
return CheckpointingConfig(**config_dict)


-def save_config(config: CheckpointingConfig, checkpoint_dir: str):
+def save_config(config: CheckpointingConfig, checkpoint_dir: str | PathBase):
"""Save given config to checkpoint directory.

Args:
@@ -72,6 +75,8 @@ def save_config(config: CheckpointingConfig, checkpoint_dir: str):
Returns:
None
"""
-    config_path = Path(checkpoint_dir, CONFIG_FNAME)
+    if isinstance(checkpoint_dir, str):
+        checkpoint_dir = Path(checkpoint_dir)
+    config_path = checkpoint_dir / CONFIG_FNAME
with config_path.open('w') as f:
json.dump(asdict(config), f)
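The same normalization pattern recurs in every entry point touched below: a plain `str` is wrapped in `pathlib.Path`, while an object that is already path-like (a `pathlib.Path` or a `PathBase` implementation) is used as-is. A condensed sketch of that pattern, with an illustrative helper name that does not appear in the diff:

```python
from pathlib import Path


def _normalize_checkpoint_dir(checkpoint_dir):
    """Wrap plain strings in pathlib.Path; pass path-like objects (pathlib.Path
    or any pathlib_abc.PathBase implementation) through untouched."""
    if isinstance(checkpoint_dir, str):
        return Path(checkpoint_dir)
    return checkpoint_dir


config_path = _normalize_checkpoint_dir('/ckpts/iter_0000100') / 'metadata.json'
print(config_path.exists())
```

Branching on `str` rather than on the ABC keeps plain `pathlib.Path` inputs working, since `pathlib.Path` does not itself derive from `pathlib_abc.PathBase`.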
35 changes: 23 additions & 12 deletions megatron/core/dist_checkpointing/serialization.py
@@ -10,6 +10,7 @@

import logging
from pathlib import Path
+from pathlib_abc import PathBase
from typing import Callable, Dict, Optional, Set, Tuple, Union

import torch
@@ -55,7 +56,7 @@

def load(
sharded_state_dict: ShardedStateDict,
-    checkpoint_dir: str,
+    checkpoint_dir: str | PathBase,
sharded_strategy: Union[LoadShardedStrategy, Tuple[str, int], None] = None,
common_strategy: Union[LoadCommonStrategy, Tuple[str, int], None] = None,
validate_access_integrity: bool = True,
@@ -102,7 +103,8 @@ def load(
checkpoint_dir, sharded_strategy, common_strategy
)

-    checkpoint_dir = Path(checkpoint_dir)
+    if isinstance(checkpoint_dir, str):
+        checkpoint_dir = Path(checkpoint_dir)
common_state_dict = common_strategy.load_common(checkpoint_dir)

sharded_state_dict, nonpersistent_state_dict, sh_ten_factories = load_preprocess(
@@ -156,7 +158,7 @@ def load(
return common_state_dict


-def load_common_state_dict(checkpoint_dir: Path) -> StateDict:
+def load_common_state_dict(checkpoint_dir: PathBase) -> StateDict:
"""Load common (non-sharded) objects state dict from the checkpoint.

Args:
@@ -170,7 +172,7 @@ def load_common_state_dict(checkpoint_dir: Path) -> StateDict:


def load_tensors_metadata(
-    checkpoint_dir: str, sharded_strategy: Union[LoadShardedStrategy, None] = None
+    checkpoint_dir: str | PathBase, sharded_strategy: Union[LoadShardedStrategy, None] = None
) -> CkptShardedMetadata:
"""Load tensors metadata from the checkpoint.

@@ -197,11 +199,13 @@ def load_tensors_metadata(
sharded_strategy, common_strategy = verify_checkpoint_and_load_strategy(
checkpoint_dir, sharded_strategy
)
-    return sharded_strategy.load_tensors_metadata(Path(checkpoint_dir))
+    if isinstance(checkpoint_dir, str):
+        checkpoint_dir = Path(checkpoint_dir)
+    return sharded_strategy.load_tensors_metadata(checkpoint_dir)


def load_sharded_metadata(
-    checkpoint_dir: str,
+    checkpoint_dir: str | PathBase,
sharded_strategy: Union[LoadShardedStrategy, None] = None,
common_strategy: Union[LoadCommonStrategy, None] = None,
) -> CkptShardedMetadata:
@@ -235,15 +239,17 @@ def load_sharded_metadata(
sharded_strategy, common_strategy = verify_checkpoint_and_load_strategy(
checkpoint_dir, sharded_strategy, common_strategy
)
-    sharded_metadata = sharded_strategy.load_sharded_metadata(Path(checkpoint_dir))
+    if isinstance(checkpoint_dir, str):
+        checkpoint_dir = Path(checkpoint_dir)
+    sharded_metadata = sharded_strategy.load_sharded_metadata(checkpoint_dir)
if not sharded_strategy.can_handle_sharded_objects:
validate_sharded_objects_handling(sharded_strategy, common_strategy)
-        common_metadata = common_strategy.load_sharded_metadata(Path(checkpoint_dir))
+        common_metadata = common_strategy.load_sharded_metadata(checkpoint_dir)
sharded_metadata = merge(sharded_metadata, common_metadata)
return sharded_metadata


-def load_plain_tensors(checkpoint_dir: str) -> StateDict:
+def load_plain_tensors(checkpoint_dir: str | PathBase) -> StateDict:
"""Load checkpoint tensors without any sharding and plain structure.

NOTE: common state dict is NOT included.
@@ -254,6 +260,8 @@ def load_plain_tensors(checkpoint_dir: str) -> StateDict:
Returns:
StateDict: checkpoint state dict containing only torch.Tensors.
"""
+    if isinstance(checkpoint_dir, str):
+        checkpoint_dir = Path(checkpoint_dir)
sharded_state_dict = load_tensors_metadata(checkpoint_dir)
# Don't validate integrity because shards will be overlapped
# if world_size > 1 (all processes load whole tensors)
@@ -279,15 +287,17 @@ def load_plain_tensors(checkpoint_dir: str) -> StateDict:
# return load(sharded_state_dict, checkpoint_dir, validate_access_integrity=False)


-def remove_sharded_tensors(checkpoint_dir: str, key_prefix: str):
+def remove_sharded_tensors(checkpoint_dir: str | PathBase, key_prefix: str):
"""determine the appropriate sharding strategy and delegate removal to the sharded strategy"""
+    if isinstance(checkpoint_dir, str):
+        checkpoint_dir = Path(checkpoint_dir)
sharded_strategy, common_strategy = verify_checkpoint_and_load_strategy(checkpoint_dir)
sharded_strategy.remove_sharded_tensors(checkpoint_dir, key_prefix)


def save(
sharded_state_dict: ShardedStateDict,
-    checkpoint_dir: str,
+    checkpoint_dir: str | PathBase,
sharded_strategy: Union[SaveShardedStrategy, Tuple[str, int], None] = None,
common_strategy: Union[SaveCommonStrategy, Tuple[str, int], None] = None,
validate_access_integrity: bool = True,
@@ -342,7 +352,8 @@ def save(
async request that should be scheduled by the caller of this function.
None otherwise.
"""
-    checkpoint_dir = Path(checkpoint_dir)
+    if isinstance(checkpoint_dir, str):
+        checkpoint_dir = Path(checkpoint_dir)

if torch.distributed.get_rank() == 0:
if not checkpoint_dir.exists():
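With the signatures above relaxed, all public entry points of `serialization.py` (`load`, `save`, `load_tensors_metadata`, `load_sharded_metadata`, `load_plain_tensors`, `remove_sharded_tensors`) accept either a `str` or a `PathBase` implementation. A short sketch exercising two of them, reusing the hypothetical `S3Path` from above and assumed to run inside an initialized `torch.distributed` job:

```python
from megatron.core.dist_checkpointing.serialization import (
    load_plain_tensors,
    load_tensors_metadata,
)

ckpt_dir = S3Path('my-bucket/ckpts/iter_0000100')  # assumed PathBase implementation

# Inspect the checkpoint contents without materializing tensor data.
metadata = load_tensors_metadata(ckpt_dir)
print(sorted(metadata)[:5])

# Load full (unsharded) tensors keyed exactly as in the metadata above.
plain_state_dict = load_plain_tensors(ckpt_dir)
```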
14 changes: 7 additions & 7 deletions megatron/core/dist_checkpointing/strategies/common.py
@@ -5,7 +5,7 @@
import logging
import os
from pathlib import Path
-
+from pathlib_abc import PathBase
import torch

from megatron.core.dist_checkpointing.mapping import ShardedStateDict, StateDict
@@ -35,13 +35,13 @@ def register_default_common_strategies():
class TorchCommonSaveStrategy(SaveCommonStrategy):
"""Common save strategy leveraging native torch save/load."""

-    def save_common(self, common_state_dict: StateDict, checkpoint_dir: Path):
+    def save_common(self, common_state_dict: StateDict, checkpoint_dir: PathBase):
"""Save common part of the state dict."""
if torch.distributed.get_rank() == 0:
torch.save(common_state_dict, checkpoint_dir / COMMON_STATE_FNAME)

def save_sharded_objects(
-        self, sharded_objects_state_dict: ShardedStateDict, checkpoint_dir: Path
+        self, sharded_objects_state_dict: ShardedStateDict, checkpoint_dir: PathBase
):
"""Save sharded objects from the state dict."""
for sh_obj in nested_values(sharded_objects_state_dict):
@@ -58,7 +58,7 @@ def can_handle_sharded_objects(self):
class TorchCommonLoadStrategy(LoadCommonStrategy):
"""Common load strategy leveraging native torch save/load."""

-    def load_common(self, checkpoint_dir: Path):
+    def load_common(self, checkpoint_dir: PathBase):
"""Load common (non-sharded) objects state dict from the checkpoint.

Args:
@@ -67,7 +67,7 @@ def load_common(self, checkpoint_dir: Path):
Returns:
StateDict: state dict with non-sharded objects from the checkpoint
"""
-        load_path = Path(checkpoint_dir) / COMMON_STATE_FNAME
+        load_path = checkpoint_dir / COMMON_STATE_FNAME
try:
return torch.load(load_path, map_location='cpu', weights_only=False)
except FileNotFoundError as e:
@@ -77,7 +77,7 @@ def load_common(self, checkpoint_dir: Path):
raise CheckpointingException(err_msg) from e

def load_sharded_objects(
-        self, sharded_objects_state_dict: ShardedStateDict, checkpoint_dir: Path
+        self, sharded_objects_state_dict: ShardedStateDict, checkpoint_dir: PathBase
):
"""Replaces all ShardedObject from a given state dict with values loaded from the
checkpoint.
@@ -120,7 +120,7 @@ def load_sharded_object(sh_obj: ShardedObject):

return dict_list_map_inplace(load_sharded_object, sharded_objects_state_dict)

-    def load_sharded_metadata(self, checkpoint_dir: Path) -> ShardedStateDict:
+    def load_sharded_metadata(self, checkpoint_dir: PathBase) -> ShardedStateDict:
sharded_metadata = {}
for subdir in checkpoint_dir.iterdir():
if not subdir.is_dir():
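The strategy changes above show which operations a checkpoint-directory object actually has to provide: joining with `/`, `open()`, `exists()`, `iterdir()` and `is_dir()`, plus being acceptable to `torch.save`/`torch.load` for the common state file. An illustrative (non-exhaustive) protocol capturing that surface; it is not part of the PR:

```python
from typing import Iterator, Protocol


class CheckpointPathLike(Protocol):
    """Path operations the dist_checkpointing code paths shown here rely on."""

    def __truediv__(self, other: str) -> 'CheckpointPathLike': ...  # checkpoint_dir / COMMON_STATE_FNAME
    def open(self, mode: str = 'r'): ...  # metadata.json read/write in core.py
    def exists(self) -> bool: ...  # existence checks in core.py / validation.py
    def iterdir(self) -> Iterator['CheckpointPathLike']: ...  # load_sharded_metadata
    def is_dir(self) -> bool: ...  # load_sharded_metadata
```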
7 changes: 5 additions & 2 deletions megatron/core/dist_checkpointing/validation.py
@@ -3,6 +3,7 @@
from collections import Counter, defaultdict
from enum import Enum
from pathlib import Path
+from pathlib_abc import PathBase
from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union

import numpy as np
@@ -202,7 +203,7 @@ def validate_integrity_and_strict_load(


def verify_checkpoint_and_load_strategy(
-    checkpoint_dir: str,
+    checkpoint_dir: str | PathBase,
sharded_strategy: Union[LoadShardedStrategy, Tuple[str, int], None] = None,
common_strategy: Union[LoadCommonStrategy, Tuple[str, int], None] = None,
) -> Tuple[LoadShardedStrategy, LoadCommonStrategy]:
@@ -219,7 +220,9 @@ def verify_checkpoint_and_load_strategy(
if compatible with the checkpoint content. If None, the default common load strategy
for the checkpoint backend will be returned.
"""
-    if not Path(checkpoint_dir).exists():
+    if isinstance(checkpoint_dir, str):
+        checkpoint_dir = Path(checkpoint_dir)
+    if not checkpoint_dir.exists():
raise CheckpointingException(f'Checkpoint directory {checkpoint_dir} does not exist')

saved_config = maybe_load_config(checkpoint_dir)
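Taken together with the strategy-selection change here, a `PathBase` implementation skips the `Path(...)` wrapping entirely, while plain strings keep their existing behaviour. A quick sanity-check sketch against a local checkpoint (the path is illustrative):

```python
from megatron.core.dist_checkpointing.validation import verify_checkpoint_and_load_strategy

# A plain str is still wrapped in pathlib.Path internally; a PathBase
# implementation would now be passed through to the strategies unchanged.
sharded_strategy, common_strategy = verify_checkpoint_and_load_strategy('/ckpts/iter_0000100')
print(type(sharded_strategy).__name__, type(common_strategy).__name__)
```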