From 4cb5a5235e49f08e9f47165fb7f95e35145fce43 Mon Sep 17 00:00:00 2001
From: Liangsheng Yin
Date: Sun, 5 Oct 2025 23:41:04 +0800
Subject: [PATCH] Tiny `skip_sample` adjust (#11225)
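
Rename the `skip_sample` argument of `TpModelWorker.forward_batch_generation`
to `is_verify`, and let the worker derive the skip decision itself: sampling
is skipped when the pass is a speculative verify pass or when the batch is
prefill-only. Prefill-only batches that request logprobs still go through
`compute_logprobs_only`; verify passes skip it, since the speculative workers
(EAGLE, n-gram) consume the logits themselves. `Req.is_prefill_only` now also
returns False whenever a speculative algorithm is configured, so the
prefill-only fast path stays out of the way of speculative decoding, and the
overlap-thread client no longer passes `skip_sample` at the call site.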
---
 python/sglang/srt/managers/schedule_batch.py  |  6 ++++-
 python/sglang/srt/managers/tp_worker.py       | 23 ++++++++-----------
 .../srt/managers/tp_worker_overlap_thread.py  |  2 --
 python/sglang/srt/speculative/eagle_worker.py |  2 +-
 python/sglang/srt/speculative/ngram_worker.py |  2 +-
 5 files changed, 17 insertions(+), 18 deletions(-)

diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 31b696f9f..7ec246e4b 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -663,7 +663,11 @@ class Req:
     @property
     def is_prefill_only(self) -> bool:
         """Check if this request is prefill-only (no token generation needed)."""
-        return self.sampling_params.max_new_tokens == 0
+        # NOTE: when spec is enabled, prefill_only optimizations are disabled
+        return (
+            self.sampling_params.max_new_tokens == 0
+            and global_server_args_dict["speculative_algorithm"] is None
+        )
 
     def add_latency(self, stage: RequestStage):
         if self.metrics_collector is None:
diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py
index 6a3ef3b96..475305a2f 100644
--- a/python/sglang/srt/managers/tp_worker.py
+++ b/python/sglang/srt/managers/tp_worker.py
@@ -237,7 +237,7 @@ class TpModelWorker:
         self,
         model_worker_batch: ModelWorkerBatch,
         launch_done: Optional[threading.Event] = None,
-        skip_sample: bool = False,
+        is_verify: bool = False,
     ) -> ForwardBatchOutput:
         # update the consumer index of hicache to the running batch
         self.set_hicache_consumer(model_worker_batch.hicache_consumer_index)
@@ -259,19 +259,16 @@ class TpModelWorker:
         if launch_done is not None:
             launch_done.set()
 
-        if skip_sample:
-            next_token_ids = None
-            # For prefill-only requests, we still need to compute logprobs even when sampling is skipped
-            if (
-                model_worker_batch.is_prefill_only
-                and model_worker_batch.return_logprob
-            ):
-                # Compute logprobs without full sampling
-                self.model_runner.compute_logprobs_only(
-                    logits_output, model_worker_batch
-                )
-        else:
+        skip_sample = is_verify or model_worker_batch.is_prefill_only
+        next_token_ids = None
+
+        if not skip_sample:
             next_token_ids = self.model_runner.sample(logits_output, forward_batch)
+        elif model_worker_batch.return_logprob and not is_verify:
+            # NOTE: Compute logprobs without full sampling
+            self.model_runner.compute_logprobs_only(
+                logits_output, model_worker_batch
+            )
 
         return ForwardBatchOutput(
             logits_output=logits_output,
diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py
index 1af05a434..3e0240361 100644
--- a/python/sglang/srt/managers/tp_worker_overlap_thread.py
+++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py
@@ -164,8 +164,6 @@ class TpModelWorkerClient:
         forward_batch_output = self.worker.forward_batch_generation(
             model_worker_batch,
             model_worker_batch.launch_done,
-            # Skip sampling for prefill-only requests
-            skip_sample=model_worker_batch.is_prefill_only,
         )
 
         logits_output, next_token_ids, can_run_cuda_graph = (
diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py
index 9df1ef973..08d659d02 100644
--- a/python/sglang/srt/speculative/eagle_worker.py
+++ b/python/sglang/srt/speculative/eagle_worker.py
@@ -823,7 +823,7 @@ class EAGLEWorker(TpModelWorker):
         # Forward
         forward_batch_output = self.target_worker.forward_batch_generation(
-            model_worker_batch, skip_sample=True
+            model_worker_batch, is_verify=True
         )
         logits_output, can_run_cuda_graph = (
             forward_batch_output.logits_output,
             forward_batch_output.can_run_cuda_graph,
diff --git a/python/sglang/srt/speculative/ngram_worker.py b/python/sglang/srt/speculative/ngram_worker.py
index 473e040d2..97aa620ce 100644
--- a/python/sglang/srt/speculative/ngram_worker.py
+++ b/python/sglang/srt/speculative/ngram_worker.py
@@ -214,7 +214,7 @@ class NGRAMWorker: