From 8e66299bf16a7b48cf85b080dc3f8927144d5add Mon Sep 17 00:00:00 2001 From: Ruowei Zheng <892882856@qq.com> Date: Thu, 5 Feb 2026 20:58:54 +0800 Subject: [PATCH] [Bugfix] Fix the incorrect use of the output parameter in _forward_fia_slidingwindow (#6469) ### What this PR does / why we need it? Fix the incorrect use of the `output` parameter in `_forward_fia_slidingwindow`: ``` # Original (incorrect) output, _ = torch_npu.npu_fused_infer_attention_score(...) output= output.view(batch_size, self.num_heads, self.head_size) ``` In the original writing, the `output `parameter was directly assigned a new value, which is inconsistent with the interface definition, resulting in the inability to directly update `output `when calling externally. ``` attn_output, _ = torch_npu.npu_fused_infer_attention_score(...) attn_output = attn_output.view(batch_size, self.num_heads, self.head_size) output[:batch_size] = attn_output[:batch_size] ``` ### Does this PR introduce _any_ user-facing change? No change. Co-authored-by: GoCHug ### How was this patch tested? vLLM ascend version: v0.13.0rc1 Signed-off-by: acat-rw <892882856@qq.com> --- vllm_ascend/attention/attention_v1.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index 87d28ec7..6955ccea 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -727,7 +727,7 @@ class AscendAttentionBackendImpl(AttentionImpl): key = self.key_cache.flatten(2, 3).contiguous() value = self.value_cache.flatten(2, 3).contiguous() - output, _ = torch_npu.npu_fused_infer_attention_score( + attn_output, _ = torch_npu.npu_fused_infer_attention_score( query, key, value, @@ -742,7 +742,8 @@ class AscendAttentionBackendImpl(AttentionImpl): actual_seq_lengths_kv=attn_metadata.seq_lens, ) - output = output.view(batch_size, self.num_heads, self.head_size) + attn_output = attn_output.view(batch_size, self.num_heads, self.head_size) + output[:batch_size] = attn_output[:batch_size] return output def forward_fused_infer_attention(