diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py
index 13eac206f..e39727842 100644
--- a/python/sglang/srt/layers/logits_processor.py
+++ b/python/sglang/srt/layers/logits_processor.py
@@ -72,7 +72,10 @@ class LogitsProcessorOutput:
     next_token_top_logprobs_val: Optional[List] = None
     next_token_top_logprobs_idx: Optional[List] = None
     # The logprobs and ids of the requested token ids in output positions. shape: [#seq, n] (n is the number of requested token ids)
-    next_token_token_ids_logprobs_val: Optional[List] = None
+    # Can contain either lists or GPU tensors (for delayed copy optimization in prefill-only requests)
+    next_token_token_ids_logprobs_val: Optional[
+        List[Union[List[float], torch.Tensor]]
+    ] = None
     next_token_token_ids_logprobs_idx: Optional[List] = None
 
     ## Part 3: Prefill-only. This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py
index 565ce106e..03972c58b 100644
--- a/python/sglang/srt/layers/sampler.py
+++ b/python/sglang/srt/layers/sampler.py
@@ -1,5 +1,5 @@
 import logging
-from typing import List
+from typing import List, Tuple
 
 import torch
 import torch.distributed as dist
@@ -39,6 +39,25 @@ class Sampler(nn.Module):
         if is_dp_attention_enabled():
             self.tp_sync_group = get_attention_tp_group().device_group
 
+    def _preprocess_logits(
+        self, logits: torch.Tensor, sampling_info: SamplingBatchInfo
+    ) -> torch.Tensor:
+        """Apply custom logit processors and handle NaN detection."""
+        # Apply the custom logit processors if registered in the sampling info
+        if sampling_info.has_custom_logit_processor:
+            apply_custom_logit_processor(logits, sampling_info)
+
+        # Detect and handle NaN values in logits
+        if self.use_nan_detection and torch.any(torch.isnan(logits)):
+            logger.warning("Detected errors during sampling! NaN in the logits.")
+            logits = torch.where(
+                torch.isnan(logits), torch.full_like(logits, -1e5), logits
+            )
+            if crash_on_warnings():
+                raise ValueError("Detected errors during sampling! NaN in the logits.")
+
+        return logits
+
     def forward(
         self,
         logits_output: LogitsProcessorOutput,
@@ -61,17 +80,8 @@ class Sampler(nn.Module):
         """
         logits = logits_output.next_token_logits
 
-        # Apply the custom logit processors if registered in the sampling info.
-        if sampling_info.has_custom_logit_processor:
-            apply_custom_logit_processor(logits, sampling_info)
-
-        if self.use_nan_detection and torch.any(torch.isnan(logits)):
-            logger.warning("Detected errors during sampling! NaN in the logits.")
-            logits = torch.where(
-                torch.isnan(logits), torch.full_like(logits, -1e5), logits
-            )
-            if crash_on_warnings():
-                raise ValueError("Detected errors during sampling! NaN in the logits.")
+        # Preprocess logits (custom processors and NaN handling)
+        logits = self._preprocess_logits(logits, sampling_info)
 
         if sampling_info.is_all_greedy:
             # Use torch.argmax if all requests use greedy sampling
@@ -165,6 +175,54 @@ class Sampler(nn.Module):
 
         return batch_next_token_ids
 
+    def compute_logprobs_only(
+        self,
+        logits_output: LogitsProcessorOutput,
+        sampling_info: SamplingBatchInfo,
+        return_logprob: bool,
+        top_logprobs_nums: List[int],
+        token_ids_logprobs: List[List[int]],
+    ) -> None:
+        """
+        Compute logprobs for requested token IDs without performing sampling.
+
+        Optimized for prefill-only scoring requests that need token probabilities
+        but don't require next token generation.
+        """
+        if logits_output.next_token_logits is None:
+            logger.warning("No logits available for logprob computation")
+            return
+
+        # Check if any requests actually need logprobs computation
+        needs_token_ids_logprobs = any(
+            token_ids is not None and len(token_ids) > 0
+            for token_ids in token_ids_logprobs
+        )
+        needs_top_logprobs = any(x > 0 for x in top_logprobs_nums)
+
+        if not (needs_token_ids_logprobs or needs_top_logprobs):
+            return
+
+        # Preprocess logits (custom processors and NaN handling)
+        logits = self._preprocess_logits(logits_output.next_token_logits, sampling_info)
+
+        # Compute logprobs
+        logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
+
+        # Handle top logprobs if requested
+        if needs_top_logprobs:
+            (
+                logits_output.next_token_top_logprobs_val,
+                logits_output.next_token_top_logprobs_idx,
+            ) = get_top_logprobs(logprobs, top_logprobs_nums)
+
+        # Handle token_ids logprobs if requested
+        if needs_token_ids_logprobs:
+            (
+                logits_output.next_token_token_ids_logprobs_val,
+                logits_output.next_token_token_ids_logprobs_idx,
+            ) = get_token_ids_logprobs_batch_optimized(logprobs, token_ids_logprobs)
+
 
 def top_k_top_p_min_p_sampling_from_probs_torch(
     probs: torch.Tensor,
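A standalone sketch of what the new logprob-only path computes, using invented logits and request parameters and plain tensor ops rather than the PR's classes:

```python
import torch

# Stand-in inputs: 2 requests over a 6-token vocab (invented values)
logits = torch.randn(2, 6)
logprobs = torch.nn.functional.log_softmax(logits, dim=-1)

top_logprobs_nums = [2, 0]          # mirrors the top_logprobs_nums argument
token_ids_logprobs = [[1, 3], []]   # mirrors the token_ids_logprobs argument

# Top-k logprobs, the piece get_top_logprobs provides
for i, k in enumerate(top_logprobs_nums):
    if k > 0:
        vals, idx = logprobs[i].topk(k)
        print(f"req {i}: top-{k} logprobs {vals.tolist()} at ids {idx.tolist()}")

# Requested-token logprobs, the piece get_token_ids_logprobs_batch_optimized vectorizes
for i, ids in enumerate(token_ids_logprobs):
    if ids:
        print(f"req {i}: logprobs for {ids} -> {logprobs[i, ids].tolist()}")
```

Note that no sampling kernel runs at any point; when neither top-k nor requested-token logprobs are asked for, the method simply returns early.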
+ """ + if logits_output.next_token_logits is None: + logger.warning("No logits available for logprob computation") + return + + # Check if any requests actually need logprobs computation + needs_token_ids_logprobs = any( + token_ids is not None and len(token_ids) > 0 + for token_ids in token_ids_logprobs + ) + needs_top_logprobs = any(x > 0 for x in top_logprobs_nums) + + if not (needs_token_ids_logprobs or needs_top_logprobs): + return + + # Preprocess logits (custom processors and NaN handling) + logits = self._preprocess_logits(logits_output.next_token_logits, sampling_info) + + # Compute logprobs + logprobs = torch.nn.functional.log_softmax(logits, dim=-1) + + # Handle top logprobs if requested + if needs_top_logprobs: + ( + logits_output.next_token_top_logprobs_val, + logits_output.next_token_top_logprobs_idx, + ) = get_top_logprobs(logprobs, top_logprobs_nums) + + # Handle token_ids logprobs if requested + if needs_token_ids_logprobs: + ( + logits_output.next_token_token_ids_logprobs_val, + logits_output.next_token_token_ids_logprobs_idx, + ) = get_token_ids_logprobs_batch_optimized(logprobs, token_ids_logprobs) + def top_k_top_p_min_p_sampling_from_probs_torch( probs: torch.Tensor, @@ -234,10 +292,95 @@ def get_top_logprobs( ) -def get_token_ids_logprobs( +def get_token_ids_logprobs_batch_optimized( logprobs: torch.Tensor, token_ids_logprobs: List[List[int]], -): +) -> Tuple[List, List]: + """ + Vectorized batch processing for token ID logprobs extraction. + + Uses a single GPU kernel call for the entire batch instead of multiple + separate calls, significantly improving performance for large batches. + + Args: + logprobs: Log probabilities tensor [batch_size, vocab_size] + token_ids_logprobs: List of token IDs to extract logprobs for + + Example: + # Input: batch_size=3, vocab_size=5 + logprobs = torch.tensor([ + [-1.2, -2.1, -0.8, -3.0, -1.5], # batch 0 + [-0.5, -1.8, -2.2, -1.1, -2.7], # batch 1 + [-2.0, -0.9, -1.4, -2.8, -1.6], # batch 2 + ]) + token_ids_logprobs = [[1, 3], [2], [0, 2, 4]] + + # Output: + # values = [tensor([-2.1, -3.0]), tensor([-2.2]), tensor([-2.0, -1.4, -1.6])] + # indices = [[1, 3], [2], [0, 2, 4]] + """ + batch_size = len(token_ids_logprobs) + device = logprobs.device + + # Step 1: Calculate lengths for each request, treating None as empty list + # Example: [[1, 3], [2], [0, 2, 4]] -> token_lengths = tensor([2, 1, 3]) + token_lengths = torch.tensor( + [len(token_ids or []) for token_ids in token_ids_logprobs], device=device + ) + total_tokens = int(token_lengths.sum().item()) # 2 + 1 + 3 = 6 + + # Handle edge case where no tokens are requested + if total_tokens == 0: + return [logprobs.new_empty(0) for _ in token_ids_logprobs], [ + [] for _ in token_ids_logprobs + ] + + # Step 2: Build flattened indices using torch operations + # Example: row_indices = [0, 0, 1, 2, 2, 2] (batch indices repeated by their lengths) + row_indices = torch.repeat_interleave( + torch.arange(batch_size, device=device), token_lengths + ) + # Example: col_indices = [1, 3, 2, 0, 2, 4] (flattened token IDs from all requests) + col_indices = torch.tensor( + [ + token_id + for token_ids in token_ids_logprobs + for token_id in (token_ids or []) + ], + device=device, + dtype=torch.long, + ) + + # Step 3: Single vectorized gather operation + # Example: logprobs[row_indices, col_indices] -> [-2.1, -3.0, -2.2, -2.0, -1.4, -1.6] + gathered_logprobs = logprobs[row_indices, col_indices] + + # Step 4: Split results back per request using torch operations + # Example: split tensor [6] into 
chunks of sizes [2, 1, 3] -> [tensor(2), tensor(1), tensor(3)] + split_logprobs = torch.split_with_sizes( + gathered_logprobs, token_lengths.tolist(), dim=0 + ) + + # Step 5: Format output to match expected return structure + # Example: Convert split tensors back to list format with proper empty handling + # i=0: [1,3] -> append split_logprobs[0] and [1,3] + # i=1: [2] -> append split_logprobs[1] and [2] + # i=2: [0,2,4] -> append split_logprobs[2] and [0,2,4] + output_token_ids_logprobs_val = [] + output_token_ids_logprobs_idx = [] + + for i, token_ids in enumerate(token_ids_logprobs): + if token_ids is not None and len(token_ids) > 0: + output_token_ids_logprobs_val.append(split_logprobs[i]) + output_token_ids_logprobs_idx.append(token_ids) + else: + output_token_ids_logprobs_val.append(logprobs.new_empty(0)) + output_token_ids_logprobs_idx.append([]) + + return output_token_ids_logprobs_val, output_token_ids_logprobs_idx + + +def get_token_ids_logprobs(logprobs: torch.Tensor, token_ids_logprobs: List[List[int]]): output_token_ids_logprobs_val = [] output_token_ids_logprobs_idx = [] for i, token_ids in enumerate(token_ids_logprobs): diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index c0c0917ac..5fd830afe 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -561,7 +561,10 @@ class Req: # shape: (bs, k) self.output_top_logprobs_val = [] self.output_top_logprobs_idx = [] - self.output_token_ids_logprobs_val = [] + # Can contain either lists or GPU tensors (delayed copy optimization for prefill-only scoring) + self.output_token_ids_logprobs_val: List[ + Union[List[float], torch.Tensor] + ] = [] self.output_token_ids_logprobs_idx = [] else: self.output_token_logprobs_val = self.output_token_logprobs_idx = ( @@ -619,6 +622,11 @@ class Req: def seqlen(self): return len(self.origin_input_ids) + len(self.output_ids) + @property + def is_prefill_only(self) -> bool: + """Check if this request is prefill-only (no token generation needed).""" + return self.sampling_params.max_new_tokens == 0 + def extend_image_inputs(self, image_inputs): if self.multimodal_inputs is None: self.multimodal_inputs = image_inputs @@ -950,9 +958,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): device=req_to_token_pool.device, spec_algorithm=spec_algorithm, return_hidden_states=any(req.return_hidden_states for req in reqs), - is_prefill_only=all( - req.sampling_params.max_new_tokens == 0 for req in reqs - ), + is_prefill_only=all(req.is_prefill_only for req in reqs), chunked_req=chunked_req, ) @@ -1210,13 +1216,36 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): req.is_retracted = False # Compute the relative logprob_start_len in an extend batch + # + # Key variables: + # - logprob_start_len: Absolute position in full sequence where logprob computation begins + # - extend_logprob_start_len: Relative position within current extend batch where logprob computation begins + # - extend_input_len: Number of tokens that need to be processed in this extend batch + # (= len(fill_ids) - len(prefix_indices), where fill_ids = origin_input_ids + output_ids + # and prefix_indices are the cached/shared prefix tokens) + # if req.logprob_start_len >= pre_len: - req.extend_logprob_start_len = min( - req.logprob_start_len - pre_len, - req.extend_input_len, - req.seqlen - 1, - ) + # Optimization for prefill-only requests: When we only need logprobs at + # positions beyond the input sequence (to 
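A worked example of the new `extend_logprob_start_len` branch versus the regular path, with all values invented for illustration:

```python
# Invented values for a single request (not from the diff)
origin_input_ids = [1, 2, 3, 4, 5]   # 5 prompt tokens, positions 0-4
pre_len = 0                          # no cached prefix
extend_input_len = 5                 # all 5 tokens processed in this extend batch
seqlen = 5                           # no output tokens yet

# Prefill-only scoring: the scheduler sets logprob_start_len = len(origin_input_ids) = 5,
# so the new branch skips every input logprob:
prefill_only_start = extend_input_len                       # 5 -> zero input logprobs

# Regular request with logprob_start_len = 3 takes the old min(...) path:
regular_start = min(3 - pre_len, extend_input_len, seqlen - 1)
assert regular_start == 3                                   # logprobs for positions 3..4
```

The saved work is the per-position input logprob gather during prefill, which a scoring request never consumes.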
score next-token likelihood), skip all + # input logprob computation during prefill since no generation will occur. + if self.is_prefill_only and req.logprob_start_len == len( + req.origin_input_ids + ): + # Skip ALL input logprobs: set extend_logprob_start_len = extend_input_len + req.extend_logprob_start_len = req.extend_input_len + else: + # Convert absolute logprob_start_len to relative extend_logprob_start_len + # + # Example: origin_input_ids=[1,2,3,4,5] (5 tokens, positions 0-4), logprob_start_len=3 + # Regular logic: min(3-0, 5, 5-1) = min(3,5,4) = 3 + # This means: "compute logprobs from position 3 onwards in extend batch" + req.extend_logprob_start_len = min( + req.logprob_start_len - pre_len, + req.extend_input_len, + req.seqlen - 1, + ) else: + # logprob_start_len is before the current extend batch, so start from beginning req.extend_logprob_start_len = 0 if self.return_logprob: @@ -1763,6 +1792,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): ), extend_input_logprob_token_ids=self.extend_input_logprob_token_ids, launch_done=self.launch_done, + is_prefill_only=self.is_prefill_only, ) def copy(self): @@ -1905,6 +1935,9 @@ class ModelWorkerBatch: # Overlap event launch_done: Optional[threading.Event] = None + # Whether this batch is prefill-only (no token generation needed) + is_prefill_only: bool = False + @triton.jit def write_req_to_token_pool_triton( diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index cd915f765..b4a95a584 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -1261,11 +1261,19 @@ class Scheduler( # Copy more attributes if recv_req.logprob_start_len == -1 or not recv_req.return_logprob: # By default, only return the logprobs for output tokens - req.logprob_start_len = len(req.origin_input_ids) - 1 + # For prefill-only requests with logprob_start_len == -1, set logprob_start_len beyond input sequence + # to skip input logprob computation entirely + if req.is_prefill_only: + req.logprob_start_len = len(req.origin_input_ids) + else: + # TODO: For text generation, evaluate setting logprob_start_len to len(req.origin_input_ids) as well + req.logprob_start_len = len(req.origin_input_ids) - 1 else: req.logprob_start_len = recv_req.logprob_start_len - if req.logprob_start_len >= len(req.origin_input_ids): + if not req.is_prefill_only and req.logprob_start_len >= len( + req.origin_input_ids + ): error_msg = f"{req.logprob_start_len=} is higher than the number of input tokens {len(req.origin_input_ids)=}. Please use a smaller logprob_start_len." 
diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py
index d931759bb..aa060af8a 100644
--- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py
+++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py
@@ -5,6 +5,8 @@ import threading
 import time
 from typing import TYPE_CHECKING, List, Optional, Tuple, Union
 
+import torch
+
 from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import AbortReq, BatchEmbeddingOut, BatchTokenIDOut
@@ -71,6 +73,7 @@ class SchedulerOutputProcessorMixin:
 
         # Check finish conditions
         logprob_pt = 0
+
         for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
             if req.is_retracted:
                 continue
@@ -99,6 +102,7 @@
                     extend_logprob_start_len = extend_logprob_start_len_per_req[i]
                     extend_input_len = extend_input_len_per_req[i]
                     num_input_logprobs = extend_input_len - extend_logprob_start_len
+
                     if req.return_logprob:
                         self.add_logprob_return_values(
                             i,
@@ -441,27 +445,59 @@
         output: LogitsProcessorOutput,
     ):
         """Attach logprobs to the return values."""
-        req.output_token_logprobs_val.append(output.next_token_logprobs[i])
-        req.output_token_logprobs_idx.append(next_token_ids[i])
+        if output.next_token_logprobs is not None:
+            req.output_token_logprobs_val.append(output.next_token_logprobs[i])
+            req.output_token_logprobs_idx.append(next_token_ids[i])
 
-        self.add_input_logprob_return_values(
-            i, req, output, pt, num_input_logprobs, last_prefill_chunk=True
-        )
+        # Only add input logprobs if there are input tokens to process
+        # Note: For prefill-only requests with default logprob_start_len, this will be 0,
+        # meaning we only compute output logprobs (which is the intended behavior)
+        if num_input_logprobs > 0:
+            self.add_input_logprob_return_values(
+                i, req, output, pt, num_input_logprobs, last_prefill_chunk=True
+            )
+        else:
+            self._initialize_empty_logprob_containers(req)
 
         if req.top_logprobs_num > 0:
             req.output_top_logprobs_val.append(output.next_token_top_logprobs_val[i])
             req.output_top_logprobs_idx.append(output.next_token_top_logprobs_idx[i])
 
-        if req.token_ids_logprob is not None:
-            req.output_token_ids_logprobs_val.append(
-                output.next_token_token_ids_logprobs_val[i]
-            )
+        if (
+            req.token_ids_logprob is not None
+            and output.next_token_token_ids_logprobs_val is not None
+        ):
+            # Convert GPU tensor to list if needed
+            logprobs_val = output.next_token_token_ids_logprobs_val[i]
+            if isinstance(logprobs_val, torch.Tensor):
+                logprobs_val = logprobs_val.tolist()
+            req.output_token_ids_logprobs_val.append(logprobs_val)
             req.output_token_ids_logprobs_idx.append(
                 output.next_token_token_ids_logprobs_idx[i]
             )
 
         return num_input_logprobs
 
+    def _initialize_empty_logprob_containers(self, req: Req) -> None:
+        """
+        Initialize logprob fields to empty lists if unset.
+
+        This is needed for prefill-only requests where the normal initialization
+        flow might be bypassed, but downstream code expects these fields to be lists.
+        """
+        if req.input_token_logprobs_val is None:
+            req.input_token_logprobs_val = []
+        if req.input_token_logprobs_idx is None:
+            req.input_token_logprobs_idx = []
+        if req.input_top_logprobs_val is None:
+            req.input_top_logprobs_val = []
+        if req.input_top_logprobs_idx is None:
+            req.input_top_logprobs_idx = []
+        if req.input_token_ids_logprobs_val is None:
+            req.input_token_ids_logprobs_val = []
+        if req.input_token_ids_logprobs_idx is None:
+            req.input_token_ids_logprobs_idx = []
+
     def stream_output(
         self: Scheduler,
         reqs: List[Req],
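Consumers of the delayed-copy values normalize them with the isinstance check shown above; a minimal sketch of that pattern (the helper name is ours, not the PR's):

```python
import torch

def as_list(logprobs_val):
    """Hypothetical helper mirroring the isinstance check above."""
    if isinstance(logprobs_val, torch.Tensor):
        # Materializes the deferred device-to-host copy for prefill-only requests
        return logprobs_val.tolist()
    return logprobs_val  # already a plain list

print(as_list([-2.1, -3.0]))                # [-2.1, -3.0], passed through unchanged
print(as_list(torch.tensor([-2.1, -3.0])))  # float32 values converted to a Python list
```

Keeping the values as tensors until this point lets the scheduler defer the GPU-to-CPU transfer out of the hot path; by the time the output processor runs, the copy event has been synchronized and `.tolist()` is cheap.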
+ """ + if req.input_token_logprobs_val is None: + req.input_token_logprobs_val = [] + if req.input_token_logprobs_idx is None: + req.input_token_logprobs_idx = [] + if req.input_top_logprobs_val is None: + req.input_top_logprobs_val = [] + if req.input_top_logprobs_idx is None: + req.input_top_logprobs_idx = [] + if req.input_token_ids_logprobs_val is None: + req.input_token_ids_logprobs_val = [] + if req.input_token_ids_logprobs_idx is None: + req.input_token_ids_logprobs_idx = [] + def stream_output( self: Scheduler, reqs: List[Req], diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 730ee05aa..7e1ea2cd6 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -1778,11 +1778,15 @@ class TokenizerManager(TokenizerCommunicatorMixin): # the next position after the last token in the prompt output_logprobs = result["meta_info"].get("output_token_ids_logprobs", []) - # Throw an error here if output_logprobs is None - if output_logprobs is None: + # Check if output_logprobs is properly populated + if ( + output_logprobs is None + or not output_logprobs + or len(output_logprobs) == 0 + ): raise RuntimeError( - f"output_logprobs is None for request {result['meta_info'].get('id', '')}. " - "This usually indicates a problem with the scoring request or the backend output." + f"output_logprobs is empty for request {result['meta_info'].get('id', '')}. " + "This indicates token_ids_logprobs were not computed properly for the scoring request." ) for logprob, token_id, _ in output_logprobs[0]: diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 4c13ff796..059813f83 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -259,6 +259,15 @@ class TpModelWorker: if skip_sample: next_token_ids = None + # For prefill-only requests, we still need to compute logprobs even when sampling is skipped + if ( + model_worker_batch.is_prefill_only + and model_worker_batch.return_logprob + ): + # Compute logprobs without full sampling + self.model_runner.compute_logprobs_only( + logits_output, model_worker_batch + ) else: next_token_ids = self.model_runner.sample( logits_output, model_worker_batch diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py index 399ac1675..e34399a41 100644 --- a/python/sglang/srt/managers/tp_worker_overlap_thread.py +++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py @@ -174,21 +174,28 @@ class TpModelWorkerClient: # Run forward logits_output, next_token_ids, can_run_cuda_graph = ( self.worker.forward_batch_generation( - model_worker_batch, model_worker_batch.launch_done + model_worker_batch, + model_worker_batch.launch_done, + # Skip sampling for prefill-only requests + skip_sample=model_worker_batch.is_prefill_only, ) ) # Update the future token ids map bs = len(model_worker_batch.seq_lens) + if model_worker_batch.is_prefill_only: + # For prefill-only requests, create dummy token IDs on CPU + next_token_ids = torch.zeros(bs, dtype=torch.long) self.future_token_ids_map[ future_token_ids_ct + 1 : future_token_ids_ct + bs + 1 ] = next_token_ids # Copy results to the CPU if model_worker_batch.return_logprob: - logits_output.next_token_logprobs = ( - logits_output.next_token_logprobs.to("cpu", non_blocking=True) - ) + if logits_output.next_token_logprobs is not None: + 
@@ -221,10 +230,10 @@
             logits_output.next_token_logprobs = (
                 logits_output.next_token_logprobs.tolist()
             )
-        if logits_output.input_token_logprobs is not None:
-            logits_output.input_token_logprobs = tuple(
-                logits_output.input_token_logprobs.tolist()
-            )
+            if logits_output.input_token_logprobs is not None:
+                logits_output.input_token_logprobs = tuple(
+                    logits_output.input_token_logprobs.tolist()
+                )
         next_token_ids = next_token_ids.tolist()
 
         return logits_output, next_token_ids, can_run_cuda_graph
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index c9aae4d2b..2cb27bcbe 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -2158,6 +2158,38 @@ class ModelRunner:
         )
         return next_token_ids
 
+    def compute_logprobs_only(
+        self,
+        logits_output: LogitsProcessorOutput,
+        forward_batch: ForwardBatch,
+    ) -> None:
+        """
+        Compute token_ids_logprobs without performing sampling.
+
+        Optimized path for prefill-only requests that need token_ids_logprobs but don't
+        require next token generation. Skips expensive sampling operations
+        while still providing the requested probability information.
+
+        Args:
+            logits_output: The logits output from the model forward
+            forward_batch: The forward batch that generated logits_output
+        """
+        if not forward_batch.token_ids_logprobs:
+            return
+
+        # Preprocess logits (same as in the sample method)
+        self._preprocess_logits(logits_output, forward_batch.sampling_info)
+
+        # Delegate to the sampler for logprob-only computation
+        # This populates logits_output with the requested token probabilities
+        self.sampler.compute_logprobs_only(
+            logits_output,
+            forward_batch.sampling_info,
+            forward_batch.return_logprob,
+            forward_batch.top_logprobs_nums,
+            forward_batch.token_ids_logprobs,
+        )
+
     @property
     def model_is_mrope(self) -> bool:
         """Detect if the model has "mrope" rope_scaling type.
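As a sanity check on the new fast path, the vectorized extraction can be compared against the per-request loop it replaces; a quick property-test sketch over random data (shapes and values invented):

```python
import torch

def naive(logprobs, token_ids_logprobs):
    # One small gather per request, as in the original get_token_ids_logprobs loop
    return [logprobs[i, ids] for i, ids in enumerate(token_ids_logprobs)]

def vectorized(logprobs, token_ids_logprobs):
    # Same three ops as get_token_ids_logprobs_batch_optimized above
    lengths = torch.tensor([len(ids) for ids in token_ids_logprobs])
    rows = torch.repeat_interleave(torch.arange(len(token_ids_logprobs)), lengths)
    cols = torch.tensor([t for ids in token_ids_logprobs for t in ids])
    return list(torch.split_with_sizes(logprobs[rows, cols], lengths.tolist()))

logprobs = torch.log_softmax(torch.randn(8, 32), dim=-1)
ids = [[1, 3], [2], [0, 2, 4], [31], [5, 6], [7], [8, 9, 10], [11]]
for a, b in zip(naive(logprobs, ids), vectorized(logprobs, ids)):
    assert torch.equal(a, b)
print("vectorized gather matches the per-request loop")
```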