Commit 182a7ae

Using string_list for model.handlers argument.
1 parent 64a5338 commit 182a7ae

12 files changed: 148 additions, 128 deletions

docs/float8.md (+1, -1)

@@ -9,7 +9,7 @@ Launch training job with the following command (or alternatively set configs in
 ```
 CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --float8.enable_float8_linear --float8.enable_fsdp_float8_all_gather --float8.precompute_float8_dynamic_scale_for_fsdp
 ```
-<!-- * `--float8.enable_float8_linear`: swap `nn.Linear` with `Float8Linear` to perform float8 matmul. -->
+* `--float8.enable_float8_linear`: swap `nn.Linear` with `Float8Linear` to perform float8 matmul.
 * `--float8.enable_fsdp_float8_all_gather`: cast `Float8Linear.weight` from high precision to float8 before FSDP all-gather so we can communicate in float8 to save bandwidth.
 * `--float8.precompute_float8_dynamic_scale_for_fsdp` (optional): communicate AMAX/scales efficiently in a single all-reduce for all parameters instead of doing many small all-reduce for each parameter.

tests/unit_tests/test_job_config.py (+9)

@@ -116,6 +116,15 @@ def test_parse_pp_split_points(self):
             config.experimental.pipeline_parallel_split_points == cmdline_splits
         ), config.experimental.pipeline_parallel_split_points
 
+    def test_job_config_model_converters_split(self):
+        config = JobConfig()
+        config.parse_args([])
+        assert config.model.converters == []
+
+        config = JobConfig()
+        config.parse_args(["--model.converters", "float8,mxfp"])
+        assert config.model.converters == ["float8", "mxfp"]
+
     def test_print_help(self):
         config = JobConfig()
         parser = config.parser

torchtitan/config_manager.py (+41, -26)

@@ -26,9 +26,22 @@
 
 
 def string_list(raw_arg):
+    """Comma-separated string list argument."""
     return raw_arg.split(",")
 
 
+def check_string_list_argument(args_dict: dict[str, any], fullargname: str):
+    section, name = fullargname.split(".")
+    # Split string list which are still raw strings.
+    if (
+        section in args_dict
+        and name in args_dict[section]
+        and isinstance(args_dict[section][name], str)
+    ):
+        sec = args_dict[section]
+        sec[name] = string_list(sec[name])
+
+
 class JobConfig:
     """
     A helper class to manage the train configuration.

@@ -183,13 +196,14 @@ def __init__(self):
             help="Tokenizer path",
         )
         self.parser.add_argument(
-            "--model.handlers",
-            type=str,
-            default="",
+            "--model.converters",
+            type=string_list,
+            nargs="+",
+            default=[],
             help="""
-                Comma separated list of handlers to apply to the model.
+                Comma separated list of converters to apply to the model.
 
-                For instance, the `float8` handler swaps `torch.nn.Linear`
+                For instance, the `float8` converter swaps `torch.nn.Linear`
                 with `Float8Linear`. This feature requires you to install 'torchao'
                 which can be found here: https://github.com/pytorch/ao
             """,

@@ -541,15 +555,15 @@ def __init__(self):
         )
 
         # float8 configs
-        # self.parser.add_argument(
-        #     "--float8.enable_float8_linear",
-        #     action="store_true",
-        #     help="""
-        #         If true, swaps `torch.nn.Linear` with `Float8Linear`.
-        #         This feature requires you to install 'torchao' which can be found
-        #         here: https://github.com/pytorch/ao
-        #     """,
-        # )
+        self.parser.add_argument(
+            "--float8.enable_float8_linear",
+            action="store_true",
+            help="""
+                If true, swaps `torch.nn.Linear` with `Float8Linear`.
+                This feature requires you to install 'torchao' which can be found
+                here: https://github.com/pytorch/ao
+            """,
+        )
         self.parser.add_argument(
             "--float8.enable_fsdp_float8_all_gather",
             action="store_true",

@@ -618,18 +632,11 @@ def parse_args(self, args_list: list = sys.argv[1:]):
             logger.exception(f"Error details: {str(e)}")
             raise e
 
+        # Checking string-list arguments are properly split into a list
         # if split-points came from 'args' (from cmd line) it would have already been parsed into a list by that parser
-        if (
-            "experimental" in args_dict
-            and "pipeline_parallel_split_points" in args_dict["experimental"]
-            and isinstance(
-                args_dict["experimental"]["pipeline_parallel_split_points"], str
-            )
-        ):
-            exp = args_dict["experimental"]
-            exp["pipeline_parallel_split_points"] = string_list(
-                exp["pipeline_parallel_split_points"]
-            )
+        string_list_argnames = self._get_string_list_argument_names()
+        for n in string_list_argnames:
+            check_string_list_argument(args_dict, n)
 
         # override args dict with cmd_args
         cmd_args_dict = self._args_to_two_level_dict(cmd_args)

@@ -657,13 +664,21 @@ def _validate_config(self) -> None:
         assert self.model.flavor
         assert self.model.tokenizer_path
 
+    def _get_string_list_argument_names(self) -> list[str]:
+        """Get the parser argument names of type `string_list`."""
+        string_list_args = [
+            v.dest for v in self.parser._actions if v.type is string_list
+        ]
+        return string_list_args
+
     def parse_args_from_command_line(
         self, args_list
     ) -> Tuple[argparse.Namespace, argparse.Namespace]:
         """
         Parse command line arguments and return the parsed args and the command line only args
         """
         args = self.parser.parse_args(args_list)
+        string_list_argnames = set(self._get_string_list_argument_names())
 
         # aux parser to parse the command line only args, with no defaults from main parser
         aux_parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS)

@@ -672,7 +687,7 @@ def parse_args_from_command_line(
                 aux_parser.add_argument(
                     "--" + arg, action="store_true" if val else "store_false"
                 )
-            elif arg == "experimental.pipeline_parallel_split_points":
+            elif arg in string_list_argnames:
                 # without this special case, type inference breaks here,
                 # since the inferred type is just 'list' and it ends up flattening
                 # e.g. from ["layers.0", "layers.1"] into ["l", "a", "y", "e", "r", "s", ".0", ...]
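
Note: to make the intended behavior of the two new helpers concrete, here is a minimal, self-contained sketch (not part of the commit) of how a comma-separated value that is still a raw string, e.g. one loaded from a TOML config file, gets split into a list, while a value the CLI parser has already turned into a list is left untouched.

```python
from typing import Any


def string_list(raw_arg: str) -> list[str]:
    # Mirrors the helper added above: split a comma-separated value into a list.
    return raw_arg.split(",")


def check_string_list_argument(args_dict: dict[str, Any], fullargname: str) -> None:
    # Mirrors the helper added above: split the value in place only if it is
    # still a raw string (i.e. it did not come from the command-line parser).
    section, name = fullargname.split(".")
    if (
        section in args_dict
        and name in args_dict[section]
        and isinstance(args_dict[section][name], str)
    ):
        args_dict[section][name] = string_list(args_dict[section][name])


if __name__ == "__main__":
    # e.g. the result of loading `converters = "float8,mxfp"` from a TOML file
    cfg = {"model": {"converters": "float8,mxfp"}}
    check_string_list_argument(cfg, "model.converters")
    print(cfg["model"]["converters"])  # ['float8', 'mxfp']

    # a value already parsed into a list is left as-is
    cfg = {"model": {"converters": ["float8"]}}
    check_string_list_argument(cfg, "model.converters")
    print(cfg["model"]["converters"])  # ['float8']
```

The new unit test above exercises the same path through `JobConfig.parse_args`, asserting that `--model.converters float8,mxfp` ends up as `["float8", "mxfp"]`.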

torchtitan/float8.py (+3, -3)

@@ -20,7 +20,7 @@
 
 from torchtitan.config_manager import JobConfig
 from torchtitan.logging import logger
-from torchtitan.model_handler import ModelHandler, register_model_handler
+from torchtitan.model_converter import ModelConverter, register_model_converter
 from torchtitan.parallelisms import ParallelDims
 
 

@@ -29,7 +29,7 @@ def _is_sm89_or_later():
     return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
 
 
-class Float8Handler(ModelHandler):
+class Float8Handler(ModelConverter):
     def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
         self.enabled = False
 

@@ -109,4 +109,4 @@ def precompute_float8_dynamic_scale_for_fsdp(
             precompute_float8_dynamic_scale_for_fsdp(m)
 
 
-register_model_handler(Float8Handler, "float8")
+register_model_converter(Float8Handler, "float8")

torchtitan/model_converter.py (new file, +80)

@@ -0,0 +1,80 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+from typing import Dict, List, Protocol, Union
+
+import torch.nn as nn
+
+from torchtitan.config_manager import JobConfig
+from torchtitan.parallelisms import ParallelDims
+
+
+class ModelConverter(Protocol):
+    """General model converter interface.
+
+    A model converter applies a modification to a PyTorch model.
+    Typical use cases are:
+    - Quantization: using QAT, FP8, ... specialized linear layers;
+    - Fused optimized layers (e.g. flash-attention, norms, ...)
+    """
+
+    def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
+        ...
+
+    def convert(self, model: nn.Module):
+        """In-place conversion of the model."""
+        ...
+
+    def post_optimizer_hook(self, model: Union[nn.Module, List[nn.Module]]):
+        """Post-optimizer (optional) hook (e.g. compute weights statistics)."""
+        ...
+
+
+_registry_model_converter_cls: Dict[str, type[ModelConverter]] = {}
+"""Registry of model converter classes.
+"""
+
+
+def register_model_converter(converter_cls: type[ModelConverter], name: str):
+    """Register a model converter class.
+
+    A registered model converter can be applied to any model
+    using the `model.converters` config parameter.
+    """
+    assert (
+        name not in _registry_model_converter_cls
+    ), f"A model converter '{name}' is already registered."
+    _registry_model_converter_cls[name] = converter_cls
+
+
+class ModelConvertersContainer(ModelConverter):
+    """Sequential container of model converters.
+
+    This class builds the sequence of model converters defined in the
+    `model.converters` job config, and applies them to the model sequentially.
+    """
+
+    def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
+        converter_classes = [
+            _registry_model_converter_cls[name] for name in job_config.model.converters
+        ]
+        self.converters = [
+            mh_cls(job_config, parallel_dims) for mh_cls in converter_classes
+        ]
+
+    def convert(self, model: nn.Module):
+        for mh in self.converters:
+            mh.convert(model)
+
+    def post_optimizer_hook(self, model: Union[nn.Module, List[nn.Module]]):
+        for mh in self.converters:
+            mh.post_optimizer_hook(model)
+
+
+def build_model_converters(
+    job_config: JobConfig, parallel_dims: ParallelDims
+) -> ModelConvertersContainer:
+    """Build the collection of model converters to apply to the model."""
+    return ModelConvertersContainer(job_config, parallel_dims)
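
Because `ModelConverter` is a `typing.Protocol` and converters are resolved by name from the registry, adding a new one amounts to implementing the three methods and calling `register_model_converter` at import time, as `torchtitan/float8.py` now does for `"float8"`. A hypothetical sketch (the `NoOpConverter` class is illustrative only, not part of the commit):

```python
from typing import List, Union

import torch.nn as nn

from torchtitan.config_manager import JobConfig
from torchtitan.model_converter import register_model_converter
from torchtitan.parallelisms import ParallelDims


class NoOpConverter:
    """Hypothetical converter satisfying the ModelConverter protocol; changes nothing."""

    def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
        self.job_config = job_config

    def convert(self, model: nn.Module):
        # A real converter would swap modules in place here
        # (e.g. nn.Linear -> Float8Linear).
        pass

    def post_optimizer_hook(self, model: Union[nn.Module, List[nn.Module]]):
        # Optional per-step work after optimizer.step(),
        # e.g. recomputing scales.
        pass


# Makes the converter selectable via `--model.converters noop`, analogous to
# `register_model_converter(Float8Handler, "float8")` in torchtitan/float8.py.
register_model_converter(NoOpConverter, "noop")
```

Once registered, the converter can be listed together with others, e.g. `--model.converters float8,noop`, and `ModelConvertersContainer` instantiates and applies them in the listed order.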

torchtitan/model_handler.py (deleted, -86)

This file was deleted.

torchtitan/parallelisms/parallelize_llama.py (+3, -2)

@@ -33,8 +33,9 @@
 
 from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.logging import logger
+from torchtitan.model_converter import parse_model_converters
 from torchtitan.parallelisms.parallel_dims import ParallelDims
-from torchtitan.model_handler import parse_model_handlers
+
 
 def parallelize_llama(
     model: nn.Module,

@@ -56,7 +57,7 @@ def parallelize_llama(
             and not job_config.training.compile
         ):
             raise RuntimeError("Async TP requires --training.compile")
-        enable_float8 = "float8" in parse_model_handlers(job_config)
+        enable_float8 = "float8" in parse_model_converters(job_config)
         apply_tp(
             model,
             world_mesh["tp"],

train.py (+6, -6)

@@ -18,7 +18,7 @@
 from torchtitan.datasets import build_hf_data_loader, build_tokenizer
 from torchtitan.logging import init_logger, logger
 from torchtitan.metrics import build_device_memory_monitor, build_metric_logger
-from torchtitan.model_handler import build_model_handlers_container
+from torchtitan.model_converter import build_model_converters
 from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config
 from torchtitan.optimizer import build_lr_schedulers, build_optimizers
 from torchtitan.parallelisms import (

@@ -110,9 +110,9 @@ def main(job_config: JobConfig):
     with torch.device("meta"):
         model = model_cls.from_model_args(model_config)
 
-    # Build the collection of model handlers. No-op if `model.handlers` empty
-    model_handlers = build_model_handlers_container(job_config, parallel_dims)
-    model_handlers.convert(model)
+    # Build the collection of model converters. No-op if `model.converters` empty
+    model_converters = build_model_converters(job_config, parallel_dims)
+    model_converters.convert(model)
 
     # log model size
     model_param_count = utils.get_num_params(model)

@@ -325,10 +325,10 @@ def loss_fn(pred, labels):
             optimizers.step()
             lr_schedulers.step()
 
-            # Post-optimizer model handlers hook.
+            # Post-optimizer model converters hook.
             # e.g. calculate float8 dynamic amax/scale for all-parameter for FSDP2
             # it issues a single all-reduce for all parameters at once for better performance
-            model_handlers.post_optimizer_hook(model_parts)
+            model_converters.post_optimizer_hook(model_parts)
 
             # log metrics
             if (
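
Distilled from the `train.py` changes: `convert()` runs once on the model before training starts, and `post_optimizer_hook()` runs once per step after `optimizers.step()` and `lr_schedulers.step()`. A toy, self-contained sketch of that call order (the `PrintingConverter` stand-in and the loop are illustrative, not torchtitan code):

```python
import torch
import torch.nn as nn


class PrintingConverter:
    """Stand-in converter used only to show when the two hooks fire."""

    def convert(self, model: nn.Module):
        print(f"convert() called on {type(model).__name__}")

    def post_optimizer_hook(self, model):
        print("post_optimizer_hook() called after optimizer.step()")


model = nn.Linear(4, 4)
converters = [PrintingConverter()]

# 1. Convert the model once, before the training loop.
for c in converters:
    c.convert(model)

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
for _ in range(2):  # toy training loop
    loss = model(torch.randn(2, 4)).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    # 2. Run the post-optimizer hook once per step, after optimizer.step().
    for c in converters:
        c.post_optimizer_hook(model)
```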
