[Attention] Temporarily add back pa for small batch sizes. (#4765)
### What this PR does / why we need it?
This PR adds back paged attention (pa) for small batch sizes due to performance considerations. We will remove pa once fused infer attention (fia) performs better than pa in all scenarios.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
CI passed with existing tests.
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
---------
Signed-off-by: whx-sjtu <2952154980@qq.com>
Co-authored-by: weijinqian0 <1184188277@qq.com>
@@ -153,6 +153,13 @@ class AscendConfig:
             raise NotImplementedError(
                 "This feature is still in the experiment and will be supported soon."
             )
+        # We find that _npu_paged_attention still performs better than
+        # npu_fused_infer_attention_score in some cases. We allow executing
+        # _npu_paged_attention in these cases. This should be removed once
+        # npu_fused_infer_attention_score performs better in all scenarios.
+        self.pa_shape_list = additional_config.get("pa_shape_list",
+                                                   [1, 2, 3, 4])
+
         kv_cfg = vllm_config.kv_transfer_config
         if kv_cfg is not None and not getattr(kv_cfg, "_engine_id_patched",
                                               False):
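For illustration, here is a minimal usage sketch of the new knob. The `pa_shape_list` key comes from the hunk above; the model name and the exact `LLM(...)` invocation are illustrative assumptions, not part of this PR, and the fallback only takes effect under the FULL_DECODE_ONLY aclgraph mode without speculative decoding (see `using_paged_attention` below).

```python
# Hypothetical usage sketch: pass pa_shape_list through vLLM's additional_config
# so decode batches of size 1 or 2 fall back to _npu_paged_attention, while all
# other captured shapes keep using npu_fused_infer_attention_score.
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",  # illustrative model choice
    additional_config={"pa_shape_list": [1, 2]},
)

# An empty list would make the gate always return False, i.e. effectively
# disable the paged-attention fallback:
# additional_config={"pa_shape_list": []}
```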
@@ -34,7 +34,8 @@ from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import AttentionSpec
 
 from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
-                                         split_decodes_and_prefills)
+                                         split_decodes_and_prefills,
+                                         using_paged_attention)
 from vllm_ascend.compilation.acl_graph import (get_graph_params,
                                                update_graph_params_workspaces)
 from vllm_ascend.utils import (AscendDeviceType, get_ascend_device_type,
@@ -488,6 +489,67 @@ class AscendAttentionBackendImpl(AttentionImpl):
             graph_params.handles[num_tokens].append(handle)
         return output, num_tokens
 
+    def full_graph_attention_with_pa(
+        self,
+        query: torch.Tensor,
+        attn_metadata: AscendMetadata,
+        output: Optional[torch.Tensor] = None,
+    ):
+        graph_params = get_graph_params()
+        forward_context: ForwardContext = get_forward_context()
+        num_tokens = query.shape[0]
+        if forward_context.capturing:
+            # Get workspace from cache or calculate it if not present.
+            workspace = graph_params.workspaces.get(num_tokens)
+            if workspace is None:
+                workspace = torch_npu._npu_paged_attention_get_workspace(
+                    query=query,
+                    key_cache=self.key_cache,
+                    value_cache=self.value_cache,
+                    num_kv_heads=self.num_kv_heads,
+                    num_heads=self.num_heads,
+                    scale_value=self.scale,
+                    block_table=attn_metadata.block_tables,
+                    context_lens=attn_metadata.seq_lens,
+                    out=output)
+                update_graph_params_workspaces(num_tokens,
+                                               weak_ref_tensors(workspace))
+
+            # Handle graph capturing mode
+            stream = torch_npu.npu.current_stream()
+
+            event = torch.npu.ExternalEvent()
+            event.wait(stream)
+            event.reset(stream)
+            graph_params.events[num_tokens].append(event)
+            graph_params.attn_params[num_tokens].append((
+                weak_ref_tensors(query),
+                weak_ref_tensors(self.key_cache),
+                weak_ref_tensors(self.value_cache),
+                self.num_kv_heads,
+                self.num_heads,
+                self.scale,
+                attn_metadata.block_tables,
+                attn_metadata.seq_lens,
+                weak_ref_tensors(output),
+            ))
+
+            torch.npu.graph_task_group_begin(stream)
+            torch_npu._npu_paged_attention(
+                query=query,
+                key_cache=self.key_cache,
+                value_cache=self.value_cache,
+                num_kv_heads=self.num_kv_heads,
+                num_heads=self.num_heads,
+                scale_value=self.scale,
+                block_table=attn_metadata.block_tables,
+                context_lens=attn_metadata.seq_lens,
+                out=output,
+                workspace=workspace)
+            handle = torch.npu.graph_task_group_end(stream)
+            graph_params.handles[num_tokens].append(handle)
+        return output
+
     def _forward_prefill(self, query: torch.Tensor, key: torch.Tensor,
                          value: torch.Tensor, attn_metadata: AscendMetadata,
                          output: torch.Tensor):
@@ -701,9 +763,14 @@ class AscendAttentionBackendImpl(AttentionImpl):
                 output = self._forward_prefill(query, key, value,
                                                attn_metadata, output)
             else:
-                attn_output, num_tokens = self.full_graph_attention(
-                    query, key, value, attn_metadata, output)
-                output[:num_tokens] = attn_output[:num_tokens]
+                num_tokens = query.shape[0]
+                if using_paged_attention(num_tokens):
+                    output = self.full_graph_attention_with_pa(
+                        query, attn_metadata, output)
+                else:
+                    attn_output, num_tokens = self.full_graph_attention(
+                        query, key, value, attn_metadata, output)
+                    output[:num_tokens] = attn_output[:num_tokens]
 
         return output
 
@@ -1,13 +1,29 @@
 from dataclasses import dataclass
+from functools import lru_cache
 from typing import Any, List, Optional
 
 import torch
 import torch.nn.functional as F
+from vllm.config import get_current_vllm_config
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group,
                                           is_v1_kv_transfer_group)
 from vllm.forward_context import ForwardContext, get_forward_context
 
+from vllm_ascend.utils import get_ascend_config
+
+
+@lru_cache
+def using_paged_attention(runtime_shape: int) -> bool:
+    vllm_config = get_current_vllm_config()
+    if vllm_config.speculative_config is not None:
+        return False
+    from vllm.config.compilation import CUDAGraphMode
+    if vllm_config.compilation_config.cudagraph_mode != CUDAGraphMode.FULL_DECODE_ONLY:
+        return False
+
+    return runtime_shape in get_ascend_config().pa_shape_list
+
 
 @dataclass
 class AscendCommonLongSequenceMetadata:
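To make the gate easier to follow in isolation, here is a small standalone sketch (not the vllm_ascend code itself): `PA_SHAPE_LIST` stands in for `get_ascend_config().pa_shape_list`, and the speculative-decoding and cudagraph-mode checks are omitted. Because the result only depends on the runtime shape once the config is fixed, memoizing it with `lru_cache` is safe.

```python
from functools import lru_cache

# Stand-in for get_ascend_config().pa_shape_list (default from this PR: [1, 2, 3, 4]).
PA_SHAPE_LIST = [1, 2, 3, 4]


@lru_cache
def use_paged_attention_for(runtime_shape: int) -> bool:
    # True only for the small decode batch sizes listed in PA_SHAPE_LIST;
    # the real helper additionally requires FULL_DECODE_ONLY graph mode
    # and no speculative decoding.
    return runtime_shape in PA_SHAPE_LIST


print(use_paged_attention_for(2))   # True  -> dispatch to _npu_paged_attention
print(use_paged_attention_for(16))  # False -> dispatch to npu_fused_infer_attention_score
```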
@@ -19,6 +19,8 @@ from vllm.forward_context import BatchDescriptor, get_forward_context
 from vllm.logger import logger
 from vllm.platforms import current_platform
 
+from vllm_ascend.attention.utils import using_paged_attention
+
 from ..utils import weak_ref_tensors
 
 
@@ -193,7 +195,65 @@ class ACLGraphWrapper:
         return entry.output
 
 
-def update_attn_params(update_stream, forward_context, runtime_shape):
+def _update_attn_pa_params(update_stream, forward_context, runtime_shape):
+    graph_params = get_graph_params()
+    # FIXME: Behold! We are using a temporary hack here to update the args
+    # for each layer's attention op in the graph.
+    with torch.npu.stream(update_stream):
+        for key, param, handle, event in zip(
+                forward_context.attn_metadata,
+                graph_params.attn_params[runtime_shape],
+                graph_params.handles[runtime_shape],
+                graph_params.events[runtime_shape],
+        ):
+            (
+                query,
+                key_cache,
+                value_cache,
+                num_kv_heads,
+                num_heads,
+                scale,
+                block_table,
+                seq_lens,
+                output,
+            ) = param
+            seq_lens = forward_context.attn_metadata[key].seq_lens
+
+            # In FULL_DECODE_ONLY mode there are some rare bugs with GQA that are
+            # triggered when getting the workspace for _npu_paged_attention in
+            # torch_npu. In some rare cases, _npu_paged_attention with smaller
+            # seq_lens might require a bigger workspace, while we currently use
+            # max_model_len to calculate the max workspace during capturing. So an
+            # additional get_workspace call is added here to avoid such bugs.
+            # TODO(Angazenn): we will remove this once _npu_paged_attention is fully
+            # replaced by npu_fused_infer_attention_score which does not contain such bugs.
+            workspace = torch_npu._npu_paged_attention_get_workspace(
+                query=query,
+                key_cache=key_cache,
+                value_cache=value_cache,
+                num_kv_heads=num_kv_heads,
+                num_heads=num_heads,
+                scale_value=scale,
+                block_table=block_table,
+                context_lens=seq_lens,
+                out=output)
+            torch.npu.graph_task_update_begin(update_stream, handle)
+            torch_npu._npu_paged_attention(query=query,
+                                           key_cache=key_cache,
+                                           value_cache=value_cache,
+                                           num_kv_heads=num_kv_heads,
+                                           num_heads=num_heads,
+                                           scale_value=scale,
+                                           block_table=block_table,
+                                           context_lens=seq_lens,
+                                           out=output,
+                                           workspace=workspace)
+            torch.npu.graph_task_update_end(update_stream)
+
+            event.record(update_stream)
+
+
+def _update_attn_fia_params(update_stream, forward_context, runtime_shape):
     graph_params = get_graph_params()
     # For Qwen3-next, since the kv_cache_config has already categorized
     # linear_attn and self_attn, the attn_metadata is first arranged with
@@ -236,6 +296,13 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
             event.record(update_stream)
 
 
+def update_attn_params(update_stream, forward_context, runtime_shape):
+    if using_paged_attention(runtime_shape):
+        _update_attn_pa_params(update_stream, forward_context, runtime_shape)
+    else:
+        _update_attn_fia_params(update_stream, forward_context, runtime_shape)
+
+
 def update_mla_attn_params(update_stream, forward_context, runtime_shape,
                            speculative_config):
     if forward_context.is_mtp_model:
@@ -446,7 +513,7 @@ def set_graph_params(aclgraph_capture_sizes: list[int]):
     )
 
 
-def update_graph_params_workspaces(num_tokens: int, workspace: int):
+def update_graph_params_workspaces(num_tokens: int, workspace: torch.Tensor):
     global _graph_params
     if _graph_params is not None:
         _graph_params.workspaces[num_tokens] = weak_ref_tensors(workspace)