[Bugfix] Fix Qwen P/D Disaggregation accuracy issue (#5340)
### What this PR does / why we need it?
Fix Qwen P/D Disaggregation accuracy issue
- vLLM version: release/v0.13.0
- vLLM main: bc0a5a0c08
Signed-off-by: F.Liu <liufeng248@huawei.com>
Co-authored-by: F.Liu <liufeng248@huawei.com>
@@ -549,6 +549,11 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl):
```python
        attn_out, attn_lse = torch_npu.npu_fused_infer_attention_score(
            query, k_nope, value, **common_kwargs)

        out_mask = attn_metadata.decode_meta.batch_seq_mask[:, None,
                                                            None].expand_as(
                                                                attn_out)
        attn_out = torch.where(out_mask, 0, attn_out)

        lse_mask = attn_metadata.decode_meta.batch_seq_mask[:, None,
                                                            None].expand_as(
                                                                attn_lse)
```
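The masking pattern in the patch is plain broadcast-and-zero: a per-batch boolean mask is expanded over the head and hidden dimensions of the attention output, and masked batches are overwritten with zeros. A minimal, hardware-independent sketch of that idea using numpy (shapes, the `batch_seq_mask` semantics, and the `(batch, heads, head_dim)` layout are assumptions for illustration, not taken from the patch):

```python
import numpy as np

# Hypothetical per-batch mask: True marks a batch entry whose attention
# result should be suppressed (assumed semantics for this sketch).
batch_seq_mask = np.array([False, True, False])

# Stand-in attention output with shape (batch, heads, head_dim).
attn_out = np.ones((3, 2, 4), dtype=np.float32)

# Broadcast the 1-D mask over heads and head_dim, mirroring
# batch_seq_mask[:, None, None].expand_as(attn_out) in the patch.
out_mask = np.broadcast_to(batch_seq_mask[:, None, None], attn_out.shape)

# Zero the output of masked batches, analogous to
# torch.where(out_mask, 0, attn_out).
attn_out = np.where(out_mask, 0.0, attn_out)
```

After this, batch index 1 is all zeros while the other batches keep their original values, so masked entries cannot leak into the disaggregated decode path.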