diff --git a/tests/ut/kv_connector/test_mooncake_connector.py b/tests/ut/kv_connector/test_mooncake_connector.py index fabb67b9..48271269 100644 --- a/tests/ut/kv_connector/test_mooncake_connector.py +++ b/tests/ut/kv_connector/test_mooncake_connector.py @@ -1195,7 +1195,7 @@ class TestMooncakeConnectorWorker(unittest.TestCase): "prefill": {"tp_size": prefill_tp_size, "dp_size": 1, "pp_size": prefill_pp_size}, "decode": {"tp_size": decode_tp_size, "dp_size": 1, "pp_size": 1} }.get(k, d)): - self.vllm_config.model_config.hf_config.num_key_value_heads = num_kv_heads + self.vllm_config.model_config.hf_text_config.num_key_value_heads = num_kv_heads self.vllm_config.model_config.is_deepseek_mla = is_deepseek_mla worker = MooncakeConnectorWorker(self.vllm_config, self.engine_id) diff --git a/vllm_ascend/distributed/mooncake_connector.py b/vllm_ascend/distributed/mooncake_connector.py index 92ff6413..8f269610 100644 --- a/vllm_ascend/distributed/mooncake_connector.py +++ b/vllm_ascend/distributed/mooncake_connector.py @@ -319,14 +319,14 @@ class KVCacheRecvingThread(threading.Thread): self.block_size = self.vllm_config.cache_config.block_size if not is_vl_model(vllm_config): if self.use_mla: - self.k_head_dim = self.model_config.hf_config.kv_lora_rank - self.v_head_dim = self.model_config.hf_config.qk_rope_head_dim + self.k_head_dim = self.model_config.hf_text_config.kv_lora_rank + self.v_head_dim = self.model_config.hf_text_config.qk_rope_head_dim self.num_kv_heads = 1 else: - self.k_head_dim = self.model_config.hf_config.head_dim - self.v_head_dim = self.model_config.hf_config.head_dim + self.k_head_dim = self.model_config.hf_text_config.head_dim + self.v_head_dim = self.model_config.hf_text_config.head_dim self.num_kv_heads = max( - self.model_config.hf_config.num_key_value_heads // + self.model_config.hf_text_config.num_key_value_heads // self.tp_size, 1) def add_request(self, request_id: str, local_block_ids: list[int], @@ -438,7 +438,7 @@ class KVCacheRecvingThread(threading.Thread): remote_kv_caches_base_addrs = \ self.kv_caches_base_addr[remote_engine_id][remote_handshake_port] - num_layers = self.model_config.hf_config.num_hidden_layers + num_layers = self.model_config.hf_text_config.num_hidden_layers first_layer_index, end_layer_index = get_pp_indices( num_layers, prefill_pp_rank, self._prefill_pp_size) num_cache_per_layer = len(list( @@ -503,10 +503,11 @@ class KVCacheRecvingThread(threading.Thread): k_cache = list(self.kv_caches.values())[0][0] dtype = k_cache.dtype device = k_cache.device - head_dim = self.model_config.hf_config.head_dim + head_dim = self.model_config.hf_text_config.head_dim block_size = self.vllm_config.cache_config.block_size num_kv_head = max( - self.model_config.hf_config.num_key_value_heads // self.tp_size, 1) + self.model_config.hf_text_config.num_key_value_heads // + self.tp_size, 1) flat_block_ids = [item for sublist in block_ids for item in sublist] block_ids_tensor = torch.tensor(flat_block_ids, dtype=torch.int32) @@ -1010,7 +1011,7 @@ class MooncakeConnectorWorker: self.max_device_id = self.tp_size * self.dp_size * self.pcp_size * self.pp_size self.kv_role = vllm_config.kv_transfer_config.kv_role - self.num_key_value_heads = self.vllm_config.model_config.hf_config.num_key_value_heads + self.num_key_value_heads = self.vllm_config.model_config.hf_text_config.num_key_value_heads # Handshake base port self.side_channel_port = (