Add Dynamic Model Import and ModelSpec Definition #814

Open · wants to merge 10 commits into base: gh/fegin/8/base
12 changes: 12 additions & 0 deletions torchtitan/__init__.py
@@ -0,0 +1,12 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

# Import the built-in models here so that the corresponding register_model_spec()
# will be called.
import torchtitan.models

4 changes: 2 additions & 2 deletions torchtitan/checkpoint.py
@@ -28,7 +28,7 @@
from torch.utils.data import DataLoader
from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.logging import init_logger, logger
from torchtitan.optimizer import OptimizersContainer, SchedulersContainer
from torchtitan.optimizer import LRSchedulersContainer, OptimizersContainer


class IntervalType(enum.Enum):
@@ -140,7 +140,7 @@ def __init__(
dataloader: DataLoader,
model_parts: List[nn.Module],
optimizers: OptimizersContainer,
lr_schedulers: SchedulersContainer,
lr_schedulers: LRSchedulersContainer,
states: Dict[str, Any],
job_config: JobConfig,
) -> None:
17 changes: 17 additions & 0 deletions torchtitan/config_manager.py
@@ -393,6 +393,23 @@ def __init__(self):
The default value is 'allgather'.
""",
)
# I'm not particularly fond of this. Users can choose to write their own wrapper
# module and import the TorchTitan training loop and execute it, which looks cleaner.
# One reason to provide this option is to allow users to use the existing run script.
# While the script is pretty trivial now, we may add more logic when integrating
# with TorchFT.
# This option is subject to change and may be deleted in the future.
self.parser.add_argument(
"--experimental.custom_model_path",
type=str,
default="",
help="""
The --custom_model_path option allows users to specify a custom path to a model module
that is not natively implemented within TorchTitan.
Acceptable values are either a file system path to the module (e.g., my_models/model_x)
or a dotted import path to the module (e.g., some_package.model_x).
""",
)
self.parser.add_argument(
"--training.mixed_precision_param",
type=str,
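A note on how the --experimental.custom_model_path option above might be consumed: the trainer would need to import the user's module dynamically so that its register_model_spec() call runs as a side effect. The sketch below is only an illustration of that idea, not the PR's actual loader; the helper name _import_custom_model and the exact loading strategy are assumptions.

import importlib
import importlib.util
import os
import sys


def _import_custom_model(path: str) -> None:
    # Hypothetical helper: import a user-provided model package so that its
    # register_model_spec() call executes when the package is loaded.
    if not path:
        return
    if os.path.exists(path):
        # Treat the value as a file system path (e.g., my_models/model_x):
        # load the package through its __init__.py.
        init_file = os.path.join(path, "__init__.py")
        module_name = os.path.basename(os.path.normpath(path))
        spec = importlib.util.spec_from_file_location(module_name, init_file)
        module = importlib.util.module_from_spec(spec)
        sys.modules[module_name] = module
        spec.loader.exec_module(module)
    else:
        # Treat the value as a dotted import path (e.g., some_package.model_x).
        importlib.import_module(path)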
77 changes: 77 additions & 0 deletions torchtitan/model_spec.py
@@ -0,0 +1,77 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.


from dataclasses import dataclass
from typing import Callable, Dict, List, Protocol, Tuple, Type

import torch.nn as nn
from torch.distributed.pipelining.schedules import _PipelineSchedule
from torchtitan.config_manager import JobConfig
from torchtitan.optimizer import LRSchedulersContainer, OptimizersContainer


@dataclass
class BaseModelArgs:
"""All ModelArgs should inherit from this class.
The only usage of this class right now is type checking, but it also allows us to
extend common arguments to all models in the future.
"""

_enforced: str = "This field is used to enforce all fields have defaults."


class ModelProtocol(Protocol):
"""Defines the interface for a model class.
This is used to enforce that all model classes have some methods that are
required by the TorchTitan trainer.
"""

@staticmethod
def from_model_args(args: BaseModelArgs) -> nn.Module: ...


@dataclass
class ModelSpec:
Contributor: Since the Spec is not only about the model (e.g., conceptually there can be multiple ways to do training for the same model: GPU/TPU, customized parallelize/pipeline), shall we consider renaming it to TrainSpec?

Contributor (Author): That's actually a good question and suggestion. I am open to this option.

name: str
cls: Type[nn.Module]
config: Dict[str, BaseModelArgs]
# TODO: Add a ``build_dataloader_fn``
# As for now, this is a string. So it will have to be built-in to the
# TorchTitan library. A better way would be to have a dataloader class
# and a ``build_dataloader`` function that take job_config to consume
# the different dataloader and tokenizer configs.
tokenizer: str
Contributor: Currently the tokenizer is part of the data loader (https://github.com/pytorch/torchtitan/blob/main/torchtitan/datasets/hf_datasets.py#L186). Maybe let's remove it for now.

parallelize_fn: Callable[[nn.Module], None]
pipelining_fn: Callable[[nn.Module], Tuple[_PipelineSchedule, List[nn.Module]]]
build_optimizers_fn: Callable[[List[nn.Module], JobConfig], OptimizersContainer]
build_lr_schedulers_fn: Callable[
[List[nn.Module], JobConfig], LRSchedulersContainer
]
Contributor: For some models we may need to alter loss_fn as well, e.g. in diffusion models. We may add that later when necessary.


# TODO: Add a FQN convert fn to allow users to load checkpoints from
# HuggingFace or other sources that have different FQN conventions.


_model_specs = {}


def register_model_spec(model_spec: ModelSpec) -> None:
global _model_specs
if model_spec.name in _model_specs:
raise ValueError(f"Model {model_spec.name} is already registered.")
_model_specs[model_spec.name] = model_spec


def get_model_spec(name: str) -> ModelSpec:
global _model_specs
if name not in _model_specs:
raise ValueError(f"Model {name} is not registered.")
return _model_specs[name]
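A minimal usage sketch of the registry defined above. MyModelArgs, MyTransformer, and the two placeholder parallelize/pipelining functions are hypothetical names used only for illustration; the torchtitan imports come from this PR and from the existing torchtitan.optimizer module.

from dataclasses import dataclass

import torch.nn as nn

from torchtitan.model_spec import (
    BaseModelArgs,
    ModelSpec,
    get_model_spec,
    register_model_spec,
)
from torchtitan.optimizer import build_lr_schedulers, build_optimizers


@dataclass
class MyModelArgs(BaseModelArgs):
    # Hypothetical per-model arguments.
    dim: int = 256


class MyTransformer(nn.Module):
    """A toy model that satisfies ModelProtocol via from_model_args()."""

    def __init__(self, args: MyModelArgs) -> None:
        super().__init__()
        self.proj = nn.Linear(args.dim, args.dim)

    @staticmethod
    def from_model_args(args: BaseModelArgs) -> nn.Module:
        return MyTransformer(args)


def my_parallelize_fn(model: nn.Module) -> None:
    # Placeholder: apply no parallelism.
    return None


def my_pipelining_fn(model: nn.Module):
    # Placeholder: this toy model does not support pipeline parallelism.
    raise NotImplementedError


register_model_spec(
    ModelSpec(
        name="my_model",
        cls=MyTransformer,
        config={"debugmodel": MyModelArgs(dim=128)},
        tokenizer="tiktoken",
        parallelize_fn=my_parallelize_fn,
        pipelining_fn=my_pipelining_fn,
        build_optimizers_fn=build_optimizers,
        build_lr_schedulers_fn=build_lr_schedulers,
    )
)

spec = get_model_spec("my_model")
model = spec.cls.from_model_args(spec.config["debugmodel"])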
13 changes: 3 additions & 10 deletions torchtitan/models/__init__.py
@@ -4,14 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchtitan.models.llama import llama3_configs, Transformer

models_config = {
"llama3": llama3_configs,
}

model_name_to_cls = {"llama3": Transformer}

model_name_to_tokenizer = {
"llama3": "tiktoken",
}
# Import the built-in models here so that the corresponding register_model_spec()
# will be called.
import torchtitan.models.llama # noqa
22 changes: 21 additions & 1 deletion torchtitan/models/llama/__init__.py
@@ -6,9 +6,15 @@
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

from torchtitan.model_spec import ModelSpec, register_model_spec
from torchtitan.models.llama.model import ModelArgs, Transformer
from torchtitan.optimizer import build_lr_schedulers, build_optimizers

from .parallelize_llama import parallelize_llama
from .pipeline_llama import pipeline_llama

__all__ = ["parallelize_llama", "pipeline_llama", "ModelArgs", "Transformer"]
Contributor: Do we need to expose these fields in llama/__init__.py?

Contributor (Author): Yes, so that users can reuse the parallelism APIs from llama.

Contributor: That makes sense. But maybe also llama3_configs? I imagine someone wants to implement new parallelisms while relying on the existing definitions of Llama 3 8B/70B/405B. In that case they don't need ModelArgs but only the preset configs.

__all__ = ["Transformer"]

llama3_configs = {
"debugmodel": ModelArgs(dim=256, n_layers=8, n_heads=16, rope_theta=500000),
@@ -40,3 +46,17 @@
rope_theta=500000,
),
}


register_model_spec(
ModelSpec(
name="llama3",
cls=Transformer,
config=llama3_configs,
tokenizer="tiktoken",
parallelize_fn=parallelize_llama,
pipelining_fn=pipeline_llama,
build_optimizers_fn=build_optimizers,
build_lr_schedulers_fn=build_lr_schedulers,
)
)
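With the llama3 spec registered at import time, a training script could retrieve and consume it roughly as sketched below. This is only an illustration of the intended flow, not the actual train.py changes (which are not part of this diff); job_config.model.name and job_config.model.flavor are TorchTitan's existing model selection options, and the helper name build_from_spec is hypothetical.

from torchtitan.config_manager import JobConfig
from torchtitan.model_spec import get_model_spec


def build_from_spec(job_config: JobConfig):
    # Look up the spec registered above, e.g. --model.name llama3.
    spec = get_model_spec(job_config.model.name)

    # Pick the preset config, e.g. --model.flavor debugmodel.
    model_args = spec.config[job_config.model.flavor]

    # Construct the model and the optimizer/LR-scheduler containers through the
    # spec's callables instead of hard-coded per-model branches.
    model = spec.cls.from_model_args(model_args)
    optimizers = spec.build_optimizers_fn([model], job_config)
    lr_schedulers = spec.build_lr_schedulers_fn([model], job_config)
    return model, optimizers, lr_schedulers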
5 changes: 3 additions & 2 deletions torchtitan/models/llama/model.py
@@ -13,11 +13,12 @@
import torch
import torch.nn.functional as F
from torch import nn
from torchtitan.model_spec import BaseModelArgs, ModelProtocol
from torchtitan.models.norms import build_norm


@dataclass
class ModelArgs:
class ModelArgs(BaseModelArgs):
Contributor: Down the road we will have many models, like MM models. Do we want all model args to inherit this? Currently we use different model args for different model architectures.

Contributor (Author): This is mainly for typing for now, but it also preserves the ability to introduce common model args.

Contributor: Can we rename it to "TransformerModelArgs"?

dim: int = 4096
n_layers: int = 32
n_heads: int = 32
@@ -331,7 +332,7 @@ def init_weights(self):
self.feed_forward.init_weights(self.weight_init_std)


class Transformer(nn.Module):
class Transformer(nn.Module, ModelProtocol):
"""
Transformer Module

torchtitan/models/llama/parallelize_llama.py
@@ -33,7 +33,7 @@

from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.logging import logger
from torchtitan.parallelisms.parallel_dims import ParallelDims
from torchtitan.parallelisms import ParallelDims


def parallelize_llama(
torchtitan/models/llama/pipeline_llama.py
@@ -14,16 +14,19 @@
from torch.distributed import DeviceMesh
from torch.distributed.pipelining import PipelineStage

from torch.distributed.pipelining.schedules import _PipelineSchedule

from torchtitan.config_manager import JobConfig
from torchtitan.logging import logger
from torchtitan.models.llama.model import ModelArgs
from torchtitan.parallelisms.parallel_dims import ParallelDims
from torchtitan.parallelisms.pipelining_utils import (
from torchtitan.parallelisms import (
build_pipeline_schedule,
generate_split_points,
ParallelDims,
stage_ids_this_rank,
)

from .model import ModelArgs


DeviceType = Union[int, str, torch.device]

@@ -36,7 +39,7 @@ def pipeline_llama(
device: DeviceType,
model_config: ModelArgs,
loss_fn: Callable[..., torch.Tensor],
):
) -> tuple[_PipelineSchedule, list[nn.Module]]:
Contributor: Typing in uppercase vs. lowercase seems inconsistent throughout the PR. Is this intentional? And what's the recommended way? Hmm, it seems only for state_dict you used uppercase, maybe for compatibility.

Contributor (Author): Lowercase is the recommended way if we don't need to support Python <= 3.8, and after PyTorch 2.6 that's the case, so we should just change to lowercase. I may revisit the code and try to change all of them to lowercase.

stages, models = pipeline_llama_manual_split(
model, pp_mesh, parallel_dims, job_config, device, model_config
)
@@ -53,7 +56,7 @@ def pipeline_llama_manual_split(
job_config: JobConfig,
device: DeviceType,
model_config: ModelArgs,
):
) -> tuple[list[PipelineStage], list[nn.Module]]:
"""
This API extracts one torch.nn.Module object for the part of the model configured to run inside this stage.

@@ -67,10 +70,16 @@

splits = (
job_config.experimental.pipeline_parallel_split_points
or generate_split_points(job_config, parallel_dims.pp, model_config)
or generate_split_points(job_config, parallel_dims.pp, model_config.n_layers)
Contributor: nice

)

def _build_stage(stage_idx, start_layer, stop_layer, is_first=False, is_last=False):
def _build_stage(
stage_idx: int,
start_layer: Optional[str],
stop_layer: Optional[str],
is_first: bool = False,
is_last: bool = False,
) -> tuple[PipelineStage, nn.Module]:
model = copy.deepcopy(whole_model)
if not is_first:
model.tok_embeddings = None