support FULL graph mode for GQA (#3970)

### What this PR does / why we need it?
The library currently supports only the FullDecodeOnly graph mode, which
enables full graph execution during the decode phase. This PR extends support
to full graph execution in both the prefill and decode phases, referred to
as FULL graph mode.
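
For readers who want to try the new mode, below is a minimal sketch of how the two graph modes would typically be selected through vLLM's compilation config. It assumes vLLM's standard `CompilationConfig` / `CUDAGraphMode` knobs apply to the Ascend plugin; the model name is a placeholder, and this is an illustration rather than anything specified by the PR.

```python
from vllm import LLM
from vllm.config import CompilationConfig, CUDAGraphMode

# CUDAGraphMode.FULL_DECODE_ONLY was previously the only full-graph option
# for GQA models; CUDAGraphMode.FULL captures prefill and decode.
llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",  # placeholder GQA model
    compilation_config=CompilationConfig(cudagraph_mode=CUDAGraphMode.FULL),
)
outputs = llm.generate(["Hello, my name is"])
```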

- vLLM version: v0.11.0
- vLLM main: 2918c1b49c

Signed-off-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
Co-authored-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
Author: XiaoxinWang
Committed by: GitHub
Date: 2025-11-17 10:50:35 +08:00
Commit: e38ef2c434 (parent: c334114f69)
11 changed files with 328 additions and 296 deletions


```diff
@@ -203,48 +203,31 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
     for key, param, handle, event in zip(
             forward_context.attn_metadata,
             graph_params.handles[runtime_shape],
             graph_params.events[runtime_shape],
     ):
-        (
-            query,
-            key_cache,
-            value_cache,
-            num_kv_heads,
-            num_heads,
-            scale,
-            block_table,
-            seq_lens,
-            output,
-        ) = param
-        seq_lens = forward_context.attn_metadata[key].seq_lens
+        (query, key_cache, value, block_tables, attn_mask, block_size,
+         seq_lens, query_start_loc, num_kv_heads, num_heads, scale,
+         attn_output, softmax_lse) = param
-        # When using FULL_DECODE_ONLY, there are some rare bugs for FULL_DECODE_ONLY
-        # mode with GQA. This is triggered by getting workspace for _npu_paged_attention
-        # in torch_npu. On some rare cases, _npu_paged_attention with smaller seq_lens
-        # might encounter a bigger workspace, while currently we use max_model_len to
-        # calculate max workspace in capturing. So additional get_workspace is added
-        # here to avoid such bugs.
-        # TODO(Angazenn): we will remove this once _npu_paged_attention is fully
-        # replaced by npu_fused_infer_attention_score which does not contain such bugs.
-        workspace = torch_npu._npu_paged_attention_get_workspace(
-            query=query,
-            key_cache=key_cache,
-            value_cache=value_cache,
-            num_kv_heads=num_kv_heads,
-            num_heads=num_heads,
-            scale_value=scale,
-            block_table=block_table,
-            context_lens=seq_lens,
-            out=output)
+        seq_lens = forward_context.attn_metadata[key].seq_lens_list
+        query_start_loc = forward_context.attn_metadata[
+            key].query_start_loc_list
         torch.npu.graph_task_update_begin(update_stream, handle)
-        torch_npu._npu_paged_attention(query=query,
-                                       key_cache=key_cache,
-                                       value_cache=value_cache,
-                                       num_kv_heads=num_kv_heads,
-                                       num_heads=num_heads,
-                                       scale_value=scale,
-                                       block_table=block_table,
-                                       context_lens=seq_lens,
-                                       out=output,
-                                       workspace=workspace)
+        torch_npu.npu_fused_infer_attention_score.out(
+            query=query,
+            key=key_cache,
+            value=value,
+            block_table=block_tables,
+            atten_mask=attn_mask,
+            input_layout="TND",
+            block_size=block_size,
+            actual_seq_lengths=query_start_loc,
+            actual_seq_lengths_kv=seq_lens,
+            num_key_value_heads=num_kv_heads,
+            num_heads=num_heads,
+            scale=scale,
+            sparse_mode=3,
+            workspace=graph_params.workspaces.get(runtime_shape),
+            out=[attn_output, softmax_lse],
+        )
         torch.npu.graph_task_update_end(update_stream)
         event.record(update_stream)
```
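
For orientation, here is an illustrative sketch of the per-shape bookkeeping that `update_attn_params` iterates over. The field names mirror the attributes referenced in the diff (`attn_params`, `handles`, `events`, `workspaces`, keyed by the captured `runtime_shape`); the dataclass layout itself is an assumption, not the module's actual definition.

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass
class GraphParams:
    """Illustrative sketch of the per-shape state behind get_graph_params()."""
    # Per captured runtime shape: one tuple of attention arguments per layer.
    attn_params: dict[int, list[tuple]] = field(default_factory=dict)
    # Graph task handles and events used with graph_task_update_begin/end.
    handles: dict[int, list[Any]] = field(default_factory=dict)
    events: dict[int, list[Any]] = field(default_factory=dict)
    # One cached attention workspace per captured shape; looked up at replay
    # time via graph_params.workspaces.get(runtime_shape).
    workspaces: dict[int, Any] = field(default_factory=dict)
```
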
```diff
@@ -446,10 +429,10 @@ def set_graph_params(aclgraph_capture_sizes: set[int]):
     )


-def update_graph_params_workspaces(num_tokens: int, workspace: Any):
+def update_graph_params_workspaces(num_tokens: int, workspace: int):
     global _graph_params
     if _graph_params is not None:
-        _graph_params.workspaces[num_tokens] = workspace
+        _graph_params.workspaces[num_tokens] = weak_ref_tensors(workspace)


 def get_graph_params():
```
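
Storing the workspace through `weak_ref_tensors` means the module-level cache only aliases the capture-time workspace buffer rather than pinning it with an extra strong reference. Below is a hedged sketch of how the two hunks fit together at capture and replay time; the helper names and the import path are assumptions for illustration, not code from this PR.

```python
from typing import Any

# Assumed module path for the functions shown in the diffs above.
from vllm_ascend.compilation.acl_graph import (get_graph_params,
                                               update_graph_params_workspaces)


def capture_time_register(num_tokens: int, workspace: Any) -> None:
    # Hypothetical capture-side hook: after computing the maximum workspace
    # the fused attention kernel may need for this token count, hand it to
    # the registry, where it is stored as a weak reference (second hunk).
    update_graph_params_workspaces(num_tokens, workspace)


def replay_time_lookup(runtime_shape: int) -> Any:
    # Replay-side counterpart used by update_attn_params (first hunk): the
    # cached workspace is fetched by shape and passed to
    # npu_fused_infer_attention_score.out via `workspace=...`.
    return get_graph_params().workspaces.get(runtime_shape)
```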