[Perf] Optimize MTP execution by reordering state update operation (#6844)
## Summary
- Move `_update_states_after_model_execute` call from after main model
sampling to after draft model execution
- This reordering reduces pipeline bubbles between main model and draft
model execution
- No accuracy impact - the state update operation is independent of
draft token proposal
## Performance Impact
Reduces idle time between main model and draft model execution stages,
improving overall MTP (Multi-Token Prediction) performance.
- vLLM version: v0.15.0
- vLLM main:
83b47f67b1
---------
Signed-off-by: SlightwindSec <slightwindsec@gmail.com>
Co-authored-by: wanghuanjun2113 <wanghuanjun2113@gmail.com>
This commit is contained in:
@@ -116,6 +116,7 @@ from vllm_ascend.utils import (
     check_gdn_layer,
     enable_sp,
     enable_sp_by_pass,
+    global_stream,
     is_drafter_moe_model,
     is_moe_model,
     lmhead_tp_enable,
@@ -390,6 +391,7 @@ class NPUModelRunner(GPUModelRunner):
         self.long_seq_metadata = None
         self.query_lens: torch.Tensor | None = None
         self.cpu_slot_mapping = None
+        self.sampling_done_event: torch.npu.Event | None = None

     @property
     def use_cp(self) -> bool:
@@ -1457,7 +1459,11 @@ class NPUModelRunner(GPUModelRunner):
         sampler_output = self._sample(logits, spec_decode_metadata)

         if self.need_accepted_tokens:
-            self._update_states_after_model_execute(sampler_output.sampled_token_ids, scheduler_output)
+            if self.sampling_done_event is None:
+                self.sampling_done_event = torch.npu.Event()
+
+            assert self.sampling_done_event is not None
+            self.sampling_done_event.record()

         def propose_draft_token_ids(sampled_token_ids):
             assert spec_decode_common_attn_metadata is not None
@@ -1536,6 +1542,16 @@ class NPUModelRunner(GPUModelRunner):
         if self.debugger is not None:
             self.debugger.stop()
             self.debugger.step()
+
+        if self.need_accepted_tokens:
+            assert self.sampling_done_event is not None
+            with (
+                record_function_or_nullcontext("async_state_update"),
+                torch.npu.stream(global_stream()),
+            ):
+                global_stream().wait_event(self.sampling_done_event)
+                self._update_states_after_model_execute(sampler_output.sampled_token_ids, scheduler_output)

         if not self.use_async_scheduling:
             return model_runner_output
         return AsyncGPUModelRunnerOutput(
Reference in New Issue
Block a user