
Add Dynamic Model Import and ModelSpec Definition #814

Open
wants to merge 10 commits into base: gh/fegin/8/base
Conversation

@fegin (Contributor) commented Jan 31, 2025

Stack from ghstack (oldest at bottom):

What does this PR do?

  1. This PR introduces ModelSpec to describe a model and how to parallelize it.
    • All models should call register_model_spec().
    • Users can also use --experimental.custom_model_path to dynamically import a model that is not implemented by TorchTitan; that module should also call register_model_spec().
  2. This PR also refactors OptimizersContainer and LRSchedulersContainer:
    • Fixes an issue where optimizers would accept parameters whose requires_grad is False.
    • Improves typing and docstrings.
    • Improves function and class reusability.
    • OptimizersContainer now inherits from torch.optim.Optimizer.
  3. This PR also moves parallelize_llama and pipeline_llama to the llama folder.

Why do we need this PR?
This allows users to use TorchTitan with a new model without intrusively changing TorchTitan code; a sketch of the flow is below.
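For illustration, a minimal sketch of the intended flow under the description above; the torchtitan import path and the exact ModelSpec fields here are assumptions for the sketch, not verbatim from this PR:

```python
# my_model/__init__.py -- a package that lives outside the TorchTitan repo.
# Import path and ModelSpec field names below are assumptions, not this PR's code.
from torchtitan.models import ModelSpec, register_model_spec  # assumed location

from .model import MyModel, MyModelArgs        # user-defined
from .parallelize import parallelize_my_model  # user-defined
from .pipeline import pipeline_my_model        # user-defined

register_model_spec(
    ModelSpec(
        name="my_model",
        cls=MyModel,
        config={"debugmodel": MyModelArgs()},
        parallelize_fn=parallelize_my_model,
        pipelining_fn=pipeline_my_model,
    )
)
```

Training would then be launched with `--experimental.custom_model_path my_model`; importing the module registers the spec as a side effect, and the trainer looks it up by name.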

Next steps

  1. Dataloader customization is not included yet.
  2. Checkpoint customization is not included yet.

fegin added a commit that referenced this pull request Jan 31, 2025
**What does this PR do?**
1. Introduces `ModelSpec` to describe a model and how to parallelize it.
2. Requires all models to define `build_model_spec()` or `model_spec`, which will be imported by the `model` module.
3. Calls `build_model_specs()` in the trainer to obtain `model_specs`, which are then used to retrieve the corresponding model spec.
4. Allows users to dynamically import a model not implemented by TorchTitan using `--experimental.model_module_path`.

**Why do we need this PR?**
This PR enables users to integrate new models with TorchTitan without making intrusive changes to the TorchTitan codebase.

**Next steps**
1. This PR includes only the model definitions, configurations, tokenizer, parallelize_fn, and pipelining_fn. We may want to extend `ModelSpec` to include the optimizer and learning rate scheduler.
2. The current TorchTitan parallelize and pipelining_fn import ModelArgs, which can lead to circular imports. This issue needs to be addressed.

ghstack-source-id: f0847f5efebfdf8c6619f58c1b0131a233502eaf
Pull Request resolved: #814
@facebook-github-bot added the CLA Signed label Jan 31, 2025
@fegin requested review from tianyu-l, wconstab and fduwjj January 31, 2025 18:40
@fegin changed the title from "Allow users to use the customized model" to "Add Dynamic Model Import and ModelSpec Definition" Jan 31, 2025
fegin added a commit that referenced this pull request Jan 31, 2025
ghstack-source-id: 28259eb74975eeb7ad790a774b6e719f3aa19a31
Pull Request resolved: #814
fegin added a commit that referenced this pull request Jan 31, 2025
ghstack-source-id: ba1389f57808b1c6b309f554a675523d09395b42
Pull Request resolved: #814
fegin added a commit that referenced this pull request Jan 31, 2025
ghstack-source-id: a88ff3ebe5c869055dd3314fb1b791855fd0e0b2
Pull Request resolved: #814
fegin added a commit that referenced this pull request Jan 31, 2025
ghstack-source-id: 362df77a3f6a2b9f3cff514938a415bfe25e2100
Pull Request resolved: #814
@tianyu-l (Contributor) left a comment

Initial pass looks great. Had some suggestions on restructuring.

Resolved (outdated) review threads:
torchtitan/models/llama/model.py
torchtitan/models/__init__.py (2 threads)
torchtitan/config_manager.py
torchtitan/models/llama/__init__.py
from torchtitan.models.norms import build_norm


@dataclass
-class ModelArgs:
+class ModelArgs(BaseModelArgs):
tianyu-l:

Down the road we will have many models, e.g. multimodal (MM) models. Do we want all model args to inherit from this? Currently we use different model args for different model architectures.

fegin (author):
This is mainly for typing for now, but it also preserves the ability to introduce common model args later.
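To make that concrete, a minimal sketch of the pattern; the fields on the subclass are placeholders, not from this diff:

```python
from dataclasses import dataclass


@dataclass
class BaseModelArgs:
    """Common base for per-model args dataclasses.

    For now this mostly gives ModelSpec and the trainer a single type to
    annotate against; shared args could be hoisted here later.
    """


@dataclass
class TransformerModelArgs(BaseModelArgs):
    # Placeholder fields for the sketch, not this PR's actual defaults.
    dim: int = 4096
    n_layers: int = 32
    vocab_size: int = 128256
```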

fegin added a commit that referenced this pull request Feb 6, 2025
ghstack-source-id: 9ed1b54aa945af27ce0881ea02150c9e2f0022e8
Pull Request resolved: #814
fegin added a commit that referenced this pull request Feb 6, 2025
ghstack-source-id: 01c89646326d2c356b6f82e0fa714a347da7b869
Pull Request resolved: #814
fegin added a commit that referenced this pull request Feb 6, 2025
ghstack-source-id: bee7a1df8af55ea6a7ad7451a6bc9a3158922d4f
Pull Request resolved: #814
fegin added a commit that referenced this pull request Feb 6, 2025
ghstack-source-id: 91e1dca3dc8d3268c2d636e335029cb3e18318d6
Pull Request resolved: #814
@fegin (author) commented Feb 6, 2025

Tests and documentation will come in the next few updates.

fegin added a commit that referenced this pull request Feb 7, 2025
ghstack-source-id: 671424d38a040c8594f8b3d692cd8e141ce5c656
Pull Request resolved: #814
@tianyu-l (Contributor) left a comment

This is super cool! Thank you for unlocking torchtitan to reach the next level.

# TorchTitan library. A better way would be to have a dataloader class
# and a ``build_dataloader`` function that takes job_config to consume
# the different dataloader and tokenizer configs.
tokenizer: str
tianyu-l:
Currently the tokenizer is part of the data loader:
https://github.com/pytorch/torchtitan/blob/main/torchtitan/datasets/hf_datasets.py#L186

Maybe let's remove it for now.
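For context, a self-contained sketch of the shape this comment suggests; `build_dataloader` and the `JobConfig` fields below are hypothetical illustrations, not part of this PR:

```python
from dataclasses import dataclass
from typing import Iterator


@dataclass
class JobConfig:  # toy stand-in for torchtitan's JobConfig
    tokenizer_path: str
    dataset: str
    batch_size: int


def build_dataloader(job_config: JobConfig) -> Iterator[list[list[str]]]:
    # The tokenizer is built from the config inside the data loader, so
    # ModelSpec would not need a separate ``tokenizer`` field.
    tokenize = str.split  # toy tokenizer; a real one would load tokenizer_path
    samples = [f"example text from {job_config.dataset}"] * job_config.batch_size
    yield [tokenize(s) for s in samples]
```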

from .parallelize_llama import parallelize_llama
from .pipeline_llama import pipeline_llama

__all__ = ["parallelize_llama", "pipeline_llama", "ModelArgs", "Transformer"]
tianyu-l:
do we need to expose these fields in llama/__init__.py?

fegin (author):
Yes, so that users can reuse the parallelism APIs from llama.

tianyu-l:
That makes sense. But maybe also llama3_configs?
I imagine someone may want to implement new parallelisms while relying on the existing definitions of Llama 3 8B/70B/405B. In that case they don't need ModelArgs, only the preset configs.
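As a sketch of that reuse (assuming llama3_configs is exported as suggested; the ModelSpec fields and the torchtitan.models import path are the same assumptions as in the sketch above):

```python
# Hypothetical custom-parallelism sketch: reuse the Llama 3 presets but swap
# in a user-provided parallelize_fn. Exports and field names are assumptions.
from torchtitan.models import ModelSpec, register_model_spec  # assumed location
from torchtitan.models.llama import Transformer, llama3_configs, pipeline_llama

from .my_parallelisms import parallelize_with_my_strategy  # user-defined

register_model_spec(
    ModelSpec(
        name="llama3_custom_parallel",
        cls=Transformer,
        config=llama3_configs,            # preset 8B/70B/405B definitions
        parallelize_fn=parallelize_with_my_strategy,
        pipelining_fn=pipeline_llama,     # reuse the existing pipelining
    )
)
```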

from torchtitan.models.norms import build_norm


@dataclass
-class ModelArgs:
+class ModelArgs(BaseModelArgs):
tianyu-l:
can we rename it to "TransformerModelArgs"?

@@ -67,10 +70,16 @@ def pipeline_llama_manual_split(

    splits = (
        job_config.experimental.pipeline_parallel_split_points
-       or generate_split_points(job_config, parallel_dims.pp, model_config)
+       or generate_split_points(job_config, parallel_dims.pp, model_config.n_layers)
tianyu-l:
nice

# that is immutable. As long as ``training.steps`` and ``training.warmup_steps``
# in ``job_config`` remain unchanged when resuming from a checkpoint, this
# approach is safe. We call ``copy()`` here to ensure extra safety.
# TODO: Should we deepcopy the state_dict?
tianyu-l:
I think we should -- that was the intention.
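For reference, a sketch of the deepcopy variant under discussion; the container internals shown here are assumptions, not this PR's exact code:

```python
import copy


class LRSchedulersContainer:  # sketch; internals are assumptions
    def __init__(self, schedulers):
        self.schedulers = schedulers

    def load_state_dict(self, state_dict: dict) -> None:
        # A shallow ``copy()`` would still share nested objects between the
        # checkpointed state and every scheduler; ``deepcopy`` isolates each
        # scheduler's restored state.
        for scheduler in self.schedulers:
            scheduler.load_state_dict(copy.deepcopy(state_dict))
```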

Comment on lines +24 to +25
# TODO: It's unclear if this API is general enough to be used by other models.
# If not, we should move it to a Transformer-specific directory.
tianyu-l:
👍

tianyu-l:
I have to say the docs added in this file look fabulous ✨

Comment on lines +74 to +75
# We need to call super().__init__() to initialize some necessary optimizer
# functionality such as hooks.
tianyu-l:
Can we put this where _post_init is called, for better readability?
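For reference, a minimal sketch of the pattern being discussed -- OptimizersContainer subclasses torch.optim.Optimizer per this PR's description, though the constructor body here is an assumption:

```python
import torch
import torch.nn as nn


class OptimizersContainer(torch.optim.Optimizer):
    """Sketch only; the real container in this PR wraps one optimizer per
    model part and has more configuration."""

    def __init__(self, model_parts: list[nn.Module], lr: float = 8e-4) -> None:
        # Only trainable parameters are handed to the optimizers, matching
        # the requires_grad fix described in the PR summary.
        params = [p for m in model_parts for p in m.parameters() if p.requires_grad]
        self.optimizers = [torch.optim.AdamW(params, lr=lr)]
        # super().__init__() initializes machinery such as the pre/post-step
        # hook registries that torch.optim.Optimizer subclasses rely on.
        super().__init__(params, {"lr": lr})

    def step(self, closure=None) -> None:
        for optimizer in self.optimizers:
            optimizer.step(closure)
```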

@@ -36,7 +39,7 @@ def pipeline_llama(
    device: DeviceType,
    model_config: ModelArgs,
    loss_fn: Callable[..., torch.Tensor],
-):
+) -> tuple[_PipelineSchedule, list[nn.Module]]:
tianyu-l:
Typing in uppercase vs. lowercase seems inconsistent throughout the PR. Is this intentional? And what's the recommended way?

Hmm, it seems you used uppercase only for state_dict, maybe for compatibility.

fegin (author):
Uppercase was the recommended way when we still supported Python <= 3.8. After PyTorch 2.6, that's no longer the case, so we should just change to the lowercase one. I may revisit the code and try to change everything to lowercase.
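Concretely, the two styles in question:

```python
from typing import Dict, List  # uppercase generics: needed on Python <= 3.8


def old_style(parts: List[str]) -> Dict[str, int]:
    return {p: len(p) for p in parts}


# Python >= 3.9: builtin generics work directly, no typing import needed.
def new_style(parts: list[str]) -> dict[str, int]:
    return {p: len(p) for p in parts}
```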

tianyu-l:
I feel this file is growing too big -- we basically throw things in here when we don't know where to put them.
Maybe let's revisit it later as a BE (better engineering) task.

fegin (author):
Lol, yeah, that's the legacy "put everything into the utils file" issue. I agree we should split it, but we can do that in another BE PR.



@dataclass
class ModelSpec:
tianyu-l:
Since the Spec is not only about the model -- e.g. conceptually there can be multiple ways to train the same model (GPU/TPU, customized parallelize/pipeline functions) -- shall we consider renaming it to TrainSpec?

fegin (author):
That's actually a good question and suggestion. I am open to this option.

build_optimizers_fn: Callable[[List[nn.Module], JobConfig], OptimizersContainer]
build_lr_schedulers_fn: Callable[
[List[nn.Module], JobConfig], LRSchedulersContainer
]
tianyu-l:
For some models we may need to alter loss_fn as well, e.g. in diffusion models. We may add that later when necessary.
