diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 86cd598c..51018023 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -116,6 +116,7 @@ from vllm_ascend.utils import ( check_gdn_layer, enable_sp, enable_sp_by_pass, + global_stream, is_drafter_moe_model, is_moe_model, lmhead_tp_enable, @@ -390,6 +391,7 @@ class NPUModelRunner(GPUModelRunner): self.long_seq_metadata = None self.query_lens: torch.Tensor | None = None self.cpu_slot_mapping = None + self.sampling_done_event: torch.npu.Event | None = None @property def use_cp(self) -> bool: @@ -1457,7 +1459,11 @@ class NPUModelRunner(GPUModelRunner): sampler_output = self._sample(logits, spec_decode_metadata) if self.need_accepted_tokens: - self._update_states_after_model_execute(sampler_output.sampled_token_ids, scheduler_output) + if self.sampling_done_event is None: + self.sampling_done_event = torch.npu.Event() + + assert self.sampling_done_event is not None + self.sampling_done_event.record() def propose_draft_token_ids(sampled_token_ids): assert spec_decode_common_attn_metadata is not None @@ -1536,6 +1542,16 @@ class NPUModelRunner(GPUModelRunner): if self.debugger is not None: self.debugger.stop() self.debugger.step() + + if self.need_accepted_tokens: + assert self.sampling_done_event is not None + with ( + record_function_or_nullcontext("async_state_update"), + torch.npu.stream(global_stream()), + ): + global_stream().wait_event(self.sampling_done_event) + self._update_states_after_model_execute(sampler_output.sampled_token_ids, scheduler_output) + if not self.use_async_scheduling: return model_runner_output return AsyncGPUModelRunnerOutput(