[BugFix] Fix hang in async scheduling when ENPU is enabled (#8354)

### What this PR does / why we need it?
1. There is no synchronization between steps. In async scheduling with ACL graph, the CPU can therefore record the event for the current iteration before the previous iteration's graph execution has finished on the device. If the CPU runs far enough ahead, the device hangs on event_wait in iteration i+1 (assuming event_record executes immediately on the device's update stream).
2. Under ENPU, the Eagle proposers also need to follow the same ordering: event.record first, then event.wait.
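
For illustration, a minimal sketch of that record-then-wait discipline, assuming the torch_npu event/stream API mirrors `torch.cuda`; `replay_graph`, `run_iteration`, and the stream names are stand-ins, not actual vllm-ascend symbols:

```python
import torch
import torch_npu  # noqa: F401  # Ascend adapter; exposes the torch.npu namespace

main_stream = torch.npu.current_stream()
update_stream = torch.npu.Stream()  # the device-side "update stream" from point 1
step_done = torch.npu.Event()

def run_iteration(replay_graph, inputs):
    # 1) Enqueue this iteration's captured-graph replay on the main stream.
    out = replay_graph(inputs)
    # 2) Record completion *after* the replay is enqueued, so the event is
    #    tied to this iteration's work instead of completing immediately.
    step_done.record(main_stream)
    # 3) Only then enqueue the wait. With 2) and 3) inverted, a fast-enough
    #    CPU records the event for iteration i before iteration i-1's graph
    #    has finished, and the device hangs on event_wait in iteration i+1.
    step_done.wait(update_stream)
    return out
```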

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?

---------

Signed-off-by: 1zzk <785396250@qq.com>
Author: 1kzk
Committed: 2026-04-18 00:07:15 +08:00 (via GitHub)
Parent: f81f9a3c89
Commit: c995a959e6
3 changed files with 72 additions and 28 deletions


```diff
@@ -2,6 +2,7 @@
 import copy
 from collections.abc import Callable
 from contextlib import AbstractContextManager, contextmanager, nullcontext
+from functools import partial
 from typing import Any
 
 import numpy as np
@@ -17,7 +18,7 @@ from vllm.distributed.parallel_state import (
     init_model_parallel_group,
     patch_tensor_parallel_group,
 )
-from vllm.forward_context import BatchDescriptor, get_forward_context
+from vllm.forward_context import BatchDescriptor, ForwardContext, get_forward_context
 from vllm.logger import logger
 from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
 from vllm.model_executor.model_loader import get_model
@@ -169,6 +170,8 @@ class SpecDecodeBaseProposer(EagleProposer):
         self.positions = torch.zeros(self.max_num_tokens, dtype=torch.int32, device=device)
         self.token_arange_np = np.arange(self.max_num_tokens + 1)
+        self.enable_enpu = self.runner.enable_enpu
+        self.use_eagle = self.runner.use_eagle
 
     def _get_model(self) -> nn.Module:
         """
```
```diff
@@ -321,10 +324,20 @@ class SpecDecodeBaseProposer(EagleProposer):
         if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs() and self.use_cuda_graph:
             self.update_stream = torch.npu.Stream()
             if self.method == "mtp":
-                self.model = ACLGraphWrapper(self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL)
+                self.model = ACLGraphWrapper(
+                    self.model,
+                    self.vllm_config,
+                    runtime_mode=CUDAGraphMode.FULL,
+                    use_eagle=self.use_eagle,
+                    enable_enpu=self.enable_enpu,
+                )
             else:
                 self._runnable = ACLGraphWrapper(
-                    self._run_merged_draft, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
+                    self._run_merged_draft,
+                    self.vllm_config,
+                    runtime_mode=CUDAGraphMode.FULL,
+                    use_eagle=self.use_eagle,
+                    enable_enpu=self.enable_enpu,
                 )
 
     def get_model(self) -> nn.Module:
```
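
The hunk above shows only the call sites; a sketch of the assumed constructor surface on the wrapper side (everything beyond the parameters visible in the diff is an assumption):

```python
class ACLGraphWrapperSketch:
    """Hypothetical stand-in for the ACLGraphWrapper surface this PR extends."""

    def __init__(self, runnable, vllm_config, runtime_mode, use_eagle=False, enable_enpu=False):
        self.runnable = runnable
        self.vllm_config = vllm_config
        self.runtime_mode = runtime_mode
        # Cached at construction so graph capture/replay can choose the ENPU
        # record-then-wait ordering for Eagle/MTP proposers.
        self.use_eagle = use_eagle
        self.enable_enpu = enable_enpu
```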
```diff
@@ -458,6 +471,15 @@ class SpecDecodeBaseProposer(EagleProposer):
         if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and not _EXTRA_CTX.capturing:
             self._update_full_graph_params(forward_context, num_tokens, multi_steps_attn_metadata)
 
+    def _update_full_graph_params_if_needed(
+        self,
+        forward_context: ForwardContext,
+        num_input_tokens: int,
+        multi_steps_attn_metadata: list[dict[str, Any]],
+    ) -> None:
+        if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
+            self._update_full_graph_params(forward_context, num_input_tokens, multi_steps_attn_metadata)
+
     def _propose(
         self,
         # [num_tokens]
@@ -695,20 +717,24 @@ class SpecDecodeBaseProposer(EagleProposer):
         if forward_context is not None:
             forward_context.moe_layer_index = 0
-        draft_token_ids = self._runnable(
-            num_input_tokens=num_input_tokens,
-            batch_size=batch_size,
-            token_indices_to_sample=self.token_indices_to_sample[:token_indices_to_sample_len],
-            target_positions=target_positions,
-            inputs_embeds=inputs_embeds,
-            multi_steps_attn_metadata=multi_steps_attn_metadata,
-            num_tokens=num_tokens,
-            is_prefill=attn_metadata_i.num_prefills,
-        )
+        model_inputs: dict[str, Any] = {
+            "num_input_tokens": num_input_tokens,
+            "batch_size": batch_size,
+            "token_indices_to_sample": self.token_indices_to_sample[:token_indices_to_sample_len],
+            "target_positions": target_positions,
+            "inputs_embeds": inputs_embeds,
+            "multi_steps_attn_metadata": multi_steps_attn_metadata,
+            "num_tokens": num_tokens,
+            "is_prefill": attn_metadata_i.num_prefills,
+        }
+        run_draft = partial(self._runnable, **model_inputs)
         forward_context = get_forward_context()
-        if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
-            self._update_full_graph_params(forward_context, num_input_tokens, multi_steps_attn_metadata)
+        if self.enable_enpu:
+            self._update_full_graph_params_if_needed(forward_context, num_input_tokens, multi_steps_attn_metadata)
+            draft_token_ids = run_draft()
+        else:
+            draft_token_ids = run_draft()
+            self._update_full_graph_params_if_needed(forward_context, num_input_tokens, multi_steps_attn_metadata)
         return draft_token_ids
 
     def _run_merged_draft(
```
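
Condensed to its essentials, the `_propose` change defers the draft call with `functools.partial` so both orderings share a single call site; a self-contained toy with stand-in names:

```python
from functools import partial

def run_drafter(num_tokens: int, is_prefill: bool) -> str:
    # Stand-in for self._runnable (the ACL-graph-wrapped draft step).
    return f"draft(num_tokens={num_tokens}, is_prefill={is_prefill})"

def update_full_graph_params_if_needed() -> None:
    # Stand-in for the FULL-graph parameter refresh.
    print("refreshed full-graph params")

model_inputs = {"num_tokens": 8, "is_prefill": False}
# partial binds the kwargs without invoking anything, so the call can sit on
# either side of the parameter refresh without duplicating the input list.
run_draft = partial(run_drafter, **model_inputs)

enable_enpu = True
if enable_enpu:
    # ENPU: refresh params *before* replay, matching record-then-wait.
    update_full_graph_params_if_needed()
    draft = run_draft()
else:
    # Default path: replay first, then refresh params afterwards.
    draft = run_draft()
    update_full_graph_params_if_needed()
print(draft)
```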