[Attention] Temporarily add back pa for small batch sizes. (#4765)

### What this PR does / why we need it?
This PR adds back pa in scenarios of small batch sizes due to
performance consideration. Will remove pa once fia performs better than
pa in all scenarios.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
CI passed with existing test.


- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: whx-sjtu <2952154980@qq.com>
Co-authored-by: weijinqian0 <1184188277@qq.com>
This commit is contained in:
whx
2025-12-15 20:35:50 +08:00
committed by GitHub
parent 95e6400128
commit a9625851ef
4 changed files with 163 additions and 6 deletions

View File

@@ -34,7 +34,8 @@ from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
split_decodes_and_prefills)
split_decodes_and_prefills,
using_paged_attention)
from vllm_ascend.compilation.acl_graph import (get_graph_params,
update_graph_params_workspaces)
from vllm_ascend.utils import (AscendDeviceType, get_ascend_device_type,
@@ -488,6 +489,67 @@ class AscendAttentionBackendImpl(AttentionImpl):
graph_params.handles[num_tokens].append(handle)
return output, num_tokens
def full_graph_attention_with_pa(
self,
query: torch.Tensor,
attn_metadata: AscendMetadata,
output: Optional[torch.Tensor] = None,
):
graph_params = get_graph_params()
forward_context: ForwardContext = get_forward_context()
num_tokens = query.shape[0]
if forward_context.capturing:
# Get workspace from cache or calculate it if not present.
workspace = graph_params.workspaces.get(num_tokens)
if workspace is None:
workspace = torch_npu._npu_paged_attention_get_workspace(
query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
block_table=attn_metadata.block_tables,
context_lens=attn_metadata.seq_lens,
out=output)
update_graph_params_workspaces(num_tokens,
weak_ref_tensors(workspace))
# Handle graph capturing mode
stream = torch_npu.npu.current_stream()
event = torch.npu.ExternalEvent()
event.wait(stream)
event.reset(stream)
graph_params.events[num_tokens].append(event)
graph_params.attn_params[num_tokens].append((
weak_ref_tensors(query),
weak_ref_tensors(self.key_cache),
weak_ref_tensors(self.value_cache),
self.num_kv_heads,
self.num_heads,
self.scale,
attn_metadata.block_tables,
attn_metadata.seq_lens,
weak_ref_tensors(output),
))
torch.npu.graph_task_group_begin(stream)
torch_npu._npu_paged_attention(
query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
block_table=attn_metadata.block_tables,
context_lens=attn_metadata.seq_lens,
out=output,
workspace=workspace)
handle = torch.npu.graph_task_group_end(stream)
graph_params.handles[num_tokens].append(handle)
return output
def _forward_prefill(self, query: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, attn_metadata: AscendMetadata,
output: torch.Tensor):
@@ -701,9 +763,14 @@ class AscendAttentionBackendImpl(AttentionImpl):
output = self._forward_prefill(query, key, value,
attn_metadata, output)
else:
attn_output, num_tokens = self.full_graph_attention(
query, key, value, attn_metadata, output)
output[:num_tokens] = attn_output[:num_tokens]
num_tokens = query.shape[0]
if using_paged_attention(num_tokens):
output = self.full_graph_attention_with_pa(
query, attn_metadata, output)
else:
attn_output, num_tokens = self.full_graph_attention(
query, key, value, attn_metadata, output)
output[:num_tokens] = attn_output[:num_tokens]
return output