diff --git a/vllm-v0.6.2/vllm_mlu/vllm_mlu/config.py b/vllm-v0.6.2/vllm_mlu/vllm_mlu/config.py index 5f1dc07..52ba6fa 100644 --- a/vllm-v0.6.2/vllm_mlu/vllm_mlu/config.py +++ b/vllm-v0.6.2/vllm_mlu/vllm_mlu/config.py @@ -37,7 +37,7 @@ def vllm__config__CacheConfig___verify_cache_dtype(self) -> None: def vllm__config__ModelConfig__get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int: """Returns the number of KV heads per GPU.""" - if hasattr(self.hf_text_config,"model_type") and self.hf_text_config.model_type in ('deepseek_v2', 'deepseek_v3'): + if hasattr(self.hf_text_config,"model_type") and self.hf_text_config.model_type in ('deepseek_v2', 'deepseek_v3', 'deepseek_mtp'): # feature flag MLA return 1 total_num_kv_heads = self.get_total_num_kv_heads() @@ -51,7 +51,7 @@ def vllm__config__ModelConfig__get_num_kv_heads(self, parallel_config: "Parallel def vllm__config__ModelConfig__get_head_size(self) -> int: # TODO remove hard code if hasattr(self.hf_text_config, "model_type" - ) and self.hf_text_config.model_type in ('deepseek_v2', 'deepseek_v3'): + ) and self.hf_text_config.model_type in ('deepseek_v2', 'deepseek_v3', 'deepseek_mtp'): ''' ============================= Modify by vllm_mlu @@ -109,7 +109,7 @@ def vllm__config__LoRAConfig__verify_with_model_config(self, model_config: Model def vllm__config__ModelConfig__is_deepseek_v2(self) -> bool: result = hasattr( self.hf_text_config, - "model_type") and self.hf_text_config.model_type in ('deepseek_v2', 'deepseek_v3') + "model_type") and self.hf_text_config.model_type in ('deepseek_v2', 'deepseek_v3', 'deepseek_mtp') return result MluHijackObject.apply_hijack(ModelConfig,