Allow users to use customized models
**What does this PR do?**
1. This PR introduces `ModelSpec` to describe a model and how to parallelize it.
2. All models should define `build_model_spec()` or `model_spec` so that the spec can
   be imported by the `model` module.
3. `build_model_specs()` is called in the trainer to collect the available `model_specs`, and the result is used to look up the spec for the configured model.
4. Users can also use `--experimental.custom_model_path` to dynamically import a model that is not implemented by TorchTitan (a sketch of such a module is shown below).
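
A minimal sketch of what such an external model module could look like, assuming the `ModelSpec` API added in this PR. The file name `my_models/model_x.py` and the names `ModelX`, `ModelXArgs`, `parallelize_model_x`, and `pipeline_model_x` are hypothetical; the parallelize/pipelining functions are placeholders for illustration only.

```python
# my_models/model_x.py -- a hypothetical external model module (not part of TorchTitan)
from dataclasses import dataclass
from typing import List, Tuple

import torch.nn as nn
from torch.distributed.pipelining.schedules import _PipelineSchedule

from torchtitan.model_spec import BaseModelArgs, ModelSpec, register_model_spec
from torchtitan.optimizer import build_lr_schedulers, build_optimizers


@dataclass
class ModelXArgs(BaseModelArgs):
    dim: int = 256
    n_layers: int = 4


class ModelX(nn.Module):
    def __init__(self, args: ModelXArgs):
        super().__init__()
        self.layers = nn.ModuleList(
            [nn.Linear(args.dim, args.dim) for _ in range(args.n_layers)]
        )

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

    @staticmethod
    def from_model_args(args: BaseModelArgs) -> nn.Module:
        # Factory required by ModelProtocol.
        return ModelX(args)


def parallelize_model_x(model: nn.Module, *args, **kwargs) -> None:
    # Placeholder: FSDP/TP/activation checkpointing would be applied here.
    pass


def pipeline_model_x(model, *args, **kwargs) -> Tuple[_PipelineSchedule, List[nn.Module]]:
    raise NotImplementedError("ModelX does not support pipeline parallelism in this sketch.")


register_model_spec(
    ModelSpec(
        name="model_x",
        cls=ModelX,
        config={"debugmodel": ModelXArgs()},
        tokenizer="tiktoken",
        parallelize_fn=parallelize_model_x,
        pipelining_fn=pipeline_model_x,
        build_optimizers_fn=build_optimizers,
        build_lr_schedulers_fn=build_lr_schedulers,
    )
)
```

With such a file in place, pointing `--experimental.custom_model_path` at it (either as a file path or as a dotted module path) should make the spec available to the existing run script.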

**Why do we need this PR?**
This allows users to use TorchTitan with a new model without intrusively changing TorchTitan code.

**Next steps**
1. This PR only includes the model definitions, configurations, tokenizer, parallelize_fn, and
   pipelining_fn. We may also want to extend ModelSpec to include the optimizer and lr_scheduler.
2. The current TorchTitan parallelize_fn and pipelining_fn import ModelArgs, which can cause circular imports.
   We should fix this issue.

ghstack-source-id: 671424d38a040c8594f8b3d692cd8e141ce5c656
Pull Request resolved: #814
fegin committed Feb 7, 2025
1 parent 5940dde commit 70902ae
Showing 16 changed files with 465 additions and 172 deletions.
12 changes: 12 additions & 0 deletions torchtitan/__init__.py
@@ -0,0 +1,12 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

# Import the built-in models here so that the corresponding register_model_spec()
# will be called.
import torchtitan.models

4 changes: 2 additions & 2 deletions torchtitan/checkpoint.py
@@ -28,7 +28,7 @@
from torch.utils.data import DataLoader
from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.logging import init_logger, logger
from torchtitan.optimizer import OptimizersContainer, SchedulersContainer
from torchtitan.optimizer import LRSchedulersContainer, OptimizersContainer


class IntervalType(enum.Enum):
@@ -140,7 +140,7 @@ def __init__(
dataloader: DataLoader,
model_parts: List[nn.Module],
optimizers: OptimizersContainer,
lr_schedulers: SchedulersContainer,
lr_schedulers: LRSchedulersContainer,
states: Dict[str, Any],
job_config: JobConfig,
) -> None:
17 changes: 17 additions & 0 deletions torchtitan/config_manager.py
@@ -393,6 +393,23 @@ def __init__(self):
The default value is 'allgather'.
""",
)
# I'm not particularly fond of this. Users can choose to write their own wrapper
# module that imports the TorchTitan training loop and executes it, which looks cleaner.
# One reason to provide this option is to allow users to use the existing run script.
# While the script is pretty trivial now, we may add more logic when integrating
# with TorchFT.
# This option is subject to change and may be deleted in the future.
self.parser.add_argument(
"--experimental.custom_model_path",
type=str,
default="",
help="""
The --experimental.custom_model_path option allows users to specify a custom path to a model module
that is not natively implemented within TorchTitan.
Acceptable values are the file system path to the module (e.g., my_models/model_x) or the
dotted import path of the module (e.g., some_package.model_x).
""",
)
self.parser.add_argument(
"--training.mixed_precision_param",
type=str,
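
The flag accepts either a file-system path or a dotted import path. The trainer-side import logic is not shown in this diff; below is a hedged sketch of one way such a value could be resolved with `importlib`. The helper name `import_custom_model_module` is hypothetical; importing the module is enough to run its `register_model_spec()` call.

```python
import importlib
import importlib.util
import os
import sys


def import_custom_model_module(custom_model_path: str) -> None:
    """Import a user model module given either a file path or a dotted name."""
    # Treat values like "my_models/model_x" or "my_models/model_x.py" as files.
    as_file = custom_model_path if custom_model_path.endswith(".py") else custom_model_path + ".py"
    if os.path.isfile(as_file):
        module_name = os.path.splitext(os.path.basename(as_file))[0]
        spec = importlib.util.spec_from_file_location(module_name, as_file)
        module = importlib.util.module_from_spec(spec)
        sys.modules[module_name] = module
        spec.loader.exec_module(module)
    else:
        # Otherwise assume a dotted module path such as "some_package.model_x".
        importlib.import_module(custom_model_path)
```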
77 changes: 77 additions & 0 deletions torchtitan/model_spec.py
@@ -0,0 +1,77 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.


from dataclasses import dataclass
from typing import Callable, Dict, List, Protocol, Tuple, Type

import torch.nn as nn
from torch.distributed.pipelining.schedules import _PipelineSchedule
from torchtitan.config_manager import JobConfig
from torchtitan.optimizer import LRSchedulersContainer, OptimizersContainer


@dataclass
class BaseModelArgs:
"""All ModelArgs should inherit from this class.
Its only current use is type checking, but it allows us to extend common
arguments to all models in the future.
"""

_enforced: str = "This field is used to enforce all fields have defaults."


class ModelProtocol(Protocol):
"""Defines the interface for a model class.
This is used to enforce that all model classes have some methods that are
required by the TorchTitan trainer.
"""

@staticmethod
def from_model_args(args: BaseModelArgs) -> nn.Module: ...


@dataclass
class ModelSpec:
name: str
cls: Type[nn.Module]
config: Dict[str, BaseModelArgs]
# TODO: Add a ``build_dataloader_fn``
# For now, this is a string, so the tokenizer has to be built into the
# TorchTitan library. A better way would be to have a dataloader class
# and a ``build_dataloader`` function that takes job_config to consume
# the different dataloader and tokenizer configs.
tokenizer: str
parallelize_fn: Callable[[nn.Module], None]
pipelining_fn: Callable[[nn.Module], Tuple[_PipelineSchedule, List[nn.Module]]]
build_optimizers_fn: Callable[[List[nn.Module], JobConfig], OptimizersContainer]
build_lr_schedulers_fn: Callable[
[List[nn.Module], JobConfig], LRSchedulersContainer
]

# TODO: Add a FQN convert fn to allow users to load checkpoints from
# HuggingFace or other sources that have different FQN conventions.


_model_specs = {}


def register_model_spec(model_spec: ModelSpec) -> None:
global _model_specs
if model_spec.name in _model_specs:
raise ValueError(f"Model {model_spec.name} is already registered.")
_model_specs[model_spec.name] = model_spec


def get_model_spec(name: str) -> ModelSpec:
global _model_specs
if name not in _model_specs:
raise ValueError(f"Model {name} is not registered.")
return _model_specs[name]
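
For reference, a hedged sketch (not the actual trainer code) of how a registered spec could be consumed; `name`, `flavor`, and `job_config` stand in for values the real trainer reads from its CLI/TOML config.

```python
from torchtitan.model_spec import ModelSpec, get_model_spec


def build_from_spec(name: str, flavor: str, job_config):
    spec: ModelSpec = get_model_spec(name)        # raises ValueError for unknown names
    model_args = spec.config[flavor]              # a BaseModelArgs subclass
    model = spec.cls.from_model_args(model_args)  # the ModelProtocol factory
    # spec.parallelize_fn / spec.pipelining_fn would be applied here to shard
    # and (optionally) pipeline the model before building the optimizers.
    optimizers = spec.build_optimizers_fn([model], job_config)
    lr_schedulers = spec.build_lr_schedulers_fn([model], job_config)
    return model, optimizers, lr_schedulers
```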
13 changes: 3 additions & 10 deletions torchtitan/models/__init__.py
@@ -4,14 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchtitan.models.llama import llama3_configs, Transformer

models_config = {
"llama3": llama3_configs,
}

model_name_to_cls = {"llama3": Transformer}

model_name_to_tokenizer = {
"llama3": "tiktoken",
}
# Import the built-in models here so that the corresponding register_model_spec()
# will be called.
import torchtitan.models.llama # noqa
22 changes: 21 additions & 1 deletion torchtitan/models/llama/__init__.py
@@ -6,9 +6,15 @@
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

from torchtitan.model_spec import ModelSpec, register_model_spec
from torchtitan.models.llama.model import ModelArgs, Transformer
from torchtitan.optimizer import build_lr_schedulers, build_optimizers

from .parallelize_llama import parallelize_llama
from .pipeline_llama import pipeline_llama

__all__ = ["parallelize_llama", "pipeline_llama", "ModelArgs", "Transformer"]

__all__ = ["Transformer"]

llama3_configs = {
"debugmodel": ModelArgs(dim=256, n_layers=8, n_heads=16, rope_theta=500000),
@@ -40,3 +46,17 @@
rope_theta=500000,
),
}


register_model_spec(
ModelSpec(
name="llama3",
cls=Transformer,
config=llama3_configs,
tokenizer="tiktoken",
parallelize_fn=parallelize_llama,
pipelining_fn=pipeline_llama,
build_optimizers_fn=build_optimizers,
build_lr_schedulers_fn=build_lr_schedulers,
)
)
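
Since `torchtitan/__init__.py` (above) imports `torchtitan.models`, this registration runs as a side effect of `import torchtitan`. A small sanity-check sketch:

```python
import torchtitan  # importing the package imports torchtitan.models.llama, which registers llama3
from torchtitan.model_spec import get_model_spec

spec = get_model_spec("llama3")
assert spec.tokenizer == "tiktoken"
assert "debugmodel" in spec.config
```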
5 changes: 3 additions & 2 deletions torchtitan/models/llama/model.py
@@ -13,11 +13,12 @@
import torch
import torch.nn.functional as F
from torch import nn
from torchtitan.model_spec import BaseModelArgs, ModelProtocol
from torchtitan.models.norms import build_norm


@dataclass
class ModelArgs:
class ModelArgs(BaseModelArgs):
dim: int = 4096
n_layers: int = 32
n_heads: int = 32
@@ -331,7 +332,7 @@ def init_weights(self):
self.feed_forward.init_weights(self.weight_init_std)


class Transformer(nn.Module):
class Transformer(nn.Module, ModelProtocol):
"""
Transformer Module
torchtitan/models/llama/parallelize_llama.py
@@ -33,7 +33,7 @@

from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.logging import logger
from torchtitan.parallelisms.parallel_dims import ParallelDims
from torchtitan.parallelisms import ParallelDims


def parallelize_llama(
torchtitan/models/llama/pipeline_llama.py
@@ -14,16 +14,19 @@
from torch.distributed import DeviceMesh
from torch.distributed.pipelining import PipelineStage

from torch.distributed.pipelining.schedules import _PipelineSchedule

from torchtitan.config_manager import JobConfig
from torchtitan.logging import logger
from torchtitan.models.llama.model import ModelArgs
from torchtitan.parallelisms.parallel_dims import ParallelDims
from torchtitan.parallelisms.pipelining_utils import (
from torchtitan.parallelisms import (
build_pipeline_schedule,
generate_split_points,
ParallelDims,
stage_ids_this_rank,
)

from .model import ModelArgs


DeviceType = Union[int, str, torch.device]

@@ -36,7 +39,7 @@ def pipeline_llama(
device: DeviceType,
model_config: ModelArgs,
loss_fn: Callable[..., torch.Tensor],
):
) -> tuple[_PipelineSchedule, list[nn.Module]]:
stages, models = pipeline_llama_manual_split(
model, pp_mesh, parallel_dims, job_config, device, model_config
)
@@ -53,7 +56,7 @@ def pipeline_llama_manual_split(
job_config: JobConfig,
device: DeviceType,
model_config: ModelArgs,
):
) -> tuple[list[PipelineStage], list[nn.Module]]:
"""
This API extracts one torch.nn.Module object for the part of the model configured to run inside this stage.
@@ -67,10 +70,16 @@ def pipeline_llama_manual_split(

splits = (
job_config.experimental.pipeline_parallel_split_points
or generate_split_points(job_config, parallel_dims.pp, model_config)
or generate_split_points(job_config, parallel_dims.pp, model_config.n_layers)
)

def _build_stage(stage_idx, start_layer, stop_layer, is_first=False, is_last=False):
def _build_stage(
stage_idx: int,
start_layer: Optional[str],
stop_layer: Optional[str],
is_first: bool = False,
is_last: bool = False,
) -> tuple[PipelineStage, nn.Module]:
model = copy.deepcopy(whole_model)
if not is_first:
model.tok_embeddings = None
