diff --git a/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_connector.py b/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_connector.py index 02fd1a61..4c692e33 100644 --- a/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_connector.py +++ b/vllm_ascend/distributed/kv_transfer/kv_p2p/mooncake_connector.py @@ -1323,13 +1323,13 @@ class MooncakeConnectorWorker: def context_parallel_parameters_check(): assert (meta.remote_pcp_size * meta.remote_dcp_size) % (self.pcp_size * self.dcp_size) == 0 - if not self.use_mla: + if not (self.use_mla or self.use_sparse): p_node_heads_per_rank = math.ceil(self.num_key_value_heads / prefill_tp_size) d_node_heads_per_rank = math.ceil(self.num_key_value_heads / self.tp_size) assert d_node_heads_per_rank % p_node_heads_per_rank == 0 def get_kv_head_groups(tp_size): - if self.use_mla: + if self.use_mla or self.use_sparse: kv_head_groups = [] kv_head_ids = [0] kv_head_groups.append(tuple(kv_head_ids))