Add Dynamic Model Import and ModelSpec Definition #814
base: gh/fegin/8/base
Changes from all commits
df1bc6a
dfc1649
720f12a
225bfcc
650152e
687fda9
6a51325
5b33b65
2e569d7
bab9bf5
@@ -0,0 +1,12 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

# Import the built-in models here so that the corresponding register_model_spec()
# will be called.
import torchtitan.models
@@ -0,0 +1,77 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.


from dataclasses import dataclass
from typing import Callable, Dict, List, Protocol, Tuple, Type

import torch.nn as nn
from torch.distributed.pipelining.schedules import _PipelineSchedule
from torchtitan.config_manager import JobConfig
from torchtitan.optimizer import LRSchedulersContainer, OptimizersContainer


@dataclass
class BaseModelArgs:
    """All ModelArgs should inherit from this class.

    The only usage of this class is type checking, but it allows us to extend
    common arguments to all models in the future.
    """

    _enforced: str = "This field is used to enforce all fields have defaults."


class ModelProtocol(Protocol):
    """Defines the interface for a model class.

    This is used to enforce that all model classes have some methods that are
    required by the TorchTitan trainer.
    """

    @staticmethod
    def from_model_args(args: BaseModelArgs) -> nn.Module: ...


@dataclass
class ModelSpec:
    name: str
    cls: Type[nn.Module]
    config: Dict[str, BaseModelArgs]
    # TODO: Add a ``build_dataloader_fn``
    # As for now, this is a string. So it will have to be built-in to the
    # TorchTitan library. A better way would be to have a dataloader class
    # and a ``build_dataloader`` function that take job_config to consume
    # the different dataloader and tokenizer configs.
    tokenizer: str
Review comment: currently tokenizer is part of the data loader; maybe let's remove it for now
    parallelize_fn: Callable[[nn.Module], None]
    pipelining_fn: Callable[[nn.Module], Tuple[_PipelineSchedule, List[nn.Module]]]
    build_optimizers_fn: Callable[[List[nn.Module], JobConfig], OptimizersContainer]
    build_lr_schedulers_fn: Callable[
        [List[nn.Module], JobConfig], LRSchedulersContainer
    ]
Review comment: For some models we may need to alter

# TODO: Add a FQN convert fn to allow users to load checkpoints from
# HuggingFace or other sources that have different FQN conventions.


_model_specs: Dict[str, ModelSpec] = {}


def register_model_spec(model_spec: ModelSpec) -> None:
    global _model_specs
    if model_spec.name in _model_specs:
        raise ValueError(f"Model {model_spec.name} is already registered.")
    _model_specs[model_spec.name] = model_spec


def get_model_spec(name: str) -> ModelSpec:
    global _model_specs
    if name not in _model_specs:
        raise ValueError(f"Model {name} is not registered.")
    return _model_specs[name]
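
To make the intended workflow concrete, here is a minimal sketch of how an out-of-tree model could hook into this registry. Everything named `MyModelArgs`, `MyTransformer`, and `"my_model"` is hypothetical and not part of this PR, and the lambda placeholders stand in for real parallelize/pipelining functions:

```python
from dataclasses import dataclass

import torch.nn as nn

from torchtitan.model_spec import (
    BaseModelArgs,
    ModelSpec,
    get_model_spec,
    register_model_spec,
)
from torchtitan.optimizer import build_lr_schedulers, build_optimizers


@dataclass
class MyModelArgs(BaseModelArgs):
    # Hypothetical hyperparameters; every field needs a default because the
    # inherited _enforced field already has one.
    dim: int = 512
    n_layers: int = 4


class MyTransformer(nn.Module):
    def __init__(self, args: MyModelArgs):
        super().__init__()
        self.layers = nn.ModuleList(
            nn.Linear(args.dim, args.dim) for _ in range(args.n_layers)
        )

    @staticmethod
    def from_model_args(args: BaseModelArgs) -> nn.Module:
        # Satisfies ModelProtocol: construct the model purely from its args.
        return MyTransformer(args)


register_model_spec(
    ModelSpec(
        name="my_model",
        cls=MyTransformer,
        config={"debugmodel": MyModelArgs()},
        tokenizer="tiktoken",
        # Placeholders for this sketch; a real model would supply its own
        # parallelize/pipelining functions or reuse llama's.
        parallelize_fn=lambda m: None,
        pipelining_fn=lambda m: (None, [m]),
        build_optimizers_fn=build_optimizers,
        build_lr_schedulers_fn=build_lr_schedulers,
    )
)

spec = get_model_spec("my_model")  # raises ValueError if the name was never registered
assert spec.cls is MyTransformer
```

Registering the same name twice raises, so a built-in model and a user script cannot silently clobber each other's specs.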
@@ -6,9 +6,15 @@
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

from torchtitan.model_spec import ModelSpec, register_model_spec
from torchtitan.models.llama.model import ModelArgs, Transformer
from torchtitan.optimizer import build_lr_schedulers, build_optimizers

from .parallelize_llama import parallelize_llama
from .pipeline_llama import pipeline_llama

__all__ = ["parallelize_llama", "pipeline_llama", "ModelArgs", "Transformer"]
Review comment: do we need to expose these fields in
Reply: Yes, so that users can reuse the parallelism APIs from llama.
Reply: That makes sense. But maybe also
__all__ = ["Transformer"]

llama3_configs = {
    "debugmodel": ModelArgs(dim=256, n_layers=8, n_heads=16, rope_theta=500000),

@@ -40,3 +46,17 @@
        rope_theta=500000,
    ),
}

register_model_spec(
    ModelSpec(
        name="llama3",
        cls=Transformer,
        config=llama3_configs,
        tokenizer="tiktoken",
        parallelize_fn=parallelize_llama,
        pipelining_fn=pipeline_llama,
        build_optimizers_fn=build_optimizers,
        build_lr_schedulers_fn=build_lr_schedulers,
    )
)
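
For context, here is a rough sketch of how the trainer side could consume this registration once `import torchtitan.models` has run. The `flavor` value and the surrounding glue are illustrative assumptions, not code from this PR:

```python
import torchtitan.models  # triggers the register_model_spec(...) call above
from torchtitan.model_spec import get_model_spec

spec = get_model_spec("llama3")

flavor = "debugmodel"              # e.g. taken from job_config.model.flavor
model_args = spec.config[flavor]   # the ModelArgs instance defined above
# (In the real trainer, fields such as vocab_size are filled in from the
# tokenizer/job config before the model is built.)
model = spec.cls.from_model_args(model_args)

# Parallelism, pipelining, optimizers, and LR schedulers would then come from the
# spec's callables (spec.parallelize_fn, spec.pipelining_fn, spec.build_optimizers_fn,
# spec.build_lr_schedulers_fn) rather than llama-specific imports in train.py.
```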
@@ -13,11 +13,12 @@
import torch
import torch.nn.functional as F
from torch import nn
from torchtitan.model_spec import BaseModelArgs, ModelProtocol
from torchtitan.models.norms import build_norm


@dataclass
class ModelArgs:
class ModelArgs(BaseModelArgs):
Review comment: Down the road we will have many models, like MM models. Do we want all model args to inherit this? Currently we use different model args for different model architectures.
Reply: This is mainly for typing for now, but it also preserves the ability to introduce common model args.
Reply: Can we rename it to "TransformerModelArgs"?
    dim: int = 4096
    n_layers: int = 32
    n_heads: int = 32

@@ -331,7 +332,7 @@ def init_weights(self):
        self.feed_forward.init_weights(self.weight_init_std)


class Transformer(nn.Module):
class Transformer(nn.Module, ModelProtocol):
    """
    Transformer Module
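
As a side note on the inheritance question in the thread above: in practice, subclassing `BaseModelArgs` mainly constrains a model's args dataclass to give every field a default, because the inherited `_enforced` field already has one. A quick illustration (not from this PR):

```python
from dataclasses import dataclass

from torchtitan.model_spec import BaseModelArgs


@dataclass
class GoodArgs(BaseModelArgs):
    dim: int = 4096  # fine: has a default, like every field in llama's ModelArgs


# The following would fail at class-definition time with
#   TypeError: non-default argument 'dim' follows default argument
# because the inherited _enforced field (which has a default) precedes it.
#
# @dataclass
# class BadArgs(BaseModelArgs):
#     dim: int
```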
@@ -14,16 +14,19 @@
from torch.distributed import DeviceMesh
from torch.distributed.pipelining import PipelineStage

from torch.distributed.pipelining.schedules import _PipelineSchedule

from torchtitan.config_manager import JobConfig
from torchtitan.logging import logger
from torchtitan.models.llama.model import ModelArgs
from torchtitan.parallelisms.parallel_dims import ParallelDims
from torchtitan.parallelisms.pipelining_utils import (
from torchtitan.parallelisms import (
    build_pipeline_schedule,
    generate_split_points,
    ParallelDims,
    stage_ids_this_rank,
)

from .model import ModelArgs


DeviceType = Union[int, str, torch.device]
@@ -36,7 +39,7 @@ def pipeline_llama(
    device: DeviceType,
    model_config: ModelArgs,
    loss_fn: Callable[..., torch.Tensor],
):
) -> tuple[_PipelineSchedule, list[nn.Module]]:
Review comment: Typing in uppercase vs. lowercase seems inconsistent throughout the PR. Is this intentional, and what's the recommended way? Hmm, it seems only for state_dict you used uppercase, maybe for compatibility.
Reply: Uppercase is the recommended way only if we still need to support Python <= 3.8. After PyTorch 2.6, that's no longer the case, so we should just change to the lowercase one. I may revisit the code and try to change everything to lowercase.
    stages, models = pipeline_llama_manual_split(
        model, pp_mesh, parallel_dims, job_config, device, model_config
    )
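
For readers following the typing thread above, the two spellings describe the same types; the lowercase built-in generics simply require Python 3.9+, which holds once PyTorch 2.6 is the minimum supported version, as noted in the reply. A small illustration, not part of the diff:

```python
from typing import List, Tuple

import torch.nn as nn
from torch.distributed.pipelining.schedules import _PipelineSchedule

# Pre-3.9 spelling, as still used in model_spec.py:
SplitResultOld = Tuple[_PipelineSchedule, List[nn.Module]]

# Python 3.9+ spelling, as used in the new return annotations in this file;
# both aliases describe exactly the same type.
SplitResultNew = tuple[_PipelineSchedule, list[nn.Module]]
```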
@@ -53,7 +56,7 @@ def pipeline_llama_manual_split(
    job_config: JobConfig,
    device: DeviceType,
    model_config: ModelArgs,
):
) -> tuple[list[PipelineStage], list[nn.Module]]:
    """
    This API extracts one torch.nn.Module object for the part of the model configured to run inside this stage.
@@ -67,10 +70,16 @@

    splits = (
        job_config.experimental.pipeline_parallel_split_points
        or generate_split_points(job_config, parallel_dims.pp, model_config)
        or generate_split_points(job_config, parallel_dims.pp, model_config.n_layers)
    )

Review comment: nice
    def _build_stage(stage_idx, start_layer, stop_layer, is_first=False, is_last=False):
    def _build_stage(
        stage_idx: int,
        start_layer: Optional[str],
        stop_layer: Optional[str],
        is_first: bool = False,
        is_last: bool = False,
    ) -> tuple[PipelineStage, nn.Module]:
        model = copy.deepcopy(whole_model)
        if not is_first:
            model.tok_embeddings = None
Review comment: Since the Spec is not only about the model, e.g. conceptually there can be multiple ways to do training for the same model (GPU/TPU, customized parallelize/pipeline), shall we consider renaming it to TrainSpec?
Reply: That's actually a good question and suggestion. I am open to this option.