[Bugfix] Fix Dcp dimension mismatch when enable Mlapo (#4687)

### What this PR does / why we need it?
After enabling Mlapo and DCP, since Mlapo has its own mla_preprocess
logic and does not perform additional all_gather operations on the DCP
group, this will lead to dimension mismatch during the subsequent
forward proces

### Does this PR introduce _any_ user-facing change?

N/A


- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

Signed-off-by: zengran <zengran2@huawei.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
zengzengran
2025-12-08 17:19:58 +08:00
committed by GitHub
parent afe00505de
commit f0876b5d88

View File

@@ -1495,6 +1495,13 @@ class AscendMLAImpl(MLAAttentionImpl):
self.kv_lora_rank)
decode_q_pe = decode_q_pe.view(bsz, self.num_heads, -1)
if self.dcp_size > 1:
decode_q_no_split = torch.cat([decode_q_nope, decode_q_pe], dim=-1)
decode_q_no_split = get_dcp_group().all_gather(
decode_q_no_split, 1)
decode_q_nope, decode_q_pe = decode_q_no_split.split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
decode_preprocess_res = DecodeMLAPreprocessResult(
decode_q_nope, decode_q_pe, decode_k_nope, decode_k_pe)
return decode_preprocess_res, None