DeepSeek_v3 support (#1735)
Co-authored-by: Kavulya, Soila P <[email protected]>
Co-authored-by: pallavi jaini <[email protected]>
3 people authored Feb 21, 2025
1 parent 982bda6 commit c5a715c
Showing 10 changed files with 2,253 additions and 0 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -280,6 +280,7 @@ The following model architectures, tasks and device distributions have been vali
| MiniCPM3 | | <li>Single card</li> | <li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
| Baichuan2 | <li>DeepSpeed</li> | <li>Single card</li> | <li>[language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)</li><li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
| DeepSeek-V2 | :heavy_check_mark: | :heavy_check_mark: | <li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
| DeepSeek-V3 | | :heavy_check_mark: | <li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
| ChatGLM | <li>DeepSpeed</li> | <li>Single card</li> | <li>[language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)</li><li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
| Qwen2-VL | | <div style="text-align:left"><li>Single card</li></div> | <li>[image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text)</li> |
| VideoLLaVA | | <div style="text-align:left"><li>Single card</li></div> | <li>[Video comprehension](https://github.com/huggingface/optimum-habana/tree/main/examples/video-comprehension)</li> |
1 change: 1 addition & 0 deletions docs/source/index.mdx
@@ -109,6 +109,7 @@ In the tables below, ✅ means single-card, multi-card and DeepSpeed have all be
| MiniCPM3 | | <div style="text-align:left"><li>Single card</li></div> | <li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
| Baichuan2 | <div style="text-align:left"><li>DeepSpeed</li></div> | <div style="text-align:left"><li>Single card</li></div> | <li>[language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)</li><li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
| DeepSeek-V2 | ✅ | ✅ | <li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
| DeepSeek-V3 | | ✅ | <li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
| ChatGLM | <div style="text-align:left"><li>DeepSpeed</li></div> | <div style="text-align:left"><li>Single card</li></div> | <li>[language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)</li><li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
| Qwen2-VL | | <div style="text-align:left"><li>Single card</li></div> | <li>[image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text)</li> |

14 changes: 14 additions & 0 deletions examples/text-generation/README.md
@@ -202,6 +202,20 @@ python ../gaudi_spawn.py --use_deepspeed --world_size 8 run_generation.py \
--flash_attention_causal_mask
```

To run DeepSeek-R1-BF16 inference on 16 Gaudi3 cards (2 nodes), use the following command. Replace the `--hostfile` value with the path to your own hostfile; a sample hostfile is available [here](https://github.com/huggingface/optimum-habana/blob/main/examples/multi-node-training/hostfile).
```bash
python3 ../gaudi_spawn.py --hostfile=<hostfile> --use_deepspeed \
--world_size 16 ./run_generation.py \
--model_name_or_path opensourcerelease/DeepSeek-R1-bf16 \
--bf16 \
--trim_logits \
--batch_size 1 \
--use_hpu_graphs \
--use_kv_cache \
--parallel_strategy "ep" \
--prompt "DeepSpeed is a machine learning framework"
```
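
For reference, a DeepSpeed-style hostfile lists one node per line with the number of Gaudi cards exposed as `slots`. A minimal two-node sketch (the hostnames are placeholders for your own machines):
```
node-1 slots=8
node-2 slots=8
```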

> To be able to run gated models like [StarCoder](https://huggingface.co/bigcode/starcoder), you should:
> - have a HF account
> - agree to the terms of use of the model in its model card on the HF Hub
2 changes: 2 additions & 0 deletions optimum/habana/transformers/generation/utils.py
@@ -116,6 +116,7 @@
"minicpm3",
"baichuan",
"deepseek_v2",
"deepseek_v3",
"chatglm",
"qwen2_vl",
]
@@ -1095,6 +1096,7 @@ def generate(
"baichuan",
"chatglm",
"deepseek_v2",
"deepseek_v3",
], (
"reuse_cache only supported by llama, mistral, falcon, mixtral, phi, qwen2, qwen2_moe, gemma, gemma2, starcoder2, baichuan, chatglm and deepseek_v2 at the moment"
)
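
With `deepseek_v3` added to this allow-list, a generation call can opt into the pre-allocated KV cache. A minimal sketch, assuming a DeepSeek-V3 model already loaded on HPU and optimum-habana's `GaudiGenerationConfig` (the prompt tensor and vocabulary size are illustrative):
```python
import torch

from optimum.habana.transformers.generation import GaudiGenerationConfig

# reuse_cache pre-allocates the KV cache once and reuses it across decoding steps
# instead of re-allocating it on every generate() call.
generation_config = GaudiGenerationConfig(
    max_new_tokens=128,
    use_cache=True,
    reuse_cache=True,  # now permitted for deepseek_v3 by the assertion above
)
# `model` is assumed to be a DeepSeek-V3 causal LM already moved to the HPU device.
input_ids = torch.randint(0, 32000, (1, 16)).to("hpu")  # placeholder prompt tokens
outputs = model.generate(input_ids, generation_config=generation_config)
```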
10 changes: 10 additions & 0 deletions optimum/habana/transformers/modeling_utils.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import accelerate
import transformers
import transformers.utils.fx
@@ -40,6 +41,7 @@
    gaudi_awq_quantizer_process_model_before_weight_loading,
    gaudi_awq_quantizer_validate_environment,
)
from .modeling_utils_transformers import load_state_dict
from .models import (
    GAUDI_WHISPER_ATTENTION_CLASSES,
    BaichuanConfig,
@@ -54,6 +56,8 @@
    DeepseekTokenizerFast,
    DeepseekV2Config,
    DeepseekV2ForCausalLM,
    DeepseekV3Config,
    DeepseekV3ForCausalLM,
    Gaudi2Idefics2ImageProcessor,
    GaudiBloomForCausalLM,
    GaudiBloomMLP,
@@ -313,6 +317,9 @@ def adapt_transformers_to_gaudi():
    # optimize Conv1D
    transformers.pytorch_utils.Conv1D.forward = gaudi_conv1d_forward

    # Override load_state_dict for DeepSeek-V3; delete on upgrade to transformers v4.48
    transformers.modeling_utils.load_state_dict = load_state_dict

    # Optimization tweak for ViT
    transformers.models.vit.modeling_vit.ViTSelfAttention.forward = gaudi_vit_self_attention_forward

@@ -743,9 +750,12 @@ def adapt_transformers_to_gaudi():
transformers.AutoConfig.register("deci", DeciLMConfig)
transformers.AutoModelForCausalLM.register(DeciLMConfig, DeciLMForCausalLM)

# Optimization for deepseek on Gaudi
transformers.AutoConfig.register("deepseek_v2", DeepseekV2Config)
transformers.AutoModelForCausalLM.register(DeepseekV2Config, DeepseekV2ForCausalLM)
transformers.AutoTokenizer.register(DeepseekV2Config, fast_tokenizer_class=DeepseekTokenizerFast)
transformers.AutoConfig.register("deepseek_v3", DeepseekV3Config)
transformers.AutoModelForCausalLM.register(DeepseekV3Config, DeepseekV3ForCausalLM)

# Optimization for cohere on Gaudi
transformers.models.cohere.modeling_cohere.CohereDecoderLayer = GaudiCohereDecoderLayer
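
Once these registrations run, the standard Auto classes resolve the `deepseek_v3` model type to the Gaudi port. A minimal sketch, assuming a Gaudi environment; it fetches only the config, since the full model requires multiple cards (the checkpoint is the one used in the README example above):
```python
from transformers import AutoConfig

from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

adapt_transformers_to_gaudi()  # installs the Gaudi overrides and Auto-class registrations

# AutoConfig now maps model_type "deepseek_v3" to DeepseekV3Config, and
# AutoModelForCausalLM maps that config to DeepseekV3ForCausalLM.
config = AutoConfig.from_pretrained("opensourcerelease/DeepSeek-R1-bf16")
print(type(config).__name__)  # DeepseekV3Config
```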
89 changes: 89 additions & 0 deletions optimum/habana/transformers/modeling_utils_transformers.py
@@ -0,0 +1,89 @@
import os
from typing import Optional, Union
from zipfile import is_zipfile

import torch
from packaging import version
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled, is_local_dist_rank_0
from transformers.utils import (
    is_safetensors_available,
)


if is_safetensors_available():
    from safetensors import safe_open
    from safetensors.torch import load_file as safe_load_file


def load_state_dict(
    checkpoint_file: Union[str, os.PathLike],
    is_quantized: bool = False,
    map_location: Optional[Union[str, torch.device]] = None,
    weights_only: bool = True,
):
    """
    Reads a PyTorch checkpoint file, returning properly formatted errors if they arise.

    Copied from transformers v4.48.2 for DeepSeek-R1 support:
    https://github.com/huggingface/transformers/blob/b673c16cad81c71f70903a9a63f5b5f06014aa9e/src/transformers/modeling_utils.py#L493
    Delete after upgrading transformers from v4.45.2 to v4.48.
    """
    if checkpoint_file.endswith(".safetensors") and is_safetensors_available():
        # Check format of the archive
        with safe_open(checkpoint_file, framework="pt") as f:
            metadata = f.metadata()
        if metadata is not None and metadata.get("format") not in ["pt", "tf", "flax", "mlx"]:
            raise OSError(
                f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure "
                "you save your model with the `save_pretrained` method."
            )
        return safe_load_file(checkpoint_file)
    try:
        if map_location is None:
            if (
                (
                    is_deepspeed_zero3_enabled()
                    and torch.distributed.is_initialized()
                    and torch.distributed.get_rank() > 0
                )
                or (is_fsdp_enabled() and not is_local_dist_rank_0())
            ) and not is_quantized:
                map_location = "meta"
            else:
                map_location = "cpu"
        extra_args = {}
        # mmap can only be used with files serialized with zipfile-based format.
        if (
            isinstance(checkpoint_file, str)
            and map_location != "meta"
            and version.parse(torch.__version__) >= version.parse("2.1.0")
            and is_zipfile(checkpoint_file)
        ):
            extra_args = {"mmap": True}
        weights_only_kwarg = {"weights_only": weights_only}
        return torch.load(
            checkpoint_file,
            map_location=map_location,
            **weights_only_kwarg,
            **extra_args,
        )
    except Exception as e:
        try:
            with open(checkpoint_file) as f:
                if f.read(7) == "version":
                    raise OSError(
                        "You seem to have cloned a repository without having git-lfs installed. Please install "
                        "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
                        "you cloned."
                    )
                else:
                    raise ValueError(
                        f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
                        "model. Make sure you have saved the model properly."
                    ) from e
        except (UnicodeDecodeError, ValueError):
            raise OSError(
                f"Unable to load weights from pytorch checkpoint file for '{checkpoint_file}' "
                f"at '{checkpoint_file}'. "
                "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True."
            )
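
A minimal usage sketch of the patched loader (the file name is illustrative; any local PyTorch or safetensors checkpoint works):
```python
# Hypothetical smoke test: load a checkpoint and count its tensors.
state_dict = load_state_dict("pytorch_model.bin", map_location="cpu", weights_only=True)
print(f"{len(state_dict)} tensors, {sum(t.numel() for t in state_dict.values()):,} elements")
```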
4 changes: 4 additions & 0 deletions optimum/habana/transformers/models/__init__.py
@@ -71,6 +71,10 @@
    DeepseekV2Config,
    DeepseekV2ForCausalLM,
)
from .deepseek_v3 import (
    DeepseekV3Config,
    DeepseekV3ForCausalLM,
)
from .detr import (
    gaudi_DetrConvModel_forward,
    gaudi_DetrHungarianMatcher_forward,
2 changes: 2 additions & 0 deletions optimum/habana/transformers/models/deepseek_v3/__init__.py
@@ -0,0 +1,2 @@
from .configuration_deepseek_v3 import DeepseekV3Config
from .modeling_deepseek_v3 import DeepseekV3ForCausalLM
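
A quick sanity check of the new subpackage; instantiating the config alone is cheap, while building a model from the default config is not advisable at DeepSeek-V3 scale:
```python
from optimum.habana.transformers.models.deepseek_v3 import DeepseekV3Config

cfg = DeepseekV3Config()
print(cfg.model_type)  # "deepseek_v3", matching the AutoConfig registration above
```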