[Refactor][Graph] Move graph parameter logic to acl_graph module (#3101)
### What this PR does / why we need it? This is a follow-up to PR #2128. It moves the graph-parameter management components — `GraphParams`, `get_graph_params`, and `set_graph_params` — from the generic `utils.py` to the more specific `compilation/acl_graph.py`. Additionally, it extracts the `update_attn_params` logic from the `NPUModelRunner` class into a standalone function within the `acl_graph` module. This refactoring improves code organization by centralizing ACL-graph-related logic in its own dedicated module, enhancing modularity and clarity. ### Does this PR introduce _any_ user-facing change? None. ### How was this patch tested? None needed. Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
@@ -36,10 +36,10 @@ from vllm.v1.core.sched.output import SchedulerOutput
|
|||||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||||
|
|
||||||
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
|
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
|
||||||
|
from vllm_ascend.compilation.acl_graph import get_graph_params
|
||||||
from vllm_ascend.ops.attention import vanilla_chunked_prefill
|
from vllm_ascend.ops.attention import vanilla_chunked_prefill
|
||||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16,
|
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
|
||||||
get_graph_params, is_310p, nd_to_nz_2d,
|
nd_to_nz_2d, nd_to_nz_spec)
|
||||||
nd_to_nz_spec)
|
|
||||||
|
|
||||||
|
|
||||||
def wait_for_kv_layer_from_connector(layer_name: str):
|
def wait_for_kv_layer_from_connector(layer_name: str):
|
||||||
|
|||||||
@@ -3,10 +3,12 @@
|
|||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
from contextlib import ExitStack
|
from contextlib import ExitStack
|
||||||
|
from dataclasses import dataclass
|
||||||
from typing import Any, Callable, Optional
|
from typing import Any, Callable, Optional
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch_npu
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.compilation.counter import compilation_counter
|
from vllm.compilation.counter import compilation_counter
|
||||||
from vllm.compilation.cuda_graph import CUDAGraphOptions
|
from vllm.compilation.cuda_graph import CUDAGraphOptions
|
||||||
@@ -185,3 +187,74 @@ class ACLGraphWrapper:
|
|||||||
logger.info_once("Replaying aclgraph")
|
logger.info_once("Replaying aclgraph")
|
||||||
entry.aclgraph.replay()
|
entry.aclgraph.replay()
|
||||||
return entry.output
|
return entry.output
|
||||||
|
|
||||||
|
|
||||||
|
def update_attn_params(update_stream, forward_context, runtime_shape):
    """Refresh the arguments of every captured paged-attention op in the
    ACL graph for ``runtime_shape`` so the next replay uses live data.

    Args:
        update_stream: dedicated NPU stream on which the graph task
            updates are issued.
        forward_context: current forward context; its ``attn_metadata``
            maps layer keys to per-layer attention metadata.
        runtime_shape: capture-size key used to look up the recorded
            params/handles/events (callers pass the padded token count,
            e.g. ``positions.shape[0]``).
    """
    graph_params = get_graph_params()
    # FIXME: Behold! We are using a temporary hack here to update the args
    # for each layer's attention op in the graph.
    for key, param, handle, event in zip(
            forward_context.attn_metadata,
            graph_params.attn_params[runtime_shape],
            graph_params.handles[runtime_shape],
            graph_params.events[runtime_shape],
    ):
        # `param` is the argument tuple saved when the graph was captured.
        (
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            num_heads,
            scale,
            block_table,
            seq_lens,
            output,
        ) = param
        # block_table = forward_context.attn_metadata[key].block_tables
        # NOTE(review): the captured `seq_lens` unpacked above is
        # immediately overwritten with the live per-layer value here, so
        # only the sequence lengths vary between replays.
        seq_lens = forward_context.attn_metadata[key].seq_lens

        # Issue the op between graph_task_update_begin/end on the update
        # stream to patch the captured task's arguments in place.
        with torch.npu.stream(update_stream):
            torch.npu.graph_task_update_begin(update_stream, handle)
            torch_npu._npu_paged_attention(query=query,
                                           key_cache=key_cache,
                                           value_cache=value_cache,
                                           num_kv_heads=num_kv_heads,
                                           num_heads=num_heads,
                                           scale_value=scale,
                                           block_table=block_table,
                                           context_lens=seq_lens,
                                           out=output)
            torch.npu.graph_task_update_end(update_stream)

        # Record completion of this layer's update on the update stream.
        event.record(update_stream)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class GraphParams:
    """Per-capture-size bookkeeping for captured ACL graphs."""
    # External events recorded on the update stream after each task update.
    events: dict[int, list[torch.npu.ExternalEvent]]
    # Workspace tensor per capture size; starts as None (see set_graph_params).
    workspaces: dict[int, Optional[torch.Tensor]]
    # NPU task-group handles used with torch.npu.graph_task_update_begin/end.
    handles: dict[int, list[torch_npu._C._NPUTaskGroupHandle]]
    # Argument tuples saved for each captured attention op.
    attn_params: dict[int, list[tuple]]


