[Refactor][Graph] Move graph parameter logic to acl_graph module (#3101)

### What this PR does / why we need it?
This is the follow-up PR of #2128 .

Moves graph parameter management components, including `GraphParams`,
`get_graph_params`, and `set_graph_params`, from the generic `utils.py`
to the more specific `compilation/acl_graph.py`.

Additionally, extracts the `update_attn_params` logic from the
`NPUModelRunner` class into a standalone function within the `acl_graph`
module.

This refactoring improves code organization by centralizing ACL
graph-related logic into its own dedicated module, enhancing modularity
and clarity.

### Does this PR introduce _any_ user-facing change?
None.

### How was this patch tested?
None needed.

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
Yizhou
2025-09-22 22:23:14 +08:00
committed by GitHub
parent 02f89d166f
commit 3fa7cf6345
4 changed files with 84 additions and 81 deletions

View File

@@ -99,7 +99,9 @@ from vllm_ascend.ascend_forward_context import (MoECommType,
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
from vllm_ascend.compilation.acl_graph import ACLGraphWrapper
from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
set_graph_params,
update_attn_params)
from vllm_ascend.eplb.adaptor.vllm_adaptor import VllmEplbAdaptor
from vllm_ascend.eplb.core.eplb_device_transfer_loader import \
D2DExpertWeightLoader
@@ -117,9 +119,8 @@ from vllm_ascend.spec_decode.interface import SpecDcodeType
from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
AscendSocVersion, ProfileExecuteDuration,
get_ascend_soc_version, get_graph_params,
is_310p, lmhead_tp_enable, set_graph_params,
vllm_version_is)
get_ascend_soc_version, is_310p,
lmhead_tp_enable, vllm_version_is)
from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch
if TYPE_CHECKING:
@@ -1571,9 +1572,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
forward_context = get_forward_context()
if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
graph_params = get_graph_params()
self.update_attn_params(graph_params, forward_context,
positions.shape[0])
update_attn_params(self.update_stream, forward_context,
positions.shape[0])
if get_forward_context().flashcomm_v1_enabled:
hidden_states = tensor_model_parallel_all_gather(hidden_states, 0)
@@ -1582,44 +1582,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
hidden_states = hidden_states[:-pad_size, :]
return hidden_states
def update_attn_params(self, graph_params, forward_context, runtime_shape):
# FIXME: Behold! We are using a temporary hack here to update the args
# for each layer's attention op in the graph.
for key, param, handle, event in zip(
forward_context.attn_metadata,
graph_params.attn_params[runtime_shape],
graph_params.handles[runtime_shape],
graph_params.events[runtime_shape],
):
(
query,
key_cache,
value_cache,
num_kv_heads,
num_heads,
scale,
block_table,
seq_lens,
output,
) = param
# block_table = forward_context.attn_metadata[key].block_tables
seq_lens = forward_context.attn_metadata[key].seq_lens
with torch.npu.stream(self.update_stream):
torch.npu.graph_task_update_begin(self.update_stream, handle)
torch_npu._npu_paged_attention(query=query,
key_cache=key_cache,
value_cache=value_cache,
num_kv_heads=num_kv_heads,
num_heads=num_heads,
scale_value=scale,
block_table=block_table,
context_lens=seq_lens,
out=output)
torch.npu.graph_task_update_end(self.update_stream)
event.record(self.update_stream)
def _build_attn_state(self, num_reqs, num_scheduled_tokens,
num_valid_tokens):
ascend_config = get_ascend_config()