[BugFix] Fix hang in async scheduling when ENPU is enabled (#8354)

### What this PR does / why we need it?
1. There is no synchronization between steps, so in async scheduling with
aclgraph the CPU can complete the event record for the current iteration
before the previous iteration's graph execution has finished on the device.
If the CPU runs far enough ahead, the device hangs on `event_wait` in
iteration i+1 (assuming `event_record` executes immediately on the device's
update stream); the sketch below illustrates the required ordering.
2. Under ENPU, the EAGLE proposers must also follow the same discipline:
`event.record` first, then `event.wait`.
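
For reference, here is a minimal sketch of the record-then-wait ordering, assuming the `torch_npu` stream/event API mirrors `torch.cuda`; the step loop and names are hypothetical, not the actual vllm-ascend call sites:

```python
import torch
import torch_npu  # noqa: F401  (registers the torch.npu namespace)

update_stream = torch.npu.Stream()
replay_done = torch.npu.Event()

def step(graph):
    # Enqueue this iteration's replay, THEN record the event on the same
    # stream, so the event only fires after the replay has finished.
    graph.replay()
    replay_done.record(torch.npu.current_stream())
    # The update stream waits on that event before touching the parameters
    # for the next iteration. If the wait could land ahead of the record
    # (a fast CPU racing the device), iteration i+1 would block forever.
    with torch.npu.stream(update_stream):
        replay_done.wait()
        # update_attn_params / update_full_graph_params would run here
```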

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?

---------

Signed-off-by: 1zzk <785396250@qq.com>
1kzk
2026-04-18 00:07:15 +08:00 · committed by GitHub
parent f81f9a3c89 · commit c995a959e6
3 changed files with 72 additions and 28 deletions

```diff
@@ -20,8 +20,7 @@ from vllm.logger import logger
 from vllm.platforms import current_platform
 from vllm_ascend.ascend_forward_context import _EXTRA_CTX
-from ..utils import weak_ref_tensors
+from vllm_ascend.utils import weak_ref_tensors
 
 @dataclasses.dataclass
```
```diff
@@ -66,6 +65,9 @@ class ACLGraphWrapper:
         vllm_config: VllmConfig,
         runtime_mode: CUDAGraphMode,
         cudagraph_options: CUDAGraphOptions | None = None,
+        *,
+        use_eagle: bool = False,
+        enable_enpu: bool = False,
     ):
         self.runnable = runnable
         self.vllm_config = vllm_config
@@ -87,6 +89,8 @@ class ACLGraphWrapper:
         # the entries for different batch descriptors that we need to capture
         # aclgraphs for.
         self.concrete_aclgraph_entries: dict[BatchDescriptor, ACLGraphEntry] = {}
+        self.enable_enpu = enable_enpu
+        self.use_eagle = use_eagle
 
     def __getattr__(self, key: str):
         # allow accessing the attributes of the runnable.
@@ -197,12 +201,11 @@ class ACLGraphWrapper:
         # so that update_attn_params only executes after the previous graph replay has fully completed.
         # If we do not in main model and in full-graph mode when using merge-eagle-graph,
         # we do not need to synchronize.
-        use_eagle = (
-            self.vllm_config.speculative_config.method in ("eagle", "eagle3")
-            if self.vllm_config.speculative_config
-            else False
-        )
-        if self.runtime_mode != CUDAGraphMode.FULL or not _EXTRA_CTX.is_draft_model or not use_eagle:
+        # When enable_enpu is on, model_runner orders update vs replay; skip here.
+        # When FULL + EAGLE draft (merge path), replay does not need this barrier.
+        is_draft_eagle = _EXTRA_CTX.is_draft_model and self.use_eagle
+        need_sync = self.runtime_mode == CUDAGraphMode.FULL and not is_draft_eagle
+        if not self.enable_enpu and need_sync:
             torch.npu.current_stream().synchronize()
         entry.aclgraph.replay()
         return entry.output
```
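
Restated as a standalone predicate, the new gate reads as follows (a paraphrase of the diff above, not code from the PR; the `CUDAGraphMode` import path is an assumption):

```python
from vllm.config import CUDAGraphMode  # import path assumed

def should_sync_before_replay(
    runtime_mode: CUDAGraphMode,
    is_draft_model: bool,
    use_eagle: bool,
    enable_enpu: bool,
) -> bool:
    # EAGLE draft models on the merged-graph path need no barrier here.
    is_draft_eagle = is_draft_model and use_eagle
    # Only FULL-graph replays read the params that update_attn_params mutates.
    need_sync = runtime_mode == CUDAGraphMode.FULL and not is_draft_eagle
    # Under ENPU the model runner orders update vs. replay itself.
    return not enable_enpu and need_sync
```

Note that, compared with the old condition (which synchronized whenever the mode was not FULL, or the runnable was not an EAGLE draft), the rewritten gate appears to skip the barrier for non-FULL replays entirely and delegates ENPU ordering to the model runner.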

```diff
@@ -2,6 +2,7 @@
 import copy
 from collections.abc import Callable
 from contextlib import AbstractContextManager, contextmanager, nullcontext
+from functools import partial
 from typing import Any
 
 import numpy as np
@@ -17,7 +18,7 @@ from vllm.distributed.parallel_state import (
     init_model_parallel_group,
     patch_tensor_parallel_group,
 )
-from vllm.forward_context import BatchDescriptor, get_forward_context
+from vllm.forward_context import BatchDescriptor, ForwardContext, get_forward_context
 from vllm.logger import logger
 from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
 from vllm.model_executor.model_loader import get_model
@@ -169,6 +170,8 @@ class SpecDecodeBaseProposer(EagleProposer):
         self.positions = torch.zeros(self.max_num_tokens, dtype=torch.int32, device=device)
         self.token_arange_np = np.arange(self.max_num_tokens + 1)
+        self.enable_enpu = self.runner.enable_enpu
+        self.use_eagle = self.runner.use_eagle
 
     def _get_model(self) -> nn.Module:
         """
```
```diff
@@ -321,10 +324,20 @@ class SpecDecodeBaseProposer(EagleProposer):
         if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs() and self.use_cuda_graph:
             self.update_stream = torch.npu.Stream()
             if self.method == "mtp":
-                self.model = ACLGraphWrapper(self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL)
+                self.model = ACLGraphWrapper(
+                    self.model,
+                    self.vllm_config,
+                    runtime_mode=CUDAGraphMode.FULL,
+                    use_eagle=self.use_eagle,
+                    enable_enpu=self.enable_enpu,
+                )
             else:
                 self._runnable = ACLGraphWrapper(
-                    self._run_merged_draft, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
+                    self._run_merged_draft,
+                    self.vllm_config,
+                    runtime_mode=CUDAGraphMode.FULL,
+                    use_eagle=self.use_eagle,
+                    enable_enpu=self.enable_enpu,
                 )
 
     def get_model(self) -> nn.Module:
```
```diff
@@ -458,6 +471,15 @@ class SpecDecodeBaseProposer(EagleProposer):
         if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and not _EXTRA_CTX.capturing:
             self._update_full_graph_params(forward_context, num_tokens, multi_steps_attn_metadata)
 
+    def _update_full_graph_params_if_needed(
+        self,
+        forward_context: ForwardContext,
+        num_input_tokens: int,
+        multi_steps_attn_metadata: list[dict[str, Any]],
+    ) -> None:
+        if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
+            self._update_full_graph_params(forward_context, num_input_tokens, multi_steps_attn_metadata)
+
     def _propose(
         self,
         # [num_tokens]
```
```diff
@@ -695,20 +717,24 @@ class SpecDecodeBaseProposer(EagleProposer):
         if forward_context is not None:
             forward_context.moe_layer_index = 0
-        draft_token_ids = self._runnable(
-            num_input_tokens=num_input_tokens,
-            batch_size=batch_size,
-            token_indices_to_sample=self.token_indices_to_sample[:token_indices_to_sample_len],
-            target_positions=target_positions,
-            inputs_embeds=inputs_embeds,
-            multi_steps_attn_metadata=multi_steps_attn_metadata,
-            num_tokens=num_tokens,
-            is_prefill=attn_metadata_i.num_prefills,
-        )
-
-        forward_context = get_forward_context()
-        if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
-            self._update_full_graph_params(forward_context, num_input_tokens, multi_steps_attn_metadata)
+        model_inputs: dict[str, Any] = {
+            "num_input_tokens": num_input_tokens,
+            "batch_size": batch_size,
+            "token_indices_to_sample": self.token_indices_to_sample[:token_indices_to_sample_len],
+            "target_positions": target_positions,
+            "inputs_embeds": inputs_embeds,
+            "multi_steps_attn_metadata": multi_steps_attn_metadata,
+            "num_tokens": num_tokens,
+            "is_prefill": attn_metadata_i.num_prefills,
+        }
+        run_draft = partial(self._runnable, **model_inputs)
+
+        if self.enable_enpu:
+            self._update_full_graph_params_if_needed(forward_context, num_input_tokens, multi_steps_attn_metadata)
+            draft_token_ids = run_draft()
+        else:
+            draft_token_ids = run_draft()
+            self._update_full_graph_params_if_needed(forward_context, num_input_tokens, multi_steps_attn_metadata)
 
         return draft_token_ids
 
     def _run_merged_draft(
```
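
Two small design points in the `_propose` change: binding the arguments once with `functools.partial` lets both ordering branches make the identical call without repeating the long keyword list, and the ENPU branch mirrors the main runner by updating full-graph params before the replay rather than after. A condensed illustration with hypothetical stand-ins (not the PR's code):

```python
from functools import partial

def run_model(**kwargs):
    """Stand-in for self._runnable (the wrapped draft graph)."""
    return "draft_token_ids"

def update_params():
    """Stand-in for _update_full_graph_params_if_needed."""

model_inputs = {"num_input_tokens": 128, "batch_size": 4}  # abridged
run_draft = partial(run_model, **model_inputs)

enable_enpu = True  # toggle to see both orderings
if enable_enpu:
    update_params()                # ENPU: params in place before the replay
    draft_token_ids = run_draft()
else:
    draft_token_ids = run_draft()  # default: replay first ...
    update_params()                # ... then refresh params for the next replay
```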

```diff
@@ -336,6 +336,15 @@ class NPUModelRunner(GPUModelRunner):
         self.input_ids = self._make_buffer(max_buffer_num_tokens, dtype=torch.int32)
         self.positions = self._make_buffer(max_buffer_num_tokens, dtype=torch.int64)
+        self.use_eagle = (
+            vllm_config.speculative_config.method in ("eagle", "eagle3", "mtp")
+            if vllm_config.speculative_config
+            else False
+        )
+        # When True, run update_full_graph_params before self.model (ENPU / graph capture order).
+        # Internal / non-public toggle: read C getenv ``ENPU_ENABLE`` from enpu code (not in envs.py).
+        _enpu = get_c_env("ENPU_ENABLE")
+        self.enable_enpu = _enpu is not None and _enpu.lower() == "true"
 
         self._set_up_drafter()
 
         # kv role
```
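
The runner reads `ENPU_ENABLE` through `get_c_env` rather than `os.environ`, presumably because `os.environ` is snapshotted at interpreter startup and does not see variables set later by native code via `setenv(3)`. A hypothetical stand-in with the same truthiness convention (not the PR's implementation):

```python
import ctypes
import ctypes.util

def get_c_env(name: str) -> str | None:
    """Read an env var via the C runtime's getenv(3) (hypothetical stand-in)."""
    libc = ctypes.CDLL(ctypes.util.find_library("c"))
    libc.getenv.restype = ctypes.c_char_p
    value = libc.getenv(name.encode())
    return value.decode() if value is not None else None

# Same convention as the runner: only the literal "true" (case-insensitive)
# enables ENPU; unset or any other value leaves it off.
_enpu = get_c_env("ENPU_ENABLE")
enable_enpu = _enpu is not None and _enpu.lower() == "true"
```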
```diff
@@ -424,9 +433,6 @@ class NPUModelRunner(GPUModelRunner):
         self.cudagraph_batch_sizes = []
         self.mamba_state_idx: dict[str, int] = {}
         self._mamba_copy_bufs: mamba_utils.MambaCopyBuffers | None = None
-        env_enpu_enable = get_c_env("ENPU_ENABLE")
-        # When True, run update_full_graph_params before self.model (ENPU / graph capture order).
-        self.enable_enpu = env_enpu_enable is not None and env_enpu_enable.lower() == "true"
 
     @property
     def use_cp(self) -> bool:
```
```diff
@@ -1795,6 +1801,9 @@ class NPUModelRunner(GPUModelRunner):
             and not forward_context.capturing
             and not self.use_sparse
         ):
+            if self.enable_enpu:
+                torch.npu.current_stream().synchronize()
+
             assert positions is not None
             update_full_graph_params(
                 self.attn_backend,
```
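
Why a full `synchronize()` here: under ENPU the parameter update runs before the next replay is enqueued (see the proposer change above), so the runner must presumably be sure the previous replay has finished reading those params before mutating them; the stream synchronize provides the barrier that the wrapper now skips. A minimal sketch with hypothetical stand-ins:

```python
import torch
import torch_npu  # noqa: F401  (registers torch.npu)

def update_full_graph_params_sketch():
    """Stand-in for update_full_graph_params (mutates captured buffers)."""

enable_enpu = True
# ENPU ordering: update(i+1) happens BEFORE replay(i+1), so drain replay(i)
# first; otherwise we could overwrite params a running replay is reading.
if enable_enpu:
    torch.npu.current_stream().synchronize()
update_full_graph_params_sketch()
```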
```diff
@@ -2592,7 +2601,13 @@ class NPUModelRunner(GPUModelRunner):
         # wrap the model with full graph wrapper if needed.
         if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
             self.update_stream: torch.npu.Stream = torch.npu.Stream()
-            self.model = ACLGraphWrapper(self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL)
+            self.model = ACLGraphWrapper(
+                self.model,
+                self.vllm_config,
+                runtime_mode=CUDAGraphMode.FULL,
+                use_eagle=self.use_eagle,
+                enable_enpu=self.enable_enpu,
+            )
 
     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         """
```