diff --git a/vllm_ascend/distributed/mooncake_connector.py b/vllm_ascend/distributed/mooncake_connector.py
index 781aa2c4..92ff6413 100644
--- a/vllm_ascend/distributed/mooncake_connector.py
+++ b/vllm_ascend/distributed/mooncake_connector.py
@@ -42,7 +42,7 @@ from vllm.v1.request import RequestStatus
 from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config
 from vllm_ascend.distributed.mooncake_transfer_engine import global_te
 from vllm_ascend.distributed.utils import get_transfer_timeout_value
-from vllm_ascend.utils import prefill_context_parallel_enable
+from vllm_ascend.utils import is_vl_model, prefill_context_parallel_enable
 
 if TYPE_CHECKING:
     from vllm.attention.backends.abstract import AttentionMetadata
@@ -317,16 +317,17 @@ class KVCacheRecvingThread(threading.Thread):
         self.vllm_config = vllm_config
         self.model_config = self.vllm_config.model_config
         self.block_size = self.vllm_config.cache_config.block_size
-        if self.use_mla:
-            self.k_head_dim = self.model_config.hf_config.kv_lora_rank
-            self.v_head_dim = self.model_config.hf_config.qk_rope_head_dim
-            self.num_kv_heads = 1
-        else:
-            self.k_head_dim = self.model_config.hf_config.head_dim
-            self.v_head_dim = self.model_config.hf_config.head_dim
-            self.num_kv_heads = max(
-                self.model_config.hf_config.num_key_value_heads //
-                self.tp_size, 1)
+        if not is_vl_model(vllm_config):
+            if self.use_mla:
+                self.k_head_dim = self.model_config.hf_config.kv_lora_rank
+                self.v_head_dim = self.model_config.hf_config.qk_rope_head_dim
+                self.num_kv_heads = 1
+            else:
+                self.k_head_dim = self.model_config.hf_config.head_dim
+                self.v_head_dim = self.model_config.hf_config.head_dim
+                self.num_kv_heads = max(
+                    self.model_config.hf_config.num_key_value_heads //
+                    self.tp_size, 1)
 
     def add_request(self, request_id: str, local_block_ids: list[int],
                     remote_block_ids: list[int], remote_engine_id: str,
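
For context (not part of the patch): the new `is_vl_model` guard presumably exists because vision-language models often keep the text model's attention settings in a nested config (e.g. under `text_config`) rather than exposing `head_dim` and `num_key_value_heads` on the top-level `hf_config`, so the old unconditional attribute reads would fail for them. The sketch below restates the per-TP-rank KV-cache layout logic from the second hunk as a standalone function; `kv_layout` is a hypothetical name for illustration only, not part of the connector.

```python
def kv_layout(hf_config, tp_size: int, use_mla: bool) -> tuple[int, int, int]:
    """Return (k_head_dim, v_head_dim, num_kv_heads) per TP rank.

    Hypothetical helper mirroring the hunk above, not the connector's API.
    """
    if use_mla:
        # MLA caches a single compressed latent per token: the "K" slot holds
        # the kv_lora_rank-dim latent, the "V" slot the rope'd key dim.
        return hf_config.kv_lora_rank, hf_config.qk_rope_head_dim, 1
    # Standard attention shards the KV heads across tensor-parallel ranks;
    # max(..., 1) keeps at least one head when tp_size exceeds the head count.
    num_kv_heads = max(hf_config.num_key_value_heads // tp_size, 1)
    return hf_config.head_dim, hf_config.head_dim, num_kv_heads
```

For example, a non-MLA model with 8 KV heads at `tp_size=4` yields 2 KV heads per rank, while `tp_size=16` still yields 1 rather than 0.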