From 3fa7cf6345c4e65a767d193a132b51eadfcbe70b Mon Sep 17 00:00:00 2001 From: Yizhou <136800916+yiz-liu@users.noreply.github.com> Date: Mon, 22 Sep 2025 22:23:14 +0800 Subject: [PATCH] [Refactor][Graph] Move graph parameter logic to acl_graph module (#3101) ### What this PR does / why we need it? This is the follow-up PR of #2128 . Moves graph parameter management components, including `GraphParams`, `get_graph_params`, and `set_graph_params`, from the generic `utils.py` to the more specific `compilation/acl_graph.py`. Additionally, extracts the `update_attn_params` logic from the `NPUModelRunner` class into a standalone function within the `acl_graph` module. This refactoring improves code organization by centralizing ACL graph-related logic into its own dedicated module, enhancing modularity and clarity. ### Does this PR introduce _any_ user-facing change? None. ### How was this patch tested? None needed. Signed-off-by: Yizhou Liu --- vllm_ascend/attention/attention_v1.py | 6 +-- vllm_ascend/compilation/acl_graph.py | 73 +++++++++++++++++++++++++++ vllm_ascend/utils.py | 34 +------------ vllm_ascend/worker/model_runner_v1.py | 52 +++---------------- 4 files changed, 84 insertions(+), 81 deletions(-) diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index b2e7c84..7d3d18f 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -36,10 +36,10 @@ from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import AttentionSpec from vllm_ascend.attention.utils import AscendCommonAttentionMetadata +from vllm_ascend.compilation.acl_graph import get_graph_params from vllm_ascend.ops.attention import vanilla_chunked_prefill -from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, - get_graph_params, is_310p, nd_to_nz_2d, - nd_to_nz_spec) +from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p, + nd_to_nz_2d, nd_to_nz_spec) def wait_for_kv_layer_from_connector(layer_name: str): diff --git a/vllm_ascend/compilation/acl_graph.py b/vllm_ascend/compilation/acl_graph.py index d88ee34..8a41807 100644 --- a/vllm_ascend/compilation/acl_graph.py +++ b/vllm_ascend/compilation/acl_graph.py @@ -3,10 +3,12 @@ import dataclasses from contextlib import ExitStack +from dataclasses import dataclass from typing import Any, Callable, Optional from unittest.mock import patch import torch +import torch_npu import vllm.envs as envs from vllm.compilation.counter import compilation_counter from vllm.compilation.cuda_graph import CUDAGraphOptions @@ -185,3 +187,74 @@ class ACLGraphWrapper: logger.info_once("Replaying aclgraph") entry.aclgraph.replay() return entry.output + + +def update_attn_params(update_stream, forward_context, runtime_shape): + graph_params = get_graph_params() + # FIXME: Behold! We are using a temporary hack here to update the args + # for each layer's attention op in the graph. + for key, param, handle, event in zip( + forward_context.attn_metadata, + graph_params.attn_params[runtime_shape], + graph_params.handles[runtime_shape], + graph_params.events[runtime_shape], + ): + ( + query, + key_cache, + value_cache, + num_kv_heads, + num_heads, + scale, + block_table, + seq_lens, + output, + ) = param + # block_table = forward_context.attn_metadata[key].block_tables + seq_lens = forward_context.attn_metadata[key].seq_lens + + with torch.npu.stream(update_stream): + torch.npu.graph_task_update_begin(update_stream, handle) + torch_npu._npu_paged_attention(query=query, + key_cache=key_cache, + value_cache=value_cache, + num_kv_heads=num_kv_heads, + num_heads=num_heads, + scale_value=scale, + block_table=block_table, + context_lens=seq_lens, + out=output) + torch.npu.graph_task_update_end(update_stream) + + event.record(update_stream) + + +@dataclass +class GraphParams: + events: dict[int, list[torch.npu.ExternalEvent]] + workspaces: dict[int, torch.Tensor] + handles: dict[int, list[torch_npu._C._NPUTaskGroupHandle]] + attn_params: dict[int, list[tuple]] + + +_graph_params: Optional[GraphParams] = None + + +def set_graph_params(aclgraph_capture_sizes: set[int]): + global _graph_params + if _graph_params is not None: + raise ValueError("Graph parameters have already been set!") + _graph_params = GraphParams( + {size: [] + for size in aclgraph_capture_sizes}, + {size: None + for size in aclgraph_capture_sizes}, + {size: [] + for size in aclgraph_capture_sizes}, + {size: [] + for size in aclgraph_capture_sizes}, + ) + + +def get_graph_params(): + return _graph_params diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 451112d..2b2b540 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -22,13 +22,12 @@ import functools import math import os from contextlib import contextmanager, nullcontext -from dataclasses import dataclass from enum import Enum from threading import Lock from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union import torch -import torch_npu # noqa: F401 # noqa: F401 +import torch_npu # noqa: F401 from packaging.version import InvalidVersion, Version from torch_npu.npu.streams import Event from vllm.logger import logger @@ -635,34 +634,3 @@ def npu_stream_switch(target_stream: torch.npu.Stream, return nullcontext() assert target_stream is not None return torch.npu.stream(target_stream) - - -@dataclass -class GraphParams: - events: dict[int, list[torch.npu.ExternalEvent]] - workspaces: dict[int, torch.Tensor] - handles: dict[int, list[torch_npu._C._NPUTaskGroupHandle]] - attn_params: dict[int, list[tuple]] - - -_graph_params: Optional[GraphParams] = None - - -def set_graph_params(aclgraph_capture_sizes: set[int]): - global _graph_params - if _graph_params is not None: - raise ValueError("Graph parameters have already been set!") - _graph_params = GraphParams( - {size: [] - for size in aclgraph_capture_sizes}, - {size: None - for size in aclgraph_capture_sizes}, - {size: [] - for size in aclgraph_capture_sizes}, - {size: [] - for size in aclgraph_capture_sizes}, - ) - - -def get_graph_params(): - return _graph_params diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 242ff6e..22a4c4b 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -99,7 +99,9 @@ from vllm_ascend.ascend_forward_context import (MoECommType, from vllm_ascend.attention.attention_mask import AttentionMaskBuilder from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.utils import AscendCommonAttentionMetadata -from vllm_ascend.compilation.acl_graph import ACLGraphWrapper +from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper, + set_graph_params, + update_attn_params) from vllm_ascend.eplb.adaptor.vllm_adaptor import VllmEplbAdaptor from vllm_ascend.eplb.core.eplb_device_transfer_loader import \ D2DExpertWeightLoader @@ -117,9 +119,8 @@ from vllm_ascend.spec_decode.interface import SpecDcodeType from vllm_ascend.spec_decode.mtp_proposer import MtpProposer from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ, AscendSocVersion, ProfileExecuteDuration, - get_ascend_soc_version, get_graph_params, - is_310p, lmhead_tp_enable, set_graph_params, - vllm_version_is) + get_ascend_soc_version, is_310p, + lmhead_tp_enable, vllm_version_is) from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch if TYPE_CHECKING: @@ -1571,9 +1572,8 @@ class NPUModelRunner(LoRAModelRunnerMixin): forward_context = get_forward_context() if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL: - graph_params = get_graph_params() - self.update_attn_params(graph_params, forward_context, - positions.shape[0]) + update_attn_params(self.update_stream, forward_context, + positions.shape[0]) if get_forward_context().flashcomm_v1_enabled: hidden_states = tensor_model_parallel_all_gather(hidden_states, 0) @@ -1582,44 +1582,6 @@ class NPUModelRunner(LoRAModelRunnerMixin): hidden_states = hidden_states[:-pad_size, :] return hidden_states - def update_attn_params(self, graph_params, forward_context, runtime_shape): - # FIXME: Behold! We are using a temporary hack here to update the args - # for each layer's attention op in the graph. - for key, param, handle, event in zip( - forward_context.attn_metadata, - graph_params.attn_params[runtime_shape], - graph_params.handles[runtime_shape], - graph_params.events[runtime_shape], - ): - ( - query, - key_cache, - value_cache, - num_kv_heads, - num_heads, - scale, - block_table, - seq_lens, - output, - ) = param - # block_table = forward_context.attn_metadata[key].block_tables - seq_lens = forward_context.attn_metadata[key].seq_lens - - with torch.npu.stream(self.update_stream): - torch.npu.graph_task_update_begin(self.update_stream, handle) - torch_npu._npu_paged_attention(query=query, - key_cache=key_cache, - value_cache=value_cache, - num_kv_heads=num_kv_heads, - num_heads=num_heads, - scale_value=scale, - block_table=block_table, - context_lens=seq_lens, - out=output) - torch.npu.graph_task_update_end(self.update_stream) - - event.record(self.update_stream) - def _build_attn_state(self, num_reqs, num_scheduled_tokens, num_valid_tokens): ascend_config = get_ascend_config()