[Refactor][Graph] Move graph parameter logic to acl_graph module (#3101)

### What this PR does / why we need it?
This is the follow-up PR of #2128 .

Moves graph parameter management components, including `GraphParams`,
`get_graph_params`, and `set_graph_params`, from the generic `utils.py`
to the more specific `compilation/acl_graph.py`.

Additionally, extracts the `update_attn_params` logic from the
`NPUModelRunner` class into a standalone function within the `acl_graph`
module.

This refactoring improves code organization by centralizing ACL
graph-related logic into its own dedicated module, enhancing modularity
and clarity.

### Does this PR introduce _any_ user-facing change?
None.

### How was this patch tested?
None needed.

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
Yizhou
2025-09-22 22:23:14 +08:00
committed by GitHub
parent 02f89d166f
commit 3fa7cf6345
4 changed files with 84 additions and 81 deletions

View File

@@ -36,10 +36,10 @@ from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
from vllm_ascend.compilation.acl_graph import get_graph_params
from vllm_ascend.ops.attention import vanilla_chunked_prefill
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16,
get_graph_params, is_310p, nd_to_nz_2d,
nd_to_nz_spec)
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
nd_to_nz_2d, nd_to_nz_spec)
def wait_for_kv_layer_from_connector(layer_name: str):

View File

@@ -3,10 +3,12 @@
import dataclasses
from contextlib import ExitStack
from dataclasses import dataclass
from typing import Any, Callable, Optional
from unittest.mock import patch
import torch
import torch_npu
import vllm.envs as envs
from vllm.compilation.counter import compilation_counter
from vllm.compilation.cuda_graph import CUDAGraphOptions
@@ -185,3 +187,74 @@ class ACLGraphWrapper:
logger.info_once("Replaying aclgraph")
entry.aclgraph.replay()
return entry.output
def update_attn_params(update_stream, forward_context, runtime_shape):
graph_params = get_graph_params()
# FIXME: Behold! We are using a temporary hack here to update the args
# for each layer's attention op in the graph.
for key, param, handle, event in zip(
forward_context.attn_metadata,
graph_params.attn_params[runtime_shape],
graph_params.handles[runtime_shape],
graph_params.events[runtime_shape],
):
(
query,
key_cache,
value_cache,
num_kv_heads,
num_heads,
scale,
block_table,
seq_lens,
output,
) = param
# block_table = forward_context.attn_metadata[key].block_tables
seq_lens = forward_context.attn_metadata[key].seq_lens
with torch.npu.stream(update_stream):
torch.npu.graph_task_update_begin(update_stream, handle)
torch_npu._npu_paged_attention(query=query,
key_cache=key_cache,
value_cache=value_cache,
num_kv_heads=num_kv_heads,
num_heads=num_heads,
scale_value=scale,
block_table=block_table,
context_lens=seq_lens,
out=output)
torch.npu.graph_task_update_end(update_stream)
event.record(update_stream)
@dataclass
class GraphParams:
events: dict[int, list[torch.npu.ExternalEvent]]
workspaces: dict[int, torch.Tensor]
handles: dict[int, list[torch_npu._C._NPUTaskGroupHandle]]
attn_params: dict[int, list[tuple]]
_graph_params: Optional[GraphParams] = None
def set_graph_params(aclgraph_capture_sizes: set[int]):
global _graph_params
if _graph_params is not None:
raise ValueError("Graph parameters have already been set!")
_graph_params = GraphParams(
{size: []
for size in aclgraph_capture_sizes},
{size: None
for size in aclgraph_capture_sizes},
{size: []
for size in aclgraph_capture_sizes},
{size: []
for size in aclgraph_capture_sizes},
)
def get_graph_params():
return _graph_params

View File

@@ -22,13 +22,12 @@ import functools
import math
import os
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from enum import Enum
from threading import Lock
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
import torch
import torch_npu # noqa: F401 # noqa: F401
import torch_npu # noqa: F401
from packaging.version import InvalidVersion, Version
from torch_npu.npu.streams import Event
from vllm.logger import logger
@@ -635,34 +634,3 @@ def npu_stream_switch(target_stream: torch.npu.Stream,
return nullcontext()
assert target_stream is not None
return torch.npu.stream(target_stream)
@dataclass
class GraphParams:
events: dict[int, list[torch.npu.ExternalEvent]]
workspaces: dict[int, torch.Tensor]
handles: dict[int, list[torch_npu._C._NPUTaskGroupHandle]]
attn_params: dict[int, list[tuple]]
_graph_params: Optional[GraphParams] = None
def set_graph_params(aclgraph_capture_sizes: set[int]):
global _graph_params
if _graph_params is not None:
raise ValueError("Graph parameters have already been set!")
_graph_params = GraphParams(
{size: []
for size in aclgraph_capture_sizes},
{size: None
for size in aclgraph_capture_sizes},
{size: []
for size in aclgraph_capture_sizes},
{size: []
for size in aclgraph_capture_sizes},
)
def get_graph_params():
return _graph_params

View File

@@ -99,7 +99,9 @@ from vllm_ascend.ascend_forward_context import (MoECommType,
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
from vllm_ascend.compilation.acl_graph import ACLGraphWrapper
from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
set_graph_params,
update_attn_params)
from vllm_ascend.eplb.adaptor.vllm_adaptor import VllmEplbAdaptor
from vllm_ascend.eplb.core.eplb_device_transfer_loader import \
D2DExpertWeightLoader
@@ -117,9 +119,8 @@ from vllm_ascend.spec_decode.interface import SpecDcodeType
from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
AscendSocVersion, ProfileExecuteDuration,
get_ascend_soc_version, get_graph_params,
is_310p, lmhead_tp_enable, set_graph_params,
vllm_version_is)
get_ascend_soc_version, is_310p,
lmhead_tp_enable, vllm_version_is)
from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch
if TYPE_CHECKING:
@@ -1571,9 +1572,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
forward_context = get_forward_context()
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
graph_params = get_graph_params()
self.update_attn_params(graph_params, forward_context,
positions.shape[0])
update_attn_params(self.update_stream, forward_context,
positions.shape[0])
if get_forward_context().flashcomm_v1_enabled:
hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
@@ -1582,44 +1582,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
hidden_states = hidden_states[:-pad_size, :]
return hidden_states
def update_attn_params(self, graph_params, forward_context, runtime_shape):
# FIXME: Behold! We are using a temporary hack here to update the args
# for each layer's attention op in the graph.
for key, param, handle, event in zip(
forward_context.attn_metadata,
graph_params.attn_params[runtime_shape],
graph_params.handles[runtime_shape],
graph_params.events[runtime_shape],
):
(
query,
key_cache,
value_cache,
num_kv_heads,
num_heads,
scale,
block_table,
seq_lens,
output,
) = param
# block_table = forward_context.attn_metadata[key].block_tables
seq_lens = forward_context.attn_metadata[key].seq_lens
with torch.npu.stream(self.update_stream):
torch.npu.graph_task_update_begin(self.update_stream, handle)
torch_npu._npu_paged_attention(query=query,
key_cache=key_cache,
value_cache=value_cache,
num_kv_heads=num_kv_heads,
num_heads=num_heads,
scale_value=scale,
block_table=block_table,
context_lens=seq_lens,
out=output)
torch.npu.graph_task_update_end(self.update_stream)
event.record(self.update_stream)
def _build_attn_state(self, num_reqs, num_scheduled_tokens,
num_valid_tokens):
ascend_config = get_ascend_config()