fix pagedattention to support fullgraph. (#3436)
### What this PR does / why we need it? Calculate in advance the workspace memory size needed for the PagedAttention operator to avoid deadlocks during resource cleanup. This PR requires torch_npu version 0920 or newer. ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 Signed-off-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com> Co-authored-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
This commit is contained in:
@@ -33,8 +33,10 @@ from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
||||
maybe_save_kv_layer_to_connector,
|
||||
version_check,
|
||||
wait_for_kv_layer_from_connector)
|
||||
from vllm_ascend.compilation.acl_graph import get_graph_params
|
||||
from vllm_ascend.compilation.acl_graph import (get_graph_params,
|
||||
update_graph_params_workspaces)
|
||||
from vllm_ascend.ops.attention import vanilla_chunked_prefill
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
|
||||
nd_to_nz_2d, nd_to_nz_spec)
|
||||
@@ -289,6 +291,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
self.key_cache = None
|
||||
self.value_cache = None
|
||||
self.torch_npu_check = version_check()
|
||||
|
||||
def _forward_prefill_no_cache(
|
||||
self,
|
||||
@@ -396,13 +399,29 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
num_tokens = query.shape[0]
|
||||
if forward_context.capturing:
|
||||
if self.torch_npu_check:
|
||||
# Get workspace from cache or calculate it if not present.
|
||||
workspace = graph_params.workspaces.get(num_tokens)
|
||||
if workspace is None:
|
||||
workspace = torch_npu._npu_paged_attention_get_workspace(
|
||||
query=query,
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=self.scale,
|
||||
block_table=attn_metadata.block_tables,
|
||||
context_lens=attn_metadata.seq_lens,
|
||||
out=output)
|
||||
update_graph_params_workspaces(num_tokens, workspace)
|
||||
|
||||
# Handle graph capturing mode
|
||||
stream = torch_npu.npu.current_stream()
|
||||
|
||||
event = torch.npu.ExternalEvent()
|
||||
event.wait(stream)
|
||||
event.reset(stream)
|
||||
graph_params.events[num_tokens].append(event)
|
||||
|
||||
graph_params.attn_params[num_tokens].append((
|
||||
weak_ref_tensors(query),
|
||||
weak_ref_tensors(self.key_cache),
|
||||
@@ -416,16 +435,30 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
))
|
||||
|
||||
torch.npu.graph_task_group_begin(stream)
|
||||
torch_npu._npu_paged_attention(
|
||||
query=query,
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=self.scale,
|
||||
block_table=attn_metadata.block_tables,
|
||||
context_lens=attn_metadata.seq_lens,
|
||||
out=output)
|
||||
|
||||
if self.torch_npu_check:
|
||||
torch_npu._npu_paged_attention(
|
||||
query=query,
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=self.scale,
|
||||
block_table=attn_metadata.block_tables,
|
||||
context_lens=attn_metadata.seq_lens,
|
||||
out=output,
|
||||
workspace=workspace)
|
||||
else:
|
||||
torch_npu._npu_paged_attention(
|
||||
query=query,
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=self.scale,
|
||||
block_table=attn_metadata.block_tables,
|
||||
context_lens=attn_metadata.seq_lens,
|
||||
out=output)
|
||||
handle = torch.npu.graph_task_group_end(stream)
|
||||
graph_params.handles[num_tokens].append(handle)
|
||||
else:
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import functools
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, List
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
|
||||
has_kv_transfer_group,
|
||||
is_v1_kv_transfer_group)
|
||||
@@ -139,3 +141,17 @@ def maybe_save_kv_layer_to_connector(
|
||||
return
|
||||
# TODO: assert ascendMetadata
|
||||
connector.save_kv_layer(layer_name, kv_cache_layer, attn_metadata)
|
||||
|
||||
|
||||
@functools.cache
|
||||
def version_check():
|
||||
import re
|
||||
torch_npu_version = torch_npu.version.__version__
|
||||
date_pattern = r'dev(\d{8})'
|
||||
|
||||
match = re.search(date_pattern, torch_npu_version)
|
||||
if match:
|
||||
full_date = match.group(1)
|
||||
if full_date >= "20250919":
|
||||
return True
|
||||
return False
|
||||
|
||||
Reference in New Issue
Block a user