[Refactor][Graph] Move graph parameter logic to acl_graph module (#3101)

### What this PR does / why we need it?
This is the follow-up PR of #2128 .

Moves graph parameter management components, including `GraphParams`,
`get_graph_params`, and `set_graph_params`, from the generic `utils.py`
to the more specific `compilation/acl_graph.py`.

Additionally, extracts the `update_attn_params` logic from the
`NPUModelRunner` class into a standalone function within the `acl_graph`
module.

This refactoring improves code organization by centralizing ACL
graph-related logic into its own dedicated module, enhancing modularity
and clarity.

### Does this PR introduce _any_ user-facing change?
None.

### How was this patch tested?
None needed.

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
Yizhou
2025-09-22 22:23:14 +08:00
committed by GitHub
parent 02f89d166f
commit 3fa7cf6345
4 changed files with 84 additions and 81 deletions

View File

@@ -22,13 +22,12 @@ import functools
import math
import os
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from enum import Enum
from threading import Lock
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
import torch
import torch_npu # noqa: F401 # noqa: F401
import torch_npu # noqa: F401
from packaging.version import InvalidVersion, Version
from torch_npu.npu.streams import Event
from vllm.logger import logger
@@ -635,34 +634,3 @@ def npu_stream_switch(target_stream: torch.npu.Stream,
return nullcontext()
assert target_stream is not None
return torch.npu.stream(target_stream)
@dataclass
class GraphParams:
events: dict[int, list[torch.npu.ExternalEvent]]
workspaces: dict[int, torch.Tensor]
handles: dict[int, list[torch_npu._C._NPUTaskGroupHandle]]
attn_params: dict[int, list[tuple]]
_graph_params: Optional[GraphParams] = None
def set_graph_params(aclgraph_capture_sizes: set[int]):
global _graph_params
if _graph_params is not None:
raise ValueError("Graph parameters have already been set!")
_graph_params = GraphParams(
{size: []
for size in aclgraph_capture_sizes},
{size: None
for size in aclgraph_capture_sizes},
{size: []
for size in aclgraph_capture_sizes},
{size: []
for size in aclgraph_capture_sizes},
)
def get_graph_params():
return _graph_params