From f0876b5d88a04734505bb5cea7bda95b026923ff Mon Sep 17 00:00:00 2001 From: zengzengran Date: Mon, 8 Dec 2025 17:19:58 +0800 Subject: [PATCH] [Bugfix] Fix Dcp dimension mismatch when enabling Mlapo (#4687) ### What this PR does / why we need it? After enabling Mlapo and DCP, since Mlapo has its own mla_preprocess logic and does not perform additional all_gather operations on the DCP group, this will lead to a dimension mismatch during the subsequent forward process. ### Does this PR introduce _any_ user-facing change? N/A - vLLM version: v0.12.0 - vLLM main: https://github.com/vllm-project/vllm/commit/ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9 Signed-off-by: zengran Co-authored-by: wangxiyuan --- vllm_ascend/attention/mla_v1.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 23ee0692..cd2111fa 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -1495,6 +1495,13 @@ class AscendMLAImpl(MLAAttentionImpl): self.kv_lora_rank) decode_q_pe = decode_q_pe.view(bsz, self.num_heads, -1) + if self.dcp_size > 1: + decode_q_no_split = torch.cat([decode_q_nope, decode_q_pe], dim=-1) + decode_q_no_split = get_dcp_group().all_gather( + decode_q_no_split, 1) + decode_q_nope, decode_q_pe = decode_q_no_split.split( + [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + decode_preprocess_res = DecodeMLAPreprocessResult( decode_q_nope, decode_q_pe, decode_k_nope, decode_k_pe) return decode_preprocess_res, None