Introducing a generic ModelConverter interface. #823

Merged
4 changes: 2 additions & 2 deletions docs/float8.md
@@ -7,9 +7,9 @@ USE_CPP=0 python -m pip install git+https://github.com/pytorch/ao.git

Launch training job with the following command (or alternatively set configs in toml files)
```
CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --float8.enable_float8_linear --float8.enable_fsdp_float8_all_gather --float8.precompute_float8_dynamic_scale_for_fsdp
CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --model.converters="float8" --float8.enable_fsdp_float8_all_gather --float8.precompute_float8_dynamic_scale_for_fsdp
```
* `--float8.enable_float8_linear`: swap `nn.Linear` with `Float8Linear` to perform float8 matmul.
* `--model.converters="float8"`: swap `nn.Linear` with `Float8Linear` to perform float8 matmul.
* `--float8.enable_fsdp_float8_all_gather`: cast `Float8Linear.weight` from high precision to float8 before FSDP all-gather so we can communicate in float8 to save bandwidth.
* `--float8.precompute_float8_dynamic_scale_for_fsdp` (optional): communicate AMAX/scales efficiently in a single all-reduce for all parameters instead of doing many small all-reduce for each parameter.

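For quick verification, the new flag parses into a plain string list on the job config. A minimal sketch mirroring the unit tests added in this PR (assuming torchtitan is installed):

```python
from torchtitan.config_manager import JobConfig

# --model.converters accepts a comma-separated list of registered converter names.
config = JobConfig()
config.parse_args(["--model.converters", "float8"])
assert config.model.converters == ["float8"]

# No converters configured means no model conversion is applied.
config = JobConfig()
config.parse_args([])
assert config.model.converters == []
```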
17 changes: 9 additions & 8 deletions scripts/estimate/estimation.py
@@ -9,15 +9,16 @@
import os

import torch

import torchtitan.float8 # noqa
from torch._guards import active_fake_mode
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.distributed._tools.fsdp2_mem_tracker import FSDPMemTracker
from torch.testing._internal.distributed.fake_pg import FakeStore

from torchtitan import utils
from torchtitan.config_manager import JobConfig
from torchtitan.float8 import Float8Handler
from torchtitan.logging import init_logger, logger
from torchtitan.model_converter import build_model_converters
from torchtitan.optimizer import build_lr_schedulers, build_optimizers
from torchtitan.parallelisms import ParallelDims
from torchtitan.train_spec import get_train_spec
@@ -117,10 +118,9 @@ def loss_fn(pred, labels):
with torch.device("meta"):
model = model_cls.from_model_args(model_config)

# a no-op hander if float8 is not enabled
float8_handler = Float8Handler(job_config, parallel_dims)
# swap to Float8Linear based on float8 configs
float8_handler.convert_to_float8_training(model)
# Build the collection of model converters. No-op if `model.converters` is empty.
model_converters = build_model_converters(job_config, parallel_dims)
model_converters.convert(model)

# apply PT-D DP/TP parallelisms and activation checkpointing
train_spec.parallelize_fn(model, world_mesh, parallel_dims, job_config)
@@ -170,9 +170,10 @@ def loss_fn(pred, labels):
# optimizer step
optimizers.step()
lr_schedulers.step()
# calculate float8 dynamic amax/scale for all-parameter for FSDP2
# Post-optimizer hook of the model converters,
# e.g. calculating the float8 dynamic amax/scale for all parameters for FSDP2;
# this issues a single all-reduce for all parameters at once for better performance
float8_handler.precompute_float8_dynamic_scale_for_fsdp(model)
model_converters.post_optimizer_hook(model)
optimizers.zero_grad()
print(f"Peak Memory at iter: {iter_idx}")
fsdp_memtracker.display_snapshot("peak", units="MiB", tabulate=True)
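Taken together, the estimation script now exercises the converter lifecycle in three steps: build the container, convert the model before parallelizing it, and run the post-optimizer hook after each optimizer step. A condensed sketch of that flow (the model, optimizer, and scheduler objects are placeholders, not the script's real setup):

```python
import torch.nn as nn

from torchtitan.model_converter import build_model_converters


def training_step_with_converters(job_config, parallel_dims, model: nn.Module, optimizers, lr_schedulers):
    # Build the collection of model converters; a no-op if `model.converters` is empty.
    model_converters = build_model_converters(job_config, parallel_dims)
    # In-place conversion, e.g. swapping nn.Linear for Float8Linear when "float8" is configured.
    model_converters.convert(model)

    # ... forward/backward pass elided ...

    optimizers.step()
    lr_schedulers.step()
    # Post-optimizer hook, e.g. precomputing float8 dynamic scales for FSDP2
    # with a single all-reduce over all parameters.
    model_converters.post_optimizer_hook(model)
    optimizers.zero_grad()
```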
9 changes: 9 additions & 0 deletions tests/unit_tests/test_job_config.py
@@ -187,6 +187,15 @@ def test_parse_exclude_from_loading(self):
config.checkpoint.exclude_from_loading == cmdline_splits
), config.checkpoint.exclude_from_loading

def test_job_config_model_converters_split(self):
config = JobConfig()
config.parse_args([])
assert config.model.converters == []

config = JobConfig()
config.parse_args(["--model.converters", "float8,mxfp"])
assert config.model.converters == ["float8", "mxfp"]

def test_print_help(self):
config = JobConfig()
parser = config.parser
44 changes: 44 additions & 0 deletions tests/unit_tests/test_model_converter.py
@@ -0,0 +1,44 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchtitan.config_manager import JobConfig
from torchtitan.float8 import Float8Converter
from torchtitan.model_converter import build_model_converters, ModelConvertersContainer
from torchtitan.parallelisms import ParallelDims


def build_parallel_dims(job_config, world_size):
parallel_dims = ParallelDims(
dp_shard=job_config.training.data_parallel_shard_degree,
dp_replicate=job_config.training.data_parallel_replicate_degree,
cp=job_config.experimental.context_parallel_degree,
tp=job_config.training.tensor_parallel_degree,
pp=job_config.experimental.pipeline_parallel_degree,
world_size=world_size,
enable_loss_parallel=not job_config.training.disable_loss_parallel,
)
return parallel_dims


def test_build_model_converters_empty_list():
config = JobConfig()
config.parse_args([])
parallel_dims = build_parallel_dims(config, 1)

model_converters = build_model_converters(config, parallel_dims)
assert isinstance(model_converters, ModelConvertersContainer)
assert model_converters.converters == []


def test_build_model_converters_float8_converter():
config = JobConfig()
config.parse_args(["--model.converters", "float8"])
parallel_dims = build_parallel_dims(config, 1)

model_converters = build_model_converters(config, parallel_dims)
assert isinstance(model_converters, ModelConvertersContainer)
assert len(model_converters.converters) == 1
assert isinstance(model_converters.converters[0], Float8Converter)
3 changes: 3 additions & 0 deletions torchtitan/__init__.py
@@ -6,6 +6,9 @@
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

# Import to register Float8Converter.
import torchtitan.float8 # noqa: F401

# Import the built-in models here so that the corresponding register_model_spec()
# will be called.
import torchtitan.models # noqa: F401
70 changes: 39 additions & 31 deletions torchtitan/config_manager.py
@@ -26,9 +26,22 @@


def string_list(raw_arg):
"""Comma-separated string list argument."""
return [s.strip() for s in raw_arg.split(",") if s.strip()]


def check_string_list_argument(args_dict: dict[str, any], fullargname: str):
section, name = fullargname.split(".")
# Split string-list values that are still raw strings.
if (
section in args_dict
and name in args_dict[section]
and isinstance(args_dict[section][name], str)
):
sec = args_dict[section]
sec[name] = string_list(sec[name])


class JobConfig:
"""
A helper class to manage the train configuration.
@@ -183,6 +196,19 @@ def __init__(self):
default="./torchtitan/datasets/tokenizer/tokenizer.model",
help="Tokenizer path",
)
self.parser.add_argument(
"--model.converters",
type=string_list,
nargs="+",
default=[],
help="""
Comma-separated list of converters to apply to the model.

For instance, the `float8` converter swaps `torch.nn.Linear`
with `Float8Linear`. This feature requires you to install 'torchao',
which can be found here: https://github.com/pytorch/ao
""",
)

# optimizer configs
self.parser.add_argument(
@@ -575,15 +601,6 @@ def __init__(self):
)

# float8 configs
self.parser.add_argument(
"--float8.enable_float8_linear",
action="store_true",
help="""
If true, swaps `torch.nn.Linear` with `Float8Linear`.
This feature requires you to install 'torchao' which can be found
here: https://github.com/pytorch/ao
""",
)
self.parser.add_argument(
"--float8.enable_fsdp_float8_all_gather",
action="store_true",
@@ -652,25 +669,11 @@ def parse_args(self, args_list: list = sys.argv[1:]):
logger.exception(f"Error details: {str(e)}")
raise e

# Check that string-list arguments are properly split into a list.
# If they came from the command line, they would have already been parsed into a list by that parser.
if (
"experimental" in args_dict
and "pipeline_parallel_split_points" in args_dict["experimental"]
and isinstance(
args_dict["experimental"]["pipeline_parallel_split_points"], str
)
):
exp = args_dict["experimental"]
exp["pipeline_parallel_split_points"] = string_list(
exp["pipeline_parallel_split_points"]
)
if (
"checkpoint" in args_dict
and "exclude_from_loading" in args_dict["checkpoint"]
and isinstance(args_dict["checkpoint"]["exclude_from_loading"], str)
):
ckpt = args_dict["checkpoint"]
ckpt["exclude_from_loading"] = string_list(ckpt["exclude_from_loading"])
string_list_argnames = self._get_string_list_argument_names()
for n in string_list_argnames:
check_string_list_argument(args_dict, n)

# override args dict with cmd_args
cmd_args_dict = self._args_to_two_level_dict(cmd_args)
@@ -698,13 +701,21 @@ def _validate_config(self) -> None:
assert self.model.flavor
assert self.model.tokenizer_path

def _get_string_list_argument_names(self) -> list[str]:
"""Get the parser argument names of type `string_list`."""
string_list_args = [
v.dest for v in self.parser._actions if v.type is string_list
]
return string_list_args

def parse_args_from_command_line(
self, args_list
) -> Tuple[argparse.Namespace, argparse.Namespace]:
"""
Parse command line arguments and return the parsed args and the command line only args
"""
args = self.parser.parse_args(args_list)
string_list_argnames = set(self._get_string_list_argument_names())

# aux parser to parse the command line only args, with no defaults from main parser
aux_parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS)
@@ -713,14 +724,11 @@ def parse_args_from_command_line(
aux_parser.add_argument(
"--" + arg, action="store_true" if val else "store_false"
)
elif arg == "experimental.pipeline_parallel_split_points":
elif arg in string_list_argnames:
# without this special case, type inference breaks here,
# since the inferred type is just 'list' and it ends up flattening
# e.g. from ["layers.0", "layers.1"] into ["l", "a", "y", "e", "r", "s", ".0", ...]
aux_parser.add_argument("--" + arg, type=string_list)
elif arg == "checkpoint.exclude_from_loading":
# similar to the case above
aux_parser.add_argument("--" + arg, type=string_list)
else:
aux_parser.add_argument("--" + arg, type=type(val))

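The string-list machinery above generalizes what was previously hard-coded for pipeline split points and checkpoint exclusions. A small sketch of how the helpers behave (the `args_dict` shape mimics the two-level section/name dict produced from the TOML config):

```python
from torchtitan.config_manager import (
    JobConfig,
    check_string_list_argument,
    string_list,
)

# `string_list` splits a raw comma-separated value and strips whitespace.
assert string_list("float8, mxfp") == ["float8", "mxfp"]

# Values read from a TOML file may still be raw strings; the check normalizes
# them in place for a given "section.name" key.
args_dict = {"model": {"converters": "float8,mxfp"}}
check_string_list_argument(args_dict, "model.converters")
assert args_dict["model"]["converters"] == ["float8", "mxfp"]

# The arguments needing this treatment are discovered from the parser itself,
# so any new `string_list` option (such as --model.converters) is handled automatically.
config = JobConfig()
print(config._get_string_list_argument_names())
# e.g. ['experimental.pipeline_parallel_split_points',
#       'checkpoint.exclude_from_loading', 'model.converters']
```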
14 changes: 11 additions & 3 deletions torchtitan/float8.py
@@ -20,6 +20,7 @@

from torchtitan.config_manager import JobConfig
from torchtitan.logging import logger
from torchtitan.model_converter import ModelConverter, register_model_converter
from torchtitan.parallelisms import ParallelDims


@@ -28,13 +29,11 @@ def _is_sm89_or_later():
return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)


class Float8Handler:
class Float8Converter(ModelConverter):
def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
self.enabled = False

float8_config = job_config.float8
if not float8_config.enable_float8_linear:
Contributor: please remove this in config_manager.py too
Contributor Author: My mistake, forgot to add it to the commit! Now fixed.
return
if not _is_sm89_or_later():
logger.warning(
"Failed to swap to Float8Linear because float8 is only supported on SM89 or later",
@@ -66,6 +65,12 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):

logger.info("Float8 training active")

def convert(self, model: nn.Module):
return self.convert_to_float8_training(model)

def post_optimizer_hook(self, model: Union[nn.Module, List[nn.Module]]):
return self.precompute_float8_dynamic_scale_for_fsdp(model)

def convert_to_float8_training(self, model: nn.Module):
"""
This function converts the linear layers of `model` to `Float8Linear`.
@@ -102,3 +107,6 @@ def precompute_float8_dynamic_scale_for_fsdp(
models = [model] if isinstance(model, nn.Module) else model
for m in models:
precompute_float8_dynamic_scale_for_fsdp(m)


register_model_converter(Float8Converter, "float8")
1 change: 1 addition & 0 deletions torchtitan/metrics.py
@@ -11,6 +11,7 @@

import torch
from torch.utils.tensorboard import SummaryWriter

from torchtitan.config_manager import JobConfig
from torchtitan.logging import logger
from torchtitan.parallelisms import ParallelDims
80 changes: 80 additions & 0 deletions torchtitan/model_converter.py
@@ -0,0 +1,80 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, List, Protocol, Union

import torch.nn as nn

from torchtitan.config_manager import JobConfig
from torchtitan.parallelisms import ParallelDims


class ModelConverter(Protocol):
"""General model converter interface.

A model converter applies a modification to a PyTorch model.
Typical use cases are:
- Quantization: using QAT, FP8, ... specialized linear layers;
- Fused optimized layers (e.g. flash-attention, norms, ...)
"""

def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
...

def convert(self, model: nn.Module):
"""Inplace convertion of the model."""
...

def post_optimizer_hook(self, model: Union[nn.Module, List[nn.Module]]):
"""Post-optimizer (optional) hook (e.g. compute weights statistics)."""
...


_registry_model_converter_cls: Dict[str, type[ModelConverter]] = {}
"""Registry of model converter classes.
"""


def register_model_converter(converter_cls: type[ModelConverter], name: str):
"""Register a model converter class.

A registered model converter can be applied on any model
using the `model.converters` config parameter.
"""
assert (
name not in _registry_model_converter_cls
), f"A model converter '{name}' is already registered."
_registry_model_converter_cls[name] = converter_cls


class ModelConvertersContainer(ModelConverter):
"""Model converters sequential container.

This class builds the sequence of model converters defined in the `model.converters`
job config and applies them to the model sequentially.
"""

def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
Contributor: Regarding #814 (comment), I think we can call apply_to_train_specs to register hooks to optimizers here.
Contributor Author: Could indeed be an option, happy to discuss it in a new PR. The small downside is my hunch that registries should be immutable; I have a bad feeling about modifying an existing entry! But maybe it would be an issue.
Contributor: yeah, the intertwined logic is bad -- can we just specify it in the TrainSpec construction and let the spec handle the registration and usage?

converter_classes = [
_registry_model_converter_cls[name] for name in job_config.model.converters
]
self.converters = [
mh_cls(job_config, parallel_dims) for mh_cls in converter_classes
]

def convert(self, model: nn.Module):
for mh in self.converters:
mh.convert(model)

def post_optimizer_hook(self, model: Union[nn.Module, List[nn.Module]]):
for mh in self.converters:
mh.post_optimizer_hook(model)


def build_model_converters(
job_config: JobConfig, parallel_dims: ParallelDims
) -> ModelConvertersContainer:
"""Build the collection of model converters to apply to the model."""
return ModelConvertersContainer(job_config, parallel_dims)
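Beyond the built-in `float8` converter, the registry is open to additional converters. The sketch below registers a hypothetical one (the `DebugLoggingConverter` class and the "debug_logging" name are invented for illustration; only the `ModelConverter` protocol and `register_model_converter` come from this PR):

```python
import torch.nn as nn

from torchtitan.config_manager import JobConfig
from torchtitan.model_converter import register_model_converter
from torchtitan.parallelisms import ParallelDims


class DebugLoggingConverter:
    """Hypothetical converter that only inspects the model; satisfies ModelConverter."""

    def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
        self.job_config = job_config

    def convert(self, model: nn.Module):
        # A real converter would modify `model` in place (e.g. swap layers).
        print(f"Model has {sum(p.numel() for p in model.parameters())} parameters")

    def post_optimizer_hook(self, model):
        # Nothing to do after optimizer.step() for this converter.
        pass


# After registration, the converter can be enabled with --model.converters="debug_logging"
# (or combined with others, e.g. --model.converters="float8,debug_logging").
register_model_converter(DebugLoggingConverter, "debug_logging")
```

Since `ModelConverter` is a `Protocol`, no subclassing is required; implementing `__init__`, `convert`, and `post_optimizer_hook` is enough for `ModelConvertersContainer` to build and dispatch to the converter.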