Allow users to use customized models

**What does this PR do?**
1. Introduces `ModelSpec` to describe a model and how to parallelize it.
2. Requires all models to define `build_model_spec()` or `model_spec`, which will be imported by the `models` module.
3. Calls `build_model_specs()` in the trainer to obtain `model_specs`, which are then used to retrieve the corresponding model spec.
4. Allows users to dynamically import a model not implemented by TorchTitan using `--experimental.model_module_path`.

**Why do we need this PR?**
This PR enables users to integrate new models with TorchTitan without making intrusive changes to the TorchTitan codebase.
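
As a rough sketch of what that could look like (everything below — the `my_models/model_x` package, `ModelX`, and its helper functions — is hypothetical and not part of this PR), an out-of-tree model package would expose a `build_model_spec()`:

```python
# my_models/model_x/__init__.py -- hypothetical external model package (not part of this PR)
from dataclasses import dataclass

import torch.nn as nn

from torchtitan.model_spec import BaseModelArgs, ModelSpec


@dataclass
class ModelXArgs(BaseModelArgs):
    dim: int = 256
    n_layers: int = 4


class ModelX(nn.Module):
    def __init__(self, args: ModelXArgs):
        super().__init__()
        self.layers = nn.ModuleList(
            nn.Linear(args.dim, args.dim) for _ in range(args.n_layers)
        )

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

    @classmethod
    def from_model_args(cls, args: ModelXArgs) -> "ModelX":
        return cls(args)


def parallelize_model_x(model, world_mesh, parallel_dims, job_config):
    # Apply SPMD-style parallelism (TP/FSDP/compile) here; a no-op keeps the sketch short.
    return model


def pipeline_model_x(model, pp_mesh, parallel_dims, job_config, device, model_config, loss_fn):
    raise NotImplementedError("Pipeline parallelism is not supported by this toy model.")


def build_model_spec() -> ModelSpec:
    return ModelSpec(
        name="model_x",
        cls=ModelX,
        config={"debugmodel": ModelXArgs()},
        tokenizer="tiktoken",
        parallelize_fn=parallelize_model_x,
        pipelining_fn=pipeline_model_x,
    )
```

Pointing the trainer at this package would then look something like `--experimental.custom_model_path my_models/model_x` (the option added in `config_manager.py` below; a dotted module name also works), after which `build_model_specs()` in `train.py` can resolve `--model.name model_x`.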

**Next steps**
1. This PR includes only the model definitions, configurations, tokenizer, parallelize_fn, and pipelining_fn. We may want to extend ModelSpec to include the optimizer and learning rate scheduler.
2. The current TorchTitan parallelize and pipelining_fn import ModelArgs, which can lead to circular imports. This issue needs to be addressed.

ghstack-source-id: a88ff3ebe5c869055dd3314fb1b791855fd0e0b2
Pull Request resolved: #814
fegin committed Jan 31, 2025
1 parent d4c86e3 commit d6d103b
Showing 6 changed files with 158 additions and 23 deletions.
23 changes: 23 additions & 0 deletions torchtitan/config_manager.py
@@ -375,6 +375,20 @@ def __init__(self):
The default value is 'allgather'.
""",
)
self.parser.add_argument(
"--experimental.custom_model_path",
type=str,
default="",
help="""
The --experimental.custom_model_path option allows you to specify a custom path to a model module
that is not natively implemented within TorchTitan.
Acceptable values are a file system path to the module (e.g., my_models/model_x)
or a dotted module name (e.g., some_package.model_x).
""",
)
self.parser.add_argument(
"--training.mixed_precision_param",
type=str,
@@ -638,6 +652,15 @@ def parse_args(self, args_list: list = sys.argv[1:]):
exp["pipeline_parallel_split_points"]
)

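# Register a user-provided model module path (if any) so its ModelSpec can be
# discovered later by build_model_specs().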
if (
"experimental" in args_dict
and "model_module_path" in args_dict["experimental"]
and args_dict["experimental"]["model_module_path"]
):
from torchtitan.models import add_model_spec_path

add_model_spec_path(args_dict["experimental"]["model_module_path"])

# override args dict with cmd_args
cmd_args_dict = self._args_to_two_level_dict(cmd_args)
for section, section_args in cmd_args_dict.items():
28 changes: 28 additions & 0 deletions torchtitan/model_spec.py
@@ -0,0 +1,28 @@
from dataclasses import dataclass
from typing import Callable, Dict, List, Protocol, Tuple, Type

import torch.nn as nn
from torch.distributed.pipelining.schedules import _PipelineSchedule


@dataclass
class BaseModelArgs:
_enforced: str = "This field is used to enforce all fields have defaults."


class ModelProtocol(Protocol):
def from_model_args(self, args: BaseModelArgs) -> nn.Module:
...


@dataclass
class ModelSpec:
name: str
cls: Type[nn.Module]
config: Dict[str, BaseModelArgs]
# For now, this is a string, so the tokenizer has to be built into the
# TorchTitan library. In the future, we can make this a proper class
# that can be extended like ModelSpec.
tokenizer: str
parallelize_fn: Callable[[nn.Module], None]
pipelining_fn: Callable[[nn.Module], Tuple[_PipelineSchedule, List[nn.Module]]]
87 changes: 79 additions & 8 deletions torchtitan/models/__init__.py
@@ -4,14 +4,85 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchtitan.models.llama import llama3_configs, Transformer

models_config = {
"llama3": llama3_configs,
}

model_name_to_cls = {"llama3": Transformer}

model_name_to_tokenizer = {
"llama3": "tiktoken",
}
import importlib
import importlib.util
import os
import pkgutil
from typing import Dict, Set

import torchtitan.models as models
from torchtitan.model_spec import ModelSpec

_model_specs_path: Set[str] = set()


def _load_module(path: str):
path = os.path.expanduser(path)

# 1. Check if path is an existing file or directory path.
if os.path.exists(path):
if os.path.isdir(path):
init_file = os.path.join(path, "__init__.py")
if os.path.isfile(init_file):
return _load_module_from_init(path)

raise ImportError(
f"Directory '{path}' is not a Python package because it does not "
"contain an __init__.py file."
)
else:
raise ImportError(f"Path '{path}' is not a directory.")

# 2. If not a valid path, assume it's a dotted module name.
return importlib.import_module(path)


def _load_module_from_init(path: str):
module_name = os.path.basename(os.path.normpath(path))
init_file = os.path.join(path, "__init__.py")

spec = importlib.util.spec_from_file_location(module_name, init_file)
if spec is None:
raise ImportError(f"Could not create spec from '{init_file}'")

module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module


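# Pre-register every built-in model package under torchtitan.models; its spec is
# loaded lazily by build_model_specs().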
for _, name, _ in pkgutil.iter_modules(models.__path__):
full_module_name = f"{models.__name__}.{name}"
_model_specs_path.add(full_module_name)
# model_module = importlib.import_module(full_module_name)
# load_spec_from_module(model_module)


def add_model_spec_path(path: str):
global _model_specs_path
_model_specs_path.add(path)


def build_model_specs() -> Dict[str, ModelSpec]:
"""
Load all registered model specs (built-in and user-provided).
"""
global _model_specs_path
model_specs = {}
for path in _model_specs_path:
module = _load_module(path)
model_spec = getattr(module, "model_spec", None)
if model_spec is not None:
model_specs[model_spec.name] = model_spec
# We would like to just use `model_spec`, but the current TorchTitan parallelize
# functions depend on ModelArgs, which can cause circular imports.
# As a result, we use `build_model_spec` as a workaround.
build_model_spec = getattr(module, "build_model_spec", None)
if build_model_spec:
model_spec = build_model_spec()
model_specs[model_spec.name] = model_spec

return model_specs


__all__ = ["add_model_spec_path", "build_model_specs"]
17 changes: 16 additions & 1 deletion torchtitan/models/llama/__init__.py
@@ -6,9 +6,9 @@
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

from torchtitan.model_spec import ModelSpec
from torchtitan.models.llama.model import ModelArgs, Transformer

__all__ = ["Transformer"]

llama3_configs = {
"debugmodel": ModelArgs(dim=256, n_layers=8, n_heads=16, rope_theta=500000),
@@ -40,3 +40,18 @@
rope_theta=500000,
),
}


def build_model_spec() -> ModelSpec:
# Avoid circular import
from torchtitan.parallelisms.parallelize_llama import parallelize_llama
from torchtitan.parallelisms.pipeline_llama import pipeline_llama

return ModelSpec(
name="llama3",
cls=Transformer,
config=llama3_configs,
tokenizer="tiktoken",
parallelize_fn=parallelize_llama,
pipelining_fn=pipeline_llama,
)
5 changes: 3 additions & 2 deletions torchtitan/models/llama/model.py
@@ -13,11 +13,12 @@
import torch
import torch.nn.functional as F
from torch import nn
from torchtitan.model_spec import BaseModelArgs, ModelProtocol
from torchtitan.models.norms import build_norm


@dataclass
class ModelArgs:
class ModelArgs(BaseModelArgs):
dim: int = 4096
n_layers: int = 32
n_heads: int = 32
@@ -258,7 +259,7 @@ def init_weights(self, init_std: float):
nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std)


class TransformerBlock(nn.Module):
class TransformerBlock(nn.Module, ModelProtocol):
"""
TransformerBlock Module
21 changes: 9 additions & 12 deletions train.py
@@ -19,13 +19,9 @@
from torchtitan.float8 import Float8Handler
from torchtitan.logging import init_logger, logger
from torchtitan.metrics import build_device_memory_monitor, build_metric_logger
from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config
from torchtitan.models import build_model_specs
from torchtitan.optimizer import build_lr_schedulers, build_optimizers
from torchtitan.parallelisms import (
models_parallelize_fns,
models_pipelining_fns,
ParallelDims,
)
from torchtitan.parallelisms import ParallelDims
from torchtitan.profiling import maybe_enable_memory_snapshot, maybe_enable_profiling
from torchtitan.utils import device_module, device_type

@@ -80,9 +76,10 @@ def main(job_config: JobConfig):
world_mesh, device, job_config.training.seed, job_config.training.deterministic
)
model_name = job_config.model.name
model_spec = build_model_specs()[model_name]

# build tokenizer
tokenizer_type = model_name_to_tokenizer[model_name]
tokenizer_type = model_spec.tokenizer
tokenizer = build_tokenizer(tokenizer_type, job_config.model.tokenizer_path)
# build dataloader
data_loader = build_hf_data_loader(
@@ -96,8 +93,8 @@
)

# build model (using meta init)
model_cls = model_name_to_cls[model_name]
model_config = models_config[model_name][job_config.model.flavor]
model_cls = model_spec.cls
model_config = model_spec.config[job_config.model.flavor]
# set the model configs from training inputs:
# 1. norm type to decide which norm layer to use
# 2. vocab size from tokenizer
@@ -151,7 +148,7 @@ def loss_fn(pred, labels):
# apply parallelisms and initialization
if parallel_dims.pp_enabled:
# apply PT-D Pipeline Parallel
pp_schedule, model_parts = models_pipelining_fns[model_name](
pp_schedule, model_parts = model_spec.pipelining_fn(
model, pp_mesh, parallel_dims, job_config, device, model_config, loss_fn
)
# when PP is enabled, `model` obj is no longer used after this point, model_parts is used instead
@@ -162,14 +159,14 @@
# optimizer, and checkpointing
for m in model_parts:
# apply SPMD-style PT-D techniques
models_parallelize_fns[model_name](m, world_mesh, parallel_dims, job_config)
model_spec.parallelize_fn(m, world_mesh, parallel_dims, job_config)
m.to_empty(device=init_device)
with torch.no_grad():
m.init_weights(buffer_device=buffer_device)
m.train()
else:
# apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
models_parallelize_fns[model_name](model, world_mesh, parallel_dims, job_config)
model_spec.parallelize_fn(model, world_mesh, parallel_dims, job_config)
model.to_empty(device=init_device)
with torch.no_grad():
model.init_weights(buffer_device=buffer_device)
