diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index 87d28ec7..6955ccea 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -727,7 +727,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
         key = self.key_cache.flatten(2, 3).contiguous()
         value = self.value_cache.flatten(2, 3).contiguous()
 
-        output, _ = torch_npu.npu_fused_infer_attention_score(
+        attn_output, _ = torch_npu.npu_fused_infer_attention_score(
             query,
             key,
             value,
@@ -742,7 +742,8 @@ class AscendAttentionBackendImpl(AttentionImpl):
             actual_seq_lengths_kv=attn_metadata.seq_lens,
         )
-        output = output.view(batch_size, self.num_heads, self.head_size)
+        attn_output = attn_output.view(batch_size, self.num_heads, self.head_size)
+        output[:batch_size] = attn_output[:batch_size]
         return output
 
     def forward_fused_infer_attention(
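The substantive change is in the second hunk: instead of rebinding the local name `output` to the attention result, the result is renamed to `attn_output` and copied in place into the first `batch_size` rows of the preallocated `output` buffer. Below is a minimal sketch of that buffer-copy pattern, assuming `output` is a padded, preallocated tensor (e.g. sized for a captured maximum batch); the names `padded_batch_size` and the stand-in for `torch_npu.npu_fused_infer_attention_score` are illustrative and not part of the patch.

```python
import torch

num_heads, head_size = 8, 128
padded_batch_size = 16   # assumed: buffer sized for the maximum batch
batch_size = 11          # actual number of requests in this step

# Preallocated output buffer that the caller expects to be filled in place.
output = torch.zeros(padded_batch_size, num_heads, head_size)

# Stand-in for the fused attention result (torch_npu.npu_fused_infer_attention_score
# on real NPU hardware); here just random data of the matching shape.
attn_output = torch.randn(batch_size, num_heads, head_size)
attn_output = attn_output.view(batch_size, num_heads, head_size)

# The pattern from the patch: write into the existing buffer rather than
# rebinding `output`, so the padded tail and any outside reference to the
# buffer remain intact.
output[:batch_size] = attn_output[:batch_size]
```

The in-place slice assignment is what distinguishes the fix from the original code: rebinding `output` would produce a tensor of shape `(batch_size, ...)` that no longer matches the padded buffer the rest of the pipeline may hold a reference to.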