diff --git a/vllm_ascend/compilation/acl_graph.py b/vllm_ascend/compilation/acl_graph.py
index 1ceb3434..060ecf98 100644
--- a/vllm_ascend/compilation/acl_graph.py
+++ b/vllm_ascend/compilation/acl_graph.py
@@ -20,8 +20,7 @@
 from vllm.logger import logger
 from vllm.platforms import current_platform
 from vllm_ascend.ascend_forward_context import _EXTRA_CTX
-
-from ..utils import weak_ref_tensors
+from vllm_ascend.utils import weak_ref_tensors
 
 
 @dataclasses.dataclass
@@ -66,6 +65,9 @@ class ACLGraphWrapper:
         vllm_config: VllmConfig,
         runtime_mode: CUDAGraphMode,
         cudagraph_options: CUDAGraphOptions | None = None,
+        *,
+        use_eagle: bool = False,
+        enable_enpu: bool = False,
     ):
         self.runnable = runnable
         self.vllm_config = vllm_config
@@ -87,6 +89,8 @@ class ACLGraphWrapper:
         # the entries for different batch descriptors that we need to capture
         # aclgraphs for.
         self.concrete_aclgraph_entries: dict[BatchDescriptor, ACLGraphEntry] = {}
+        self.enable_enpu = enable_enpu
+        self.use_eagle = use_eagle
 
     def __getattr__(self, key: str):
         # allow accessing the attributes of the runnable.
@@ -197,12 +201,11 @@ class ACLGraphWrapper:
         # so that update_attn_params only executes after the previous graph replay has fully completed.
-        # If we do not in main model and in full-graph mode when using merge-eagle-graph,
-        # we do not need to synchronize.
-        use_eagle = (
-            self.vllm_config.speculative_config.method in ("eagle", "eagle3")
-            if self.vllm_config.speculative_config
-            else False
-        )
-        if self.runtime_mode != CUDAGraphMode.FULL or not _EXTRA_CTX.is_draft_model or not use_eagle:
+        # Two cases skip the barrier: when enable_enpu is on, the model runner
+        # itself orders the param update against the replay; and for the EAGLE
+        # draft model on the merged-graph path under FULL mode, the replay
+        # does not need this barrier at all.
+        is_draft_eagle = _EXTRA_CTX.is_draft_model and self.use_eagle
+        need_sync = self.runtime_mode == CUDAGraphMode.FULL and not is_draft_eagle
+        if not self.enable_enpu and need_sync:
             torch.npu.current_stream().synchronize()
         entry.aclgraph.replay()
         return entry.output
diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py
index 7e193371..b4338345 100644
--- a/vllm_ascend/spec_decode/eagle_proposer.py
+++ b/vllm_ascend/spec_decode/eagle_proposer.py
@@ -2,6 +2,7 @@
 import copy
 from collections.abc import Callable
 from contextlib import AbstractContextManager, contextmanager, nullcontext
+from functools import partial
 from typing import Any
 
 import numpy as np
@@ -17,7 +18,7 @@ from vllm.distributed.parallel_state import (
     init_model_parallel_group,
     patch_tensor_parallel_group,
 )
-from vllm.forward_context import BatchDescriptor, get_forward_context
+from vllm.forward_context import BatchDescriptor, ForwardContext, get_forward_context
 from vllm.logger import logger
 from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
 from vllm.model_executor.model_loader import get_model
@@ -169,6 +170,8 @@ class SpecDecodeBaseProposer(EagleProposer):
         self.positions = torch.zeros(self.max_num_tokens, dtype=torch.int32, device=device)
         self.token_arange_np = np.arange(self.max_num_tokens + 1)
+        self.enable_enpu = self.runner.enable_enpu
+        self.use_eagle = self.runner.use_eagle
 
     def _get_model(self) -> nn.Module:
         """
@@ -321,10 +324,20 @@ class SpecDecodeBaseProposer(EagleProposer):
         if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs() and self.use_cuda_graph:
             self.update_stream = torch.npu.Stream()
             if self.method == "mtp":
-                self.model = ACLGraphWrapper(self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL)
+                self.model = ACLGraphWrapper(
+                    self.model,
+                    self.vllm_config,
+                    runtime_mode=CUDAGraphMode.FULL,
+                    use_eagle=self.use_eagle,
+                    enable_enpu=self.enable_enpu,
+                )
             else:
                 self._runnable = ACLGraphWrapper(
-                    self._run_merged_draft, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
+                    self._run_merged_draft,
+                    self.vllm_config,
+                    runtime_mode=CUDAGraphMode.FULL,
+                    use_eagle=self.use_eagle,
+                    enable_enpu=self.enable_enpu,
                 )
 
     def get_model(self) -> nn.Module:
@@ -458,6 +471,15 @@ class SpecDecodeBaseProposer(EagleProposer):
         if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and not _EXTRA_CTX.capturing:
             self._update_full_graph_params(forward_context, num_tokens, multi_steps_attn_metadata)
 
+    def _update_full_graph_params_if_needed(
+        self,
+        forward_context: ForwardContext,
+        num_input_tokens: int,
+        multi_steps_attn_metadata: list[dict[str, Any]],
+    ) -> None:
+        if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
+            self._update_full_graph_params(forward_context, num_input_tokens, multi_steps_attn_metadata)
+
     def _propose(
         self,
         # [num_tokens]
@@ -695,20 +717,24 @@ class SpecDecodeBaseProposer(EagleProposer):
         if forward_context is not None:
             forward_context.moe_layer_index = 0
 
-        draft_token_ids = self._runnable(
-            num_input_tokens=num_input_tokens,
-            batch_size=batch_size,
-            token_indices_to_sample=self.token_indices_to_sample[:token_indices_to_sample_len],
-            target_positions=target_positions,
-            inputs_embeds=inputs_embeds,
-            multi_steps_attn_metadata=multi_steps_attn_metadata,
-            num_tokens=num_tokens,
-            is_prefill=attn_metadata_i.num_prefills,
-        )
+        model_inputs: dict[str, Any] = {
+            "num_input_tokens": num_input_tokens,
+            "batch_size": batch_size,
+            "token_indices_to_sample": self.token_indices_to_sample[:token_indices_to_sample_len],
+            "target_positions": target_positions,
+            "inputs_embeds": inputs_embeds,
+            "multi_steps_attn_metadata": multi_steps_attn_metadata,
+            "num_tokens": num_tokens,
+            "is_prefill": attn_metadata_i.num_prefills,
+        }
+        run_draft = partial(self._runnable, **model_inputs)
-        forward_context = get_forward_context()
-        if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
-            self._update_full_graph_params(forward_context, num_input_tokens, multi_steps_attn_metadata)
+        if self.enable_enpu:
+            self._update_full_graph_params_if_needed(forward_context, num_input_tokens, multi_steps_attn_metadata)
+            draft_token_ids = run_draft()
+        else:
+            draft_token_ids = run_draft()
+            self._update_full_graph_params_if_needed(forward_context, num_input_tokens, multi_steps_attn_metadata)
 
         return draft_token_ids
 
     def _run_merged_draft(
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 357ee039..48d24a07 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -336,6 +336,15 @@ class NPUModelRunner(GPUModelRunner):
         self.input_ids = self._make_buffer(max_buffer_num_tokens, dtype=torch.int32)
         self.positions = self._make_buffer(max_buffer_num_tokens, dtype=torch.int64)
+        self.use_eagle = (
+            vllm_config.speculative_config.method in ("eagle", "eagle3", "mtp")
+            if vllm_config.speculative_config
+            else False
+        )
+        # When True, run update_full_graph_params before self.model (ENPU / graph capture order).
+        # Internal, non-public toggle: reads the C-level getenv ``ENPU_ENABLE`` exposed by the ENPU code (not declared in envs.py).
+        _enpu = get_c_env("ENPU_ENABLE")
+        self.enable_enpu = _enpu is not None and _enpu.lower() == "true"
 
         self._set_up_drafter()
 
         # kv role
@@ -424,9 +433,6 @@ class NPUModelRunner(GPUModelRunner):
         self.cudagraph_batch_sizes = []
         self.mamba_state_idx: dict[str, int] = {}
         self._mamba_copy_bufs: mamba_utils.MambaCopyBuffers | None = None
-        env_enpu_enable = get_c_env("ENPU_ENABLE")
-        # When True, run update_full_graph_params before self.model (ENPU / graph capture order).
-        self.enable_enpu = env_enpu_enable is not None and env_enpu_enable.lower() == "true"
 
     @property
     def use_cp(self) -> bool:
@@ -1795,6 +1801,9 @@ class NPUModelRunner(GPUModelRunner):
             and not forward_context.capturing
             and not self.use_sparse
         ):
+            if self.enable_enpu:
+                torch.npu.current_stream().synchronize()
+
             assert positions is not None
             update_full_graph_params(
                 self.attn_backend,
@@ -2592,7 +2601,13 @@ class NPUModelRunner(GPUModelRunner):
         # wrap the model with full graph wrapper if needed.
         if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
             self.update_stream: torch.npu.Stream = torch.npu.Stream()
-            self.model = ACLGraphWrapper(self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL)
+            self.model = ACLGraphWrapper(
+                self.model,
+                self.vllm_config,
+                runtime_mode=CUDAGraphMode.FULL,
+                use_eagle=self.use_eagle,
+                enable_enpu=self.enable_enpu,
+            )
 
     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         """
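
Reviewer note: the update/replay ordering contract introduced by this diff now
lives in three places (the wrapper's barrier skip, the proposer's ENPU branch,
and the runner's pre-update synchronize). The sketch below illustrates that
contract only and is not vllm-ascend API: GraphMode, Ctx, and the string
placeholders are hypothetical stand-ins for CUDAGraphMode, _EXTRA_CTX, and the
real NPU stream calls.

    from dataclasses import dataclass
    from enum import Enum, auto


    class GraphMode(Enum):
        FULL = auto()
        PIECEWISE = auto()


    @dataclass
    class Ctx:
        is_draft_model: bool = False


    def replay(mode: GraphMode, ctx: Ctx, use_eagle: bool, enable_enpu: bool) -> list[str]:
        """Mirror of ACLGraphWrapper.__call__'s barrier decision."""
        ops: list[str] = []
        # The EAGLE draft on the merged-graph path needs no barrier, and under
        # ENPU the model runner orders update vs. replay itself.
        is_draft_eagle = ctx.is_draft_model and use_eagle
        need_sync = mode == GraphMode.FULL and not is_draft_eagle
        if not enable_enpu and need_sync:
            ops.append("synchronize")  # torch.npu.current_stream().synchronize()
        ops.append("replay")
        return ops


    def propose(enable_enpu: bool) -> list[str]:
        """Mirror of _propose's update/replay ordering."""
        if enable_enpu:
            # ENPU: update attention params before the draft graph replays.
            return ["update_params", "run_draft"]
        return ["run_draft", "update_params"]


    assert replay(GraphMode.FULL, Ctx(), use_eagle=False, enable_enpu=False) == ["synchronize", "replay"]
    assert replay(GraphMode.FULL, Ctx(is_draft_model=True), use_eagle=True, enable_enpu=False) == ["replay"]
    assert replay(GraphMode.FULL, Ctx(), use_eagle=False, enable_enpu=True) == ["replay"]
    assert propose(enable_enpu=True) == ["update_params", "run_draft"]

A side benefit of hoisting use_eagle into the wrapper's constructor is that the
per-replay speculative_config lookup the old code performed on the hot path is
gone.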