-
Notifications
You must be signed in to change notification settings - Fork 234
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Co-authored-by: Kavulya, Soila P <[email protected]> Co-authored-by: pallavi jaini <[email protected]>
- Loading branch information
1 parent
982bda6
commit c5a715c
Showing
10 changed files
with
2,253 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
89 changes: 89 additions & 0 deletions
89
optimum/habana/transformers/modeling_utils_transformers.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
import os | ||
from typing import Optional, Union | ||
from zipfile import is_zipfile | ||
|
||
import torch | ||
from packaging import version | ||
from transformers.integrations import is_deepspeed_zero3_enabled | ||
from transformers.modeling_utils import is_fsdp_enabled, is_local_dist_rank_0 | ||
from transformers.utils import ( | ||
is_safetensors_available, | ||
) | ||
|
||
|
||
if is_safetensors_available(): | ||
from safetensors import safe_open | ||
from safetensors.torch import load_file as safe_load_file | ||
|
||
|
||
def load_state_dict(
    checkpoint_file: Union[str, os.PathLike],
    is_quantized: bool = False,
    map_location: Optional[Union[str, torch.device]] = None,
    weights_only: bool = True,
):
    """
    Reads a PyTorch checkpoint file, returning properly formatted errors if they arise.

    Copied from transformers v4.48.2 for DeepSeek-R1 support https://github.com/huggingface/transformers/blob/b673c16cad81c71f70903a9a63f5b5f06014aa9e/src/transformers/modeling_utils.py#L493
    Delete after upgrade transformers v4.45.2 to v4.48

    Args:
        checkpoint_file: Path to a ``.safetensors`` or torch-serialized checkpoint.
        is_quantized: When True, never remap the checkpoint to the "meta" device
            (quantized weights must be materialized on every rank).
        map_location: Explicit device mapping forwarded to ``torch.load``. When
            None, it is inferred from the DeepSpeed ZeRO-3 / FSDP rank state.
        weights_only: Forwarded to ``torch.load`` to restrict unpickling to
            tensor data (safer for untrusted checkpoints).

    Returns:
        The state dict loaded from ``checkpoint_file``.

    Raises:
        OSError: If the file has invalid safetensors metadata, looks like a
            git-lfs pointer stub, or cannot be deserialized at all.
        ValueError: If the file is missing or not a readable checkpoint.
    """
    # The signature accepts os.PathLike, but the string checks below
    # (.endswith, the isinstance(str) mmap gate) only work on str — normalize
    # once so Path arguments behave identically to plain strings.
    checkpoint_file = os.fspath(checkpoint_file)

    if checkpoint_file.endswith(".safetensors") and is_safetensors_available():
        # Check format of the archive
        with safe_open(checkpoint_file, framework="pt") as f:
            metadata = f.metadata()
        if metadata is not None and metadata.get("format") not in ["pt", "tf", "flax", "mlx"]:
            raise OSError(
                f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure "
                "you save your model with the `save_pretrained` method."
            )
        return safe_load_file(checkpoint_file)
    try:
        if map_location is None:
            # Non-zero ranks under ZeRO-3/FSDP never consume the full weights,
            # so load onto "meta" to avoid materializing them — unless the
            # checkpoint is quantized, which requires real tensors everywhere.
            if (
                (
                    is_deepspeed_zero3_enabled()
                    and torch.distributed.is_initialized()
                    and torch.distributed.get_rank() > 0
                )
                or (is_fsdp_enabled() and not is_local_dist_rank_0())
            ) and not is_quantized:
                map_location = "meta"
            else:
                map_location = "cpu"
        extra_args = {}
        # mmap can only be used with files serialized with zipfile-based format.
        if (
            isinstance(checkpoint_file, str)
            and map_location != "meta"
            and version.parse(torch.__version__) >= version.parse("2.1.0")
            and is_zipfile(checkpoint_file)
        ):
            extra_args = {"mmap": True}
        weights_only_kwarg = {"weights_only": weights_only}
        return torch.load(
            checkpoint_file,
            map_location=map_location,
            **weights_only_kwarg,
            **extra_args,
        )
    except Exception as e:
        try:
            # Distinguish a git-lfs pointer stub (a small text file starting
            # with "version") from a genuinely unreadable/missing checkpoint.
            with open(checkpoint_file) as f:
                if f.read(7) == "version":
                    raise OSError(
                        "You seem to have cloned a repository without having git-lfs installed. Please install "
                        "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
                        "you cloned."
                    )
                else:
                    raise ValueError(
                        f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
                        "model. Make sure you have saved the model properly."
                    ) from e
        except (UnicodeDecodeError, ValueError):
            # Binary (non-text) content or the ValueError above: the file
            # exists but is not a loadable checkpoint.
            raise OSError(
                f"Unable to load weights from pytorch checkpoint file for '{checkpoint_file}' "
                f"at '{checkpoint_file}'. "
                "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True."
            )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .configuration_deepseek_v3 import DeepseekV3Config | ||
from .modeling_deepseek_v3 import DeepseekV3ForCausalLM |
Oops, something went wrong.