# Module-level singleton; populated exactly once by set_graph_params().
_graph_params: Optional[GraphParams] = None
|
||||||
|
def set_graph_params(aclgraph_capture_sizes: set[int]):
    """Create the module-level ``GraphParams`` singleton with one entry
    per ACL-graph capture size.

    Args:
        aclgraph_capture_sizes: the set of batch/capture sizes for which
            graphs will be captured.

    Raises:
        ValueError: if the singleton has already been initialized.
    """
    global _graph_params
    if _graph_params is not None:
        raise ValueError("Graph parameters have already been set!")

    # Build one independent container per capture size; the per-size
    # lists must not be shared between entries.
    events = {size: [] for size in aclgraph_capture_sizes}
    workspaces = {size: None for size in aclgraph_capture_sizes}
    handles = {size: [] for size in aclgraph_capture_sizes}
    attn_params = {size: [] for size in aclgraph_capture_sizes}
    _graph_params = GraphParams(events, workspaces, handles, attn_params)
||||||
|
|
||||||
|
def get_graph_params():
    """Return the module-level GraphParams singleton (None if unset)."""
    return _graph_params
|
||||||
|
|||||||
@@ -22,13 +22,12 @@ import functools
|
|||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
from contextlib import contextmanager, nullcontext
|
from contextlib import contextmanager, nullcontext
|
||||||
from dataclasses import dataclass
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from threading import Lock
|
from threading import Lock
|
||||||
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
|
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch_npu # noqa: F401 # noqa: F401
|
import torch_npu # noqa: F401
|
||||||
from packaging.version import InvalidVersion, Version
|
from packaging.version import InvalidVersion, Version
|
||||||
from torch_npu.npu.streams import Event
|
from torch_npu.npu.streams import Event
|
||||||
from vllm.logger import logger
|
from vllm.logger import logger
|
||||||
@@ -635,34 +634,3 @@ def npu_stream_switch(target_stream: torch.npu.Stream,
|
|||||||
return nullcontext()
|
return nullcontext()
|
||||||
assert target_stream is not None
|
assert target_stream is not None
|
||||||
return torch.npu.stream(target_stream)
|
return torch.npu.stream(target_stream)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
class GraphParams:
    """Per-capture-size state for captured ACL graphs."""
    # External events recorded after each graph task update.
    events: dict[int, list[torch.npu.ExternalEvent]]
    # Workspace tensor per capture size; initialized to None.
    workspaces: dict[int, Optional[torch.Tensor]]
    # NPU task-group handles for updating captured kernels.
    handles: dict[int, list[torch_npu._C._NPUTaskGroupHandle]]
    # Argument tuples saved for each captured attention op.
    attn_params: dict[int, list[tuple]]


# Module-level singleton; set once via set_graph_params().
_graph_params: Optional[GraphParams] = None
|
|
||||||
def set_graph_params(aclgraph_capture_sizes: set[int]):
    """Initialize the module-level GraphParams singleton, one entry per
    capture size.

    Raises:
        ValueError: if the singleton has already been initialized.
    """
    global _graph_params
    if _graph_params is not None:
        raise ValueError("Graph parameters have already been set!")
    # Positional args follow GraphParams field order:
    # events, workspaces, handles, attn_params.
    _graph_params = GraphParams(
        {size: []
         for size in aclgraph_capture_sizes},
        {size: None
         for size in aclgraph_capture_sizes},
        {size: []
         for size in aclgraph_capture_sizes},
        {size: []
         for size in aclgraph_capture_sizes},
    )
|
||||||
|
|
||||||
def get_graph_params():
    """Return the module-level GraphParams singleton (None if unset)."""
    return _graph_params
|
|
||||||
|
|||||||
@@ -99,7 +99,9 @@ from vllm_ascend.ascend_forward_context import (MoECommType,
|
|||||||
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
|
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
|
||||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||||
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
|
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
|
||||||
from vllm_ascend.compilation.acl_graph import ACLGraphWrapper
|
from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
|
||||||
|
set_graph_params,
|
||||||
|
update_attn_params)
|
||||||
from vllm_ascend.eplb.adaptor.vllm_adaptor import VllmEplbAdaptor
|
from vllm_ascend.eplb.adaptor.vllm_adaptor import VllmEplbAdaptor
|
||||||
from vllm_ascend.eplb.core.eplb_device_transfer_loader import \
|
from vllm_ascend.eplb.core.eplb_device_transfer_loader import \
|
||||||
D2DExpertWeightLoader
|
D2DExpertWeightLoader
|
||||||
@@ -117,9 +119,8 @@ from vllm_ascend.spec_decode.interface import SpecDcodeType
|
|||||||
from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
|
from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
|
||||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
|
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
|
||||||
AscendSocVersion, ProfileExecuteDuration,
|
AscendSocVersion, ProfileExecuteDuration,
|
||||||
get_ascend_soc_version, get_graph_params,
|
get_ascend_soc_version, is_310p,
|
||||||
is_310p, lmhead_tp_enable, set_graph_params,
|
lmhead_tp_enable, vllm_version_is)
|
||||||
vllm_version_is)
|
|
||||||
from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch
|
from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -1571,9 +1572,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
|
|
||||||
forward_context = get_forward_context()
|
forward_context = get_forward_context()
|
||||||
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
|
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
|
||||||
graph_params = get_graph_params()
|
update_attn_params(self.update_stream, forward_context,
|
||||||
self.update_attn_params(graph_params, forward_context,
|
positions.shape[0])
|
||||||
positions.shape[0])
|
|
||||||
|
|
||||||
if get_forward_context().flashcomm_v1_enabled:
|
if get_forward_context().flashcomm_v1_enabled:
|
||||||
hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
|
hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
|
||||||
@@ -1582,44 +1582,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
hidden_states = hidden_states[:-pad_size, :]
|
hidden_states = hidden_states[:-pad_size, :]
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
def update_attn_params(self, graph_params, forward_context, runtime_shape):
    """Patch the arguments of each captured paged-attention op for
    ``runtime_shape`` before replaying the ACL graph.

    Args:
        graph_params: GraphParams holding per-size params/handles/events.
        forward_context: current forward context with per-layer
            ``attn_metadata``.
        runtime_shape: capture-size key (the padded token count).
    """
    # FIXME: Behold! We are using a temporary hack here to update the args
    # for each layer's attention op in the graph.
    for key, param, handle, event in zip(
            forward_context.attn_metadata,
            graph_params.attn_params[runtime_shape],
            graph_params.handles[runtime_shape],
            graph_params.events[runtime_shape],
    ):
        # Argument tuple saved when the graph was captured.
        (
            query,
            key_cache,
            value_cache,
            num_kv_heads,
            num_heads,
            scale,
            block_table,
            seq_lens,
            output,
        ) = param
        # block_table = forward_context.attn_metadata[key].block_tables
        # NOTE(review): the captured `seq_lens` unpacked above is
        # immediately overwritten with the live value below.
        seq_lens = forward_context.attn_metadata[key].seq_lens

        # Issue the op between graph_task_update_begin/end on the
        # dedicated update stream to patch the captured task in place.
        with torch.npu.stream(self.update_stream):
            torch.npu.graph_task_update_begin(self.update_stream, handle)
            torch_npu._npu_paged_attention(query=query,
                                           key_cache=key_cache,
                                           value_cache=value_cache,
                                           num_kv_heads=num_kv_heads,
                                           num_heads=num_heads,
                                           scale_value=scale,
                                           block_table=block_table,
                                           context_lens=seq_lens,
                                           out=output)
            torch.npu.graph_task_update_end(self.update_stream)

        # Record completion of this layer's update on the update stream.
        event.record(self.update_stream)
|
|
||||||
|
|
||||||
def _build_attn_state(self, num_reqs, num_scheduled_tokens,
|
def _build_attn_state(self, num_reqs, num_scheduled_tokens,
|
||||||
num_valid_tokens):
|
num_valid_tokens):
|
||||||
ascend_config = get_ascend_config()
|
ascend_config = get_ascend_config()
|
||||||
|
|||||||
Reference in New Issue
Block a user