add pagedattention to support FULL_DECODE_ONLY. (#3102)
### What this PR does / why we need it? Calculate in advance the workspace memory size needed for the PagedAttention operator to avoid deadlocks during resource cleanup. This PR requires torch_npu version 0920 or newer. ### How was this patch tested? - vLLM version: v0.11.0 --------- Signed-off-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com> Co-authored-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
This commit is contained in:
@@ -34,7 +34,8 @@ from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
||||
maybe_save_kv_layer_to_connector,
|
||||
wait_for_kv_layer_from_connector)
|
||||
from vllm_ascend.compilation.acl_graph import get_graph_params
|
||||
from vllm_ascend.compilation.acl_graph import (get_graph_params,
|
||||
update_graph_params_workspaces)
|
||||
from vllm_ascend.ops.attention import vanilla_chunked_prefill
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
|
||||
nd_to_nz_2d, nd_to_nz_spec)
|
||||
@@ -393,13 +394,28 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
num_tokens = query.shape[0]
|
||||
if forward_context.capturing:
|
||||
# Get workspace from cache or calculate it if not present.
|
||||
workspace = graph_params.workspaces.get(num_tokens)
|
||||
if workspace is None:
|
||||
workspace = torch_npu._npu_paged_attention_get_workspace(
|
||||
query=query,
|
||||
key_cache=self.key_cache,
|
||||
value_cache=self.value_cache,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
num_heads=self.num_heads,
|
||||
scale_value=self.scale,
|
||||
block_table=attn_metadata.block_tables,
|
||||
context_lens=attn_metadata.seq_lens,
|
||||
out=output)
|
||||
update_graph_params_workspaces(num_tokens, workspace)
|
||||
|
||||
# Handle graph capturing mode
|
||||
stream = torch_npu.npu.current_stream()
|
||||
|
||||
event = torch.npu.ExternalEvent()
|
||||
event.wait(stream)
|
||||
event.reset(stream)
|
||||
graph_params.events[num_tokens].append(event)
|
||||
|
||||
graph_params.attn_params[num_tokens].append((
|
||||
query,
|
||||
self.key_cache,
|
||||
@@ -413,6 +429,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
))
|
||||
|
||||
torch.npu.graph_task_group_begin(stream)
|
||||
|
||||
torch_npu._npu_paged_attention(
|
||||
query=query,
|
||||
key_cache=self.key_cache,
|
||||
@@ -422,7 +439,8 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
||||
scale_value=self.scale,
|
||||
block_table=attn_metadata.block_tables,
|
||||
context_lens=attn_metadata.seq_lens,
|
||||
out=output)
|
||||
out=output,
|
||||
workspace=workspace)
|
||||
handle = torch.npu.graph_task_group_end(stream)
|
||||
graph_params.handles[num_tokens].append(handle)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user