-
Notifications
You must be signed in to change notification settings - Fork 2.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adding serialization to all Auto* objects in HuggingFace transformers (…
…#11645) * Adding serialization to all Auto* objects in HuggingFace transformers Signed-off-by: Marc Romeyn <[email protected]> * Apply isort and black reformatting Signed-off-by: marcromeyn <[email protected]> * Adding docs Signed-off-by: Marc Romeyn <[email protected]> * Apply isort and black reformatting Signed-off-by: marcromeyn <[email protected]> * Adding more doc-strings Signed-off-by: Marc Romeyn <[email protected]> * Adding more doc-strings Signed-off-by: Marc Romeyn <[email protected]> * Address comments Signed-off-by: Marc Romeijn <[email protected]> * Apply isort and black reformatting Signed-off-by: marcromeyn <[email protected]> * fix? Signed-off-by: Alexandros Koumparoulis <[email protected]> --------- Signed-off-by: Marc Romeyn <[email protected]> Signed-off-by: marcromeyn <[email protected]> Signed-off-by: Marc Romeijn <[email protected]> Signed-off-by: Alexandros Koumparoulis <[email protected]> Co-authored-by: marcromeyn <[email protected]> Co-authored-by: Alexandros Koumparoulis <[email protected]>
- Loading branch information
1 parent
ea5ed67
commit 2f66ada
Showing
8 changed files
with
314 additions
and
20 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
from nemo.lightning.io.artifact.base import Artifact | ||
from nemo.lightning.io.artifact.file import DirArtifact, DirOrStringArtifact, FileArtifact, PathArtifact | ||
from nemo.lightning.io.artifact.hf_auto import HFAutoArtifact | ||
|
||
# Public API of the artifact package. HFAutoArtifact is exported alongside the
# file/path/dir artifact types imported above.
__all__ = [
    "Artifact",
    "FileArtifact",
    "PathArtifact",
    "DirArtifact",
    "DirOrStringArtifact",
    "HFAutoArtifact",
]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,152 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
"""HuggingFace model serialization support for NeMo's configuration system. | ||
This module provides integration between NeMo's configuration system and HuggingFace's | ||
pretrained models. It enables automatic serialization and deserialization of HuggingFace | ||
models within NeMo's configuration framework. | ||
The integration works by: | ||
1. Detecting HuggingFace models through their characteristic methods (save_pretrained/from_pretrained) | ||
2. Converting them to Fiddle configurations that preserve the model's class and path | ||
3. Providing an artifact handler (HFAutoArtifact) that manages the actual model files | ||
Example: | ||
```python | ||
from transformers import AutoModel | ||
# This model will be automatically handled by the HFAutoArtifact system | ||
model = AutoModel.from_pretrained("bert-base-uncased") | ||
# When serialized, the model files will be saved to the artifacts directory | ||
# When deserialized, the model will be loaded from the saved files | ||
``` | ||
""" | ||
|
||
import contextlib | ||
import inspect | ||
import threading | ||
from pathlib import Path | ||
|
||
import fiddle as fdl | ||
|
||
from nemo.lightning.io.artifact import Artifact | ||
from nemo.lightning.io.to_config import to_config | ||
|
||
# Thread-local storage shared by ``from_pretrained_kwargs`` (which stores a
# ``kwargs`` dict on it) and ``from_pretrained`` (which reads that dict).
_local = threading.local()
|
||
|
||
class HFAutoArtifact(Artifact):
    """Artifact handler for HuggingFace pretrained objects (models, processors, tokenizers, etc.).

    Saving is delegated to the object's own ``save_pretrained`` method, which is
    pointed at an ``artifacts`` subdirectory of the target directory. Loading is
    a pass-through: the stored path is returned unchanged so that it can be
    handed to ``from_pretrained``.
    """

    def dump(self, instance, value: Path, absolute_dir: Path, relative_dir: Path) -> str:
        """Save a HuggingFace object to disk.

        Args:
            instance: The object to save; its ``save_pretrained`` method is called.
            value: Original path value (unused).
            absolute_dir: Absolute path to the save directory.
            relative_dir: Relative path from the config file to the save directory.

        Returns:
            str: ``"./"`` followed by ``relative_dir / "artifacts"``, i.e. the
            relative location of the saved files.
        """
        instance.save_pretrained(Path(absolute_dir) / "artifacts")
        # The result is a plain string (not a Path) so the leading "./" is kept;
        # Path normalisation would strip it.
        return "./" + str(Path(relative_dir) / "artifacts")

    def load(self, path: Path) -> Path:
        """Return the path from which a HuggingFace object should be loaded.

        Args:
            path: Path to the saved artifacts.

        Returns:
            Path: The same ``path``, unchanged, to be used with ``from_pretrained``.
        """
        return path
|
||
|
||
@contextlib.contextmanager
def from_pretrained_kwargs(**kwargs):
    """Temporarily supply extra keyword arguments for ``from_pretrained``.

    The keyword arguments are merged into the ``kwargs`` dict kept on the
    module's thread-local ``_local`` object for the duration of the ``with``
    block. On exit the dict is replaced by a snapshot taken on entry, so nested
    uses unwind correctly.

    Args:
        **kwargs: Keyword arguments to pass to from_pretrained

    Example:
        with from_pretrained_kwargs(trust_remote_code=True):
            io.load_context("path/to/checkpoint")
    """
    try:
        active = _local.kwargs
    except AttributeError:
        # First use in this thread: start from an empty mapping.
        active = _local.kwargs = {}

    snapshot = dict(active)
    active.update(kwargs)
    try:
        yield
    finally:
        # Restore the state captured on entry, even if the body raised.
        _local.kwargs = snapshot
|
||
|
||
def from_pretrained(auto_cls, pretrained_model_name_or_path="dummy"):
    """Load a HuggingFace object through ``auto_cls.from_pretrained``.

    This function is the target stored in the Fiddle configuration produced for
    HuggingFace objects; calling it re-creates the object. Any keyword
    arguments currently stored on the thread-local ``_local.kwargs`` (see
    ``from_pretrained_kwargs``) are forwarded to the call.

    Args:
        auto_cls: The HuggingFace class (e.g., AutoModel, AutoTokenizer) whose
            ``from_pretrained`` is invoked.
        pretrained_model_name_or_path: Path to the saved files or a model
            identifier. The ``"dummy"`` default is presumably a placeholder
            that the artifact mechanism replaces with the real path; confirm
            against the loader.

    Returns:
        Whatever ``auto_cls.from_pretrained`` returns.
    """
    try:
        extra_kwargs = _local.kwargs
    except AttributeError:
        # No surrounding from_pretrained_kwargs context in this thread.
        extra_kwargs = {}
    loader = auto_cls.from_pretrained
    return loader(pretrained_model_name_or_path, **extra_kwargs)
|
||
|
||
def _is_hf_pretrained_instance(value) -> bool:
    """Return True when ``value`` is an instance (not a class) from ``transformers`` with save/load methods."""
    if inspect.isclass(value):
        return False
    module_name = getattr(value, "__module__", "")
    # ``__module__`` can exist but be None (or another non-string); treat that
    # as "not a transformers object" instead of failing on ``.startswith``.
    if not isinstance(module_name, str) or not module_name.startswith("transformers"):
        return False
    return hasattr(value, "save_pretrained") and hasattr(value, "from_pretrained")


@to_config.register(_is_hf_pretrained_instance)
def handle_hf_pretrained(value):
    """Convert a HuggingFace object instance to a Fiddle configuration.

    The handler is selected for non-class values whose ``__module__`` starts
    with ``"transformers"`` and which expose both ``save_pretrained`` and
    ``from_pretrained``. The resulting configuration targets the module-level
    ``from_pretrained`` function with the instance's class as ``auto_cls``.

    Args:
        value: A HuggingFace object instance.

    Returns:
        fdl.Config: A configuration of ``from_pretrained`` with
        ``auto_cls=value.__class__`` and the placeholder path ``"dummy"``.
    """
    return fdl.Config(
        from_pretrained,
        auto_cls=value.__class__,
        pretrained_model_name_or_path="dummy",
    )
|
||
|
||
# Attach an HFAutoArtifact bound to the ``pretrained_model_name_or_path``
# parameter to ``from_pretrained`` through its ``__io_artifacts__`` attribute.
from_pretrained.__io_artifacts__ = [
    HFAutoArtifact("pretrained_model_name_or_path"),
]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.