From 53bd00d975ff1a0a4683a4f006c19636d529199f Mon Sep 17 00:00:00 2001 From: Sundara Raman Ramachandran Date: Wed, 8 Oct 2025 18:47:32 -0700 Subject: [PATCH] [Generative Score API] Multi-Item scoring with custom attention mask. (#10979) --- .../layers/attention/flashinfer_backend.py | 229 ++++++++++++- python/sglang/srt/layers/logits_processor.py | 142 +++++++- python/sglang/srt/managers/schedule_batch.py | 9 +- .../scheduler_output_processor_mixin.py | 215 ++++++++---- .../sglang/srt/managers/tokenizer_manager.py | 323 +++++++++++++++--- python/sglang/srt/managers/tp_worker.py | 12 +- .../srt/model_executor/forward_batch_info.py | 4 + python/sglang/srt/server_args.py | 25 +- python/sglang/srt/two_batch_overlap.py | 1 + test/srt/test_score_api.py | 290 ++++++++++++++++ 10 files changed, 1121 insertions(+), 129 deletions(-) diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index 048319202..520792119 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -7,6 +7,7 @@ FlashInfer is faster and Triton is easier to customize. Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode. """ +import logging import os from dataclasses import dataclass from enum import Enum, auto @@ -16,11 +17,11 @@ from typing import TYPE_CHECKING, Callable, List, Optional, Union import torch if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1": - import logging - torch._logging.set_logs(dynamo=logging.ERROR) torch._dynamo.config.suppress_errors = True +logger = logging.getLogger(__name__) + from sglang.global_config import global_config from sglang.srt.layers.attention.base_attn_backend import AttentionBackend from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton @@ -58,6 +59,36 @@ class WrapperDispatch(Enum): CROSS_ATTENTION = auto() +@dataclass +class MultiItemScoringParams: + """Parameters for multi-item scoring in attention computation. + + Used when processing sequences with multiple items separated by delimiters, + where each item needs specific attention patterns that respect item boundaries. + + Attributes: + prefix_len_ptr: A uint32 1D tensor indicating the prefix length of each prompt. + The tensor size is equal to the batch size. + token_pos_in_items_ptr: A uint16 1D tensor indicating the token position of each item + starting from 0 (delimiter) for each item. For batch size > 1, + sequences are concatenated with zero padding to ensure same length. + token_pos_in_items_len: Zero padding length for token_pos_in_items_ptr to handle + batch_size > 1 case. Defines the padded length for each sequence. + max_item_len_ptr: A uint16 tensor containing the max token length of all items + for each prompt in the batch. 
+ + """ + + prefix_len_ptr: Optional[torch.Tensor] = None + token_pos_in_items_ptr: Optional[torch.Tensor] = None + token_pos_in_items_len: int = 0 + max_item_len_ptr: Optional[torch.Tensor] = None + + def is_enabled(self) -> bool: + """Check if multi-item scoring is enabled.""" + return self.prefix_len_ptr is not None + + @dataclass class DecodeMetadata: decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper] @@ -68,6 +99,7 @@ class PrefillMetadata: prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper] use_ragged: bool extend_no_prefix: bool + multi_item_params: Optional[MultiItemScoringParams] = None # Reuse this workspace buffer across all flashinfer wrappers @@ -90,6 +122,11 @@ class FlashInferAttnBackend(AttentionBackend): ): super().__init__() + # Store multi-item scoring delimiter for efficient access + self.multi_item_scoring_delimiter = ( + model_runner.server_args.multi_item_scoring_delimiter + ) + # Parse constants self.decode_use_tensor_cores = should_use_tensor_core( kv_cache_dtype=model_runner.kv_cache_dtype, @@ -229,10 +266,133 @@ class FlashInferAttnBackend(AttentionBackend): # Other metadata self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None + self.decode_cuda_graph_metadata = {} self.prefill_cuda_graph_metadata = {} # For verify self.draft_extend_cuda_graph_metadata = {} # For draft extend + def _process_multi_item_scoring( + self, forward_batch: ForwardBatch + ) -> MultiItemScoringParams: + """Process multi-item scoring tensors for FlashInfer attention. + + This method handles sequences containing multiple "items" separated by delimiter tokens, + where each item needs specific attention patterns that respect item boundaries. + + The method produces four key tensors for FlashInfer: + - prefix_len_ptr: uint32 tensor with prefix length for each prompt in batch + - token_pos_in_items_ptr: uint16 tensor with token positions starting from 0 at delimiters + - token_pos_in_items_len: padding length for batch processing + - max_item_len_ptr: uint16 tensor with max item length for each prompt + + Args: + forward_batch: The forward batch containing input sequences and delimiter info + + Returns: + MultiItemScoringParams: The processed multi-item scoring parameters + + Examples: + Following FlashInfer definition: for 3 items of length 3, 2, 4 respectively: + token_pos_in_items_ptr = [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0] + + Case 1: Single sequence + Text: "What is the capital of France? 
London Paris Berlin " + Tokens: [What, is, the, capital, of, France, ?, , London, , Paris, , Berlin, ] + Indices: [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13] + - prefix_len_ptr: [7] (query length before first delimiter) + - token_pos_in_items_ptr: [0, 1, 0, 1, 0, 1, 0] (delim=0, London=1, delim=0, Paris=1, delim=0, Berlin=1, delim=0) + - token_pos_in_items_len: 7 (actual length) + - max_item_len_ptr: [1] (max item length is 1 token - all options are single tokens) + + Case 2: Batch processing (batch_size=2) + Sequence 1: 2 items of length 2, 1 → [0, 1, 2, 0, 1, 0] (6 elements) + Sequence 2: 3 items of length 1, 3, 2 → [0, 1, 0, 1, 2, 3, 0, 1, 2, 0] (10 elements) + After padding both to length 10: + - token_pos_in_items_ptr: [0, 1, 2, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 2, 3, 0, 1, 2, 0] + - token_pos_in_items_len: 10 (padded length for batch processing) + - max_item_len_ptr: [2, 3] (max lengths per sequence) + """ + + delimiter = self.multi_item_scoring_delimiter + if delimiter is None or forward_batch.forward_mode == ForwardMode.DECODE: + return MultiItemScoringParams() + + delimiter_mask = forward_batch.input_ids == delimiter + prefix_cache_lens = getattr(forward_batch, "extend_prefix_lens", None) + extend_seq_lens = getattr(forward_batch, "extend_seq_lens", None) + prefix_len_ptr, token_pos_in_items_ptr = [], [] + token_pos_in_items_len = 0 + + # If no extend_seq_lens, treat whole batch as one sequence + if extend_seq_lens is None or len(extend_seq_lens) <= 1: + extend_seq_lens = [forward_batch.input_ids.size(0)] + + seq_start = 0 + for i, seq_len in enumerate(extend_seq_lens): + seq_end = seq_start + seq_len + mask = delimiter_mask[seq_start:seq_end] + pos = forward_batch.positions[seq_start:seq_end] + delimiter_indices = torch.nonzero(mask, as_tuple=True)[0] + + if len(delimiter_indices) > 0: + first_delim = delimiter_indices[0] + # Prefix length: store as scalar + prefix_len = first_delim + ( + prefix_cache_lens[i] if prefix_cache_lens is not None else 0 + ) + prefix_len_ptr.append( + prefix_len.item() if torch.is_tensor(prefix_len) else prefix_len + ) + + # Compute relative positions within items after delimiters + diff = pos[first_delim:] - torch.cummax(mask[first_delim:], 0)[1] + token_pos = (diff - pos[first_delim]).to(torch.uint16) + token_pos_in_items_ptr.append(token_pos) + + # Update forward_batch positions in-place + pos[first_delim:] = diff - 1 + forward_batch.positions[seq_start:seq_end] = pos + + seq_start = seq_end + + # Pad token_pos_in_items_ptr for batch processing + if token_pos_in_items_ptr: + token_pos_in_items_len = max(t.numel() for t in token_pos_in_items_ptr) + device = forward_batch.input_ids.device + token_pos_in_items_ptr = [ + torch.cat( + [ + t, + torch.zeros( + token_pos_in_items_len - t.numel(), + dtype=torch.uint16, + device=device, + ), + ] + ) + for t in token_pos_in_items_ptr + ] + + if not prefix_len_ptr or not token_pos_in_items_ptr: + return MultiItemScoringParams() + + # Build final params + device = forward_batch.input_ids.device + return MultiItemScoringParams( + prefix_len_ptr=torch.tensor( + prefix_len_ptr, dtype=torch.uint32, device=device + ), + token_pos_in_items_ptr=torch.cat(token_pos_in_items_ptr, dim=0), + token_pos_in_items_len=token_pos_in_items_len & 0xFFFFFFFF, + max_item_len_ptr=torch.stack( + [ + t.to(torch.int32).max().to(torch.uint16) + for t in token_pos_in_items_ptr + ], + dim=0, + ), + ) + def init_forward_metadata(self, forward_batch: ForwardBatch): if forward_batch.forward_mode.is_decode_or_idle(): 
self.indices_updater_decode.update( @@ -280,13 +440,26 @@ class FlashInferAttnBackend(AttentionBackend): else: prefix_lens = forward_batch.extend_prefix_lens - if self.is_multimodal: + # Disable ragged wrapper and ensure prefix handling for multimodal and multi-item scoring + if self.is_multimodal or self.multi_item_scoring_delimiter is not None: + # use_ragged = False: Multi-item scoring requires the paged wrapper because: + # 1. Ragged wrapper doesn't support the specialized multi-item parameters + # (prefix_len_ptr, token_pos_in_items_ptr, etc.) + # 2. Paged wrapper provides better control over attention masking needed + # for respecting item boundaries in multi-item sequences + # 3. Custom masking logic conflicts with ragged wrapper's assumptions use_ragged = False extend_no_prefix = False else: use_ragged = not self.enable_deterministic extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu) + # Process multi-item scoring in attention backend instead of ForwardBatch + multi_item_params = MultiItemScoringParams() + if self.multi_item_scoring_delimiter is not None: + # Use new backend-specific implementation + multi_item_params = self._process_multi_item_scoring(forward_batch) + self.indices_updater_prefill.update( forward_batch.req_pool_indices, forward_batch.seq_lens, @@ -298,9 +471,13 @@ class FlashInferAttnBackend(AttentionBackend): encoder_lens=forward_batch.encoder_lens, spec_info=None, fixed_split_size=self.prefill_split_tile_size, + multi_item_params=multi_item_params, ) self.forward_metadata = PrefillMetadata( - self.prefill_wrappers_paged, use_ragged, extend_no_prefix + self.prefill_wrappers_paged, + use_ragged, + extend_no_prefix, + multi_item_params, ) def init_cuda_graph_state( @@ -531,7 +708,20 @@ class FlashInferAttnBackend(AttentionBackend): forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id), causal=not layer.is_cross_attention, sm_scale=layer.scaling, - window_left=layer.sliding_window_size, + # Disable sliding window attention for multi-item scoring: + # - Sliding window could cut across item boundaries, breaking semantic coherence + # - Multi-item sequences need full attention to properly handle delimiter tokens + # - Specialized multi-item parameters (prefix_len_ptr, token_pos_in_items_ptr) + # provide more precise attention control than simple sliding windows + # - Item-aware masking takes precedence over window-based masking + window_left=( + layer.sliding_window_size + if not ( + self.forward_metadata.multi_item_params + and self.forward_metadata.multi_item_params.is_enabled() + ) + else -1 + ), logits_soft_cap=logits_soft_cap, # Must use _float to avoid device-to-host copy that breaks cuda graph capture. 
k_scale=layer.k_scale_float, @@ -952,6 +1142,7 @@ class FlashInferIndicesUpdaterPrefill: encoder_lens: Optional[torch.Tensor], spec_info: Optional[SpecInput], fixed_split_size: Optional[int] = None, + multi_item_params: Optional[MultiItemScoringParams] = None, ): if use_ragged: # TODO: remove this device sync, we can use forward_batch.extend_prefix_lens_cpu @@ -976,6 +1167,7 @@ class FlashInferIndicesUpdaterPrefill: use_ragged, spec_info, fixed_split_size=fixed_split_size, + multi_item_params=multi_item_params, ) def update_sliding_window( @@ -990,6 +1182,7 @@ class FlashInferIndicesUpdaterPrefill: encoder_lens: Optional[torch.Tensor], spec_info: Optional[SpecInput], fixed_split_size: Optional[int] = None, + multi_item_params: Optional[MultiItemScoringParams] = None, ): for wrapper_id in range(2): if wrapper_id == 0: @@ -1023,6 +1216,7 @@ class FlashInferIndicesUpdaterPrefill: use_ragged, spec_info, use_sliding_window_kv_pool=use_sliding_window_kv_pool, + multi_item_params=multi_item_params, ) def update_cross_attention( @@ -1037,6 +1231,7 @@ class FlashInferIndicesUpdaterPrefill: encoder_lens: Optional[torch.Tensor], spec_info: Optional[SpecInput], fixed_split_size: Optional[int] = None, + multi_item_params: Optional[MultiItemScoringParams] = None, ): for wrapper_id in range(2): if wrapper_id == 0: @@ -1063,6 +1258,7 @@ class FlashInferIndicesUpdaterPrefill: self.qo_indptr[wrapper_id], use_ragged, spec_info, + multi_item_params=multi_item_params, ) def call_begin_forward( @@ -1081,6 +1277,7 @@ class FlashInferIndicesUpdaterPrefill: spec_info: Optional[SpecInput], use_sliding_window_kv_pool: bool = False, fixed_split_size: Optional[int] = None, + multi_item_params: Optional[MultiItemScoringParams] = None, ): bs = len(seq_lens) if spec_info is None: @@ -1136,6 +1333,22 @@ class FlashInferIndicesUpdaterPrefill: ) # cached part + # Conditionally set multi-item parameters + if multi_item_params is not None and multi_item_params.is_enabled(): + # Multi-item scoring is active - use specialized parameters and disable generic custom_mask + use_custom_mask = None + prefix_len_ptr = multi_item_params.prefix_len_ptr + token_pos_in_items_ptr = multi_item_params.token_pos_in_items_ptr + token_pos_in_items_len = multi_item_params.token_pos_in_items_len + max_item_len_ptr = multi_item_params.max_item_len_ptr + else: + # No multi-item scoring - use standard parameters + use_custom_mask = custom_mask + prefix_len_ptr = None + token_pos_in_items_ptr = None + token_pos_in_items_len = 0 + max_item_len_ptr = None + wrapper_paged.begin_forward( qo_indptr, kv_indptr, @@ -1147,9 +1360,13 @@ class FlashInferIndicesUpdaterPrefill: 1, q_data_type=self.q_data_type, kv_data_type=self.data_type, - custom_mask=custom_mask, + custom_mask=use_custom_mask, non_blocking=True, fixed_split_size=fixed_split_size, + prefix_len_ptr=prefix_len_ptr, + token_pos_in_items_ptr=token_pos_in_items_ptr, + token_pos_in_items_len=token_pos_in_items_len, + max_item_len_ptr=max_item_len_ptr, ) diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index 5f9651086..a95e2011a 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -60,7 +60,8 @@ _is_npu = is_npu() class LogitsProcessorOutput: ## Part 1: This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor # The logits of the next tokens. 
shape: [#seq, vocab_size] - next_token_logits: torch.Tensor + # Can be None for certain prefill-only requests (e.g., multi-item scoring) that don't need next token generation + next_token_logits: Optional[torch.Tensor] # Used by speculative decoding (EAGLE) # The last hidden layers hidden_states: Optional[torch.Tensor] = None @@ -85,7 +86,10 @@ class LogitsProcessorOutput: input_top_logprobs_val: List = None input_top_logprobs_idx: List = None # The logprobs and ids of the requested token ids in input positions. shape: [#seq, n] (n is the number of requested token ids) - input_token_ids_logprobs_val: Optional[List] = None + # Can contain either lists or GPU tensors (for delayed GPU-to-CPU transfer optimization) + input_token_ids_logprobs_val: Optional[List[Union[List[float], torch.Tensor]]] = ( + None + ) input_token_ids_logprobs_idx: Optional[List] = None @@ -127,6 +131,9 @@ class LogitsMetadata: # for padding padded_static_len: int = -1 + # Whether this batch is prefill-only (no token generation needed) + is_prefill_only: bool = False + @classmethod def from_forward_batch(cls, forward_batch: ForwardBatch): if ( @@ -169,6 +176,7 @@ class LogitsMetadata: token_ids_logprobs=forward_batch.token_ids_logprobs, extend_input_logprob_token_ids_gpu=forward_batch.extend_input_logprob_token_ids_gpu, padded_static_len=forward_batch.padded_static_len, + is_prefill_only=forward_batch.is_prefill_only, global_num_tokens_gpu=forward_batch.global_num_tokens_gpu, dp_local_start_pos=forward_batch.dp_local_start_pos, dp_local_num_tokens=forward_batch.dp_local_num_tokens, @@ -247,6 +255,108 @@ class LogitsProcessor(nn.Module): "debug_tensor_dump_output_folder", None ) + def compute_logprobs_for_multi_item_scoring( + self, + input_ids, + hidden_states, + lm_head: VocabParallelEmbedding, + logits_metadata: Union[LogitsMetadata, ForwardBatch], + delimiter_token: int, + ): + """ + Compute logprobs for multi-item scoring using delimiter-based token extraction. + + This method is designed for scenarios where you want to score multiple items/candidates + against a single query by combining them into one sequence separated by delimiters. + + Sequence format: QueryItem1Item2... + Scoring positions: Extracts logprobs at positions before each + + Args: + input_ids (torch.Tensor): Input token IDs containing query and items separated by delimiters. + Shape: [total_sequence_length] for single request or [batch_total_length] for batch. + hidden_states (torch.Tensor): Hidden states from the model. + Shape: [sequence_length, hidden_dim]. + lm_head (VocabParallelEmbedding): Language model head for computing logits. + logits_metadata (Union[LogitsMetadata, ForwardBatch]): Metadata containing batch info + and token ID specifications for logprob extraction. + delimiter_token (int): Token ID used as delimiter between query and items. 
+ + Returns: + LogitsProcessorOutput: Contains: + - next_token_logits: None (not needed for scoring-only requests) + - input_token_logprobs: Logprobs of delimiter tokens at scoring positions + - input_top_logprobs_val: Top-k logprobs at delimiter positions (if requested) + - input_top_logprobs_idx: Top-k token indices at delimiter positions (if requested) + - input_token_ids_logprobs_val: Logprobs for user-requested token IDs (if any) + - input_token_ids_logprobs_idx: Indices for user-requested token IDs (if any) + """ + multi_item_indices = (input_ids == delimiter_token).nonzero(as_tuple=True)[ + 0 + ] - 1 + # Extract hidden states at delimiter positions for multi-item scoring + sliced_hidden = hidden_states[multi_item_indices] + + sliced_logits = self._get_logits(sliced_hidden, lm_head, logits_metadata) + sliced_logprobs = torch.nn.functional.log_softmax(sliced_logits, dim=-1) + + # Initialize return values + input_token_ids_logprobs_val = [] + input_token_ids_logprobs_idx = [] + input_top_logprobs_val = None + input_top_logprobs_idx = None + + # Recalculate extend_logprob_pruned_lens_cpu to match delimiter counts per request + # Original contains sequence lengths, but we need delimiter counts for sliced_logprobs + if ( + logits_metadata.token_ids_logprobs + or logits_metadata.extend_return_top_logprob + ): + logits_metadata.extend_logprob_pruned_lens_cpu = [] + + if logits_metadata.extend_seq_lens_cpu is not None: + # Multi-request batch: count delimiters per request + input_pt = 0 + for req_seq_len in logits_metadata.extend_seq_lens_cpu: + req_input_ids = input_ids[input_pt : input_pt + req_seq_len] + delimiter_count = (req_input_ids == delimiter_token).sum().item() + logits_metadata.extend_logprob_pruned_lens_cpu.append( + delimiter_count + ) + input_pt += req_seq_len + else: + # Single request case: one request gets all delimiters + total_delimiters = (input_ids == delimiter_token).sum().item() + logits_metadata.extend_logprob_pruned_lens_cpu = [total_delimiters] + + # Get the logprobs of specified token ids + if logits_metadata.extend_token_ids_logprob: + ( + input_token_ids_logprobs_val, + input_token_ids_logprobs_idx, + ) = self.get_token_ids_logprobs( + sliced_logprobs, logits_metadata, delay_cpu_copy=True + ) + + # Get the logprob of top-k tokens + if logits_metadata.extend_return_top_logprob: + ( + input_top_logprobs_val, + input_top_logprobs_idx, + ) = self.get_top_logprobs(sliced_logprobs, logits_metadata) + + # For input_token_logprobs, use delimiter token logprobs + input_token_logprobs = sliced_logprobs[:, delimiter_token] + + return LogitsProcessorOutput( + next_token_logits=None, # Multi-item scoring doesn't need next token logits + input_token_logprobs=input_token_logprobs, + input_top_logprobs_val=input_top_logprobs_val, + input_top_logprobs_idx=input_top_logprobs_idx, + input_token_ids_logprobs_val=input_token_ids_logprobs_val, + input_token_ids_logprobs_idx=input_token_ids_logprobs_idx, + ) + def forward( self, input_ids, @@ -257,6 +367,16 @@ class LogitsProcessor(nn.Module): ) -> LogitsProcessorOutput: if isinstance(logits_metadata, ForwardBatch): logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata) + + # Check if multi-item scoring is enabled via server args (only for prefill-only requests) + multi_item_delimiter = global_server_args_dict.get( + "multi_item_scoring_delimiter" + ) + if multi_item_delimiter is not None and logits_metadata.is_prefill_only: + return self.compute_logprobs_for_multi_item_scoring( + input_ids, hidden_states, lm_head, 
logits_metadata, multi_item_delimiter + ) + # Get the last hidden states and last logits for the next token prediction if ( logits_metadata.forward_mode.is_decode_or_idle() @@ -584,7 +704,9 @@ class LogitsProcessor(nn.Module): @staticmethod def get_token_ids_logprobs( - all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata + all_logprobs: torch.Tensor, + logits_metadata: LogitsMetadata, + delay_cpu_copy: bool = False, ): input_token_ids_logprobs_val, input_token_ids_logprobs_idx = [], [] pt = 0 @@ -597,9 +719,17 @@ class LogitsProcessor(nn.Module): input_token_ids_logprobs_idx.append([]) continue - input_token_ids_logprobs_val.append( - [all_logprobs[pt + j, token_ids].tolist() for j in range(pruned_len)] - ) + position_logprobs = all_logprobs[ + pt : pt + pruned_len, token_ids + ] # Shape: [pruned_len, num_tokens] + + if delay_cpu_copy: + # Keep as tensor to delay GPU-to-CPU transfer + input_token_ids_logprobs_val.append(position_logprobs) + else: + # Convert to list immediately (default behavior) + input_token_ids_logprobs_val.append(position_logprobs.tolist()) + input_token_ids_logprobs_idx.append([token_ids for _ in range(pruned_len)]) pt += pruned_len diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 15cdd555a..075d90477 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -114,6 +114,7 @@ GLOBAL_SERVER_ARGS_KEYS = [ "enable_deterministic_inference", "nsa_prefill", "nsa_decode", + "multi_item_scoring_delimiter", ] # Put some global args for easy access @@ -666,9 +667,11 @@ class Req: def is_prefill_only(self) -> bool: """Check if this request is prefill-only (no token generation needed).""" # NOTE: when spec is enabled, prefill_only optimizations are disabled - return ( - self.sampling_params.max_new_tokens == 0 - and global_server_args_dict["speculative_algorithm"] is None + from sglang.srt.speculative.spec_info import SpeculativeAlgorithm + + spec_alg = global_server_args_dict["speculative_algorithm"] + return self.sampling_params.max_new_tokens == 0 and ( + spec_alg is None or spec_alg == SpeculativeAlgorithm.NONE ) def add_latency(self, stage: RequestStage): diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py index b31bf92a7..2072f9f68 100644 --- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py +++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py @@ -104,7 +104,10 @@ class SchedulerOutputProcessorMixin: assert extend_input_len_per_req is not None extend_logprob_start_len = extend_logprob_start_len_per_req[i] extend_input_len = extend_input_len_per_req[i] - num_input_logprobs = extend_input_len - extend_logprob_start_len + + num_input_logprobs = self._calculate_num_input_logprobs( + req, extend_input_len, extend_logprob_start_len + ) if req.return_logprob: self.add_logprob_return_values( @@ -159,8 +162,8 @@ class SchedulerOutputProcessorMixin: extend_input_len = extend_input_len_per_req[i] if extend_logprob_start_len < extend_input_len: # Update input logprobs. 
- num_input_logprobs = ( - extend_input_len - extend_logprob_start_len + num_input_logprobs = self._calculate_num_input_logprobs( + req, extend_input_len, extend_logprob_start_len ) if req.return_logprob: self.add_input_logprob_return_values( @@ -303,6 +306,153 @@ class SchedulerOutputProcessorMixin: ): self.log_decode_stats(can_run_cuda_graph, running_batch=batch) + def _process_input_token_logprobs( + self, req: Req, input_token_logprobs: List + ) -> None: + """Process input token logprobs values and indices.""" + is_multi_item_scoring = self._is_multi_item_scoring(req) + + # Process logprob values - handle multi-item scoring vs regular requests + if is_multi_item_scoring: + # Multi-item scoring: use all logprobs as-is + req.input_token_logprobs_val = input_token_logprobs + else: + # Regular request: add None at start, remove last (sampling token) + req.input_token_logprobs_val = [None] + input_token_logprobs[:-1] + + # Process logprob indices based on scoring type + if is_multi_item_scoring: + # Multi-item scoring: only include delimiter token positions + relevant_tokens = req.origin_input_ids[req.logprob_start_len :] + input_token_logprobs_idx = [ + token_id + for token_id in relevant_tokens + if token_id == self.server_args.multi_item_scoring_delimiter + ] + else: + # Regular request: include all tokens from logprob_start_len onwards + input_token_logprobs_idx = req.origin_input_ids[req.logprob_start_len :] + + # Clip padded hash values from image tokens to prevent detokenization errors + req.input_token_logprobs_idx = [ + x if x < self.model_config.vocab_size - 1 else 0 + for x in input_token_logprobs_idx + ] + + def _process_input_top_logprobs(self, req: Req) -> None: + """Process input top logprobs.""" + if req.top_logprobs_num <= 0: + return + + is_multi_item_scoring = self._is_multi_item_scoring(req) + + # Initialize arrays - multi-item scoring starts empty, others start with None + req.input_top_logprobs_val = [] if is_multi_item_scoring else [None] + req.input_top_logprobs_idx = [] if is_multi_item_scoring else [None] + + # Extend arrays with temp values + for val, idx in zip( + req.temp_input_top_logprobs_val, + req.temp_input_top_logprobs_idx, + strict=True, + ): + req.input_top_logprobs_val.extend(val) + req.input_top_logprobs_idx.extend(idx) + + # Remove last token (sampling token) for non multi-item scoring requests + if not is_multi_item_scoring: + req.input_top_logprobs_val.pop() + req.input_top_logprobs_idx.pop() + + # Clean up temp storage + req.temp_input_top_logprobs_idx = None + req.temp_input_top_logprobs_val = None + + def _process_input_token_ids_logprobs(self, req: Req) -> None: + """Process input token IDs logprobs.""" + if req.token_ids_logprob is None: + return + + is_multi_item_scoring = self._is_multi_item_scoring(req) + + # Initialize arrays - multi-item scoring starts empty, others start with None + req.input_token_ids_logprobs_val = [] if is_multi_item_scoring else [None] + req.input_token_ids_logprobs_idx = [] if is_multi_item_scoring else [None] + + # Process temp values - convert tensors to lists and extend arrays + for val, idx in zip( + req.temp_input_token_ids_logprobs_val, + req.temp_input_token_ids_logprobs_idx, + strict=True, + ): + val_list = val.tolist() if isinstance(val, torch.Tensor) else val + req.input_token_ids_logprobs_val.extend( + val_list if isinstance(val_list, list) else [val_list] + ) + req.input_token_ids_logprobs_idx.extend(idx) + + # Remove last token (sampling token) for non multi-item scoring requests + if not 
is_multi_item_scoring: + req.input_token_ids_logprobs_val.pop() + req.input_token_ids_logprobs_idx.pop() + + # Clean up temp storage + req.temp_input_token_ids_logprobs_idx = None + req.temp_input_token_ids_logprobs_val = None + + def _calculate_relevant_tokens_len(self, req: Req) -> int: + """Calculate the expected length of logprob arrays based on whether multi-item scoring is enabled. + + For multi-item scoring, only delimiter positions have logprobs. + For regular requests, all positions from logprob_start_len onwards have logprobs. + """ + is_multi_item_scoring = self._is_multi_item_scoring(req) + + if is_multi_item_scoring: + # Multi-item scoring: count delimiter tokens from logprob_start_len onwards + relevant_tokens = req.origin_input_ids[req.logprob_start_len :] + return sum( + 1 + for token_id in relevant_tokens + if token_id == self.server_args.multi_item_scoring_delimiter + ) + else: + # Regular request: all tokens from logprob_start_len onwards + return len(req.origin_input_ids) - req.logprob_start_len + + def _calculate_num_input_logprobs( + self, req: Req, extend_input_len: int, extend_logprob_start_len: int + ) -> int: + """Calculate the number of input logprobs based on whether multi-item scoring is enabled. + + For multi-item scoring, only delimiter positions have logprobs. + For regular requests, all positions in the range have logprobs. + """ + is_multi_item_scoring = self._is_multi_item_scoring(req) + + if is_multi_item_scoring: + # Multi-item scoring: count delimiter tokens in the relevant portion + relevant_tokens = req.origin_input_ids[ + extend_logprob_start_len:extend_input_len + ] + return sum( + 1 + for token_id in relevant_tokens + if token_id == self.server_args.multi_item_scoring_delimiter + ) + else: + # Regular request: all tokens in the range + return extend_input_len - extend_logprob_start_len + + def _is_multi_item_scoring(self, req: Req) -> bool: + """Check if request uses multi-item scoring. + + Multi-item scoring applies to prefill-only requests when a delimiter + token is configured. In this mode, only positions containing the + delimiter token receive logprobs. + """ + return req.is_prefill_only and self.server_args.multi_item_scoring_delimiter + def add_input_logprob_return_values( self: Scheduler, i: int, @@ -371,63 +521,14 @@ class SchedulerOutputProcessorMixin: assert req.input_top_logprobs_val is None assert req.input_top_logprobs_idx is None - # Compute input_token_logprobs_val - # Always pad the first one with None. - req.input_token_logprobs_val = [None] - req.input_token_logprobs_val.extend(input_token_logprobs) - # The last input logprob is for sampling, so just pop it out. - req.input_token_logprobs_val.pop() + # Process all input logprob types using helper functions + self._process_input_token_logprobs(req, input_token_logprobs) + self._process_input_top_logprobs(req) - # Compute input_token_logprobs_idx - input_token_logprobs_idx = req.origin_input_ids[req.logprob_start_len :] - # Clip the padded hash values from image tokens. - # Otherwise, it will lead to detokenization errors. 
- input_token_logprobs_idx = [ - x if x < self.model_config.vocab_size - 1 else 0 - for x in input_token_logprobs_idx - ] - req.input_token_logprobs_idx = input_token_logprobs_idx - - if req.top_logprobs_num > 0: - req.input_top_logprobs_val = [None] - req.input_top_logprobs_idx = [None] - assert len(req.temp_input_token_ids_logprobs_val) == len( - req.temp_input_token_ids_logprobs_idx - ) - for val, idx in zip( - req.temp_input_top_logprobs_val, - req.temp_input_top_logprobs_idx, - strict=True, - ): - req.input_top_logprobs_val.extend(val) - req.input_top_logprobs_idx.extend(idx) - - # Last token is a sample token. - req.input_top_logprobs_val.pop() - req.input_top_logprobs_idx.pop() - req.temp_input_top_logprobs_idx = None - req.temp_input_top_logprobs_val = None - - if req.token_ids_logprob is not None: - req.input_token_ids_logprobs_val = [None] - req.input_token_ids_logprobs_idx = [None] - - for val, idx in zip( - req.temp_input_token_ids_logprobs_val, - req.temp_input_token_ids_logprobs_idx, - strict=True, - ): - req.input_token_ids_logprobs_val.extend(val) - req.input_token_ids_logprobs_idx.extend(idx) - - # Last token is a sample token. - req.input_token_ids_logprobs_val.pop() - req.input_token_ids_logprobs_idx.pop() - req.temp_input_token_ids_logprobs_idx = None - req.temp_input_token_ids_logprobs_val = None + self._process_input_token_ids_logprobs(req) if req.return_logprob: - relevant_tokens_len = len(req.origin_input_ids) - req.logprob_start_len + relevant_tokens_len = self._calculate_relevant_tokens_len(req) assert len(req.input_token_logprobs_val) == relevant_tokens_len assert len(req.input_token_logprobs_idx) == relevant_tokens_len if req.top_logprobs_num > 0: diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index e9403a671..9d6bf9fc5 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -182,6 +182,8 @@ class TokenizerManager(TokenizerCommunicatorMixin): if speculative_algorithm.is_none() else server_args.speculative_num_draft_tokens ) + # Initialize delimiter text for multi-item scoring (will be set after tokenizer is loaded) + self.multi_item_delimiter_text = None if self.model_config.is_multimodal: import_processors("sglang.srt.multimodal.processors") @@ -223,6 +225,7 @@ class TokenizerManager(TokenizerCommunicatorMixin): self.processor = _processor self.tokenizer = get_tokenizer_from_processor(self.processor) os.environ["TOKENIZERS_PARALLELISM"] = "false" + self._initialize_multi_item_delimiter_text() else: self.mm_processor = self.processor = None @@ -235,6 +238,7 @@ class TokenizerManager(TokenizerCommunicatorMixin): trust_remote_code=server_args.trust_remote_code, revision=server_args.revision, ) + self._initialize_multi_item_delimiter_text() # Initialize async dynamic batch tokenizer if enabled (common for both multimodal and non-multimodal) if ( server_args.enable_dynamic_batch_tokenizer @@ -1678,6 +1682,201 @@ class TokenizerManager(TokenizerCommunicatorMixin): if len(self.model_update_tmp) == self.server_args.dp_size: self.model_update_result.set_result(self.model_update_tmp) + def _initialize_multi_item_delimiter_text(self): + """Initialize multi-item delimiter text from token ID after tokenizer is loaded.""" + if ( + hasattr(self.server_args, "multi_item_scoring_delimiter") + and self.server_args.multi_item_scoring_delimiter is not None + and self.tokenizer is not None + ): + try: + self.multi_item_delimiter_text = self.tokenizer.decode( 
+                    [self.server_args.multi_item_scoring_delimiter],
+                    skip_special_tokens=False,
+                )
+            except Exception as e:
+                logger.warning(
+                    f"Failed to decode delimiter token {self.server_args.multi_item_scoring_delimiter}: {e}"
+                )
+                self.multi_item_delimiter_text = None
+
+    def _build_multi_item_token_sequence(
+        self, query: List[int], items: List[List[int]], delimiter_token_id: int
+    ) -> List[int]:
+        """
+        Build a single token sequence for multi-item scoring.
+        Format: query<delimiter>item1<delimiter>item2<delimiter>item3<delimiter>
+
+        Args:
+            query: Query token IDs
+            items: List of item token ID sequences
+            delimiter_token_id: Token ID to use as delimiter
+
+        Returns:
+            Combined token sequence
+        """
+        combined_sequence = query[:]  # Start with query
+
+        for item in items:
+            combined_sequence.append(delimiter_token_id)  # Add delimiter
+            combined_sequence.extend(item)  # Add item tokens
+
+        # Add final delimiter after the last item for logprob extraction
+        combined_sequence.append(delimiter_token_id)
+
+        return combined_sequence
+
+    def _extract_logprobs_for_tokens(
+        self, logprobs_data: List, label_token_ids: List[int]
+    ) -> Dict[int, float]:
+        """
+        Extract logprobs for specified token IDs from logprobs data.
+
+        Args:
+            logprobs_data: List of (logprob, token_id, text) tuples
+            label_token_ids: Token IDs to extract logprobs for
+
+        Returns:
+            Dictionary mapping token_id to logprob
+        """
+        logprobs = {}
+        if logprobs_data:
+            for logprob, token_id, _ in logprobs_data:
+                if token_id in label_token_ids:
+                    logprobs[token_id] = logprob
+        return logprobs
+
+    def _convert_logprobs_to_scores(
+        self,
+        logprobs: Dict[int, float],
+        label_token_ids: List[int],
+        apply_softmax: bool,
+    ) -> List[float]:
+        """
+        Convert logprobs dictionary to ordered score list.
+
+        Args:
+            logprobs: Dictionary mapping token_id to logprob
+            label_token_ids: Token IDs in desired order
+            apply_softmax: Whether to apply softmax normalization
+
+        Returns:
+            List of scores in the same order as label_token_ids
+        """
+        score_list = [
+            logprobs.get(token_id, float("-inf")) for token_id in label_token_ids
+        ]
+
+        if apply_softmax:
+            score_list = torch.softmax(torch.tensor(score_list), dim=0).tolist()
+        else:
+            # Convert logprobs to probabilities if not using softmax
+            score_list = [
+                math.exp(x) if x != float("-inf") else 0.0 for x in score_list
+            ]
+
+        return score_list
+
+    def _process_multi_item_scoring_results(
+        self,
+        results: Any,
+        items: List,
+        label_token_ids: List[int],
+        apply_softmax: bool,
+        batch_request=None,
+    ) -> List[List[float]]:
+        """
+        Process results from a multi-item scoring request.
+        Extracts logprobs at delimiter positions from input_token_ids_logprobs.
+
+        Args:
+            results: Results from generate_request
+            items: List of items being scored
+            label_token_ids: Token IDs to extract scores for
+            apply_softmax: Whether to apply softmax normalization
+            batch_request: The original batch request containing the input sequence
+
+        Returns:
+            List of score lists, one for each item
+        """
+        single_result = results[0] if isinstance(results, list) else results
+
+        # For multi-item scoring, logprobs are in input_token_ids_logprobs
+        input_logprobs = single_result["meta_info"].get("input_token_ids_logprobs", [])
+
+        if not input_logprobs:
+            raise RuntimeError(
+                f"input_token_ids_logprobs is empty for multi-item scoring request {single_result['meta_info'].get('id', '')}. "
+                "This indicates token_ids_logprobs were not computed properly for multi-item scoring."
+            )
+
+        scores = []
+        num_items = len(items) if isinstance(items, list) else 1
+
+        # Check if we have the expected number of logprobs
+        expected_logprobs_count = num_items + 1
+        if len(input_logprobs) != expected_logprobs_count:
+            raise RuntimeError(
+                f"Expected {expected_logprobs_count} input_token_ids_logprobs for multi-item scoring "
+                f"with {num_items} items, but got {len(input_logprobs)}. "
+                f"Request ID: {single_result['meta_info'].get('id', '')}"
+            )
+
+        # Skip the first delimiter and process the remaining delimiter positions.
+        # The first delimiter marks the boundary between the query and the first item,
+        # not an item boundary.
+        start_idx = 1 if len(input_logprobs) > 1 else 0
+
+        # Process logprobs for each item position (excluding the first delimiter)
+        for item_idx in range(num_items):
+            logprob_idx = start_idx + item_idx
+            item_logprobs_data = input_logprobs[logprob_idx]
+            logprobs = self._extract_logprobs_for_tokens(
+                item_logprobs_data, label_token_ids
+            )
+            score_list = self._convert_logprobs_to_scores(
+                logprobs, label_token_ids, apply_softmax
+            )
+            scores.append(score_list)
+
+        return scores
+
+    def _process_single_item_scoring_results(
+        self, results: Any, label_token_ids: List[int], apply_softmax: bool
+    ) -> List[List[float]]:
+        """
+        Process results from a single-item scoring request.
+        Single-item scoring results are stored in output_token_ids_logprobs.
+
+        Args:
+            results: Results from generate_request
+            label_token_ids: Token IDs to extract scores for
+            apply_softmax: Whether to apply softmax normalization
+
+        Returns:
+            List of score lists, one for each result
+        """
+        scores = []
+
+        for result in results:
+            # For single-item scoring, logprobs are in output_token_ids_logprobs
+            output_logprobs = result["meta_info"].get("output_token_ids_logprobs", [])
+
+            if not output_logprobs or len(output_logprobs) == 0:
+                raise RuntimeError(
+                    f"output_logprobs is empty for request {result['meta_info'].get('id', '')}."
+                )
+
+            # Extract logprobs for the first (and only) position
+            logprobs = self._extract_logprobs_for_tokens(
+                output_logprobs[0], label_token_ids
+            )
+            score_list = self._convert_logprobs_to_scores(
+                logprobs, label_token_ids, apply_softmax
+            )
+            scores.append(score_list)
+
+        return scores
+
     async def score_request(
         self,
         query: Optional[Union[str, List[int]]] = None,
@@ -1688,7 +1887,29 @@
         request: Optional[Any] = None,
     ) -> List[List[float]]:
         """
-        See Engine.score() for more details.
+        Score the probability of specified token IDs appearing after the given (query + item) pair.
+
+        This method supports two scoring approaches:
+        1. Single-item scoring (default): process each query+item pair independently.
+        2. Multi-item scoring: when multi_item_scoring_delimiter is set, combine the query and
+           multiple items into a single sequence using the delimiter for efficient processing.
+           Note: the item_first parameter is ignored in multi-item scoring mode since it uses
+           a fixed format: query<delimiter>item1<delimiter>item2<delimiter>item3<delimiter>
+
+        Multi-item scoring works with both text and pre-tokenized inputs:
+        - Text: query<delimiter>item1<delimiter>item2<delimiter>item3<delimiter>
+        - Tokens: query<delimiter>item1<delimiter>item2<delimiter>item3<delimiter>
+
+        Args:
+            query: The query text or pre-tokenized query token IDs
+            items: The item text(s) or pre-tokenized item token IDs
+            label_token_ids: List of token IDs to compute probabilities for
+            apply_softmax: Whether to normalize probabilities using softmax
+            item_first: If True, prepend items to query. Ignored for multi-item scoring.
+ request: Optional FastAPI request object + + Returns: + List of lists containing probabilities for each item and each label token """ if label_token_ids is None: raise ValueError("label_token_ids must be provided") @@ -1701,9 +1922,17 @@ class TokenizerManager(TokenizerCommunicatorMixin): f"Token ID {token_id} is out of vocabulary (vocab size: {vocab_size})" ) + # Check if multi-item scoring is enabled by presence of delimiter + use_multi_item_scoring = ( + self.server_args.multi_item_scoring_delimiter is not None + and self.multi_item_delimiter_text is not None + ) + batch_request = GenerateReqInput( token_ids_logprob=label_token_ids, return_logprob=True, + # Set logprob_start_len=0 for multi-item scoring since we want logprobs at all delimiter positions + logprob_start_len=0 if use_multi_item_scoring else -1, stream=False, sampling_params={"max_new_tokens": 0}, ) @@ -1715,12 +1944,23 @@ class TokenizerManager(TokenizerCommunicatorMixin): ): # Both query and items are text items_list = [items] if isinstance(items, str) else items - if item_first: - prompts = [f"{item}{query}" for item in items_list] - else: - prompts = [f"{query}{item}" for item in items_list] - batch_request.text = prompts + if use_multi_item_scoring: + # Multi-item scoring: create single prompt with delimiter text + # Always use format: queryitem1item2item3 + # (item_first is ignored for multi-item scoring) + delimiter = self.multi_item_delimiter_text + combined_items = delimiter.join(items_list) + # Add final delimiter after the last item for logprob extraction + single_prompt = f"{query}{delimiter}{combined_items}{delimiter}" + batch_request.text = [single_prompt] + else: + # Single-item scoring: create separate prompts for each item + if item_first: + prompts = [f"{item}{query}" for item in items_list] + else: + prompts = [f"{query}{item}" for item in items_list] + batch_request.text = prompts elif ( isinstance(query, list) @@ -1729,61 +1969,38 @@ class TokenizerManager(TokenizerCommunicatorMixin): and isinstance(items[0], list) ): # Both query and items are token IDs - if item_first: - input_ids_list = [item + query for item in items] + if use_multi_item_scoring: + # Multi-item scoring: concatenate with delimiter token ID + # Format: queryitem1item2item3 + delimiter_token_id = self.server_args.multi_item_scoring_delimiter + combined_input_ids = self._build_multi_item_token_sequence( + query, items, delimiter_token_id + ) + batch_request.input_ids = [combined_input_ids] else: - input_ids_list = [query + item for item in items] - - batch_request.input_ids = input_ids_list + # Single-item scoring: process each item separately + if item_first: + input_ids_list = [item + query for item in items] + else: + input_ids_list = [query + item for item in items] + batch_request.input_ids = input_ids_list else: raise ValueError( "Invalid combination of query/items types for score_request." 
) results = await self.generate_request(batch_request, request).__anext__() - scores = [] - for result in results: - # Get logprobs for each token - logprobs = {} - - # For scoring requests, we read from output_token_ids_logprobs since we want - # the logprobs for specific tokens mentioned in the label_token_ids at - # the next position after the last token in the prompt - output_logprobs = result["meta_info"].get("output_token_ids_logprobs", []) - - # Check if output_logprobs is properly populated - if ( - output_logprobs is None - or not output_logprobs - or len(output_logprobs) == 0 - ): - raise RuntimeError( - f"output_logprobs is empty for request {result['meta_info'].get('id', '')}. " - "This indicates token_ids_logprobs were not computed properly for the scoring request." - ) - - for logprob, token_id, _ in output_logprobs[0]: - if token_id in label_token_ids: - logprobs[token_id] = logprob - - # Get scores in order of label_token_ids - score_list = [ - logprobs.get(token_id, float("-inf")) for token_id in label_token_ids - ] - - # Apply softmax to logprobs if needed - if apply_softmax: - score_list = torch.softmax(torch.tensor(score_list), dim=0).tolist() - else: - # Convert logprobs to probabilities if not using softmax - score_list = [ - math.exp(x) if x != float("-inf") else 0.0 for x in score_list - ] - - scores.append(score_list) - - return scores + if use_multi_item_scoring: + # Multi-item scoring: extract scores from input_token_ids_logprobs + return self._process_multi_item_scoring_results( + results, items, label_token_ids, apply_softmax, batch_request + ) + else: + # Single-item scoring: process each result separately + return self._process_single_item_scoring_results( + results, label_token_ids, apply_softmax + ) async def watch_load_thread(self): # Only for dp_controller when dp_size > 1 diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 051df74d7..33ac661b9 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -266,10 +266,16 @@ class TpModelWorker: if model_worker_batch.is_prefill_only: # For prefill-only requests, create dummy token IDs on CPU - batch_result.next_token_ids = torch.zeros_like( - model_worker_batch.input_ids, dtype=torch.long + # The size should match the batch size (number of sequences), not total tokens + batch_result.next_token_ids = torch.zeros( + len(model_worker_batch.seq_lens), + dtype=torch.long, + device=model_worker_batch.input_ids.device, ) - if model_worker_batch.return_logprob: + if ( + model_worker_batch.return_logprob + and logits_output.next_token_logits is not None + ): # NOTE: Compute logprobs without full sampling self.model_runner.compute_logprobs_only( logits_output, model_worker_batch diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index e16458e02..97a81f872 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -278,6 +278,9 @@ class ForwardBatch: can_run_dp_cuda_graph: bool = False global_forward_mode: Optional[ForwardMode] = None + # Whether this batch is prefill-only (no token generation needed) + is_prefill_only: bool = False + # Speculative decoding spec_info: Optional[SpecInput] = None spec_algorithm: SpeculativeAlgorithm = None @@ -325,6 +328,7 @@ class ForwardBatch: is_extend_in_batch=batch.is_extend_in_batch, can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph, 
            global_forward_mode=batch.global_forward_mode,
+            is_prefill_only=batch.is_prefill_only,
             lora_ids=batch.lora_ids,
             sampling_info=batch.sampling_info,
             req_to_token_pool=model_runner.req_to_token_pool,
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 52cbd038f..64d3371df 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -382,6 +382,12 @@ class ServerArgs:
     offload_prefetch_step: int = 1
     offload_mode: str = "cpu"
 
+    # Scoring configuration
+    # Delimiter token ID used to combine Query and Items into a single sequence for multi-item scoring.
+    # Format: Query<delimiter>Item1<delimiter>Item2<delimiter>...
+    # This enables efficient batch processing of multiple items against a single query.
+    multi_item_scoring_delimiter: Optional[int] = None
+
     # Optimization/debug options
     disable_radix_cache: bool = False
     cuda_graph_max_bs: Optional[int] = None
@@ -2334,7 +2340,13 @@
             choices=["float32", "bfloat16"],
             help="The data type of the SSM states in mamba cache.",
         )
-
+        # Args for multi-item scoring
+        parser.add_argument(
+            "--multi-item-scoring-delimiter",
+            type=int,
+            default=ServerArgs.multi_item_scoring_delimiter,
+            help="Delimiter token ID for multi-item scoring. Used to combine Query and Items into a single sequence: Query<delimiter>Item1<delimiter>Item2<delimiter>... This enables efficient batch processing of multiple items against a single query.",
+        )
         # Hierarchical cache
         parser.add_argument(
             "--enable-hierarchical-cache",
@@ -3004,6 +3016,17 @@
                 "lof",
             ], f"To use priority scheduling, schedule_policy must be 'fcfs' or 'lof'. '{self.schedule_policy}' is not supported."
 
+        # Check multi-item scoring
+        if self.multi_item_scoring_delimiter is not None:
+            assert self.disable_radix_cache, (
+                "Multi-item scoring requires radix cache to be disabled. "
+                "Please set --disable-radix-cache when using --multi-item-scoring-delimiter."
+            )
+            assert self.chunked_prefill_size == -1, (
+                "Multi-item scoring requires chunked prefill to be disabled. "
+                "Please set --chunked-prefill-size -1 when using --multi-item-scoring-delimiter."
+            )
+
     def check_lora_server_args(self):
         assert self.max_loras_per_batch > 0, "max_loras_per_batch must be positive"
diff --git a/python/sglang/srt/two_batch_overlap.py b/python/sglang/srt/two_batch_overlap.py
index 49205d4c9..b62ad4136 100644
--- a/python/sglang/srt/two_batch_overlap.py
+++ b/python/sglang/srt/two_batch_overlap.py
@@ -667,6 +667,7 @@ class TboForwardBatchPreparer:
             "can_run_dp_cuda_graph",
             "dp_padding_mode",
             "global_forward_mode",
+            "is_prefill_only",
             "spec_algorithm",
             "capture_hidden_mode",
             "padded_static_len",
diff --git a/test/srt/test_score_api.py b/test/srt/test_score_api.py
index d08ae9df7..757af86de 100644
--- a/test/srt/test_score_api.py
+++ b/test/srt/test_score_api.py
@@ -295,6 +295,296 @@ class TestScoreAPI(CustomTestCase):
         )
         self.assertFalse(request.stream, "Scoring requests should not stream")
 
+    def test_multi_item_scoring_basic(self):
+        """Test basic multi-item scoring functionality."""
+        # Test with a simple query and items
+        query = "What is the capital of California? 
Answer Yes or No for each of the following options:" + items = ["Sacramento", "San Jose", "San Francisco"] + label_token_ids = [9454, 2753] # "Yes" and "No" tokens + + # Get scores using SGLang + scores = self.engine.score( + query=query, + items=items, + label_token_ids=label_token_ids, + apply_softmax=True, + ) + + # Verify we get the expected number of scores + self.assertEqual(len(scores), len(items), "Should get one score list per item") + + # Verify each score list has the correct length + for i, score_list in enumerate(scores): + self.assertEqual( + len(score_list), + len(label_token_ids), + f"Item {i} should have {len(label_token_ids)} scores", + ) + # Verify scores are probabilities (sum to 1) + self.assertAlmostEqual( + sum(score_list), + 1.0, + places=6, + msg=f"Scores for item {i} should sum to 1", + ) + # Verify all scores are non-negative + for j, score in enumerate(score_list): + self.assertGreaterEqual( + score, 0, f"Score {j} for item {i} should be non-negative" + ) + + def test_multi_item_scoring_consistency(self): + """Test that multi-item scoring gives consistent results.""" + query = "Choose the best option:" + items = ["Option A", "Option B", "Option C"] + label_token_ids = [1, 2, 3] + + # Run the same test multiple times + scores1 = self.engine.score( + query=query, + items=items, + label_token_ids=label_token_ids, + apply_softmax=True, + ) + + scores2 = self.engine.score( + query=query, + items=items, + label_token_ids=label_token_ids, + apply_softmax=True, + ) + + # Results should be identical (deterministic) + self.assertEqual(len(scores1), len(scores2), "Should get same number of items") + for i, (s1, s2) in enumerate(zip(scores1, scores2)): + self.assertEqual( + len(s1), len(s2), f"Item {i} should have same number of scores" + ) + for j, (score1, score2) in enumerate(zip(s1, s2)): + self.assertAlmostEqual( + score1, + score2, + places=6, + msg=f"Score {j} for item {i} should be identical", + ) + + def test_multi_item_scoring_different_sizes(self): + """Test multi-item scoring with different numbers of items.""" + query = "Rate each option:" + label_token_ids = [1, 2, 3, 4, 5] + + # Test with different numbers of items + test_cases = [ + ["Single item"], + ["Item 1", "Item 2"], + ["A", "B", "C", "D"], + ["X", "Y", "Z", "W", "V", "U"], + ] + + for items in test_cases: + with self.subTest(items=items): + scores = self.engine.score( + query=query, + items=items, + label_token_ids=label_token_ids, + apply_softmax=True, + ) + + self.assertEqual( + len(scores), len(items), f"Should get {len(items)} score lists" + ) + + for i, score_list in enumerate(scores): + self.assertEqual( + len(score_list), + len(label_token_ids), + f"Item {i} should have {len(label_token_ids)} scores", + ) + self.assertAlmostEqual(sum(score_list), 1.0, places=6) + + def test_multi_item_scoring_empty_items(self): + """Test multi-item scoring with empty items list.""" + query = "Test query" + items = [] + label_token_ids = [1, 2] + + scores = self.engine.score( + query=query, + items=items, + label_token_ids=label_token_ids, + apply_softmax=True, + ) + + self.assertEqual(len(scores), 0, "Should return empty list for empty items") + + def test_multi_item_scoring_single_item(self): + """Test multi-item scoring with single item (should work like regular scoring).""" + query = "Complete this sentence: The capital of France is" + items = ["Paris"] + label_token_ids = [1, 2, 3] + + scores = self.engine.score( + query=query, + items=items, + label_token_ids=label_token_ids, + apply_softmax=True, + ) + + 
self.assertEqual(len(scores), 1, "Should get one score list") + self.assertEqual( + len(scores[0]), len(label_token_ids), "Should have correct number of scores" + ) + self.assertAlmostEqual(sum(scores[0]), 1.0, places=6) + + def test_multi_item_scoring_different_queries(self): + """Test multi-item scoring with different types of queries.""" + items = ["Yes", "No"] + label_token_ids = [1, 2] + + test_queries = [ + "Is this true?", + "Choose the correct answer:", + "What is the best option?", + "Select all that apply:", + "", # Empty query + ] + + for query in test_queries: + with self.subTest(query=query): + scores = self.engine.score( + query=query, + items=items, + label_token_ids=label_token_ids, + apply_softmax=True, + ) + + self.assertEqual( + len(scores), + len(items), + f"Should get {len(items)} score lists for query: '{query}'", + ) + + for i, score_list in enumerate(scores): + self.assertEqual(len(score_list), len(label_token_ids)) + self.assertAlmostEqual(sum(score_list), 1.0, places=6) + + def test_multi_item_scoring_different_label_tokens(self): + """Test multi-item scoring with different label token sets.""" + query = "Choose the best option:" + items = ["Option A", "Option B"] + + test_label_tokens = [ + [1, 2], # Two tokens + [1, 2, 3, 4], # Four tokens + [1], # Single token + [1, 2, 3, 4, 5, 6, 7, 8], # Many tokens + ] + + for label_token_ids in test_label_tokens: + with self.subTest(label_tokens=label_token_ids): + scores = self.engine.score( + query=query, + items=items, + label_token_ids=label_token_ids, + apply_softmax=True, + ) + + self.assertEqual(len(scores), len(items)) + + for i, score_list in enumerate(scores): + self.assertEqual( + len(score_list), + len(label_token_ids), + f"Item {i} should have {len(label_token_ids)} scores", + ) + self.assertAlmostEqual(sum(score_list), 1.0, places=6) + + def test_multi_item_scoring_without_softmax(self): + """Test multi-item scoring without softmax normalization.""" + query = "Rate each option:" + items = ["Good", "Bad", "Neutral"] + label_token_ids = [1, 2, 3] + + scores = self.engine.score( + query=query, + items=items, + label_token_ids=label_token_ids, + apply_softmax=False, # No softmax + ) + + self.assertEqual(len(scores), len(items)) + + for i, score_list in enumerate(scores): + self.assertEqual(len(score_list), len(label_token_ids)) + # Without softmax, scores don't need to sum to 1 + # But they should still be valid logits/probabilities + for j, score in enumerate(score_list): + self.assertIsInstance( + score, (int, float), f"Score {j} for item {i} should be numeric" + ) + + def test_multi_item_scoring_large_batch(self): + """Test multi-item scoring with a large number of items.""" + query = "Classify each item:" + items = [f"Item {i}" for i in range(20)] # 20 items + label_token_ids = [1, 2, 3] + + scores = self.engine.score( + query=query, + items=items, + label_token_ids=label_token_ids, + apply_softmax=True, + ) + + self.assertEqual(len(scores), len(items), "Should handle large batches") + + for i, score_list in enumerate(scores): + self.assertEqual(len(score_list), len(label_token_ids)) + self.assertAlmostEqual(sum(score_list), 1.0, places=6) + + def test_multi_item_scoring_unicode(self): + """Test multi-item scoring with unicode characters.""" + query = "选择最佳选项:" + items = ["选项A", "选项B", "选项C"] + label_token_ids = [1, 2, 3] + + scores = self.engine.score( + query=query, + items=items, + label_token_ids=label_token_ids, + apply_softmax=True, + ) + + self.assertEqual(len(scores), len(items)) + + for i, score_list 
in enumerate(scores): + self.assertEqual(len(score_list), len(label_token_ids)) + self.assertAlmostEqual(sum(score_list), 1.0, places=6) + + def test_multi_item_scoring_error_handling(self): + """Test multi-item scoring error handling.""" + query = "Test query" + items = ["Item 1", "Item 2"] + label_token_ids = [1, 2] + + # Test with invalid label_token_ids + with self.assertRaises((ValueError, TypeError)): + self.engine.score( + query=query, + items=items, + label_token_ids="invalid", # Should be list of ints + apply_softmax=True, + ) + + # Test with None items + with self.assertRaises((ValueError, TypeError)): + self.engine.score( + query=query, + items=None, + label_token_ids=label_token_ids, + apply_softmax=True, + ) + if __name__ == "__main__": unittest.main()
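
---

Usage sketch (editor's addition, not part of the patch): a minimal offline example of exercising the new multi-item scoring path. It assumes the offline sgl.Engine accepts the same knobs as the CLI flags added above; the model path and the delimiter token ID are placeholders to replace with values appropriate for your tokenizer. Note the two constraints enforced by the new server-args check: radix cache disabled and chunked prefill disabled.

    import sglang as sgl

    # Placeholder values: pick a model and a delimiter token ID that exists in its vocabulary.
    MODEL_PATH = "meta-llama/Llama-3.1-8B-Instruct"
    DELIMITER_TOKEN_ID = 128011  # hypothetical; use a reserved/special token of your tokenizer

    engine = sgl.Engine(
        model_path=MODEL_PATH,
        attention_backend="flashinfer",       # the custom multi-item mask is implemented in the FlashInfer backend
        multi_item_scoring_delimiter=DELIMITER_TOKEN_ID,
        disable_radix_cache=True,             # required when multi-item scoring is enabled
        chunked_prefill_size=-1,              # required when multi-item scoring is enabled
    )

    # One query and several candidate items, scored in a single prefill-only forward pass.
    scores = engine.score(
        query="What is the capital of California? Answer Yes or No for each of the following options:",
        items=["Sacramento", "San Jose", "San Francisco"],
        label_token_ids=[9454, 2753],         # token IDs for "Yes" / "No" (tokenizer-specific)
        apply_softmax=True,
    )
    print(scores)  # one [P(Yes), P(No)] pair per item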
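
Layout sketch (editor's addition, not part of the patch): a small self-contained illustration of how the per-sequence FlashInfer tensors described in _process_multi_item_scoring are laid out, mirroring the docstring examples. The real code builds uint32/uint16 tensors and pads them across the batch; the helper name and the string "tokens" below are illustrative only.

    def build_multi_item_params(token_ids, delimiter_id):
        """Derive prefix_len, token_pos_in_items, and max_item_len for one sequence."""
        first = token_ids.index(delimiter_id)   # first delimiter ends the shared query prefix
        prefix_len = first                      # number of query tokens
        token_pos_in_items = []
        pos = 0
        for tok in token_ids[first:]:
            # 0 at each delimiter, then 1..item_len inside each item
            pos = 0 if tok == delimiter_id else pos + 1
            token_pos_in_items.append(pos)
        return prefix_len, token_pos_in_items, max(token_pos_in_items)

    # The "France" example from the docstring: 7 query tokens, 3 single-token items.
    D = "<delim>"  # stands in for the delimiter token ID
    tokens = ["What", "is", "the", "capital", "of", "France", "?",
              D, "London", D, "Paris", D, "Berlin", D]
    print(build_multi_item_params(tokens, D))
    # -> (7, [0, 1, 0, 1, 0, 1, 0], 1)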