forked from EngineX-Cambricon/enginex-mlu370-vllm
add deepseekv3 and llama4
This commit is contained in:
@@ -37,7 +37,7 @@ def vllm__config__CacheConfig___verify_cache_dtype(self) -> None:
|
||||
|
||||
def vllm__config__ModelConfig__get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
|
||||
"""Returns the number of KV heads per GPU."""
|
||||
if hasattr(self.hf_text_config,"model_type") and self.hf_text_config.model_type in ('deepseek_v2', 'deepseek_v3'):
|
||||
if hasattr(self.hf_text_config,"model_type") and self.hf_text_config.model_type in ('deepseek_v2', 'deepseek_v3', 'deepseek_mtp'):
|
||||
# feature flag MLA
|
||||
return 1
|
||||
total_num_kv_heads = self.get_total_num_kv_heads()
|
||||
@@ -51,7 +51,7 @@ def vllm__config__ModelConfig__get_num_kv_heads(self, parallel_config: "Parallel
|
||||
def vllm__config__ModelConfig__get_head_size(self) -> int:
|
||||
# TODO remove hard code
|
||||
if hasattr(self.hf_text_config, "model_type"
|
||||
) and self.hf_text_config.model_type in ('deepseek_v2', 'deepseek_v3'):
|
||||
) and self.hf_text_config.model_type in ('deepseek_v2', 'deepseek_v3', 'deepseek_mtp'):
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
@@ -109,7 +109,7 @@ def vllm__config__LoRAConfig__verify_with_model_config(self, model_config: Model
|
||||
def vllm__config__ModelConfig__is_deepseek_v2(self) -> bool:
|
||||
result = hasattr(
|
||||
self.hf_text_config,
|
||||
"model_type") and self.hf_text_config.model_type in ('deepseek_v2', 'deepseek_v3')
|
||||
"model_type") and self.hf_text_config.model_type in ('deepseek_v2', 'deepseek_v3', 'deepseek_mtp')
|
||||
return result
|
||||
|
||||
MluHijackObject.apply_hijack(ModelConfig,
|
||||
|
||||
Reference in New Issue
Block a user