[0.18.0][BugFix]: order acl graph updates before model forward for ENPU (#8317)

<!--  Thanks for sending a pull request!

BEFORE SUBMITTING, PLEASE READ
https://docs.vllm.ai/en/latest/contributing/overview.html

-->
### What this PR does / why we need it?
<!--
- Please clarify what changes you are proposing. The purpose of this
section is to outline the changes and how this PR fixes the issue.
If possible, please consider writing useful notes for better and faster
reviews in your PR.

- Please clarify why the changes are needed. For instance, the use case
and bug description.

- Fixes #
-->
For the ENPU scenario, it is required that device events follow the
principle of "record first, wait later", otherwise the inference process
may become stuck. However, in the current model_forward function,
event.wait precedes event.record. Therefore, for the ENPU scenario,
graph parameter updates should be performed before model execution.

### Does this PR introduce _any_ user-facing change?
<!--
Note that it means *any* user-facing change including all aspects such
as API, interface or other behavior changes.
Documentation-only updates are not considered user-facing changes.
-->
N/A

### How was this patch tested?
<!--
CI passed with new added/existing test.
If it was tested in a way different from regular unit tests, please
clarify how you tested step by step, ideally copy and paste-able, so
that other reviewers can test and check, and descendants can verify in
the future.
If tests were not added, please describe why they were not added and/or
why it was difficult to add.
-->

---------

Signed-off-by: 1zzk <785396250@qq.com>
Signed-off-by: 1kzk <785396250@qq.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
1kzk
2026-04-16 16:26:59 +08:00
committed by GitHub
parent 2ac8bfb4cb
commit 52f0f9b5e4
2 changed files with 64 additions and 19 deletions

View File

@@ -23,6 +23,7 @@ from collections import defaultdict
from contextlib import contextmanager, nullcontext
from copy import copy, deepcopy
from dataclasses import dataclass
from functools import partial
from multiprocessing import Manager
from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias
@@ -36,7 +37,7 @@ from vllm.distributed import get_tensor_model_parallel_world_size, tensor_model_
from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer
from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group
from vllm.distributed.parallel_state import get_dcp_group, get_dp_group, get_pcp_group, get_pp_group, get_tp_group
from vllm.forward_context import BatchDescriptor, get_forward_context
from vllm.forward_context import BatchDescriptor, ForwardContext, get_forward_context
from vllm.logger import logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.mamba.abstract import MambaBase
@@ -122,6 +123,7 @@ from vllm_ascend.utils import (
check_gdn_layer,
enable_sp,
enable_sp_by_pass,
get_c_env,
global_stream,
lmhead_tp_enable,
set_weight_prefetch_method,
@@ -422,6 +424,9 @@ class NPUModelRunner(GPUModelRunner):
self.cudagraph_batch_sizes = []
self.mamba_state_idx: dict[str, int] = {}
self._mamba_copy_bufs: mamba_utils.MambaCopyBuffers | None = None
env_enpu_enable = get_c_env("ENPU_ENABLE")
# When True, run update_full_graph_params before self.model (ENPU / graph capture order).
self.enable_enpu = env_enpu_enable is not None and env_enpu_enable.lower() == "true"
@property
def use_cp(self) -> bool:
@@ -1767,25 +1772,12 @@ class NPUModelRunner(GPUModelRunner):
)
return NPUModelRunner._all_gather_hidden_states(hidden_states)
def _model_forward(
def _update_full_graph_params_if_needed(
self,
forward_context: ForwardContext,
num_tokens_padded: int,
input_ids: torch.Tensor | None = None,
positions: torch.Tensor | None = None,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
**model_kwargs: dict[str, Any],
):
assert self.model is not None
hidden_states = self.model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
**model_kwargs,
)
forward_context = get_forward_context()
assert forward_context is not None
positions: torch.Tensor | None,
) -> None:
if (
forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL
and not forward_context.capturing
@@ -1801,7 +1793,42 @@ class NPUModelRunner(GPUModelRunner):
self.speculative_config,
positions.shape[0],
)
if get_forward_context().flash_comm_v1_enabled and not isinstance(hidden_states, IntermediateTensors):
def _model_forward(
self,
num_tokens_padded: int,
input_ids: torch.Tensor | None = None,
positions: torch.Tensor | None = None,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
**model_kwargs: dict[str, Any],
):
assert self.model is not None
forward_context = get_forward_context()
assert forward_context is not None
model_inputs: dict[str, Any] = {
"input_ids": input_ids,
"positions": positions,
"intermediate_tensors": intermediate_tensors,
"inputs_embeds": inputs_embeds,
**model_kwargs,
}
run_model = partial(self.model, **model_inputs)
if self.enable_enpu:
# The soft segmentation scenario requires event.record first, then event.wait
self._update_full_graph_params_if_needed(
forward_context, num_tokens_padded, positions
)
hidden_states = run_model()
else:
hidden_states = run_model()
self._update_full_graph_params_if_needed(
forward_context, num_tokens_padded, positions
)
if forward_context.flash_comm_v1_enabled and not isinstance(hidden_states, IntermediateTensors):
hidden_states = self._all_gather_hidden_states_and_aux(hidden_states)
return hidden_states