Add Dynamic Model Import and ModelSpec Definition #814

Merged: 20 commits, Feb 12, 2025
23 changes: 23 additions & 0 deletions torchtitan/config_manager.py
@@ -375,6 +375,20 @@ def __init__(self):
The default value is 'allgather'.
""",
)
self.parser.add_argument(
"--experimental.custom_model_path",
type=str,
default="",
help="""
The --experimental.custom_model_path option allows one to specify a custom path to a
model module that is not natively implemented within TorchTitan.
Acceptable values are either a file system path to the module (e.g., my_models/model_x)
or a dotted import module name (e.g., some_package.model_x).
""",
)
self.parser.add_argument(
"--training.mixed_precision_param",
type=str,
@@ -638,6 +652,15 @@ def parse_args(self, args_list: list = sys.argv[1:]):
exp["pipeline_parallel_split_points"]
)

if (
"experimental" in args_dict
and "custom_model_path" in args_dict["experimental"]
and args_dict["experimental"]["custom_model_path"]
):
from torchtitan.models import add_model_spec_path

add_model_spec_path(args_dict["experimental"]["custom_model_path"])

# override args dict with cmd_args
cmd_args_dict = self._args_to_two_level_dict(cmd_args)
for section, section_args in cmd_args_dict.items():
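For context, here is a minimal sketch of how the new flag might be passed programmatically. This is editorial illustration rather than part of the diff: the module locations are hypothetical, and it assumes `JobConfig` in this file is the class whose `parse_args` appears in the hunk above.

from torchtitan.config_manager import JobConfig

config = JobConfig()
# Either a file system path to a package directory containing __init__.py ...
config.parse_args(["--experimental.custom_model_path", "my_models/model_x"])
# ... or a dotted import module name:
# config.parse_args(["--experimental.custom_model_path", "some_package.model_x"])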
37 changes: 37 additions & 0 deletions torchtitan/model_spec.py
@@ -0,0 +1,37 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.


from dataclasses import dataclass
from typing import Callable, Dict, List, Protocol, Tuple, Type

import torch.nn as nn
from torch.distributed.pipelining.schedules import _PipelineSchedule


@dataclass
class BaseModelArgs:
_enforced: str = "This field is used to enforce all fields have defaults."


class ModelProtocol(Protocol):
def from_model_args(self, args: BaseModelArgs) -> nn.Module:
...


@dataclass
class ModelSpec:
name: str
cls: Type[nn.Module]
config: Dict[str, BaseModelArgs]
# For now, this is a string, so the tokenizer has to be built into the
# TorchTitan library. In the future, we can make this a proper class
# that can be extended like ModelSpec.
tokenizer: str
parallelize_fn: Callable[[nn.Module], None]
pipelining_fn: Callable[[nn.Module], Tuple[_PipelineSchedule, List[nn.Module]]]
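As a sketch of what this contract enables, a hypothetical out-of-tree model could describe itself with a ModelSpec like the one below. Every name here (MyModelArgs, MyTransformer, the placeholder parallelize and pipelining callables) is invented for illustration and is not part of this PR.

from dataclasses import dataclass

import torch.nn as nn

from torchtitan.model_spec import BaseModelArgs, ModelSpec


@dataclass
class MyModelArgs(BaseModelArgs):
    dim: int = 256
    n_layers: int = 2


class MyTransformer(nn.Module):
    def __init__(self, args: MyModelArgs):
        super().__init__()
        self.layers = nn.Sequential(
            *[nn.Linear(args.dim, args.dim) for _ in range(args.n_layers)]
        )

    @classmethod
    def from_model_args(cls, args: MyModelArgs) -> "MyTransformer":
        # Mirrors ModelProtocol: construct the module from its args dataclass.
        return cls(args)


def _parallelize_noop(model, *args, **kwargs):
    # Placeholder standing in for a real parallelize function.
    return model


def _pipeline_noop(model, *args, **kwargs):
    # Placeholder standing in for a real pipelining function.
    raise NotImplementedError("pipelining is not supported in this sketch")


# Module-level `model_spec` is what build_model_specs() looks for (see below).
model_spec = ModelSpec(
    name="my_model",
    cls=MyTransformer,
    config={"debugmodel": MyModelArgs()},
    tokenizer="tiktoken",
    parallelize_fn=_parallelize_noop,
    pipelining_fn=_pipeline_noop,
)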
87 changes: 79 additions & 8 deletions torchtitan/models/__init__.py
@@ -4,14 +4,85 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchtitan.models.llama import llama3_configs, Transformer
import importlib
import importlib.util

models_config = {
"llama3": llama3_configs,
}
import os
import pkgutil
from typing import Dict, Set

model_name_to_cls = {"llama3": Transformer}
import torchtitan.models as models
from torchtitan.model_spec import ModelSpec

model_name_to_tokenizer = {
"llama3": "tiktoken",
}

_model_specs_path: Set[str] = set()


def _load_module(path: str):
path = os.path.expanduser(path)

# 1. Check if path is an existing file or directory path.
if os.path.exists(path):
if os.path.isdir(path):
init_file = os.path.join(path, "__init__.py")
if os.path.isfile(init_file):
return _load_module_from_init(path)

raise ImportError(
f"Directory '{path}' is not a Python package because it does not "
"contain an __init__.py file."
)
else:
raise ImportError(f"Path '{path}' is not a directory.")

# 2. If not a valid path, assume it's a dotted module name.
return importlib.import_module(path)


def _load_module_from_init(path: str):
module_name = os.path.basename(os.path.normpath(path))
init_file = os.path.join(path, "__init__.py")

spec = importlib.util.spec_from_file_location(module_name, init_file)
if spec is None:
raise ImportError(f"Could not create spec from '{init_file}'")

module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module


for _, name, _ in pkgutil.iter_modules(models.__path__):
full_module_name = f"{models.__name__}.{name}"
_model_specs_path.add(full_module_name)
# model_module = importlib.import_module(full_module_name)
# load_spec_from_module(model_module)


def add_model_spec_path(path: str):
global _model_specs_path
_model_specs_path.add(path)


def build_model_specs() -> Dict[str, ModelSpec]:
"""
Load all model specs from the `models` package and from any paths
registered via `add_model_spec_path`.
"""
global _model_specs_path
model_specs = {}
for path in _model_specs_path:
module = _load_module(path)
model_spec = getattr(module, "model_spec", None)
if model_spec is not None:
model_specs[model_spec.name] = model_spec
# We would like to just use `model_spec`, but the current TorchTitan parallelize
# functions depend on ModelArgs, which can cause circular imports.
# As a result, we have to use `build_model_spec` as a workaround.
build_model_spec = getattr(module, "build_model_spec", None)
if build_model_spec:
model_spec = build_model_spec()
model_specs[model_spec.name] = model_spec

return model_specs


__all__ = ["add_model_spec_path", "build_model_specs"]
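Putting the pieces together, an out-of-tree package could be registered and discovered roughly as follows. This is a sketch, not part of the diff: the path `my_models/model_x` is hypothetical, and the package is assumed to expose either a module-level `model_spec` or a `build_model_spec()` factory, as the comments above describe.

from torchtitan.models import add_model_spec_path, build_model_specs

# Either a file system path to a directory containing __init__.py ...
add_model_spec_path("my_models/model_x")
# ... or a dotted module name would also be accepted by _load_module:
# add_model_spec_path("some_package.model_x")

specs = build_model_specs()
# Built-in specs (e.g., "llama3") and registered custom specs are keyed by name.
print(sorted(specs.keys()))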
17 changes: 16 additions & 1 deletion torchtitan/models/llama/__init__.py
@@ -6,9 +6,9 @@
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

from torchtitan.model_spec import ModelSpec
from torchtitan.models.llama.model import ModelArgs, Transformer

__all__ = ["Transformer"]

llama3_configs = {
"debugmodel": ModelArgs(dim=256, n_layers=8, n_heads=16, rope_theta=500000),
@@ -40,3 +40,18 @@
rope_theta=500000,
),
}


def build_model_spec() -> ModelSpec:
# Avoid circular import
from torchtitan.parallelisms.parallelize_llama import parallelize_llama
from torchtitan.parallelisms.pipeline_llama import pipeline_llama

return ModelSpec(
name="llama3",
cls=Transformer,
config=llama3_configs,
tokenizer="tiktoken",
parallelize_fn=parallelize_llama,
pipelining_fn=pipeline_llama,
)
5 changes: 3 additions & 2 deletions torchtitan/models/llama/model.py
@@ -13,11 +13,12 @@
import torch
import torch.nn.functional as F
from torch import nn
from torchtitan.model_spec import BaseModelArgs, ModelProtocol
from torchtitan.models.norms import build_norm


@dataclass
class ModelArgs:
class ModelArgs(BaseModelArgs):
dim: int = 4096
n_layers: int = 32
n_heads: int = 32
@@ -258,7 +259,7 @@ def init_weights(self, init_std: float):
nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std)


class TransformerBlock(nn.Module):
class TransformerBlock(nn.Module, ModelProtocol):
"""
TransformerBlock Module

21 changes: 9 additions & 12 deletions train.py
@@ -19,13 +19,9 @@
from torchtitan.float8 import Float8Handler
from torchtitan.logging import init_logger, logger
from torchtitan.metrics import build_device_memory_monitor, build_metric_logger
from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config
from torchtitan.models import build_model_specs
from torchtitan.optimizer import build_lr_schedulers, build_optimizers
from torchtitan.parallelisms import (
models_parallelize_fns,
models_pipelining_fns,
ParallelDims,
)
from torchtitan.parallelisms import ParallelDims
from torchtitan.profiling import maybe_enable_memory_snapshot, maybe_enable_profiling
from torchtitan.utils import device_module, device_type

@@ -80,9 +76,10 @@ def main(job_config: JobConfig):
world_mesh, device, job_config.training.seed, job_config.training.deterministic
)
model_name = job_config.model.name
model_spec = build_model_specs()[model_name]

# build tokenizer
tokenizer_type = model_name_to_tokenizer[model_name]
tokenizer_type = model_spec.tokenizer
tokenizer = build_tokenizer(tokenizer_type, job_config.model.tokenizer_path)
# build dataloader
data_loader = build_hf_data_loader(
@@ -96,8 +93,8 @@
)

# build model (using meta init)
model_cls = model_name_to_cls[model_name]
model_config = models_config[model_name][job_config.model.flavor]
model_cls = model_spec.cls
model_config = model_spec.config[job_config.model.flavor]
# set the model configs from training inputs:
# 1. norm type to decide which norm layer to use
# 2. vocab size from tokenizer
@@ -151,7 +148,7 @@ def loss_fn(pred, labels):
# apply parallelisms and initialization
if parallel_dims.pp_enabled:
# apply PT-D Pipeline Parallel
pp_schedule, model_parts = models_pipelining_fns[model_name](
pp_schedule, model_parts = model_spec.pipelining_fn(
model, pp_mesh, parallel_dims, job_config, device, model_config, loss_fn
)
# when PP is enabled, `model` obj is no longer used after this point, model_parts is used instead
@@ -162,14 +159,14 @@
# optimizer, and checkpointing
for m in model_parts:
# apply SPMD-style PT-D techniques
models_parallelize_fns[model_name](m, world_mesh, parallel_dims, job_config)
model_spec.parallelize_fn(m, world_mesh, parallel_dims, job_config)
m.to_empty(device=init_device)
with torch.no_grad():
m.init_weights(buffer_device=buffer_device)
m.train()
else:
# apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
models_parallelize_fns[model_name](model, world_mesh, parallel_dims, job_config)
model_spec.parallelize_fn(model, world_mesh, parallel_dims, job_config)
model.to_empty(device=init_device)
with torch.no_grad():
model.init_weights(buffer_device=buffer_device)