[BugFix] fix hang in async scheduling when ENPU is enabled (#8354)
### What this PR does / why we need it?
1. There is no synchronization between steps. In async scheduling with aclgraph, the CPU's event record for the current iteration can complete before the previous iteration's graph execution has finished. If the CPU runs far enough ahead, the device hangs on event_wait in iteration i+1 (assuming event_record executes immediately on the device's update stream).
2. Under ENPU, eagle proposers must also follow the same ordering: event.record first, then event.wait.

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?

---------

Signed-off-by: 1zzk <785396250@qq.com>
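Before the diff, a minimal sketch of the failure mode described above. It is illustrative only: `update_params`, `params_ready`, and `run_one_step` are hypothetical stand-ins for the real vllm-ascend code paths, and it assumes an Ascend torch build where `torch.npu` mirrors the `torch.cuda` Stream/Event API.

```python
import torch

update_stream = torch.npu.Stream()
main_stream = torch.npu.current_stream()
params_ready = torch.npu.Event()

def run_one_step(graph, update_params):
    # Host code for step i+1 is enqueued while step i's replay may still
    # be running on main_stream (async scheduling: no barrier between steps).
    with torch.npu.stream(update_stream):
        update_params()                    # in-place rewrite of attn params
        params_ready.record(update_stream)
    main_stream.wait_event(params_ready)
    graph.replay()                         # long-running aclgraph execution

def run_one_step_fixed(graph, update_params):
    # The fix: drain the previous replay before re-recording the event, so
    # record/wait for step i+1 can never pair against an unfinished step i.
    torch.npu.current_stream().synchronize()
    with torch.npu.stream(update_stream):
        update_params()
        params_ready.record(update_stream)
    main_stream.wait_event(params_ready)
    graph.replay()
```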
@@ -20,8 +20,7 @@ from vllm.logger import logger
 from vllm.platforms import current_platform

 from vllm_ascend.ascend_forward_context import _EXTRA_CTX

-from ..utils import weak_ref_tensors
+from vllm_ascend.utils import weak_ref_tensors


 @dataclasses.dataclass
@@ -66,6 +65,9 @@ class ACLGraphWrapper:
         vllm_config: VllmConfig,
         runtime_mode: CUDAGraphMode,
         cudagraph_options: CUDAGraphOptions | None = None,
+        *,
+        use_eagle: bool = False,
+        enable_enpu: bool = False,
     ):
         self.runnable = runnable
         self.vllm_config = vllm_config
@@ -87,6 +89,8 @@ class ACLGraphWrapper:
         # the entries for different batch descriptors that we need to capture
         # aclgraphs for.
         self.concrete_aclgraph_entries: dict[BatchDescriptor, ACLGraphEntry] = {}
+        self.enable_enpu = enable_enpu
+        self.use_eagle = use_eagle

     def __getattr__(self, key: str):
         # allow accessing the attributes of the runnable.
@@ -197,12 +201,11 @@ class ACLGraphWrapper:
         # so that update_attn_params only executes after the previous graph replay has fully completed.
         # If we are not in the main model and are in full-graph mode using merge-eagle-graph,
         # we do not need to synchronize.
-        use_eagle = (
-            self.vllm_config.speculative_config.method in ("eagle", "eagle3")
-            if self.vllm_config.speculative_config
-            else False
-        )
-        if self.runtime_mode != CUDAGraphMode.FULL or not _EXTRA_CTX.is_draft_model or not use_eagle:
+        # When enable_enpu is on, model_runner orders update vs replay; skip here.
+        # When FULL + EAGLE draft (merge path), replay does not need this barrier.
+        is_draft_eagle = _EXTRA_CTX.is_draft_model and self.use_eagle
+        need_sync = self.runtime_mode == CUDAGraphMode.FULL and not is_draft_eagle
+        if not self.enable_enpu and need_sync:
             torch.npu.current_stream().synchronize()
         entry.aclgraph.replay()
         return entry.output
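In prose, the new gate in the hunk above reads: synchronize before replay only for FULL graphs that are not the merged EAGLE draft path, and never when ENPU owns the ordering. A hedged one-function restatement (flag names mirror the wrapper's new attributes; `runtime_mode` is simplified to a string for illustration):

```python
def needs_pre_replay_sync(runtime_mode: str, is_draft_model: bool,
                          use_eagle: bool, enable_enpu: bool) -> bool:
    # ENPU: model_runner already orders update vs replay, so never sync here.
    if enable_enpu:
        return False
    # Merged EAGLE draft replays do not need the barrier.
    is_draft_eagle = is_draft_model and use_eagle
    return runtime_mode == "FULL" and not is_draft_eagle
```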
@@ -2,6 +2,7 @@
 import copy
 from collections.abc import Callable
 from contextlib import AbstractContextManager, contextmanager, nullcontext
+from functools import partial
 from typing import Any

 import numpy as np
@@ -17,7 +18,7 @@ from vllm.distributed.parallel_state import (
     init_model_parallel_group,
     patch_tensor_parallel_group,
 )
-from vllm.forward_context import BatchDescriptor, get_forward_context
+from vllm.forward_context import BatchDescriptor, ForwardContext, get_forward_context
 from vllm.logger import logger
 from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
 from vllm.model_executor.model_loader import get_model
@@ -169,6 +170,8 @@ class SpecDecodeBaseProposer(EagleProposer):
         self.positions = torch.zeros(self.max_num_tokens, dtype=torch.int32, device=device)

         self.token_arange_np = np.arange(self.max_num_tokens + 1)
+        self.enable_enpu = self.runner.enable_enpu
+        self.use_eagle = self.runner.use_eagle

     def _get_model(self) -> nn.Module:
         """
@@ -321,10 +324,20 @@ class SpecDecodeBaseProposer(EagleProposer):
         if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs() and self.use_cuda_graph:
             self.update_stream = torch.npu.Stream()
             if self.method == "mtp":
-                self.model = ACLGraphWrapper(self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL)
+                self.model = ACLGraphWrapper(
+                    self.model,
+                    self.vllm_config,
+                    runtime_mode=CUDAGraphMode.FULL,
+                    use_eagle=self.use_eagle,
+                    enable_enpu=self.enable_enpu,
+                )
             else:
                 self._runnable = ACLGraphWrapper(
-                    self._run_merged_draft, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
+                    self._run_merged_draft,
+                    self.vllm_config,
+                    runtime_mode=CUDAGraphMode.FULL,
+                    use_eagle=self.use_eagle,
+                    enable_enpu=self.enable_enpu,
                 )

     def get_model(self) -> nn.Module:
@@ -458,6 +471,15 @@ class SpecDecodeBaseProposer(EagleProposer):
         if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and not _EXTRA_CTX.capturing:
             self._update_full_graph_params(forward_context, num_tokens, multi_steps_attn_metadata)

+    def _update_full_graph_params_if_needed(
+        self,
+        forward_context: ForwardContext,
+        num_input_tokens: int,
+        multi_steps_attn_metadata: list[dict[str, Any]],
+    ) -> None:
+        if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
+            self._update_full_graph_params(forward_context, num_input_tokens, multi_steps_attn_metadata)
+
     def _propose(
         self,
         # [num_tokens]
@@ -695,20 +717,24 @@ class SpecDecodeBaseProposer(EagleProposer):
         if forward_context is not None:
             forward_context.moe_layer_index = 0

-        draft_token_ids = self._runnable(
-            num_input_tokens=num_input_tokens,
-            batch_size=batch_size,
-            token_indices_to_sample=self.token_indices_to_sample[:token_indices_to_sample_len],
-            target_positions=target_positions,
-            inputs_embeds=inputs_embeds,
-            multi_steps_attn_metadata=multi_steps_attn_metadata,
-            num_tokens=num_tokens,
-            is_prefill=attn_metadata_i.num_prefills,
-        )
+        model_inputs: dict[str, Any] = {
+            "num_input_tokens": num_input_tokens,
+            "batch_size": batch_size,
+            "token_indices_to_sample": self.token_indices_to_sample[:token_indices_to_sample_len],
+            "target_positions": target_positions,
+            "inputs_embeds": inputs_embeds,
+            "multi_steps_attn_metadata": multi_steps_attn_metadata,
+            "num_tokens": num_tokens,
+            "is_prefill": attn_metadata_i.num_prefills,
+        }
+        run_draft = partial(self._runnable, **model_inputs)

         forward_context = get_forward_context()
-        if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL:
-            self._update_full_graph_params(forward_context, num_input_tokens, multi_steps_attn_metadata)
+        if self.enable_enpu:
+            self._update_full_graph_params_if_needed(forward_context, num_input_tokens, multi_steps_attn_metadata)
+            draft_token_ids = run_draft()
+        else:
+            draft_token_ids = run_draft()
+            self._update_full_graph_params_if_needed(forward_context, num_input_tokens, multi_steps_attn_metadata)
         return draft_token_ids

     def _run_merged_draft(
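The control flow of the hunk above, stripped of the model plumbing (a sketch; `runnable` and `update_params` are placeholders for `self._runnable` and `_update_full_graph_params_if_needed`):

```python
from functools import partial

def drive_draft(runnable, model_inputs, enable_enpu, update_params):
    run_draft = partial(runnable, **model_inputs)
    if enable_enpu:
        update_params()      # ENPU: record (update) first, then wait/replay
        return run_draft()
    out = run_draft()
    update_params()          # legacy order: replay first, update afterwards
    return out
```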
@@ -336,6 +336,15 @@ class NPUModelRunner(GPUModelRunner):
         self.input_ids = self._make_buffer(max_buffer_num_tokens, dtype=torch.int32)
         self.positions = self._make_buffer(max_buffer_num_tokens, dtype=torch.int64)

+        self.use_eagle = (
+            vllm_config.speculative_config.method in ("eagle", "eagle3", "mtp")
+            if vllm_config.speculative_config
+            else False
+        )
+        # When True, run update_full_graph_params before self.model (ENPU / graph capture order).
+        # Internal / non-public toggle: read C getenv ``ENPU_ENABLE`` from enpu code (not in envs.py).
+        _enpu = get_c_env("ENPU_ENABLE")
+        self.enable_enpu = _enpu is not None and _enpu.lower() == "true"
         self._set_up_drafter()

         # kv role
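`get_c_env` is used here presumably because `ENPU_ENABLE` is set at the C level by the enpu code rather than exported before Python starts, and `os.environ` is snapshotted at interpreter startup. A plausible shape for such a helper (illustrative sketch, not the actual vllm_ascend implementation):

```python
import ctypes

_libc = ctypes.CDLL(None)                 # already-loaded C library (POSIX)
_libc.getenv.restype = ctypes.c_char_p
_libc.getenv.argtypes = [ctypes.c_char_p]

def get_c_env(name: str) -> str | None:
    """Read the C-level environment, which sees setenv() calls made by
    native code after Python started (os.environ would not)."""
    value = _libc.getenv(name.encode())
    return value.decode() if value is not None else None

enable_enpu = (get_c_env("ENPU_ENABLE") or "").lower() == "true"
```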
@@ -424,9 +433,6 @@ class NPUModelRunner(GPUModelRunner):
         self.cudagraph_batch_sizes = []
         self.mamba_state_idx: dict[str, int] = {}
         self._mamba_copy_bufs: mamba_utils.MambaCopyBuffers | None = None
-        env_enpu_enable = get_c_env("ENPU_ENABLE")
-        # When True, run update_full_graph_params before self.model (ENPU / graph capture order).
-        self.enable_enpu = env_enpu_enable is not None and env_enpu_enable.lower() == "true"

     @property
     def use_cp(self) -> bool:
@@ -1795,6 +1801,9 @@ class NPUModelRunner(GPUModelRunner):
             and not forward_context.capturing
             and not self.use_sparse
         ):
+            if self.enable_enpu:
+                torch.npu.current_stream().synchronize()
+
             assert positions is not None
             update_full_graph_params(
                 self.attn_backend,
@@ -2592,7 +2601,13 @@ class NPUModelRunner(GPUModelRunner):
         # wrap the model with full graph wrapper if needed.
         if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
             self.update_stream: torch.npu.Stream = torch.npu.Stream()
-            self.model = ACLGraphWrapper(self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL)
+            self.model = ACLGraphWrapper(
+                self.model,
+                self.vllm_config,
+                runtime_mode=CUDAGraphMode.FULL,
+                use_eagle=self.use_eagle,
+                enable_enpu=self.enable_enpu,
+            )

     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         """