[Refactor][Graph] Move graph parameter logic to acl_graph module (#3101)

### What this PR does / why we need it?
This is a follow-up to #2128.

Moves graph parameter management components, including `GraphParams`,
`get_graph_params`, and `set_graph_params`, from the generic `utils.py`
to the more specific `compilation/acl_graph.py`.
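For downstream code this is purely an import-path change; a minimal before/after sketch (illustrative, not an exhaustive list of call sites):

```python
# Old location (removed by this PR):
# from vllm_ascend.utils import GraphParams, get_graph_params, set_graph_params

# New location:
from vllm_ascend.compilation.acl_graph import (GraphParams, get_graph_params,
                                               set_graph_params)
```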

Additionally, extracts the `update_attn_params` logic from the
`NPUModelRunner` class into a standalone function within the `acl_graph`
module.
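At the call site in `NPUModelRunner`, the method call becomes a call to the module-level function, which now receives the runner's update stream explicitly and fetches `GraphParams` itself via `get_graph_params()`; roughly (see the last diff hunk below):

```python
from vllm_ascend.compilation.acl_graph import update_attn_params

# Before (method on the runner):
#     graph_params = get_graph_params()
#     self.update_attn_params(graph_params, forward_context, positions.shape[0])
#
# After (standalone function):
#     update_attn_params(self.update_stream, forward_context, positions.shape[0])
```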

This refactoring improves code organization by centralizing ACL
graph-related logic into its own dedicated module, enhancing modularity
and clarity.

### Does this PR introduce _any_ user-facing change?
None.

### How was this patch tested?
No new tests needed; this is a pure refactor with no functional change.

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
Commit 3fa7cf6345 (parent 02f89d166f), authored by Yizhou on 2025-09-22 22:23:14 +08:00 and committed via GitHub.
4 changed files with 84 additions and 81 deletions.


@@ -36,10 +36,10 @@ from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import AttentionSpec
 from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
+from vllm_ascend.compilation.acl_graph import get_graph_params
 from vllm_ascend.ops.attention import vanilla_chunked_prefill
-from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16,
-                               get_graph_params, is_310p, nd_to_nz_2d,
-                               nd_to_nz_spec)
+from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
+                               nd_to_nz_2d, nd_to_nz_spec)

 def wait_for_kv_layer_from_connector(layer_name: str):


@@ -3,10 +3,12 @@
 import dataclasses
 from contextlib import ExitStack
+from dataclasses import dataclass
 from typing import Any, Callable, Optional
 from unittest.mock import patch

 import torch
+import torch_npu
 import vllm.envs as envs
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.cuda_graph import CUDAGraphOptions
@@ -185,3 +187,74 @@ class ACLGraphWrapper:
         logger.info_once("Replaying aclgraph")
         entry.aclgraph.replay()
         return entry.output
+
+
+def update_attn_params(update_stream, forward_context, runtime_shape):
+    graph_params = get_graph_params()
+    # FIXME: Behold! We are using a temporary hack here to update the args
+    # for each layer's attention op in the graph.
+    for key, param, handle, event in zip(
+            forward_context.attn_metadata,
+            graph_params.attn_params[runtime_shape],
+            graph_params.handles[runtime_shape],
+            graph_params.events[runtime_shape],
+    ):
+        (
+            query,
+            key_cache,
+            value_cache,
+            num_kv_heads,
+            num_heads,
+            scale,
+            block_table,
+            seq_lens,
+            output,
+        ) = param
+        # block_table = forward_context.attn_metadata[key].block_tables
+        seq_lens = forward_context.attn_metadata[key].seq_lens
+
+        with torch.npu.stream(update_stream):
+            torch.npu.graph_task_update_begin(update_stream, handle)
+            torch_npu._npu_paged_attention(query=query,
+                                           key_cache=key_cache,
+                                           value_cache=value_cache,
+                                           num_kv_heads=num_kv_heads,
+                                           num_heads=num_heads,
+                                           scale_value=scale,
+                                           block_table=block_table,
+                                           context_lens=seq_lens,
+                                           out=output)
+            torch.npu.graph_task_update_end(update_stream)
+        event.record(update_stream)
+
+
+@dataclass
+class GraphParams:
+    events: dict[int, list[torch.npu.ExternalEvent]]
+    workspaces: dict[int, torch.Tensor]
+    handles: dict[int, list[torch_npu._C._NPUTaskGroupHandle]]
+    attn_params: dict[int, list[tuple]]
+
+
+_graph_params: Optional[GraphParams] = None
+
+
+def set_graph_params(aclgraph_capture_sizes: set[int]):
+    global _graph_params
+    if _graph_params is not None:
+        raise ValueError("Graph parameters have already been set!")
+    _graph_params = GraphParams(
+        {size: []
+         for size in aclgraph_capture_sizes},
+        {size: None
+         for size in aclgraph_capture_sizes},
+        {size: []
+         for size in aclgraph_capture_sizes},
+        {size: []
+         for size in aclgraph_capture_sizes},
+    )
+
+
+def get_graph_params():
+    return _graph_params

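For reference, a minimal sketch of how the relocated helpers fit together; the capture sizes here are illustrative, not taken from this PR:

```python
from vllm_ascend.compilation.acl_graph import (GraphParams, get_graph_params,
                                               set_graph_params)

# Called once, before aclgraph capture: one bucket per capture size.
set_graph_params({1, 8, 32})

params = get_graph_params()
assert isinstance(params, GraphParams)
assert sorted(params.events) == [1, 8, 32]  # per-shape event lists, initially empty
assert params.workspaces[8] is None         # workspaces are filled in during capture

# A second set_graph_params() call raises ValueError, since the module keeps a
# single global GraphParams instance.
```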

@@ -22,13 +22,12 @@ import functools
 import math
 import os
 from contextlib import contextmanager, nullcontext
-from dataclasses import dataclass
 from enum import Enum
 from threading import Lock
 from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union

 import torch
-import torch_npu  # noqa: F401  # noqa: F401
+import torch_npu  # noqa: F401
 from packaging.version import InvalidVersion, Version
 from torch_npu.npu.streams import Event
 from vllm.logger import logger
@@ -635,34 +634,3 @@ def npu_stream_switch(target_stream: torch.npu.Stream,
         return nullcontext()
     assert target_stream is not None
     return torch.npu.stream(target_stream)
-
-
-@dataclass
-class GraphParams:
-    events: dict[int, list[torch.npu.ExternalEvent]]
-    workspaces: dict[int, torch.Tensor]
-    handles: dict[int, list[torch_npu._C._NPUTaskGroupHandle]]
-    attn_params: dict[int, list[tuple]]
-
-
-_graph_params: Optional[GraphParams] = None
-
-
-def set_graph_params(aclgraph_capture_sizes: set[int]):
-    global _graph_params
-    if _graph_params is not None:
-        raise ValueError("Graph parameters have already been set!")
-    _graph_params = GraphParams(
-        {size: []
-         for size in aclgraph_capture_sizes},
-        {size: None
-         for size in aclgraph_capture_sizes},
-        {size: []
-         for size in aclgraph_capture_sizes},
-        {size: []
-         for size in aclgraph_capture_sizes},
-    )
-
-
-def get_graph_params():
-    return _graph_params


@@ -99,7 +99,9 @@ from vllm_ascend.ascend_forward_context import (MoECommType,
 from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
-from vllm_ascend.compilation.acl_graph import ACLGraphWrapper
+from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
+                                               set_graph_params,
+                                               update_attn_params)
 from vllm_ascend.eplb.adaptor.vllm_adaptor import VllmEplbAdaptor
 from vllm_ascend.eplb.core.eplb_device_transfer_loader import \
     D2DExpertWeightLoader
@@ -117,9 +119,8 @@ from vllm_ascend.spec_decode.interface import SpecDcodeType
 from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
-from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
-                               AscendSocVersion, ProfileExecuteDuration,
-                               get_ascend_soc_version, get_graph_params,
-                               is_310p, lmhead_tp_enable, set_graph_params,
-                               vllm_version_is)
+from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
+                               AscendSocVersion, ProfileExecuteDuration,
+                               get_ascend_soc_version, is_310p,
+                               lmhead_tp_enable, vllm_version_is)
 from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch

 if TYPE_CHECKING:
@@ -1571,9 +1572,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         forward_context = get_forward_context()
         if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
-            graph_params = get_graph_params()
-            self.update_attn_params(graph_params, forward_context,
-                                    positions.shape[0])
+            update_attn_params(self.update_stream, forward_context,
+                               positions.shape[0])

         if get_forward_context().flashcomm_v1_enabled:
             hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
@@ -1582,44 +1582,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             hidden_states = hidden_states[:-pad_size, :]
         return hidden_states

-    def update_attn_params(self, graph_params, forward_context, runtime_shape):
-        # FIXME: Behold! We are using a temporary hack here to update the args
-        # for each layer's attention op in the graph.
-        for key, param, handle, event in zip(
-                forward_context.attn_metadata,
-                graph_params.attn_params[runtime_shape],
-                graph_params.handles[runtime_shape],
-                graph_params.events[runtime_shape],
-        ):
-            (
-                query,
-                key_cache,
-                value_cache,
-                num_kv_heads,
-                num_heads,
-                scale,
-                block_table,
-                seq_lens,
-                output,
-            ) = param
-            # block_table = forward_context.attn_metadata[key].block_tables
-            seq_lens = forward_context.attn_metadata[key].seq_lens
-
-            with torch.npu.stream(self.update_stream):
-                torch.npu.graph_task_update_begin(self.update_stream, handle)
-                torch_npu._npu_paged_attention(query=query,
-                                               key_cache=key_cache,
-                                               value_cache=value_cache,
-                                               num_kv_heads=num_kv_heads,
-                                               num_heads=num_heads,
-                                               scale_value=scale,
-                                               block_table=block_table,
-                                               context_lens=seq_lens,
-                                               out=output)
-                torch.npu.graph_task_update_end(self.update_stream)
-            event.record(self.update_stream)
-
     def _build_attn_state(self, num_reqs, num_scheduled_tokens,
                           num_valid_tokens):
         ascend_config = get_ascend_config()