diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 23ee0692..cd2111fa 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -1495,6 +1495,13 @@ class AscendMLAImpl(MLAAttentionImpl): self.kv_lora_rank) decode_q_pe = decode_q_pe.view(bsz, self.num_heads, -1) + if self.dcp_size > 1: + decode_q_no_split = torch.cat([decode_q_nope, decode_q_pe], dim=-1) + decode_q_no_split = get_dcp_group().all_gather( + decode_q_no_split, 1) + decode_q_nope, decode_q_pe = decode_q_no_split.split( + [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + decode_preprocess_res = DecodeMLAPreprocessResult( decode_q_nope, decode_q_pe, decode_k_nope, decode_k_pe) return decode_preprocess_res, None