Add PagedAttention support for FULL_DECODE_ONLY mode. (#3102)

### What this PR does / why we need it?
Calculate in advance the workspace memory size needed for the
PagedAttention operator to avoid deadlocks during resource cleanup. This
PR requires torch_npu version 0920 or newer.
### How was this patch tested?


- vLLM version: v0.11.0

---------

Signed-off-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
Co-authored-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
This commit is contained in:
XiaoxinWang
2025-10-10 08:50:33 +08:00
committed by GitHub
parent 1c2c72af8d
commit 579b7e5f21
5 changed files with 245 additions and 12 deletions

View File

@@ -215,15 +215,17 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
with torch.npu.stream(update_stream):
torch.npu.graph_task_update_begin(update_stream, handle)
torch_npu._npu_paged_attention(query=query,
key_cache=key_cache,
value_cache=value_cache,
num_kv_heads=num_kv_heads,
num_heads=num_heads,
scale_value=scale,
block_table=block_table,
context_lens=seq_lens,
out=output)
torch_npu._npu_paged_attention(
query=query,
key_cache=key_cache,
value_cache=value_cache,
num_kv_heads=num_kv_heads,
num_heads=num_heads,
scale_value=scale,
block_table=block_table,
context_lens=seq_lens,
out=output,
workspace=graph_params.workspaces.get(runtime_shape))
torch.npu.graph_task_update_end(update_stream)
event.record(update_stream)
@@ -256,5 +258,11 @@ def set_graph_params(aclgraph_capture_sizes: set[int]):
)
def update_graph_params_workspaces(num_tokens: int, workspace: int):
    """Record ``workspace`` for ``num_tokens`` in the shared graph params.

    The value is stored in the ``workspaces`` mapping of the module-level
    ``_graph_params`` object, keyed by ``num_tokens``. When
    ``_graph_params`` is ``None`` the call does nothing.

    Args:
        num_tokens: Key under which the workspace is stored.
        workspace: Value to store for that key.
            NOTE(review): annotated as ``int``, but entries of a
            ``workspaces`` mapping are passed as the ``workspace=``
            argument of ``torch_npu._npu_paged_attention`` elsewhere in
            this change -- confirm whether this is a size or a buffer.
    """
    # Only the object behind the name is mutated (the name itself is never
    # rebound), so a plain read of the module global is sufficient.
    params = _graph_params
    if params is None:
        return
    params.workspaces[num_tokens] = workspace
def get_graph_params():
    """Return the module-level ``_graph_params`` object.

    The returned value may be ``None``; other code in this module checks
    for that case before using it.
    """
    return _graph_params