[Perf] Optimize MTP execution by reordering state update operation (#6844)

## Summary
- Move the `_update_states_after_model_execute` call from after main-model
sampling to after draft-model execution
- This reordering reduces pipeline bubbles between main-model and
draft-model execution
- No accuracy impact: the state update operation is independent of
draft-token proposal

## Performance Impact
Reduces idle time between the main-model and draft-model execution stages,
improving overall MTP (Multi-Token Prediction) performance; the stream/event
pattern behind the overlap is sketched below.
- vLLM version: v0.15.0
- vLLM main: 83b47f67b1
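
The change amounts to the standard record/wait pattern on device streams:
record an event right after sampling, launch the draft model on the main
stream, and have a side stream wait on the event before running the state
update, so the update overlaps the draft model instead of delaying it.
Below is a minimal sketch of that pattern, written against `torch.cuda`,
whose stream and event API `torch.npu` mirrors; the `step` function, the
argmax "sampling", and the in-place "state update" are hypothetical
stand-ins, not code from this PR:

```python
import torch

def step(hidden: torch.Tensor, side_stream: torch.cuda.Stream) -> torch.Tensor:
    sampling_done = torch.cuda.Event()

    # Main stream: sampling produces the token ids the state update needs.
    sampled_ids = hidden.argmax(dim=-1)   # stand-in for self._sample(...)
    sampling_done.record()                # mark the point where sampling finishes

    # Main stream moves straight on to the draft model; no bubble while the
    # state update runs elsewhere.
    draft_logits = torch.relu(hidden)     # stand-in for draft-model execution

    # Side stream: wait only on the sampling event, then run the state update
    # so it overlaps the draft model executing on the main stream.
    with torch.cuda.stream(side_stream):
        side_stream.wait_event(sampling_done)
        sampled_ids.add_(1)               # stand-in for the deferred state update

    return draft_logits

if torch.cuda.is_available():
    out = step(torch.randn(4, 8, device="cuda"), torch.cuda.Stream())
    torch.cuda.synchronize()
```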

---------

Signed-off-by: SlightwindSec <slightwindsec@gmail.com>
Co-authored-by: wanghuanjun2113 <wanghuanjun2113@gmail.com>

```diff
@@ -116,6 +116,7 @@ from vllm_ascend.utils import (
     check_gdn_layer,
     enable_sp,
     enable_sp_by_pass,
+    global_stream,
     is_drafter_moe_model,
     is_moe_model,
     lmhead_tp_enable,
@@ -390,6 +391,7 @@ class NPUModelRunner(GPUModelRunner):
         self.long_seq_metadata = None
         self.query_lens: torch.Tensor | None = None
         self.cpu_slot_mapping = None
+        self.sampling_done_event: torch.npu.Event | None = None
 
     @property
     def use_cp(self) -> bool:
@@ -1457,7 +1459,11 @@ class NPUModelRunner(GPUModelRunner):
         sampler_output = self._sample(logits, spec_decode_metadata)
         if self.need_accepted_tokens:
-            self._update_states_after_model_execute(sampler_output.sampled_token_ids, scheduler_output)
+            if self.sampling_done_event is None:
+                self.sampling_done_event = torch.npu.Event()
+            assert self.sampling_done_event is not None
+            self.sampling_done_event.record()
 
         def propose_draft_token_ids(sampled_token_ids):
             assert spec_decode_common_attn_metadata is not None
@@ -1536,6 +1542,16 @@ class NPUModelRunner(GPUModelRunner):
         if self.debugger is not None:
             self.debugger.stop()
             self.debugger.step()
+
+        if self.need_accepted_tokens:
+            assert self.sampling_done_event is not None
+            with (
+                record_function_or_nullcontext("async_state_update"),
+                torch.npu.stream(global_stream()),
+            ):
+                global_stream().wait_event(self.sampling_done_event)
+                self._update_states_after_model_execute(sampler_output.sampled_token_ids, scheduler_output)
+
         if not self.use_async_scheduling:
             return model_runner_output
         return AsyncGPUModelRunnerOutput(
```
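
Note on the last hunk: `wait_event` establishes an asynchronous, device-side
dependency, so only work queued on `global_stream()` stalls until sampling
completes; the default stream and the host thread that launches the draft
model are never blocked, which is where the bubble reduction comes from. The
lazy `torch.npu.Event()` creation followed by an `assert` in the earlier hunk
is presumably there to satisfy the `Event | None` annotation under static
type checking.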