diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 923e1a6f..3b0632ba 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -1353,3 +1353,21 @@ def parse_layer_idx(prefix: str) -> int | None: """Extract the layer index from a module prefix string like 'model.layers.0.self_attn'.""" match = re.search(r"layers\.(\d+)", prefix) return int(match.group(1)) if match else None + + +@lru_cache(maxsize=1) +def _libc_getenv(): + import ctypes + + libc = ctypes.CDLL(None) + libc.getenv.argtypes = [ctypes.c_char_p] + libc.getenv.restype = ctypes.c_char_p + return libc.getenv + + +def get_c_env(name: str, encoding: str = "utf-8") -> str | None: + """Read env via C getenv; returns None if unset.""" + raw = _libc_getenv()(name.encode(encoding)) + if raw is None: + return None + return raw.decode(encoding) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 5fc405dd..4a6bf06f 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -23,6 +23,7 @@ from collections import defaultdict from contextlib import contextmanager, nullcontext from copy import copy, deepcopy from dataclasses import dataclass +from functools import partial from multiprocessing import Manager from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias @@ -36,7 +37,7 @@ from vllm.distributed import get_tensor_model_parallel_world_size, tensor_model_ from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group from vllm.distributed.parallel_state import get_dcp_group, get_dp_group, get_pcp_group, get_pp_group, get_tp_group -from vllm.forward_context import BatchDescriptor, get_forward_context +from vllm.forward_context import BatchDescriptor, ForwardContext, get_forward_context from vllm.logger import logger from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from 
def _model_forward(
    self,
    num_tokens_padded: int,
    input_ids: torch.Tensor | None = None,
    positions: torch.Tensor | None = None,
    intermediate_tensors: IntermediateTensors | None = None,
    inputs_embeds: torch.Tensor | None = None,
    **model_kwargs: dict[str, Any],
):
    """Run the model forward pass.

    Depending on ``self.enable_enpu``, the full-graph parameter refresh is
    performed either before or after the model call; when flash_comm_v1 is
    enabled the resulting hidden states are all-gathered before returning.
    """
    assert self.model is not None
    forward_context = get_forward_context()
    assert forward_context is not None

    def run_model():
        # Single call site for the model so both ordering branches stay in sync.
        return self.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
            **model_kwargs,
        )

    update_params = partial(
        self._update_full_graph_params_if_needed,
        forward_context,
        num_tokens_padded,
        positions,
    )

    if self.enable_enpu:
        # ENPU soft-segmentation requires event.record before event.wait, so
        # the full-graph params must be refreshed ahead of the model launch.
        update_params()
        hidden_states = run_model()
    else:
        hidden_states = run_model()
        update_params()

    if forward_context.flash_comm_v1_enabled and not isinstance(
        hidden_states, IntermediateTensors
    ):
        hidden_states = self._all_gather_hidden_states_and_aux(hidden_states)
    return hidden_states