diff --git a/vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/models/deepseek_v2.py b/vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/models/deepseek_v2.py index e45f792..2dd9446 100644 --- a/vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/models/deepseek_v2.py +++ b/vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/models/deepseek_v2.py @@ -252,19 +252,27 @@ def forward_prefill( updated_slot_mapping = attn_metadata.slot_mapping if self.attn.kv_cache_dtype == 'int8': key_cache_scale = kv_cache[1][0] - mlu_ops.quant_to_paged_cache(key_value, + mlu_ops.quant_to_linear_cache(key_value, + None, + key_cache, + None, + key_cache_scale, + None, + attn_metadata.cu_seq_lens, + attn_metadata.max_seq_len, + True, None, + attn_metadata.batch_ids, + attn_metadata.slot_mapping_unpaged) + else: + mlu_ops.reshape_linear_cache(key_value, None, key_cache, None, - key_cache_scale, - None, - updated_slot_mapping.flatten()) - else: - mlu_ops.reshape_paged_cache(key_value, - None, - key_cache, - None, - updated_slot_mapping.flatten()) + attn_metadata.cu_seq_lens, + attn_metadata.max_seq_len, + True, None, + attn_metadata.batch_ids, + attn_metadata.slot_mapping_unpaged) ''' ================== End of MLU Hijack