Introducing a generic ModelConverter interface. #823
Changes from all commits: 9c227aa, ed24d73, abf3cb9, a210898, dc8bb73, dc5f891, 84860d6
@@ -0,0 +1,44 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchtitan.config_manager import JobConfig
from torchtitan.float8 import Float8Converter
from torchtitan.model_converter import build_model_converters, ModelConvertersContainer
from torchtitan.parallelisms import ParallelDims


def build_parallel_dims(job_config, world_size):
    parallel_dims = ParallelDims(
        dp_shard=job_config.training.data_parallel_shard_degree,
        dp_replicate=job_config.training.data_parallel_replicate_degree,
        cp=job_config.experimental.context_parallel_degree,
        tp=job_config.training.tensor_parallel_degree,
        pp=job_config.experimental.pipeline_parallel_degree,
        world_size=world_size,
        enable_loss_parallel=not job_config.training.disable_loss_parallel,
    )
    return parallel_dims


def test_build_model_converters_empty_list():
    config = JobConfig()
    config.parse_args([])
    parallel_dims = build_parallel_dims(config, 1)

    model_converters = build_model_converters(config, parallel_dims)
    assert isinstance(model_converters, ModelConvertersContainer)
    assert model_converters.converters == []


def test_build_model_converters_float8_converter():
    config = JobConfig()
    config.parse_args(["--model.converters", "float8"])
    parallel_dims = build_parallel_dims(config, 1)

    model_converters = build_model_converters(config, parallel_dims)
    assert isinstance(model_converters, ModelConvertersContainer)
    assert len(model_converters.converters) == 1
    assert isinstance(model_converters.converters[0], Float8Converter)
@@ -0,0 +1,80 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, List, Protocol, Union

import torch.nn as nn

from torchtitan.config_manager import JobConfig
from torchtitan.parallelisms import ParallelDims


class ModelConverter(Protocol):
    """General model converter interface.

    A model converter applies a modification to a PyTorch model.
    Typical use cases are:
    - Quantization: using QAT, FP8, ... specialized linear layers;
    - Fused optimized layers (e.g. flash-attention, norms, ...)
    """

    def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
        ...

    def convert(self, model: nn.Module):
        """In-place conversion of the model."""
        ...

    def post_optimizer_hook(self, model: Union[nn.Module, List[nn.Module]]):
        """Post-optimizer (optional) hook (e.g. compute weight statistics)."""
        ...
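For illustration only, here is a minimal sketch of a concrete class that satisfies this protocol; the `BF16LinearConverter` name and its behavior are hypothetical and not part of this PR:

import torch
import torch.nn as nn

from torchtitan.config_manager import JobConfig
from torchtitan.parallelisms import ParallelDims


class BF16LinearConverter:
    """Hypothetical converter: casts every nn.Linear in the model to bfloat16."""

    def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
        # A real converter would read its options from job_config here.
        self.num_converted = 0

    def convert(self, model: nn.Module):
        # In-place modification: walk the module tree and cast Linear layers.
        for module in model.modules():
            if isinstance(module, nn.Linear):
                module.to(dtype=torch.bfloat16)
                self.num_converted += 1

    def post_optimizer_hook(self, model):
        # Nothing to do after the optimizer step for this converter.
        pass

Note that `Protocol` classes are matched structurally, so such a sketch does not need to inherit from `ModelConverter`.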
_registry_model_converter_cls: Dict[str, type[ModelConverter]] = {}
"""Registry of model converter classes."""


def register_model_converter(converter_cls: type[ModelConverter], name: str):
    """Register a model converter class.

    A registered model converter can be applied to any model
    using the `model.converters` config parameter.
    """
    assert (
        name not in _registry_model_converter_cls
    ), f"A model converter '{name}' is already registered."
    _registry_model_converter_cls[name] = converter_cls
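Continuing the hypothetical `BF16LinearConverter` sketch above, registering it makes it selectable by name through the `model.converters` config (the `"bf16_linear"` key is, again, purely illustrative):

register_model_converter(BF16LinearConverter, "bf16_linear")

Registering the same name twice trips the assertion above, which keeps the registry effectively append-only.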
class ModelConvertersContainer(ModelConverter):
    """Sequential container of model converters.

    The class builds the sequence of model converters defined in the
    `model.converters` job config, and applies them to the model sequentially.
    """

    def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
        converter_classes = [
            _registry_model_converter_cls[name] for name in job_config.model.converters
        ]
        self.converters = [
            mh_cls(job_config, parallel_dims) for mh_cls in converter_classes
        ]

Review thread on __init__:
Comment: Regarding #814 (comment), I think we can call …
Reply: Could be indeed an option. Happy to discuss on a new PR, the small downside is my hunch that registers should be immutable, I have a bad feeling about modifying an existing entry! But maybe it wouldn't be an issue.
Comment: yeah, the intertwined logic is bad -- can we just specify it in the …

    def convert(self, model: nn.Module):
        for mh in self.converters:
            mh.convert(model)

    def post_optimizer_hook(self, model: Union[nn.Module, List[nn.Module]]):
        for mh in self.converters:
            mh.post_optimizer_hook(model)


def build_model_converters(
    job_config: JobConfig, parallel_dims: ParallelDims
) -> ModelConvertersContainer:
    """Build the collection of model converters to apply to the model."""
    return ModelConvertersContainer(job_config, parallel_dims)
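A rough sketch of how a training script could wire the container in, assuming `job_config`, `parallel_dims`, `model`, `optimizer`, and `dataloader` already exist (an illustration of the intended flow, not code from this PR):

model_converters = build_model_converters(job_config, parallel_dims)

# Apply all configured converters (e.g. float8 linear swapping) before
# parallelizing or compiling the model.
model_converters.convert(model)

for batch in dataloader:
    loss = model(batch).mean()  # placeholder loss computation
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    # Let every converter run its per-step post-optimizer work,
    # e.g. recomputing float8 scales or logging weight statistics.
    model_converters.post_optimizer_hook(model)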
Comment: please remove this in config_manager.py too
Reply: My mistake, forgot to add it to the commit! Now fixed.