[Refactor][Graph] Move graph parameter logic to acl_graph module (#3101)
### What this PR does / why we need it? This is a follow-up to PR #2128. It moves the graph parameter management components, including `GraphParams`, `get_graph_params`, and `set_graph_params`, from the generic `utils.py` to the more specific `compilation/acl_graph.py`. Additionally, it extracts the `update_attn_params` logic from the `NPUModelRunner` class into a standalone function within the `acl_graph` module. This refactoring improves code organization by centralizing ACL graph-related logic in its own dedicated module, enhancing modularity and clarity. ### Does this PR introduce _any_ user-facing change? None. ### How was this patch tested? None needed. Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
@@ -3,10 +3,12 @@
|
||||
|
||||
import dataclasses
|
||||
from contextlib import ExitStack
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Optional
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.cuda_graph import CUDAGraphOptions
|
||||
@@ -185,3 +187,74 @@ class ACLGraphWrapper:
|
||||
logger.info_once("Replaying aclgraph")
|
||||
entry.aclgraph.replay()
|
||||
return entry.output
|
||||
|
||||
|
||||
def update_attn_params(update_stream, forward_context, runtime_shape):
    """Re-issue each captured paged-attention task with current seq_lens.

    Walks, in lockstep, the entries of ``forward_context.attn_metadata`` and
    the attention arguments, task-group handles and events stored in the
    module-level graph parameters under ``runtime_shape``. For every entry the
    captured arguments are replayed through ``torch_npu._npu_paged_attention``
    between ``graph_task_update_begin`` / ``graph_task_update_end`` on
    ``update_stream``, and the matching event is then recorded on that stream.

    Args:
        update_stream: Stream on which the task updates are issued and the
            events are recorded.
        forward_context: Object whose ``attn_metadata`` is iterated for keys
            and indexed by those keys to obtain the current ``seq_lens``.
        runtime_shape: Key selecting the per-size lists inside the graph
            parameters.

    Note:
        ``zip`` stops at the shortest input, so surplus entries in any of the
        four sequences are silently ignored.
    """
    captured = get_graph_params()
    attn_metadata = forward_context.attn_metadata
    # FIXME: Behold! We are using a temporary hack here to update the args
    # for each layer's attention op in the graph.
    per_layer = zip(
        attn_metadata,
        captured.attn_params[runtime_shape],
        captured.handles[runtime_shape],
        captured.events[runtime_shape],
    )
    for layer_key, layer_args, task_handle, done_event in per_layer:
        # The stored tuple must hold exactly nine items; its seq_lens slot is
        # discarded in favour of the value from the current metadata.
        (query, key_cache, value_cache, num_kv_heads, num_heads, scale,
         block_table, _stale_seq_lens, output) = layer_args
        # NOTE: the captured block_table is passed through unchanged; only
        # seq_lens is refreshed from the current attention metadata.
        current_seq_lens = attn_metadata[layer_key].seq_lens

        with torch.npu.stream(update_stream):
            torch.npu.graph_task_update_begin(update_stream, task_handle)
            torch_npu._npu_paged_attention(
                query=query,
                key_cache=key_cache,
                value_cache=value_cache,
                num_kv_heads=num_kv_heads,
                num_heads=num_heads,
                scale_value=scale,
                block_table=block_table,
                context_lens=current_seq_lens,
                out=output,
            )
            torch.npu.graph_task_update_end(update_stream)

            # Presumably signals whoever waits on this event that the update
            # for this task has been enqueued -- confirm against the capture
            # side.
            done_event.record(update_stream)
|
||||
|
||||
|
||||
@dataclass
class GraphParams:
    """Per-capture-size state used when updating attention ops in ACL graphs.

    Every field is a dict keyed by an int size (the values passed to
    ``set_graph_params``; ``update_attn_params`` indexes the same dicts with
    its ``runtime_shape`` argument).
    """

    # Events recorded on the update stream, one list per size.
    events: dict[int, list[torch.npu.ExternalEvent]]
    # ``set_graph_params`` seeds every size with ``None``, so a value may be
    # absent until something assigns a tensor; hence ``Optional``.
    workspaces: dict[int, Optional[torch.Tensor]]
    # Task-group handles passed to ``torch.npu.graph_task_update_begin``.
    handles: dict[int, list[torch_npu._C._NPUTaskGroupHandle]]
    # Argument tuples for the paged-attention call; ``update_attn_params``
    # unpacks each tuple into nine values.
    attn_params: dict[int, list[tuple]]
|
||||
|
||||
|
||||
# Module-wide GraphParams singleton; stays None until set_graph_params() runs
# and is read back through get_graph_params().
_graph_params: Optional[GraphParams] = None
|
||||
|
||||
|
||||
def set_graph_params(aclgraph_capture_sizes: set[int]) -> None:
    """Create the module-level ``GraphParams`` for the given capture sizes.

    Each size gets its own empty list in ``events``, ``handles`` and
    ``attn_params``, and a ``None`` placeholder in ``workspaces``.

    Args:
        aclgraph_capture_sizes: Sizes used as keys in every per-size dict.

    Raises:
        ValueError: If the graph parameters were already initialised.
    """
    global _graph_params
    if _graph_params is not None:
        raise ValueError("Graph parameters have already been set!")

    def _fresh_lists():
        # A new list per key -- the lists must not be shared between sizes.
        return {size: [] for size in aclgraph_capture_sizes}

    _graph_params = GraphParams(
        events=_fresh_lists(),
        workspaces=dict.fromkeys(aclgraph_capture_sizes),
        handles=_fresh_lists(),
        attn_params=_fresh_lists(),
    )
|
||||
|
||||
|
||||
def get_graph_params() -> "Optional[GraphParams]":
    """Return the module-level graph parameters.

    Returns:
        The ``GraphParams`` stored by ``set_graph_params``, or ``None`` when
        it has not been called yet.
    """
    return _graph_params
|
||||
|
||||
Reference in New Issue
Block a user