 from vllm.attention import Attention, AttentionMetadata
 from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
 from vllm.config import DeviceConfig
-from typing import Tuple
+
+from vllm._C import ops
 from ipex_llm.utils.common import invalidInputError
+from typing import List, Optional, Tuple, Union


 def _MLP_forward(self, x):
@@ -42,7 +44,7 @@ def _Attention_forward(
     kv_cache: torch.Tensor,
     attn_metadata: AttentionMetadata,
 ) -> torch.Tensor:
-    qkv = self.qkv_proj(hidden_states)
+    qkv = self.qkv_proj(hidden_states).to(dtype=kv_cache.dtype)
     q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
     q, k = self.rotary_emb(positions, q, k)
     attn_output = self.attn(q, k, v, kv_cache, attn_metadata, self.kv_scale)
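The only change in this hunk casts the projected QKV activations to the KV-cache dtype before they reach the attention kernel, presumably so the CPU attention path sees one consistent dtype. A minimal sketch of the same cast pattern in isolation (the shapes, dtypes, and layer below are illustrative assumptions, not taken from the patch):

import torch

# Hypothetical setup: float32 activations, a bfloat16 KV cache.
hidden_states = torch.randn(2, 8, dtype=torch.float32)
kv_cache = torch.empty(1, dtype=torch.bfloat16)
qkv_proj = torch.nn.Linear(8, 24)

# Mirrors the added .to(dtype=kv_cache.dtype) above: the projection output
# is converted so q/k/v match the cache dtype before attention runs.
qkv = qkv_proj(hidden_states).to(dtype=kv_cache.dtype)
assert qkv.dtype == kv_cache.dtype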
@@ -145,21 +147,77 @@ def _model_attention_convert():


 def _ipex_llm_convert(load_in_low_bit):
-    from vllm.worker.model_runner import ModelRunner
+    from vllm.worker.cpu_model_runner import CPUModelRunner
     import vllm.model_executor.model_loader as model_loader
-    setattr(ModelRunner, "load_model", get_load_function(load_in_low_bit))
+    setattr(CPUModelRunner, "load_model", get_load_function(load_in_low_bit))
+
+    from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
+    setattr(RotaryEmbedding, "forward", _ipex_llm_rotary_embedding_forward)
+    from vllm.model_executor.layers.layernorm import RMSNorm
+    setattr(RMSNorm, "forward", _ipex_llm_rmsnorm_forward)
+
+
+def _ipex_llm_rotary_embedding_forward(
+    self,
+    positions: torch.Tensor,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    offsets: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    self.cos_sin_cache = self.cos_sin_cache.to(positions.device, dtype=query.dtype)
+
+    # ops.rotary_embedding()/batched_rotary_embedding()
+    # are in-place operations that update the query and key tensors.
+    if offsets is not None:
+        ops.batched_rotary_embedding(positions, query, key, self.head_size,
+                                     self.cos_sin_cache,
+                                     self.is_neox_style, self.rotary_dim,
+                                     offsets)
+    else:
+        ops.rotary_embedding(positions, query, key, self.head_size,
+                             self.cos_sin_cache, self.is_neox_style)
+    return query, key
+
+
+def _ipex_llm_rmsnorm_forward(
+    self,
+    x: torch.Tensor,
+    residual: Optional[torch.Tensor] = None,
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+    x = x.to(dtype=self.weight.data.dtype)
+    if residual is not None:
+        residual = residual.to(dtype=self.weight.data.dtype)
+        ops.fused_add_rms_norm(
+            x,
+            residual,
+            self.weight.data,
+            self.variance_epsilon,
+        )
+        return x, residual
+    out = torch.empty_like(x)
+    ops.rms_norm(
+        out,
+        x,
+        self.weight.data,
+        self.variance_epsilon,
+    )
+    return out


 def get_load_function(low_bit):
     def _ipex_llm_load_model(self) -> None:
         _model_mlp_convert()
         _model_attention_convert()

-        self.model = get_model(self.model_config,
-                               self.device_config,
-                               lora_config=self.lora_config,
-                               parallel_config=self.parallel_config,
-                               scheduler_config=self.scheduler_config)
+        self.model = get_model(
+            model_config=self.model_config,
+            load_config=self.load_config,
+            device_config=self.device_config,
+            vision_language_config=self.vision_language_config,
+            lora_config=self.lora_config,
+            parallel_config=self.parallel_config,
+            scheduler_config=self.scheduler_config)
+
         from ipex_llm import optimize_model
         optimize_model(self.model, low_bit=low_bit, torch_dtype=self.model_config.dtype)
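Taken together, `_ipex_llm_convert` monkey-patches `CPUModelRunner.load_model`, `RotaryEmbedding.forward`, and `RMSNorm.forward` before the engine loads weights, and the patched load path then passes the model through ipex-llm's `optimize_model`. A minimal usage sketch under stated assumptions: the import path, the "sym_int4" low-bit format, and the model id are placeholders for illustration, and the `LLM`/`SamplingParams` calls follow vLLM's v0.4-era API rather than anything shown in this diff.

# Hypothetical wiring: patch vLLM first, then build the engine as usual.
from ipex_llm.vllm.model_convert import _ipex_llm_convert  # assumed module path
from vllm import LLM, SamplingParams

_ipex_llm_convert(load_in_low_bit="sym_int4")  # "sym_int4" is an assumed low-bit format

llm = LLM(model="meta-llama/Llama-2-7b-chat-hf",  # placeholder model id
          device="cpu",
          dtype="bfloat16")
outputs = llm.generate(["What is vLLM?"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)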