[Bugfix] Fix Dcp dimension mismatch when enable Mlapo (#4687)
### What this PR does / why we need it?
After enabling Mlapo and DCP, since Mlapo has its own mla_preprocess
logic and does not perform additional all_gather operations on the DCP
group, this will lead to dimension mismatch during the subsequent
forward proces
### Does this PR introduce _any_ user-facing change?
N/A
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
Signed-off-by: zengran <zengran2@huawei.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
@@ -1495,6 +1495,13 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
self.kv_lora_rank)
|
self.kv_lora_rank)
|
||||||
decode_q_pe = decode_q_pe.view(bsz, self.num_heads, -1)
|
decode_q_pe = decode_q_pe.view(bsz, self.num_heads, -1)
|
||||||
|
|
||||||
|
if self.dcp_size > 1:
|
||||||
|
decode_q_no_split = torch.cat([decode_q_nope, decode_q_pe], dim=-1)
|
||||||
|
decode_q_no_split = get_dcp_group().all_gather(
|
||||||
|
decode_q_no_split, 1)
|
||||||
|
decode_q_nope, decode_q_pe = decode_q_no_split.split(
|
||||||
|
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
||||||
|
|
||||||
decode_preprocess_res = DecodeMLAPreprocessResult(
|
decode_preprocess_res = DecodeMLAPreprocessResult(
|
||||||
decode_q_nope, decode_q_pe, decode_k_nope, decode_k_pe)
|
decode_q_nope, decode_q_pe, decode_k_nope, decode_k_pe)
|
||||||
return decode_preprocess_res, None
|
return decode_preprocess_res, None
|
||||||
|
|||||||
Reference in New Issue
Block a user