diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py
index f7522320..48d23a7c 100644
--- a/vllm_ascend/ascend_config.py
+++ b/vllm_ascend/ascend_config.py
@@ -153,6 +153,13 @@ class AscendConfig:
             raise NotImplementedError(
                 "This feature is still in the experiment and will be supported soon."
             )
+        # We find that _npu_paged_attention still performs better than
+        # npu_fused_infer_attention_score in some cases, so we allow falling
+        # back to _npu_paged_attention for those cases. This should be removed
+        # once npu_fused_infer_attention_score performs better in all scenarios.
+        self.pa_shape_list = additional_config.get("pa_shape_list",
+                                                   [1, 2, 3, 4])
+
         kv_cfg = vllm_config.kv_transfer_config
         if kv_cfg is not None and not getattr(kv_cfg, "_engine_id_patched",
                                               False):
diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index b5e91096..ad22b0dc 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -34,7 +34,8 @@ from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import AttentionSpec
 
 from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
-                                         split_decodes_and_prefills)
+                                         split_decodes_and_prefills,
+                                         using_paged_attention)
 from vllm_ascend.compilation.acl_graph import (get_graph_params,
                                                update_graph_params_workspaces)
 from vllm_ascend.utils import (AscendDeviceType, get_ascend_device_type,
@@ -488,6 +489,67 @@ class AscendAttentionBackendImpl(AttentionImpl):
             graph_params.handles[num_tokens].append(handle)
         return output, num_tokens
 
+    def full_graph_attention_with_pa(
+        self,
+        query: torch.Tensor,
+        attn_metadata: AscendMetadata,
+        output: Optional[torch.Tensor] = None,
+    ):
+        graph_params = get_graph_params()
+        forward_context: ForwardContext = get_forward_context()
+        num_tokens = query.shape[0]
+        if forward_context.capturing:
+            # Get workspace from cache or calculate it if not present.
+            workspace = graph_params.workspaces.get(num_tokens)
+            if workspace is None:
+                workspace = torch_npu._npu_paged_attention_get_workspace(
+                    query=query,
+                    key_cache=self.key_cache,
+                    value_cache=self.value_cache,
+                    num_kv_heads=self.num_kv_heads,
+                    num_heads=self.num_heads,
+                    scale_value=self.scale,
+                    block_table=attn_metadata.block_tables,
+                    context_lens=attn_metadata.seq_lens,
+                    out=output)
+                update_graph_params_workspaces(num_tokens,
+                                               weak_ref_tensors(workspace))
+
+            # Handle graph capturing mode
+            stream = torch_npu.npu.current_stream()
+
+            event = torch.npu.ExternalEvent()
+            event.wait(stream)
+            event.reset(stream)
+            graph_params.events[num_tokens].append(event)
+            graph_params.attn_params[num_tokens].append((
+                weak_ref_tensors(query),
+                weak_ref_tensors(self.key_cache),
+                weak_ref_tensors(self.value_cache),
+                self.num_kv_heads,
+                self.num_heads,
+                self.scale,
+                attn_metadata.block_tables,
+                attn_metadata.seq_lens,
+                weak_ref_tensors(output),
+            ))
+
+            torch.npu.graph_task_group_begin(stream)
+            torch_npu._npu_paged_attention(
+                query=query,
+                key_cache=self.key_cache,
+                value_cache=self.value_cache,
+                num_kv_heads=self.num_kv_heads,
+                num_heads=self.num_heads,
+                scale_value=self.scale,
+                block_table=attn_metadata.block_tables,
+                context_lens=attn_metadata.seq_lens,
+                out=output,
+                workspace=workspace)
+            handle = torch.npu.graph_task_group_end(stream)
+            graph_params.handles[num_tokens].append(handle)
+        return output
+
     def _forward_prefill(self, query: torch.Tensor, key: torch.Tensor,
                          value: torch.Tensor, attn_metadata: AscendMetadata,
                          output: torch.Tensor):
@@ -701,9 +763,14 @@ class AscendAttentionBackendImpl(AttentionImpl):
                 output = self._forward_prefill(query, key, value,
                                                attn_metadata, output)
             else:
-                attn_output, num_tokens = self.full_graph_attention(
-                    query, key, value, attn_metadata, output)
-                output[:num_tokens] = attn_output[:num_tokens]
+                num_tokens = query.shape[0]
+                if using_paged_attention(num_tokens):
+                    output = self.full_graph_attention_with_pa(
+                        query, attn_metadata, output)
+                else:
+                    attn_output, num_tokens = self.full_graph_attention(
+                        query, key, value, attn_metadata, output)
+                    output[:num_tokens] = attn_output[:num_tokens]
 
         return output
 
diff --git a/vllm_ascend/attention/utils.py b/vllm_ascend/attention/utils.py
index 25a2f2b8..cb95871a 100644
--- a/vllm_ascend/attention/utils.py
+++ b/vllm_ascend/attention/utils.py
@@ -1,13 +1,29 @@
 from dataclasses import dataclass
+from functools import lru_cache
 from typing import Any, List, Optional
 
 import torch
 import torch.nn.functional as F
+from vllm.config import get_current_vllm_config
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group,
                                           is_v1_kv_transfer_group)
 from vllm.forward_context import ForwardContext, get_forward_context
 
+from vllm_ascend.utils import get_ascend_config
+
+
+@lru_cache
+def using_paged_attention(runtime_shape: int) -> bool:
+    vllm_config = get_current_vllm_config()
+    if vllm_config.speculative_config is not None:
+        return False
+    from vllm.config.compilation import CUDAGraphMode
+    if vllm_config.compilation_config.cudagraph_mode != CUDAGraphMode.FULL_DECODE_ONLY:
+        return False
+
+    return runtime_shape in get_ascend_config().pa_shape_list
+
 
 @dataclass
 # class AscendCommonLongSequenceMetadata:
diff --git a/vllm_ascend/compilation/acl_graph.py b/vllm_ascend/compilation/acl_graph.py
index ab611445..ab1c6ae2 100644
--- a/vllm_ascend/compilation/acl_graph.py
+++ b/vllm_ascend/compilation/acl_graph.py
@@ -19,6 +19,8 @@
 from vllm.forward_context import BatchDescriptor, get_forward_context
 from vllm.logger import logger
 from vllm.platforms import current_platform
 
+from vllm_ascend.attention.utils import using_paged_attention
+
 from ..utils import weak_ref_tensors
 
@@ -193,7 +195,65 @@ class ACLGraphWrapper:
         return entry.output
 
 
-def update_attn_params(update_stream, forward_context, runtime_shape):
+def _update_attn_pa_params(update_stream, forward_context, runtime_shape):
+    graph_params = get_graph_params()
+    # FIXME: Behold! We are using a temporary hack here to update the args
+    # for each layer's attention op in the graph.
+    with torch.npu.stream(update_stream):
+        for key, param, handle, event in zip(
+                forward_context.attn_metadata,
+                graph_params.attn_params[runtime_shape],
+                graph_params.handles[runtime_shape],
+                graph_params.events[runtime_shape],
+        ):
+            (
+                query,
+                key_cache,
+                value_cache,
+                num_kv_heads,
+                num_heads,
+                scale,
+                block_table,
+                seq_lens,
+                output,
+            ) = param
+            seq_lens = forward_context.attn_metadata[key].seq_lens
+
+            # In FULL_DECODE_ONLY mode with GQA there are rare bugs triggered
+            # by getting the workspace for _npu_paged_attention in torch_npu:
+            # in rare cases _npu_paged_attention with a smaller seq_lens may
+            # require a larger workspace, while we currently use max_model_len
+            # to calculate the maximum workspace during capturing. So an
+            # additional get_workspace call is added here to avoid such bugs.
+            # TODO(Angazenn): remove this once _npu_paged_attention is fully
+            # replaced by npu_fused_infer_attention_score, which has no such bug.
+            workspace = torch_npu._npu_paged_attention_get_workspace(
+                query=query,
+                key_cache=key_cache,
+                value_cache=value_cache,
+                num_kv_heads=num_kv_heads,
+                num_heads=num_heads,
+                scale_value=scale,
+                block_table=block_table,
+                context_lens=seq_lens,
+                out=output)
+            torch.npu.graph_task_update_begin(update_stream, handle)
+            torch_npu._npu_paged_attention(query=query,
+                                           key_cache=key_cache,
+                                           value_cache=value_cache,
+                                           num_kv_heads=num_kv_heads,
+                                           num_heads=num_heads,
+                                           scale_value=scale,
+                                           block_table=block_table,
+                                           context_lens=seq_lens,
+                                           out=output,
+                                           workspace=workspace)
+            torch.npu.graph_task_update_end(update_stream)
+
+            event.record(update_stream)
+
+
+def _update_attn_fia_params(update_stream, forward_context, runtime_shape):
     graph_params = get_graph_params()
     # For Qwen3-next, since the kv_cache_config has already categorized
     # linear_attn and self_attn, the attn_metadata is first arranged with
@@ -236,6 +296,13 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
             event.record(update_stream)
 
 
+def update_attn_params(update_stream, forward_context, runtime_shape):
+    if using_paged_attention(runtime_shape):
+        _update_attn_pa_params(update_stream, forward_context, runtime_shape)
+    else:
+        _update_attn_fia_params(update_stream, forward_context, runtime_shape)
+
+
 def update_mla_attn_params(update_stream, forward_context, runtime_shape,
                            speculative_config):
     if forward_context.is_mtp_model:
@@ -446,7 +513,7 @@ def set_graph_params(aclgraph_capture_sizes: list[int]):
     )
 
 
-def update_graph_params_workspaces(num_tokens: int, workspace: int):
+def update_graph_params_workspaces(num_tokens: int, workspace: torch.Tensor):
     global _graph_params
     if _graph_params is not None:
         _graph_params.workspaces[num_tokens] = weak_ref_tensors(workspace)
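
Usage note: pa_shape_list holds the decode batch sizes (runtime shapes) that fall back to _npu_paged_attention; every other shape keeps using npu_fused_infer_attention_score. Since using_paged_attention is lru_cache'd and also checks the engine configuration, the fallback only applies in FULL_DECODE_ONLY aclgraph mode without speculative decoding, and the list is effectively fixed for the lifetime of the engine.

A minimal sketch of how the new knob could be exercised (not part of this diff; the model name is a placeholder, and it assumes a vLLM version whose LLM entry point accepts compilation_config and additional_config as shown, with additional_config plumbed through to AscendConfig as above):

    from vllm import LLM, SamplingParams

    # Capture full decode graphs so the paged-attention fallback is eligible.
    llm = LLM(
        model="Qwen/Qwen2.5-7B-Instruct",  # placeholder model
        compilation_config={"cudagraph_mode": "FULL_DECODE_ONLY"},
        # Decode batch sizes routed to _npu_paged_attention instead of
        # npu_fused_infer_attention_score (mirrors the default above).
        additional_config={"pa_shape_list": [1, 2, 3, 4]},
    )

    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(max_tokens=32))
    print(outputs[0].outputs[0].text)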