[Refactor][Graph] Move graph parameter logic to acl_graph module (#3101)
### What this PR does / why we need it? This is the follow-up PR of #2128 . Moves graph parameter management components, including `GraphParams`, `get_graph_params`, and `set_graph_params`, from the generic `utils.py` to the more specific `compilation/acl_graph.py`. Additionally, extracts the `update_attn_params` logic from the `NPUModelRunner` class into a standalone function within the `acl_graph` module. This refactoring improves code organization by centralizing ACL graph-related logic into its own dedicated module, enhancing modularity and clarity. ### Does this PR introduce _any_ user-facing change? None. ### How was this patch tested? None needed. Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
@@ -22,13 +22,12 @@ import functools
|
||||
import math
|
||||
import os
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from threading import Lock
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch_npu # noqa: F401 # noqa: F401
|
||||
import torch_npu # noqa: F401
|
||||
from packaging.version import InvalidVersion, Version
|
||||
from torch_npu.npu.streams import Event
|
||||
from vllm.logger import logger
|
||||
@@ -635,34 +634,3 @@ def npu_stream_switch(target_stream: torch.npu.Stream,
|
||||
return nullcontext()
|
||||
assert target_stream is not None
|
||||
return torch.npu.stream(target_stream)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GraphParams:
|
||||
events: dict[int, list[torch.npu.ExternalEvent]]
|
||||
workspaces: dict[int, torch.Tensor]
|
||||
handles: dict[int, list[torch_npu._C._NPUTaskGroupHandle]]
|
||||
attn_params: dict[int, list[tuple]]
|
||||
|
||||
|
||||
_graph_params: Optional[GraphParams] = None
|
||||
|
||||
|
||||
def set_graph_params(aclgraph_capture_sizes: set[int]):
|
||||
global _graph_params
|
||||
if _graph_params is not None:
|
||||
raise ValueError("Graph parameters have already been set!")
|
||||
_graph_params = GraphParams(
|
||||
{size: []
|
||||
for size in aclgraph_capture_sizes},
|
||||
{size: None
|
||||
for size in aclgraph_capture_sizes},
|
||||
{size: []
|
||||
for size in aclgraph_capture_sizes},
|
||||
{size: []
|
||||
for size in aclgraph_capture_sizes},
|
||||
)
|
||||
|
||||
|
||||
def get_graph_params():
|
||||
return _graph_params
|
||||
|
||||
Reference in New Issue
Block a user