add pagedattention to support FULL_DECODE_ONLY. (#3102)
### What this PR does / why we need it? Calculate in advance the workspace memory size needed for the PagedAttention operator to avoid deadlocks during resource cleanup. This PR requires torch_npu version 0920 or newer. ### How was this patch tested? - vLLM version: v0.11.0 --------- Signed-off-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com> Co-authored-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
This commit is contained in:
@@ -215,15 +215,17 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
|
||||
|
||||
with torch.npu.stream(update_stream):
|
||||
torch.npu.graph_task_update_begin(update_stream, handle)
|
||||
torch_npu._npu_paged_attention(query=query,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
num_kv_heads=num_kv_heads,
|
||||
num_heads=num_heads,
|
||||
scale_value=scale,
|
||||
block_table=block_table,
|
||||
context_lens=seq_lens,
|
||||
out=output)
|
||||
torch_npu._npu_paged_attention(
|
||||
query=query,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
num_kv_heads=num_kv_heads,
|
||||
num_heads=num_heads,
|
||||
scale_value=scale,
|
||||
block_table=block_table,
|
||||
context_lens=seq_lens,
|
||||
out=output,
|
||||
workspace=graph_params.workspaces.get(runtime_shape))
|
||||
torch.npu.graph_task_update_end(update_stream)
|
||||
|
||||
event.record(update_stream)
|
||||
@@ -256,5 +258,11 @@ def set_graph_params(aclgraph_capture_sizes: set[int]):
|
||||
)
|
||||
|
||||
|
||||
def update_graph_params_workspaces(num_tokens: int, workspace: int):
|
||||
global _graph_params
|
||||
if _graph_params is not None:
|
||||
_graph_params.workspaces[num_tokens] = workspace
|
||||
|
||||
|
||||
def get_graph_params():
|
||||
return _graph_params
|
||||
|
||||
Reference in New Issue
Block a user