From cb4c7de856bb4bea822602c82abb23cb421b11a4 Mon Sep 17 00:00:00 2001 From: Cao Yi Date: Mon, 9 Mar 2026 15:55:27 +0800 Subject: [PATCH] [Perf] Optimize MTP execution by reordering state update operation (#6844) ## Summary - Move `_update_states_after_model_execute` call from after main model sampling to after draft model execution - This reordering reduces pipeline bubbles between main model and draft model execution - No accuracy impact - the state update operation is independent of draft token proposal ## Performance Impact Reduces idle time between main model and draft model execution stages, improving overall MTP (Multi-Token Prediction) performance. - vLLM version: v0.15.0 - vLLM main: https://github.com/vllm-project/vllm/commit/83b47f67b1dfad505606070ae4d9f83e50ad4ebd --------- Signed-off-by: SlightwindSec Co-authored-by: wanghuanjun2113 --- vllm_ascend/worker/model_runner_v1.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 86cd598c..51018023 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -116,6 +116,7 @@ from vllm_ascend.utils import ( check_gdn_layer, enable_sp, enable_sp_by_pass, + global_stream, is_drafter_moe_model, is_moe_model, lmhead_tp_enable, @@ -390,6 +391,7 @@ class NPUModelRunner(GPUModelRunner): self.long_seq_metadata = None self.query_lens: torch.Tensor | None = None self.cpu_slot_mapping = None + self.sampling_done_event: torch.npu.Event | None = None @property def use_cp(self) -> bool: @@ -1457,7 +1459,11 @@ class NPUModelRunner(GPUModelRunner): sampler_output = self._sample(logits, spec_decode_metadata) if self.need_accepted_tokens: - self._update_states_after_model_execute(sampler_output.sampled_token_ids, scheduler_output) + if self.sampling_done_event is None: + self.sampling_done_event = torch.npu.Event() + + assert self.sampling_done_event is not None + self.sampling_done_event.record() def propose_draft_token_ids(sampled_token_ids): assert spec_decode_common_attn_metadata is not None @@ -1536,6 +1542,16 @@ class NPUModelRunner(GPUModelRunner): if self.debugger is not None: self.debugger.stop() self.debugger.step() + + if self.need_accepted_tokens: + assert self.sampling_done_event is not None + with ( + record_function_or_nullcontext("async_state_update"), + torch.npu.stream(global_stream()), + ): + global_stream().wait_event(self.sampling_done_event) + self._update_states_after_model_execute(sampler_output.sampled_token_ids, scheduler_output) + if not self.use_async_scheduling: return model_runner_output return AsyncGPUModelRunnerOutput(