support FULL graph mode for GQA (#3970)
### What this PR does / why we need it?
The current library only supports the FullDecodeOnly graph mode, which
enables full graph execution during the decode phase. This PR extends
support to allow full graph execution in both the prefill and decode
phases, referred to as FULL graph mode.
- vLLM version: v0.11.0
- vLLM main:
2918c1b49c
Signed-off-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
Co-authored-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
This commit is contained in:
@@ -203,48 +203,31 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
|
||||
graph_params.handles[runtime_shape],
|
||||
graph_params.events[runtime_shape],
|
||||
):
|
||||
(
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
num_kv_heads,
|
||||
num_heads,
|
||||
scale,
|
||||
block_table,
|
||||
seq_lens,
|
||||
output,
|
||||
) = param
|
||||
seq_lens = forward_context.attn_metadata[key].seq_lens
|
||||
(query, key_cache, value, block_tables, attn_mask, block_size,
|
||||
seq_lens, query_start_loc, num_kv_heads, num_heads, scale,
|
||||
attn_output, softmax_lse) = param
|
||||
|
||||
# When using FULL_DECODE_ONLY, there are some rare bugs for FULL_DECODE_ONLY
|
||||
# mode with GQA. This is triggered by getting workspace for _npu_paged_attention
|
||||
# in torch_npu. In some rare cases, _npu_paged_attention with smaller seq_lens
|
||||
# might encounter a bigger workspace, while currently we use max_model_len to
|
||||
# calculate max workspace in capturing. So additional get_workspace is added
|
||||
# here to avoid such bugs.
|
||||
# TODO(Angazenn): we will remove this once _npu_paged_attention is fully
|
||||
# replaced by npu_fused_infer_attention_score which does not contain such bugs.
|
||||
workspace = torch_npu._npu_paged_attention_get_workspace(
|
||||
query=query,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
num_kv_heads=num_kv_heads,
|
||||
num_heads=num_heads,
|
||||
scale_value=scale,
|
||||
block_table=block_table,
|
||||
context_lens=seq_lens,
|
||||
out=output)
|
||||
seq_lens = forward_context.attn_metadata[key].seq_lens_list
|
||||
query_start_loc = forward_context.attn_metadata[
|
||||
key].query_start_loc_list
|
||||
torch.npu.graph_task_update_begin(update_stream, handle)
|
||||
torch_npu._npu_paged_attention(query=query,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
num_kv_heads=num_kv_heads,
|
||||
num_heads=num_heads,
|
||||
scale_value=scale,
|
||||
block_table=block_table,
|
||||
context_lens=seq_lens,
|
||||
out=output,
|
||||
workspace=workspace)
|
||||
torch_npu.npu_fused_infer_attention_score.out(
|
||||
query=query,
|
||||
key=key_cache,
|
||||
value=value,
|
||||
block_table=block_tables,
|
||||
atten_mask=attn_mask,
|
||||
input_layout="TND",
|
||||
block_size=block_size,
|
||||
actual_seq_lengths=query_start_loc,
|
||||
actual_seq_lengths_kv=seq_lens,
|
||||
num_key_value_heads=num_kv_heads,
|
||||
num_heads=num_heads,
|
||||
scale=scale,
|
||||
sparse_mode=3,
|
||||
workspace=graph_params.workspaces.get(runtime_shape),
|
||||
out=[attn_output, softmax_lse],
|
||||
)
|
||||
torch.npu.graph_task_update_end(update_stream)
|
||||
|
||||
event.record(update_stream)
|
||||
@@ -446,10 +429,10 @@ def set_graph_params(aclgraph_capture_sizes: set[int]):
|
||||
)
|
||||
|
||||
|
||||
def update_graph_params_workspaces(num_tokens: int, workspace: Any) -> None:
    """Record the captured attention workspace for a given graph batch size.

    Args:
        num_tokens: The runtime shape (token count) the ACL graph was
            captured with; used as the key into ``_graph_params.workspaces``.
        workspace: The workspace buffer returned by the attention op during
            capture. Annotated ``Any`` because it is an opaque torch_npu
            handle — presumably a tensor, since it is wrapped in
            ``weak_ref_tensors`` below (TODO confirm against torch_npu).

    No-op when graph params have not been initialised yet
    (``set_graph_params`` not called).
    """
    global _graph_params
    if _graph_params is not None:
        # Store only a weak reference so the captured graph's bookkeeping
        # does not keep the (potentially large) workspace tensor alive
        # beyond its natural lifetime.
        _graph_params.workspaces[num_tokens] = weak_ref_tensors(workspace)
|
||||
|
||||
|
||||
def get_graph_params():
|
||||
|
||||
Reference in New Issue
Block a user