diff --git a/vllm_ascend/distributed/mooncake_layerwise_connector.py b/vllm_ascend/distributed/mooncake_layerwise_connector.py index e7e3219a..f9e15d5c 100644 --- a/vllm_ascend/distributed/mooncake_layerwise_connector.py +++ b/vllm_ascend/distributed/mooncake_layerwise_connector.py @@ -1078,10 +1078,14 @@ class MooncakeLayerwiseConnectorWorker: ): # enable decode prefix cache if self.use_mla or self.use_sparse: - num_kv_head = self._decode_tp_size + num_need_send = self._decode_tp_size else: num_kv_head = self.vllm_config.model_config.hf_config.num_key_value_heads - num_replica_groups = self.tp_size // num_kv_head if self.tp_size >= num_kv_head else 1 + if self.tp_size <= num_kv_head: + num_need_send = self.tp_size + else: + num_need_send = self._decode_tp_size if self._decode_tp_size >= num_kv_head else num_kv_head + num_replica_groups = self.tp_size // num_need_send if self.tp_size >= num_need_send else 1 replica_group_idx = self.tp_rank % num_replica_groups req_ids = sorted(list(connector_metadata.requests.keys())) selected_req_ids = [