Allow users to use the customized model
**What does this PR do?**

1. Introduces `ModelSpec` to describe a model and how to parallelize it.
2. Requires all models to define `build_model_spec()` or `model_spec`, which is imported by the `model` module.
3. Calls `build_model_specs()` in the trainer to obtain `model_specs`, which is then used to look up the corresponding model spec.
4. Allows users to dynamically import a model that is not implemented by TorchTitan using `--experimental.model_module_path`.

**Why do we need this PR?**

This PR lets users run TorchTitan with a new model without making intrusive changes to the TorchTitan codebase.

**Next steps**

1. This PR includes only the model definitions, configurations, tokenizer, `parallelize_fn`, and `pipelining_fn`. We may also want to extend `ModelSpec` to include the optimizer and learning rate scheduler.
2. The current TorchTitan parallelize and `pipelining_fn` implementations import `ModelArgs`, which can cause circular imports. This issue should be fixed.

ghstack-source-id: a88ff3ebe5c869055dd3314fb1b791855fd0e0b2
Pull Request resolved: #814
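As a rough illustration of item 4, the sketch below shows one way a module path could be resolved into a spec. It is not taken from the PR: the helper `load_model_spec`, the module name `my_custom_model`, the lookup order (`build_model_spec()` first, then `model_spec`), and the use of `importlib.import_module` are all assumptions, and a plain dict stands in for the real `ModelSpec`.

```python
import importlib
import sys
import tempfile
from pathlib import Path

# Hypothetical user module. A plain dict stands in for the real ModelSpec,
# and the name/tokenizer values are placeholders.
USER_MODULE_SOURCE = (
    'model_spec = {"name": "my_model", "tokenizer": "my_tokenizer"}\n'
)


def load_model_spec(module_path: str):
    """Import `module_path` and return the spec it exposes.

    Assumed lookup order (not confirmed by the PR): a `build_model_spec()`
    callable first, then a module-level `model_spec` attribute.
    """
    module = importlib.import_module(module_path)
    if hasattr(module, "build_model_spec"):
        return module.build_model_spec()
    if hasattr(module, "model_spec"):
        return module.model_spec
    raise AttributeError(
        f"{module_path} defines neither build_model_spec() nor model_spec"
    )


def demo() -> dict:
    # Write the hypothetical module to a temp directory, make it importable,
    # load its spec, then undo the changes to the import state.
    with tempfile.TemporaryDirectory() as tmp:
        Path(tmp, "my_custom_model.py").write_text(USER_MODULE_SOURCE)
        sys.path.insert(0, tmp)
        importlib.invalidate_caches()
        try:
            return load_model_spec("my_custom_model")
        finally:
            sys.path.remove(tmp)
            sys.modules.pop("my_custom_model", None)


spec = demo()
print(spec["name"])
```

In this sketch the string passed to `load_model_spec` plays the role of the value given to `--experimental.model_module_path`; how the trainer actually consumes that flag is defined by the PR, not by this example.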
Showing 6 changed files with 158 additions and 23 deletions.
@@ -0,0 +1,28 @@
from dataclasses import dataclass
from typing import Callable, Dict, List, Protocol, Tuple, Type

import torch.nn as nn
from torch.distributed.pipelining.schedules import _PipelineSchedule


@dataclass
class BaseModelArgs:
    _enforced: str = "This field is used to enforce all fields have defaults."


class ModelProtocol(Protocol):
    def from_model_args(self, args: BaseModelArgs) -> nn.Module:
        ...


@dataclass
class ModelSpec:
    name: str
    cls: Type[nn.Module]
    config: Dict[str, BaseModelArgs]
    # As for now, this is a string. So it will have to be built-in to the
    # TorchTitan library. In the future, we can make this a defined class
    # that can be extended like ModelSpec.
    tokenizer: str
    parallelize_fn: Callable[[nn.Module], None]
    pipelining_fn: Callable[[nn.Module], Tuple[_PipelineSchedule, List[nn.Module]]]
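The sketch below illustrates two things about the definitions in this diff: why a defaulted `_enforced` field in `BaseModelArgs` forces subclasses to give every field a default, and how a `ModelSpec` might be filled in for a custom model. It is a torch-free approximation, not the PR's code: `MyModelArgs`, `MyModel`, the `"debug"` flavor name, the tokenizer string, and the no-op functions are hypothetical, and the `ModelSpec` type hints are loosened so the example runs with the standard library only.

```python
from dataclasses import dataclass
from typing import Callable, Dict


@dataclass
class BaseModelArgs:
    # Same trick as in the diff: because the base class has a defaulted
    # field, dataclass rules reject subclass fields without defaults.
    _enforced: str = "This field is used to enforce all fields have defaults."


@dataclass
class MyModelArgs(BaseModelArgs):  # hypothetical args for a custom model
    dim: int = 256
    n_layers: int = 4


def define_args_without_default() -> str:
    """Try to declare a subclass field with no default; report the outcome."""
    try:
        @dataclass
        class BadArgs(BaseModelArgs):
            dim: int  # no default: rejected when the class is defined
    except TypeError as exc:
        return type(exc).__name__
    return "ok"


class MyModel:  # stand-in for an nn.Module subclass
    def __init__(self, args: MyModelArgs) -> None:
        self.args = args


@dataclass
class ModelSpec:  # simplified mirror of the diff's ModelSpec (no torch types)
    name: str
    cls: type
    config: Dict[str, BaseModelArgs]
    tokenizer: str
    parallelize_fn: Callable
    pipelining_fn: Callable


def no_op(model) -> None:
    """Placeholder for real parallelize/pipelining functions."""


model_spec = ModelSpec(
    name="my_model",
    cls=MyModel,
    config={"debug": MyModelArgs()},  # "debug" is a hypothetical flavor name
    tokenizer="my_tokenizer",         # placeholder string
    parallelize_fn=no_op,
    pipelining_fn=no_op,
)

# Pick a config flavor by name and build the model from it.
model = model_spec.cls(model_spec.config["debug"])
print(define_args_without_default(), model.args.dim)
```

Keeping `config` as a mapping from flavor name to args object means one spec can carry several sizes of the same model, which matches the `Dict[str, BaseModelArgs]` annotation in the diff.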