From 52f0f9b5e431119d0970f784d9e8a39892393d9b Mon Sep 17 00:00:00 2001 From: 1kzk <785396250@qq.com> Date: Thu, 16 Apr 2026 16:26:59 +0800 Subject: [PATCH] [0.18.0][BugFix]: order acl graph updates before model forward for ENPU (#8317) ### What this PR does / why we need it? For the ENPU scenario, it is required that device events follow the principle of "record first, wait later", otherwise the inference process may become stuck. However, in the current model_forward function, event.wait precedes event.record. Therefore, for the ENPU scenario, graph parameter updates should be performed before model execution. ### Does this PR introduce _any_ user-facing change? N/A ### How was this patch tested? --------- Signed-off-by: 1zzk <785396250@qq.com> Signed-off-by: 1kzk <785396250@qq.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- vllm_ascend/utils.py | 18 ++++++++ vllm_ascend/worker/model_runner_v1.py | 65 +++++++++++++++++++-------- 2 files changed, 64 insertions(+), 19 deletions(-) diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 923e1a6f..3b0632ba 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -1353,3 +1353,21 @@ def parse_layer_idx(prefix: str) -> int | None: """Extract the layer index from a module prefix string like 'model.layers.0.self_attn'.""" match = re.search(r"layers\.(\d+)", prefix) return int(match.group(1)) if match else None + + +@lru_cache(maxsize=1) +def _libc_getenv(): + import ctypes + + libc = ctypes.CDLL(None) + libc.getenv.argtypes = [ctypes.c_char_p] + libc.getenv.restype = ctypes.c_char_p + return libc.getenv + + +def get_c_env(name: str, encoding: str = "utf-8") -> str | None: + """Read env via C getenv; returns None if unset.""" + raw = _libc_getenv()(name.encode(encoding)) + if raw is None: + return None + return raw.decode(encoding) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 5fc405dd..4a6bf06f 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -23,6 +23,7 @@ from collections import defaultdict from contextlib import contextmanager, nullcontext from copy import copy, deepcopy from dataclasses import dataclass +from functools import partial from multiprocessing import Manager from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias @@ -36,7 +37,7 @@ from vllm.distributed import get_tensor_model_parallel_world_size, tensor_model_ from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group from vllm.distributed.parallel_state import get_dcp_group, get_dp_group, get_pcp_group, get_pp_group, get_tp_group -from vllm.forward_context import BatchDescriptor, get_forward_context +from vllm.forward_context import BatchDescriptor, ForwardContext, get_forward_context from vllm.logger import logger from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.mamba.abstract import MambaBase @@ -122,6 +123,7 @@ from vllm_ascend.utils import ( check_gdn_layer, enable_sp, enable_sp_by_pass, + get_c_env, global_stream, lmhead_tp_enable, set_weight_prefetch_method, @@ -422,6 +424,9 @@ class NPUModelRunner(GPUModelRunner): self.cudagraph_batch_sizes = [] self.mamba_state_idx: dict[str, int] = {} self._mamba_copy_bufs: mamba_utils.MambaCopyBuffers | None = None + env_enpu_enable = get_c_env("ENPU_ENABLE") + # When True, run update_full_graph_params before self.model (ENPU / graph capture order). + self.enable_enpu = env_enpu_enable is not None and env_enpu_enable.lower() == "true" @property def use_cp(self) -> bool: @@ -1767,25 +1772,12 @@ class NPUModelRunner(GPUModelRunner): ) return NPUModelRunner._all_gather_hidden_states(hidden_states) - def _model_forward( + def _update_full_graph_params_if_needed( self, + forward_context: ForwardContext, num_tokens_padded: int, - input_ids: torch.Tensor | None = None, - positions: torch.Tensor | None = None, - intermediate_tensors: IntermediateTensors | None = None, - inputs_embeds: torch.Tensor | None = None, - **model_kwargs: dict[str, Any], - ): - assert self.model is not None - hidden_states = self.model( - input_ids=input_ids, - positions=positions, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds, - **model_kwargs, - ) - forward_context = get_forward_context() - assert forward_context is not None + positions: torch.Tensor | None, + ) -> None: if ( forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and not forward_context.capturing @@ -1801,7 +1793,42 @@ class NPUModelRunner(GPUModelRunner): self.speculative_config, positions.shape[0], ) - if get_forward_context().flash_comm_v1_enabled and not isinstance(hidden_states, IntermediateTensors): + + def _model_forward( + self, + num_tokens_padded: int, + input_ids: torch.Tensor | None = None, + positions: torch.Tensor | None = None, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + **model_kwargs: dict[str, Any], + ): + assert self.model is not None + forward_context = get_forward_context() + assert forward_context is not None + + model_inputs: dict[str, Any] = { + "input_ids": input_ids, + "positions": positions, + "intermediate_tensors": intermediate_tensors, + "inputs_embeds": inputs_embeds, + **model_kwargs, + } + run_model = partial(self.model, **model_inputs) + + if self.enable_enpu: + # The soft segmentation scenario requires event.record first, then event.wait + self._update_full_graph_params_if_needed( + forward_context, num_tokens_padded, positions + ) + hidden_states = run_model() + else: + hidden_states = run_model() + self._update_full_graph_params_if_needed( + forward_context, num_tokens_padded, positions + ) + + if forward_context.flash_comm_v1_enabled and not isinstance(hidden_states, IntermediateTensors): hidden_states = self._all_gather_hidden_states_and_aux(hidden_states) return hidden_states