[BugFix] Fix hang in async scheduling when ENPU is enabled (#8354)
### What this PR does / why we need it?

1. There is no synchronization between steps, so in async scheduling with aclgraph the CPU can enqueue the `event.record` for the current iteration before the previous iteration's graph execution has finished on the device. If the CPU runs far enough ahead, the device hangs on `event.wait` in iteration i+1 (assuming `event.record` executes immediately on the device's update stream).
2. Under ENPU, the eagle proposers must also follow the same ordering: `event.record` first, then `event.wait`.

### Does this PR introduce _any_ user-facing change?

N/A

### How was this patch tested?

---------

Signed-off-by: 1zzk <785396250@qq.com>
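Below is a minimal, self-contained sketch of the record-then-wait ordering described in point 1. It uses CUDA streams and events as a stand-in for the Ascend (`torch.npu`) equivalents, and every name in it is illustrative rather than taken from vllm-ascend:

```python
# Minimal sketch of record-then-wait ordering across two streams, assuming
# CUDA-style streams/events as a stand-in for torch.npu on Ascend. All names
# are illustrative; this is not the actual vllm-ascend code.
import torch

exec_stream = torch.cuda.Stream()    # stands in for the graph-execution stream
update_stream = torch.cuda.Stream()  # stands in for the device's update stream
step_done = torch.cuda.Event()

x = torch.ones(1 << 20, device="cuda")

for step in range(4):
    with torch.cuda.stream(exec_stream):
        x = x * 2                      # stands in for the captured graph replay
        step_done.record(exec_stream)  # record AFTER this step's work is enqueued
    with torch.cuda.stream(update_stream):
        # The wait is enqueued after the record it pairs with, so the update
        # stream blocks exactly until this step's graph work finishes. If the
        # CPU instead recorded step i+1's event before step i's graph replay
        # completed, the wait could pair with the wrong record and the device
        # would stall -- the hang this PR fixes.
        update_stream.wait_event(step_done)
        x = x + 1                      # stands in for the update-stream work

torch.cuda.synchronize()
```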
```diff
@@ -2,6 +2,7 @@
 import copy
 from collections.abc import Callable
 from contextlib import AbstractContextManager, contextmanager, nullcontext
+from functools import partial
 from typing import Any
 
 import numpy as np
@@ -17,7 +18,7 @@ from vllm.distributed.parallel_state import (
     init_model_parallel_group,
     patch_tensor_parallel_group,
 )
-from vllm.forward_context import BatchDescriptor, get_forward_context
+from vllm.forward_context import BatchDescriptor, ForwardContext, get_forward_context
 from vllm.logger import logger
 from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
 from vllm.model_executor.model_loader import get_model
@@ -169,6 +170,8 @@ class SpecDecodeBaseProposer(EagleProposer):
         self.positions = torch.zeros(self.max_num_tokens, dtype=torch.int32, device=device)
 
         self.token_arange_np = np.arange(self.max_num_tokens + 1)
+        self.enable_enpu = self.runner.enable_enpu
+        self.use_eagle = self.runner.use_eagle
 
     def _get_model(self) -> nn.Module:
         """
@@ -321,10 +324,20 @@
         if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs() and self.use_cuda_graph:
             self.update_stream = torch.npu.Stream()
             if self.method == "mtp":
-                self.model = ACLGraphWrapper(self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL)
+                self.model = ACLGraphWrapper(
+                    self.model,
+                    self.vllm_config,
+                    runtime_mode=CUDAGraphMode.FULL,
+                    use_eagle=self.use_eagle,
+                    enable_enpu=self.enable_enpu,
+                )
             else:
                 self._runnable = ACLGraphWrapper(
-                    self._run_merged_draft, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
+                    self._run_merged_draft,
+                    self.vllm_config,
+                    runtime_mode=CUDAGraphMode.FULL,
+                    use_eagle=self.use_eagle,
+                    enable_enpu=self.enable_enpu,
                 )
 
     def get_model(self) -> nn.Module:
@@ -458,6 +471,15 @@ class SpecDecodeBaseProposer(EagleProposer):
             if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and not _EXTRA_CTX.capturing:
                 self._update_full_graph_params(forward_context, num_tokens, multi_steps_attn_metadata)
 
+    def _update_full_graph_params_if_needed(
+        self,
+        forward_context: ForwardContext,
+        num_input_tokens: int,
+        multi_steps_attn_metadata: list[dict[str, Any]],
+    ) -> None:
+        if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
+            self._update_full_graph_params(forward_context, num_input_tokens, multi_steps_attn_metadata)
+
     def _propose(
         self,
         # [num_tokens]
@@ -695,20 +717,24 @@ class SpecDecodeBaseProposer(EagleProposer):
         if forward_context is not None:
             forward_context.moe_layer_index = 0
 
-        draft_token_ids = self._runnable(
-            num_input_tokens=num_input_tokens,
-            batch_size=batch_size,
-            token_indices_to_sample=self.token_indices_to_sample[:token_indices_to_sample_len],
-            target_positions=target_positions,
-            inputs_embeds=inputs_embeds,
-            multi_steps_attn_metadata=multi_steps_attn_metadata,
-            num_tokens=num_tokens,
-            is_prefill=attn_metadata_i.num_prefills,
-        )
+        model_inputs: dict[str, Any] = {
+            "num_input_tokens": num_input_tokens,
+            "batch_size": batch_size,
+            "token_indices_to_sample": self.token_indices_to_sample[:token_indices_to_sample_len],
+            "target_positions": target_positions,
+            "inputs_embeds": inputs_embeds,
+            "multi_steps_attn_metadata": multi_steps_attn_metadata,
+            "num_tokens": num_tokens,
+            "is_prefill": attn_metadata_i.num_prefills,
+        }
+        run_draft = partial(self._runnable, **model_inputs)
 
         forward_context = get_forward_context()
-        if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
-            self._update_full_graph_params(forward_context, num_input_tokens, multi_steps_attn_metadata)
+        if self.enable_enpu:
+            self._update_full_graph_params_if_needed(forward_context, num_input_tokens, multi_steps_attn_metadata)
+            draft_token_ids = run_draft()
+        else:
+            draft_token_ids = run_draft()
+            self._update_full_graph_params_if_needed(forward_context, num_input_tokens, multi_steps_attn_metadata)
         return draft_token_ids
 
     def _run_merged_draft(
```
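As a side note on the last hunk: bundling the draft call's arguments with `functools.partial` is what lets the ENPU and non-ENPU branches differ only in *when* the draft runs relative to the graph-parameter update. A small illustrative sketch of that pattern, with hypothetical names:

```python
# Self-contained sketch of the functools.partial pattern from the last hunk.
# Bundling the arguments once lets each branch decide only *when* to run the
# call relative to the parameter update. All names here are hypothetical.
from functools import partial

def run_model(num_tokens: int, batch_size: int) -> str:
    return f"draft({num_tokens=}, {batch_size=})"

def update_graph_params() -> None:
    print("params updated")

model_inputs = {"num_tokens": 8, "batch_size": 2}
run_draft = partial(run_model, **model_inputs)

enable_enpu = True
if enable_enpu:
    update_graph_params()   # ENPU: update first, then launch the draft
    draft = run_draft()
else:
    draft = run_draft()     # otherwise: launch first, update afterwards
    update_graph_params()
print(draft)
```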