Allow users to use customized models
**What does this PR do?**

1. This PR introduces `ModelSpec` to describe a model and how to parallelize it.
2. All models should define `build_model_spec()` or `model_spec` so they can be imported by the `model` module.
3. `build_model_specs()` is called in the trainer to collect the `model_specs`, and the result is used to look up the corresponding model spec.
4. Users can also pass `--experimental.model_module_path` to dynamically import a model that is not implemented by TorchTitan.

**Why do we need this PR?**

This allows users to use TorchTitan with a new model without intrusively changing TorchTitan code.

**Next steps**

1. This PR only includes the model definitions, configurations, tokenizer, parallelize_fn, and pipelining_fn. We may also want to extend `ModelSpec` to include the optimizer and lr_scheduler.
2. The current TorchTitan parallelize_fn and pipelining_fn import ModelArgs, which can cause circular imports. We should fix this issue.

ghstack-source-id: 671424d38a040c8594f8b3d692cd8e141ce5c656
Pull Request resolved: #814
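For illustration, a rough sketch of what an external model module loaded via `--experimental.model_module_path` could look like is shown below. Only `ModelSpec`, `BaseModelArgs`, and `register_model_spec` come from this PR; everything else (the `torchtitan.model_spec` import path, `MyTransformer`, the parallelize/pipelining stubs, and the builder fields left as `...`) is a hypothetical placeholder:

```python
# my_models/spec.py -- hypothetical external module, registered at import time
# via --experimental.model_module_path. The import path of the new spec module
# is an assumption; adjust it to wherever ModelSpec lives in your checkout.
from dataclasses import dataclass

import torch.nn as nn

from torchtitan.model_spec import BaseModelArgs, ModelSpec, register_model_spec


@dataclass
class MyModelArgs(BaseModelArgs):
    dim: int = 256
    n_layers: int = 2


class MyTransformer(nn.Module):
    def __init__(self, args: MyModelArgs):
        super().__init__()
        self.layers = nn.ModuleList(
            [nn.Linear(args.dim, args.dim) for _ in range(args.n_layers)]
        )

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

    @staticmethod
    def from_model_args(args: BaseModelArgs) -> nn.Module:
        # Required by ModelProtocol: construct the model from its ModelArgs.
        return MyTransformer(args)


def my_parallelize(model: nn.Module) -> None:
    # Placeholder: apply TP/FSDP/compile to the model in place.
    pass


def my_pipeline(model: nn.Module):
    # Placeholder: split the model into pipeline stages and return
    # (schedule, model_parts); not supported in this sketch.
    raise NotImplementedError


register_model_spec(
    ModelSpec(
        name="my_transformer",
        cls=MyTransformer,
        config={"debugmodel": MyModelArgs()},
        tokenizer="tiktoken",  # tokenizer name understood by TorchTitan
        parallelize_fn=my_parallelize,
        pipelining_fn=my_pipeline,
        build_optimizers_fn=...,  # plug in an optimizer builder here
        build_lr_schedulers_fn=...,  # plug in an lr-scheduler builder here
    )
)
```

Once the module is importable, the trainer only needs its registered name (`my_transformer` above) in the model config; the import triggered by `--experimental.model_module_path` takes care of calling `register_model_spec()`.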
Showing 16 changed files with 465 additions and 172 deletions.
@@ -0,0 +1,12 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.


# Import the built-in models here so that the corresponding register_model_spec()
# will be called.
import torchtitan.models
@@ -0,0 +1,77 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.


from dataclasses import dataclass
from typing import Callable, Dict, List, Protocol, Tuple, Type

import torch.nn as nn
from torch.distributed.pipelining.schedules import _PipelineSchedule
from torchtitan.config_manager import JobConfig
from torchtitan.optimizer import LRSchedulersContainer, OptimizersContainer


@dataclass
class BaseModelArgs:
    """All ModelArgs should inherit from this class.

    The only usage of this class is type checking, but it allows us to extend
    common arguments to all models in the future.
    """

    _enforced: str = "This field is used to enforce all fields have defaults."


class ModelProtocol(Protocol):
    """Defines the interface for a model class.

    This is used to enforce that all model classes have some methods that are
    required by the TorchTitan trainer.
    """

    @staticmethod
    def from_model_args(args: BaseModelArgs) -> nn.Module: ...


@dataclass
class ModelSpec:
    name: str
    cls: Type[nn.Module]
    config: Dict[str, BaseModelArgs]
    # TODO: Add a ``build_dataloader_fn``.
    # For now, this is a string, so it will have to be built into the
    # TorchTitan library. A better way would be to have a dataloader class
    # and a ``build_dataloader`` function that takes job_config to consume
    # the different dataloader and tokenizer configs.
    tokenizer: str
    parallelize_fn: Callable[[nn.Module], None]
    pipelining_fn: Callable[[nn.Module], Tuple[_PipelineSchedule, List[nn.Module]]]
    build_optimizers_fn: Callable[[List[nn.Module], JobConfig], OptimizersContainer]
    build_lr_schedulers_fn: Callable[
        [List[nn.Module], JobConfig], LRSchedulersContainer
    ]

    # TODO: Add a FQN convert fn to allow users to load checkpoints from
    # HuggingFace or other sources that have different FQN conventions.


_model_specs: Dict[str, ModelSpec] = {}


def register_model_spec(model_spec: ModelSpec) -> None:
    global _model_specs
    if model_spec.name in _model_specs:
        raise ValueError(f"Model {model_spec.name} is already registered.")
    _model_specs[model_spec.name] = model_spec


def get_model_spec(name: str) -> ModelSpec:
    global _model_specs
    if name not in _model_specs:
        raise ValueError(f"Model {name} is not registered.")
    return _model_specs[name]
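To make the intended flow concrete, here is a minimal sketch of the trainer-side lookup. The surrounding trainer code is not part of this excerpt, and the `job_config.model.name` / `job_config.model.flavor` field names are assumptions about the existing TorchTitan config rather than code from this diff:

```python
# Sketch only: how a trainer could consume a registered ModelSpec.
# Assumes get_model_spec and JobConfig are importable in the trainer module.
def build_from_spec(job_config: JobConfig):
    spec = get_model_spec(job_config.model.name)

    # Pick the ModelArgs variant and build the model through the protocol method.
    model_args = spec.config[job_config.model.flavor]
    model = spec.cls.from_model_args(model_args)

    # Apply the spec's parallelization, then build optimizers/schedulers with
    # the callables carried by the spec.
    spec.parallelize_fn(model)
    optimizers = spec.build_optimizers_fn([model], job_config)
    lr_schedulers = spec.build_lr_schedulers_fn([model], job_config)
    return model, optimizers, lr_schedulers
```

Because both built-in models (registered when `torchtitan.models` is imported) and user models (registered when `--experimental.model_module_path` is imported) go through the same registry, the trainer does not need to know which kind it is running.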