[Generative Score API] Multi-Item scoring with custom attention mask. (#10979)
Parent: e22b13c569 · Commit: 53bd00d975
@@ -7,6 +7,7 @@ FlashInfer is faster and Triton is easier to customize.
Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
"""

import logging
import os
from dataclasses import dataclass
from enum import Enum, auto
@@ -16,11 +17,11 @@ from typing import TYPE_CHECKING, Callable, List, Optional, Union
import torch

if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
    import logging

    torch._logging.set_logs(dynamo=logging.ERROR)
    torch._dynamo.config.suppress_errors = True

logger = logging.getLogger(__name__)

from sglang.global_config import global_config
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
@@ -58,6 +59,36 @@ class WrapperDispatch(Enum):
    CROSS_ATTENTION = auto()
@dataclass
class MultiItemScoringParams:
    """Parameters for multi-item scoring in attention computation.

    Used when processing sequences with multiple items separated by delimiters,
    where each item needs specific attention patterns that respect item boundaries.

    Attributes:
        prefix_len_ptr: A uint32 1D tensor indicating the prefix length of each prompt.
            The tensor size is equal to the batch size.
        token_pos_in_items_ptr: A uint16 1D tensor indicating the token position of each item
            starting from 0 (delimiter) for each item. For batch size > 1,
            sequences are concatenated with zero padding to ensure same length.
        token_pos_in_items_len: Zero padding length for token_pos_in_items_ptr to handle
            batch_size > 1 case. Defines the padded length for each sequence.
        max_item_len_ptr: A uint16 tensor containing the max token length of all items
            for each prompt in the batch.
    """

    prefix_len_ptr: Optional[torch.Tensor] = None
    token_pos_in_items_ptr: Optional[torch.Tensor] = None
    token_pos_in_items_len: int = 0
    max_item_len_ptr: Optional[torch.Tensor] = None

    def is_enabled(self) -> bool:
        """Check if multi-item scoring is enabled."""
        return self.prefix_len_ptr is not None
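A minimal sketch (not part of the commit) of how these fields encode the single-sequence example worked out in _process_multi_item_scoring below: a 7-token query followed by three single-token items, each opened by a delimiter. Assumes the dataclass above is in scope and a torch build with uint16/uint32 tensor support.

    import torch

    params = MultiItemScoringParams(
        prefix_len_ptr=torch.tensor([7], dtype=torch.uint32),
        token_pos_in_items_ptr=torch.tensor([0, 1, 0, 1, 0, 1, 0], dtype=torch.uint16),
        token_pos_in_items_len=7,
        max_item_len_ptr=torch.tensor([1], dtype=torch.uint16),
    )
    assert params.is_enabled()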
@dataclass
class DecodeMetadata:
    decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper]
@@ -68,6 +99,7 @@ class PrefillMetadata:
    prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper]
    use_ragged: bool
    extend_no_prefix: bool
    multi_item_params: Optional[MultiItemScoringParams] = None


# Reuse this workspace buffer across all flashinfer wrappers
@@ -90,6 +122,11 @@ class FlashInferAttnBackend(AttentionBackend):
    ):
        super().__init__()

        # Store multi-item scoring delimiter for efficient access
        self.multi_item_scoring_delimiter = (
            model_runner.server_args.multi_item_scoring_delimiter
        )

        # Parse constants
        self.decode_use_tensor_cores = should_use_tensor_core(
            kv_cache_dtype=model_runner.kv_cache_dtype,
@@ -229,10 +266,133 @@ class FlashInferAttnBackend(AttentionBackend):
        # Other metadata
        self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None

        self.decode_cuda_graph_metadata = {}
        self.prefill_cuda_graph_metadata = {}  # For verify
        self.draft_extend_cuda_graph_metadata = {}  # For draft extend
    def _process_multi_item_scoring(
        self, forward_batch: ForwardBatch
    ) -> MultiItemScoringParams:
        """Process multi-item scoring tensors for FlashInfer attention.

        This method handles sequences containing multiple "items" separated by delimiter tokens,
        where each item needs specific attention patterns that respect item boundaries.

        The method produces four key tensors for FlashInfer:
        - prefix_len_ptr: uint32 tensor with prefix length for each prompt in batch
        - token_pos_in_items_ptr: uint16 tensor with token positions starting from 0 at delimiters
        - token_pos_in_items_len: padding length for batch processing
        - max_item_len_ptr: uint16 tensor with max item length for each prompt

        Args:
            forward_batch: The forward batch containing input sequences and delimiter info

        Returns:
            MultiItemScoringParams: The processed multi-item scoring parameters

        Examples:
            Following the FlashInfer definition: for 3 items of length 3, 2, 4 respectively:
            token_pos_in_items_ptr = [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0]

            Case 1: Single sequence
            Text: "What is the capital of France? <delim> London <delim> Paris <delim> Berlin <delim>"
            Tokens:  [What, is, the, capital, of, France, ?, <delim>, London, <delim>, Paris, <delim>, Berlin, <delim>]
            Indices: [   0,  1,   2,       3,  4,      5, 6,       7,      8,       9,    10,      11,     12,      13]
            - prefix_len_ptr: [7] (query length before first delimiter)
            - token_pos_in_items_ptr: [0, 1, 0, 1, 0, 1, 0] (delim=0, London=1, delim=0, Paris=1, delim=0, Berlin=1, delim=0)
            - token_pos_in_items_len: 7 (actual length)
            - max_item_len_ptr: [1] (max item length is 1 token - all options are single tokens)

            Case 2: Batch processing (batch_size=2)
            Sequence 1: 2 items of length 2, 1 → [0, 1, 2, 0, 1, 0] (6 elements)
            Sequence 2: 3 items of length 1, 3, 2 → [0, 1, 0, 1, 2, 3, 0, 1, 2, 0] (10 elements)
            After padding both to length 10:
            - token_pos_in_items_ptr: [0, 1, 2, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 2, 3, 0, 1, 2, 0]
            - token_pos_in_items_len: 10 (padded length for batch processing)
            - max_item_len_ptr: [2, 3] (max lengths per sequence)
        """
        delimiter = self.multi_item_scoring_delimiter
        if delimiter is None or forward_batch.forward_mode == ForwardMode.DECODE:
            return MultiItemScoringParams()

        delimiter_mask = forward_batch.input_ids == delimiter
        prefix_cache_lens = getattr(forward_batch, "extend_prefix_lens", None)
        extend_seq_lens = getattr(forward_batch, "extend_seq_lens", None)
        prefix_len_ptr, token_pos_in_items_ptr = [], []
        token_pos_in_items_len = 0

        # If no extend_seq_lens, treat whole batch as one sequence
        if extend_seq_lens is None or len(extend_seq_lens) <= 1:
            extend_seq_lens = [forward_batch.input_ids.size(0)]

        seq_start = 0
        for i, seq_len in enumerate(extend_seq_lens):
            seq_end = seq_start + seq_len
            mask = delimiter_mask[seq_start:seq_end]
            pos = forward_batch.positions[seq_start:seq_end]
            delimiter_indices = torch.nonzero(mask, as_tuple=True)[0]

            if len(delimiter_indices) > 0:
                first_delim = delimiter_indices[0]
                # Prefix length: store as scalar
                prefix_len = first_delim + (
                    prefix_cache_lens[i] if prefix_cache_lens is not None else 0
                )
                prefix_len_ptr.append(
                    prefix_len.item() if torch.is_tensor(prefix_len) else prefix_len
                )

                # Compute relative positions within items after delimiters
                diff = pos[first_delim:] - torch.cummax(mask[first_delim:], 0)[1]
                token_pos = (diff - pos[first_delim]).to(torch.uint16)
                token_pos_in_items_ptr.append(token_pos)

                # Update forward_batch positions in-place
                pos[first_delim:] = diff - 1
                forward_batch.positions[seq_start:seq_end] = pos

            seq_start = seq_end

        # Pad token_pos_in_items_ptr for batch processing
        if token_pos_in_items_ptr:
            token_pos_in_items_len = max(t.numel() for t in token_pos_in_items_ptr)
            device = forward_batch.input_ids.device
            token_pos_in_items_ptr = [
                torch.cat(
                    [
                        t,
                        torch.zeros(
                            token_pos_in_items_len - t.numel(),
                            dtype=torch.uint16,
                            device=device,
                        ),
                    ]
                )
                for t in token_pos_in_items_ptr
            ]

        if not prefix_len_ptr or not token_pos_in_items_ptr:
            return MultiItemScoringParams()

        # Build final params
        device = forward_batch.input_ids.device
        return MultiItemScoringParams(
            prefix_len_ptr=torch.tensor(
                prefix_len_ptr, dtype=torch.uint32, device=device
            ),
            token_pos_in_items_ptr=torch.cat(token_pos_in_items_ptr, dim=0),
            token_pos_in_items_len=token_pos_in_items_len & 0xFFFFFFFF,
            max_item_len_ptr=torch.stack(
                [
                    t.to(torch.int32).max().to(torch.uint16)
                    for t in token_pos_in_items_ptr
                ],
                dim=0,
            ),
        )
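To make the cummax trick above concrete, here is a standalone replay (not part of the commit) on the Case 1 sequence from the docstring, using a hypothetical delimiter id of 99:

    import torch

    ids = torch.tensor([1, 2, 3, 4, 5, 6, 7, 99, 10, 99, 11, 99, 12, 99])
    pos = torch.arange(ids.numel())
    mask = ids == 99
    first = torch.nonzero(mask, as_tuple=True)[0][0]  # tensor(7)

    # cummax over the delimiter mask tracks, for every token, the index of the
    # most recent delimiter; subtracting it leaves the offset inside the item.
    diff = pos[first:] - torch.cummax(mask[first:], 0)[1]
    token_pos = diff - pos[first]
    print(token_pos.tolist())  # expected: [0, 1, 0, 1, 0, 1, 0]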
    def init_forward_metadata(self, forward_batch: ForwardBatch):
        if forward_batch.forward_mode.is_decode_or_idle():
            self.indices_updater_decode.update(
@@ -280,13 +440,26 @@ class FlashInferAttnBackend(AttentionBackend):
        else:
            prefix_lens = forward_batch.extend_prefix_lens

-            if self.is_multimodal:
            # Disable ragged wrapper and ensure prefix handling for multimodal and multi-item scoring
            if self.is_multimodal or self.multi_item_scoring_delimiter is not None:
                # use_ragged = False: Multi-item scoring requires the paged wrapper because:
                # 1. Ragged wrapper doesn't support the specialized multi-item parameters
                #    (prefix_len_ptr, token_pos_in_items_ptr, etc.)
                # 2. Paged wrapper provides better control over attention masking needed
                #    for respecting item boundaries in multi-item sequences
                # 3. Custom masking logic conflicts with ragged wrapper's assumptions
                use_ragged = False
                extend_no_prefix = False
            else:
                use_ragged = not self.enable_deterministic
                extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)

            # Process multi-item scoring in attention backend instead of ForwardBatch
            multi_item_params = MultiItemScoringParams()
            if self.multi_item_scoring_delimiter is not None:
                # Use new backend-specific implementation
                multi_item_params = self._process_multi_item_scoring(forward_batch)

            self.indices_updater_prefill.update(
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
@@ -298,9 +471,13 @@ class FlashInferAttnBackend(AttentionBackend):
                encoder_lens=forward_batch.encoder_lens,
                spec_info=None,
                fixed_split_size=self.prefill_split_tile_size,
                multi_item_params=multi_item_params,
            )
            self.forward_metadata = PrefillMetadata(
-                self.prefill_wrappers_paged, use_ragged, extend_no_prefix
                self.prefill_wrappers_paged,
                use_ragged,
                extend_no_prefix,
                multi_item_params,
            )

    def init_cuda_graph_state(
@@ -531,7 +708,20 @@ class FlashInferAttnBackend(AttentionBackend):
                forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
                causal=not layer.is_cross_attention,
                sm_scale=layer.scaling,
-                window_left=layer.sliding_window_size,
                # Disable sliding window attention for multi-item scoring:
                # - Sliding window could cut across item boundaries, breaking semantic coherence
                # - Multi-item sequences need full attention to properly handle delimiter tokens
                # - Specialized multi-item parameters (prefix_len_ptr, token_pos_in_items_ptr)
                #   provide more precise attention control than simple sliding windows
                # - Item-aware masking takes precedence over window-based masking
                window_left=(
                    layer.sliding_window_size
                    if not (
                        self.forward_metadata.multi_item_params
                        and self.forward_metadata.multi_item_params.is_enabled()
                    )
                    else -1
                ),
                logits_soft_cap=logits_soft_cap,
                # Must use _float to avoid device-to-host copy that breaks cuda graph capture.
                k_scale=layer.k_scale_float,
@@ -952,6 +1142,7 @@ class FlashInferIndicesUpdaterPrefill:
        encoder_lens: Optional[torch.Tensor],
        spec_info: Optional[SpecInput],
        fixed_split_size: Optional[int] = None,
        multi_item_params: Optional[MultiItemScoringParams] = None,
    ):
        if use_ragged:
            # TODO: remove this device sync, we can use forward_batch.extend_prefix_lens_cpu
@@ -976,6 +1167,7 @@ class FlashInferIndicesUpdaterPrefill:
            use_ragged,
            spec_info,
            fixed_split_size=fixed_split_size,
            multi_item_params=multi_item_params,
        )

    def update_sliding_window(
@@ -990,6 +1182,7 @@ class FlashInferIndicesUpdaterPrefill:
        encoder_lens: Optional[torch.Tensor],
        spec_info: Optional[SpecInput],
        fixed_split_size: Optional[int] = None,
        multi_item_params: Optional[MultiItemScoringParams] = None,
    ):
        for wrapper_id in range(2):
            if wrapper_id == 0:
@@ -1023,6 +1216,7 @@ class FlashInferIndicesUpdaterPrefill:
                use_ragged,
                spec_info,
                use_sliding_window_kv_pool=use_sliding_window_kv_pool,
                multi_item_params=multi_item_params,
            )

    def update_cross_attention(
@@ -1037,6 +1231,7 @@ class FlashInferIndicesUpdaterPrefill:
        encoder_lens: Optional[torch.Tensor],
        spec_info: Optional[SpecInput],
        fixed_split_size: Optional[int] = None,
        multi_item_params: Optional[MultiItemScoringParams] = None,
    ):
        for wrapper_id in range(2):
            if wrapper_id == 0:
@@ -1063,6 +1258,7 @@ class FlashInferIndicesUpdaterPrefill:
                self.qo_indptr[wrapper_id],
                use_ragged,
                spec_info,
                multi_item_params=multi_item_params,
            )

    def call_begin_forward(
@@ -1081,6 +1277,7 @@ class FlashInferIndicesUpdaterPrefill:
        spec_info: Optional[SpecInput],
        use_sliding_window_kv_pool: bool = False,
        fixed_split_size: Optional[int] = None,
        multi_item_params: Optional[MultiItemScoringParams] = None,
    ):
        bs = len(seq_lens)
        if spec_info is None:
@@ -1136,6 +1333,22 @@ class FlashInferIndicesUpdaterPrefill:
        )
        # cached part
        # Conditionally set multi-item parameters
        if multi_item_params is not None and multi_item_params.is_enabled():
            # Multi-item scoring is active - use specialized parameters and disable generic custom_mask
            use_custom_mask = None
            prefix_len_ptr = multi_item_params.prefix_len_ptr
            token_pos_in_items_ptr = multi_item_params.token_pos_in_items_ptr
            token_pos_in_items_len = multi_item_params.token_pos_in_items_len
            max_item_len_ptr = multi_item_params.max_item_len_ptr
        else:
            # No multi-item scoring - use standard parameters
            use_custom_mask = custom_mask
            prefix_len_ptr = None
            token_pos_in_items_ptr = None
            token_pos_in_items_len = 0
            max_item_len_ptr = None

        wrapper_paged.begin_forward(
            qo_indptr,
            kv_indptr,
@@ -1147,9 +1360,13 @@ class FlashInferIndicesUpdaterPrefill:
            1,
            q_data_type=self.q_data_type,
            kv_data_type=self.data_type,
-            custom_mask=custom_mask,
            custom_mask=use_custom_mask,
            non_blocking=True,
            fixed_split_size=fixed_split_size,
            prefix_len_ptr=prefix_len_ptr,
            token_pos_in_items_ptr=token_pos_in_items_ptr,
            token_pos_in_items_len=token_pos_in_items_len,
            max_item_len_ptr=max_item_len_ptr,
        )
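For intuition, the four parameters describe an item-aware mask: every token still sees the shared prefix causally, but tokens of one item cannot see tokens of other items. Below is a rough dense reconstruction of that pattern (our reading of the FlashInfer semantics, illustrative only, not the kernel's implementation):

    import torch

    def build_multi_item_mask(prefix_len: int, token_pos_in_items: list) -> torch.Tensor:
        n = prefix_len + len(token_pos_in_items)
        mask = torch.tril(torch.ones(n, n, dtype=torch.bool))  # causal base
        # A zero marks the delimiter that opens each item; cumsum assigns item ids.
        item_id = torch.cumsum(
            torch.tensor([int(p == 0) for p in token_pos_in_items]), dim=0
        )
        # Block attention between tokens that belong to different items.
        for q in range(len(token_pos_in_items)):
            for k in range(len(token_pos_in_items)):
                if item_id[q] != item_id[k]:
                    mask[prefix_len + q, prefix_len + k] = False
        return mask

    # Case 1 from the docstring: 7-token prefix, three single-token items.
    print(build_multi_item_mask(7, [0, 1, 0, 1, 0, 1, 0]).int())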
@@ -60,7 +60,8 @@ _is_npu = is_npu()
class LogitsProcessorOutput:
    ## Part 1: This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
    # The logits of the next tokens. shape: [#seq, vocab_size]
-    next_token_logits: torch.Tensor
    # Can be None for certain prefill-only requests (e.g., multi-item scoring) that don't need next token generation
    next_token_logits: Optional[torch.Tensor]
    # Used by speculative decoding (EAGLE)
    # The last hidden layers
    hidden_states: Optional[torch.Tensor] = None
@@ -85,7 +86,10 @@ class LogitsProcessorOutput:
    input_top_logprobs_val: List = None
    input_top_logprobs_idx: List = None
    # The logprobs and ids of the requested token ids in input positions. shape: [#seq, n] (n is the number of requested token ids)
-    input_token_ids_logprobs_val: Optional[List] = None
    # Can contain either lists or GPU tensors (for delayed GPU-to-CPU transfer optimization)
    input_token_ids_logprobs_val: Optional[List[Union[List[float], torch.Tensor]]] = (
        None
    )
    input_token_ids_logprobs_idx: Optional[List] = None
@@ -127,6 +131,9 @@ class LogitsMetadata:
    # for padding
    padded_static_len: int = -1

    # Whether this batch is prefill-only (no token generation needed)
    is_prefill_only: bool = False

    @classmethod
    def from_forward_batch(cls, forward_batch: ForwardBatch):
        if (
@@ -169,6 +176,7 @@ class LogitsMetadata:
            token_ids_logprobs=forward_batch.token_ids_logprobs,
            extend_input_logprob_token_ids_gpu=forward_batch.extend_input_logprob_token_ids_gpu,
            padded_static_len=forward_batch.padded_static_len,
            is_prefill_only=forward_batch.is_prefill_only,
            global_num_tokens_gpu=forward_batch.global_num_tokens_gpu,
            dp_local_start_pos=forward_batch.dp_local_start_pos,
            dp_local_num_tokens=forward_batch.dp_local_num_tokens,
@@ -247,6 +255,108 @@ class LogitsProcessor(nn.Module):
            "debug_tensor_dump_output_folder", None
        )
    def compute_logprobs_for_multi_item_scoring(
        self,
        input_ids,
        hidden_states,
        lm_head: VocabParallelEmbedding,
        logits_metadata: Union[LogitsMetadata, ForwardBatch],
        delimiter_token: int,
    ):
        """
        Compute logprobs for multi-item scoring using delimiter-based token extraction.

        This method is designed for scenarios where you want to score multiple items/candidates
        against a single query by combining them into one sequence separated by delimiters.

        Sequence format: Query<delimiter>Item1<delimiter>Item2<delimiter>...
        Scoring positions: Extracts logprobs at positions before each <delimiter>

        Args:
            input_ids (torch.Tensor): Input token IDs containing query and items separated by delimiters.
                Shape: [total_sequence_length] for single request or [batch_total_length] for batch.
            hidden_states (torch.Tensor): Hidden states from the model.
                Shape: [sequence_length, hidden_dim].
            lm_head (VocabParallelEmbedding): Language model head for computing logits.
            logits_metadata (Union[LogitsMetadata, ForwardBatch]): Metadata containing batch info
                and token ID specifications for logprob extraction.
            delimiter_token (int): Token ID used as delimiter between query and items.

        Returns:
            LogitsProcessorOutput: Contains:
                - next_token_logits: None (not needed for scoring-only requests)
                - input_token_logprobs: Logprobs of delimiter tokens at scoring positions
                - input_top_logprobs_val: Top-k logprobs at delimiter positions (if requested)
                - input_top_logprobs_idx: Top-k token indices at delimiter positions (if requested)
                - input_token_ids_logprobs_val: Logprobs for user-requested token IDs (if any)
                - input_token_ids_logprobs_idx: Indices for user-requested token IDs (if any)
        """
        multi_item_indices = (input_ids == delimiter_token).nonzero(as_tuple=True)[
            0
        ] - 1
        # Extract hidden states at delimiter positions for multi-item scoring
        sliced_hidden = hidden_states[multi_item_indices]

        sliced_logits = self._get_logits(sliced_hidden, lm_head, logits_metadata)
        sliced_logprobs = torch.nn.functional.log_softmax(sliced_logits, dim=-1)

        # Initialize return values
        input_token_ids_logprobs_val = []
        input_token_ids_logprobs_idx = []
        input_top_logprobs_val = None
        input_top_logprobs_idx = None

        # Recalculate extend_logprob_pruned_lens_cpu to match delimiter counts per request
        # Original contains sequence lengths, but we need delimiter counts for sliced_logprobs
        if (
            logits_metadata.token_ids_logprobs
            or logits_metadata.extend_return_top_logprob
        ):
            logits_metadata.extend_logprob_pruned_lens_cpu = []

            if logits_metadata.extend_seq_lens_cpu is not None:
                # Multi-request batch: count delimiters per request
                input_pt = 0
                for req_seq_len in logits_metadata.extend_seq_lens_cpu:
                    req_input_ids = input_ids[input_pt : input_pt + req_seq_len]
                    delimiter_count = (req_input_ids == delimiter_token).sum().item()
                    logits_metadata.extend_logprob_pruned_lens_cpu.append(
                        delimiter_count
                    )
                    input_pt += req_seq_len
            else:
                # Single request case: one request gets all delimiters
                total_delimiters = (input_ids == delimiter_token).sum().item()
                logits_metadata.extend_logprob_pruned_lens_cpu = [total_delimiters]

        # Get the logprobs of specified token ids
        if logits_metadata.extend_token_ids_logprob:
            (
                input_token_ids_logprobs_val,
                input_token_ids_logprobs_idx,
            ) = self.get_token_ids_logprobs(
                sliced_logprobs, logits_metadata, delay_cpu_copy=True
            )

        # Get the logprob of top-k tokens
        if logits_metadata.extend_return_top_logprob:
            (
                input_top_logprobs_val,
                input_top_logprobs_idx,
            ) = self.get_top_logprobs(sliced_logprobs, logits_metadata)

        # For input_token_logprobs, use delimiter token logprobs
        input_token_logprobs = sliced_logprobs[:, delimiter_token]

        return LogitsProcessorOutput(
            next_token_logits=None,  # Multi-item scoring doesn't need next token logits
            input_token_logprobs=input_token_logprobs,
            input_top_logprobs_val=input_top_logprobs_val,
            input_top_logprobs_idx=input_top_logprobs_idx,
            input_token_ids_logprobs_val=input_token_ids_logprobs_val,
            input_token_ids_logprobs_idx=input_token_ids_logprobs_idx,
        )
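A quick standalone check (not from the commit) of the index arithmetic above: subtracting 1 from each delimiter position selects the hidden state of the token just before the delimiter, i.e. after the model has read the whole item (hypothetical delimiter id 99):

    import torch

    delim = 99
    input_ids = torch.tensor([5, 6, 7, delim, 11, 12, delim, 21, delim])
    scoring_positions = (input_ids == delim).nonzero(as_tuple=True)[0] - 1
    print(scoring_positions.tolist())  # [2, 5, 7]: last token of the query, item 1, item 2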
    def forward(
        self,
        input_ids,
@@ -257,6 +367,16 @@ class LogitsProcessor(nn.Module):
    ) -> LogitsProcessorOutput:
        if isinstance(logits_metadata, ForwardBatch):
            logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata)

        # Check if multi-item scoring is enabled via server args (only for prefill-only requests)
        multi_item_delimiter = global_server_args_dict.get(
            "multi_item_scoring_delimiter"
        )
        if multi_item_delimiter is not None and logits_metadata.is_prefill_only:
            return self.compute_logprobs_for_multi_item_scoring(
                input_ids, hidden_states, lm_head, logits_metadata, multi_item_delimiter
            )

        # Get the last hidden states and last logits for the next token prediction
        if (
            logits_metadata.forward_mode.is_decode_or_idle()
@@ -584,7 +704,9 @@ class LogitsProcessor(nn.Module):
    @staticmethod
    def get_token_ids_logprobs(
-        all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata
        all_logprobs: torch.Tensor,
        logits_metadata: LogitsMetadata,
        delay_cpu_copy: bool = False,
    ):
        input_token_ids_logprobs_val, input_token_ids_logprobs_idx = [], []
        pt = 0
@@ -597,9 +719,17 @@ class LogitsProcessor(nn.Module):
                input_token_ids_logprobs_idx.append([])
                continue

-            input_token_ids_logprobs_val.append(
-                [all_logprobs[pt + j, token_ids].tolist() for j in range(pruned_len)]
-            )
            position_logprobs = all_logprobs[
                pt : pt + pruned_len, token_ids
            ]  # Shape: [pruned_len, num_tokens]

            if delay_cpu_copy:
                # Keep as tensor to delay GPU-to-CPU transfer
                input_token_ids_logprobs_val.append(position_logprobs)
            else:
                # Convert to list immediately (default behavior)
                input_token_ids_logprobs_val.append(position_logprobs.tolist())

            input_token_ids_logprobs_idx.append([token_ids for _ in range(pruned_len)])
            pt += pruned_len
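A shape sketch (illustrative values) of the slice above: selecting the requested label tokens at each scored position stays on-GPU when delay_cpu_copy=True, deferring the device-to-host copy until the scheduler actually needs lists.

    import torch

    all_logprobs = torch.randn(5, 32000)  # 5 scored positions, 32k vocab
    token_ids = [319, 350]                # hypothetical label token ids
    position_logprobs = all_logprobs[2:5, token_ids]
    print(position_logprobs.shape)        # torch.Size([3, 2])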
@@ -114,6 +114,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
    "enable_deterministic_inference",
    "nsa_prefill",
    "nsa_decode",
    "multi_item_scoring_delimiter",
]

# Put some global args for easy access
@@ -666,9 +667,11 @@ class Req:
    def is_prefill_only(self) -> bool:
        """Check if this request is prefill-only (no token generation needed)."""
        # NOTE: when spec is enabled, prefill_only optimizations are disabled
-        return (
-            self.sampling_params.max_new_tokens == 0
-            and global_server_args_dict["speculative_algorithm"] is None
        from sglang.srt.speculative.spec_info import SpeculativeAlgorithm

        spec_alg = global_server_args_dict["speculative_algorithm"]
        return self.sampling_params.max_new_tokens == 0 and (
            spec_alg is None or spec_alg == SpeculativeAlgorithm.NONE
        )

    def add_latency(self, stage: RequestStage):
|
||||
assert extend_input_len_per_req is not None
|
||||
extend_logprob_start_len = extend_logprob_start_len_per_req[i]
|
||||
extend_input_len = extend_input_len_per_req[i]
|
||||
num_input_logprobs = extend_input_len - extend_logprob_start_len
|
||||
|
||||
num_input_logprobs = self._calculate_num_input_logprobs(
|
||||
req, extend_input_len, extend_logprob_start_len
|
||||
)
|
||||
|
||||
if req.return_logprob:
|
||||
self.add_logprob_return_values(
|
||||
@@ -159,8 +162,8 @@ class SchedulerOutputProcessorMixin:
|
||||
extend_input_len = extend_input_len_per_req[i]
|
||||
if extend_logprob_start_len < extend_input_len:
|
||||
# Update input logprobs.
|
||||
num_input_logprobs = (
|
||||
extend_input_len - extend_logprob_start_len
|
||||
num_input_logprobs = self._calculate_num_input_logprobs(
|
||||
req, extend_input_len, extend_logprob_start_len
|
||||
)
|
||||
if req.return_logprob:
|
||||
self.add_input_logprob_return_values(
|
||||
@@ -303,6 +306,153 @@ class SchedulerOutputProcessorMixin:
|
||||
):
|
||||
self.log_decode_stats(can_run_cuda_graph, running_batch=batch)
|
||||
|
||||
def _process_input_token_logprobs(
|
||||
self, req: Req, input_token_logprobs: List
|
||||
) -> None:
|
||||
"""Process input token logprobs values and indices."""
|
||||
is_multi_item_scoring = self._is_multi_item_scoring(req)
|
||||
|
||||
# Process logprob values - handle multi-item scoring vs regular requests
|
||||
if is_multi_item_scoring:
|
||||
# Multi-item scoring: use all logprobs as-is
|
||||
req.input_token_logprobs_val = input_token_logprobs
|
||||
else:
|
||||
# Regular request: add None at start, remove last (sampling token)
|
||||
req.input_token_logprobs_val = [None] + input_token_logprobs[:-1]
|
||||
|
||||
# Process logprob indices based on scoring type
|
||||
if is_multi_item_scoring:
|
||||
# Multi-item scoring: only include delimiter token positions
|
||||
relevant_tokens = req.origin_input_ids[req.logprob_start_len :]
|
||||
input_token_logprobs_idx = [
|
||||
token_id
|
||||
for token_id in relevant_tokens
|
||||
if token_id == self.server_args.multi_item_scoring_delimiter
|
||||
]
|
||||
else:
|
||||
# Regular request: include all tokens from logprob_start_len onwards
|
||||
input_token_logprobs_idx = req.origin_input_ids[req.logprob_start_len :]
|
||||
|
||||
# Clip padded hash values from image tokens to prevent detokenization errors
|
||||
req.input_token_logprobs_idx = [
|
||||
x if x < self.model_config.vocab_size - 1 else 0
|
||||
for x in input_token_logprobs_idx
|
||||
]
|
||||
|
||||
def _process_input_top_logprobs(self, req: Req) -> None:
|
||||
"""Process input top logprobs."""
|
||||
if req.top_logprobs_num <= 0:
|
||||
return
|
||||
|
||||
is_multi_item_scoring = self._is_multi_item_scoring(req)
|
||||
|
||||
# Initialize arrays - multi-item scoring starts empty, others start with None
|
||||
req.input_top_logprobs_val = [] if is_multi_item_scoring else [None]
|
||||
req.input_top_logprobs_idx = [] if is_multi_item_scoring else [None]
|
||||
|
||||
# Extend arrays with temp values
|
||||
for val, idx in zip(
|
||||
req.temp_input_top_logprobs_val,
|
||||
req.temp_input_top_logprobs_idx,
|
||||
strict=True,
|
||||
):
|
||||
req.input_top_logprobs_val.extend(val)
|
||||
req.input_top_logprobs_idx.extend(idx)
|
||||
|
||||
# Remove last token (sampling token) for non multi-item scoring requests
|
||||
if not is_multi_item_scoring:
|
||||
req.input_top_logprobs_val.pop()
|
||||
req.input_top_logprobs_idx.pop()
|
||||
|
||||
# Clean up temp storage
|
||||
req.temp_input_top_logprobs_idx = None
|
||||
req.temp_input_top_logprobs_val = None
|
||||
|
||||
def _process_input_token_ids_logprobs(self, req: Req) -> None:
|
||||
"""Process input token IDs logprobs."""
|
||||
if req.token_ids_logprob is None:
|
||||
return
|
||||
|
||||
is_multi_item_scoring = self._is_multi_item_scoring(req)
|
||||
|
||||
# Initialize arrays - multi-item scoring starts empty, others start with None
|
||||
req.input_token_ids_logprobs_val = [] if is_multi_item_scoring else [None]
|
||||
req.input_token_ids_logprobs_idx = [] if is_multi_item_scoring else [None]
|
||||
|
||||
# Process temp values - convert tensors to lists and extend arrays
|
||||
for val, idx in zip(
|
||||
req.temp_input_token_ids_logprobs_val,
|
||||
req.temp_input_token_ids_logprobs_idx,
|
||||
strict=True,
|
||||
):
|
||||
val_list = val.tolist() if isinstance(val, torch.Tensor) else val
|
||||
req.input_token_ids_logprobs_val.extend(
|
||||
val_list if isinstance(val_list, list) else [val_list]
|
||||
)
|
||||
req.input_token_ids_logprobs_idx.extend(idx)
|
||||
|
||||
# Remove last token (sampling token) for non multi-item scoring requests
|
||||
if not is_multi_item_scoring:
|
||||
req.input_token_ids_logprobs_val.pop()
|
||||
req.input_token_ids_logprobs_idx.pop()
|
||||
|
||||
# Clean up temp storage
|
||||
req.temp_input_token_ids_logprobs_idx = None
|
||||
req.temp_input_token_ids_logprobs_val = None
|
||||
|
||||
def _calculate_relevant_tokens_len(self, req: Req) -> int:
|
||||
"""Calculate the expected length of logprob arrays based on whether multi-item scoring is enabled.
|
||||
|
||||
For multi-item scoring, only delimiter positions have logprobs.
|
||||
For regular requests, all positions from logprob_start_len onwards have logprobs.
|
||||
"""
|
||||
is_multi_item_scoring = self._is_multi_item_scoring(req)
|
||||
|
||||
if is_multi_item_scoring:
|
||||
# Multi-item scoring: count delimiter tokens from logprob_start_len onwards
|
||||
relevant_tokens = req.origin_input_ids[req.logprob_start_len :]
|
||||
return sum(
|
||||
1
|
||||
for token_id in relevant_tokens
|
||||
if token_id == self.server_args.multi_item_scoring_delimiter
|
||||
)
|
||||
else:
|
||||
# Regular request: all tokens from logprob_start_len onwards
|
||||
return len(req.origin_input_ids) - req.logprob_start_len
|
||||
|
||||
def _calculate_num_input_logprobs(
|
||||
self, req: Req, extend_input_len: int, extend_logprob_start_len: int
|
||||
) -> int:
|
||||
"""Calculate the number of input logprobs based on whether multi-item scoring is enabled.
|
||||
|
||||
For multi-item scoring, only delimiter positions have logprobs.
|
||||
For regular requests, all positions in the range have logprobs.
|
||||
"""
|
||||
is_multi_item_scoring = self._is_multi_item_scoring(req)
|
||||
|
||||
if is_multi_item_scoring:
|
||||
# Multi-item scoring: count delimiter tokens in the relevant portion
|
||||
relevant_tokens = req.origin_input_ids[
|
||||
extend_logprob_start_len:extend_input_len
|
||||
]
|
||||
return sum(
|
||||
1
|
||||
for token_id in relevant_tokens
|
||||
if token_id == self.server_args.multi_item_scoring_delimiter
|
||||
)
|
||||
else:
|
||||
# Regular request: all tokens in the range
|
||||
return extend_input_len - extend_logprob_start_len
|
||||
|
||||
def _is_multi_item_scoring(self, req: Req) -> bool:
|
||||
"""Check if request uses multi-item scoring.
|
||||
|
||||
Multi-item scoring applies to prefill-only requests when a delimiter
|
||||
token is configured. In this mode, only positions containing the
|
||||
delimiter token receive logprobs.
|
||||
"""
|
||||
return req.is_prefill_only and self.server_args.multi_item_scoring_delimiter
|
||||
|
||||
def add_input_logprob_return_values(
|
||||
self: Scheduler,
|
||||
i: int,
|
||||
@@ -371,63 +521,14 @@ class SchedulerOutputProcessorMixin:
|
||||
assert req.input_top_logprobs_val is None
|
||||
assert req.input_top_logprobs_idx is None
|
||||
|
||||
# Compute input_token_logprobs_val
|
||||
# Always pad the first one with None.
|
||||
req.input_token_logprobs_val = [None]
|
||||
req.input_token_logprobs_val.extend(input_token_logprobs)
|
||||
# The last input logprob is for sampling, so just pop it out.
|
||||
req.input_token_logprobs_val.pop()
|
||||
# Process all input logprob types using helper functions
|
||||
self._process_input_token_logprobs(req, input_token_logprobs)
|
||||
self._process_input_top_logprobs(req)
|
||||
|
||||
# Compute input_token_logprobs_idx
|
||||
input_token_logprobs_idx = req.origin_input_ids[req.logprob_start_len :]
|
||||
# Clip the padded hash values from image tokens.
|
||||
# Otherwise, it will lead to detokenization errors.
|
||||
input_token_logprobs_idx = [
|
||||
x if x < self.model_config.vocab_size - 1 else 0
|
||||
for x in input_token_logprobs_idx
|
||||
]
|
||||
req.input_token_logprobs_idx = input_token_logprobs_idx
|
||||
|
||||
if req.top_logprobs_num > 0:
|
||||
req.input_top_logprobs_val = [None]
|
||||
req.input_top_logprobs_idx = [None]
|
||||
assert len(req.temp_input_token_ids_logprobs_val) == len(
|
||||
req.temp_input_token_ids_logprobs_idx
|
||||
)
|
||||
for val, idx in zip(
|
||||
req.temp_input_top_logprobs_val,
|
||||
req.temp_input_top_logprobs_idx,
|
||||
strict=True,
|
||||
):
|
||||
req.input_top_logprobs_val.extend(val)
|
||||
req.input_top_logprobs_idx.extend(idx)
|
||||
|
||||
# Last token is a sample token.
|
||||
req.input_top_logprobs_val.pop()
|
||||
req.input_top_logprobs_idx.pop()
|
||||
req.temp_input_top_logprobs_idx = None
|
||||
req.temp_input_top_logprobs_val = None
|
||||
|
||||
if req.token_ids_logprob is not None:
|
||||
req.input_token_ids_logprobs_val = [None]
|
||||
req.input_token_ids_logprobs_idx = [None]
|
||||
|
||||
for val, idx in zip(
|
||||
req.temp_input_token_ids_logprobs_val,
|
||||
req.temp_input_token_ids_logprobs_idx,
|
||||
strict=True,
|
||||
):
|
||||
req.input_token_ids_logprobs_val.extend(val)
|
||||
req.input_token_ids_logprobs_idx.extend(idx)
|
||||
|
||||
# Last token is a sample token.
|
||||
req.input_token_ids_logprobs_val.pop()
|
||||
req.input_token_ids_logprobs_idx.pop()
|
||||
req.temp_input_token_ids_logprobs_idx = None
|
||||
req.temp_input_token_ids_logprobs_val = None
|
||||
self._process_input_token_ids_logprobs(req)
|
||||
|
||||
if req.return_logprob:
|
||||
relevant_tokens_len = len(req.origin_input_ids) - req.logprob_start_len
|
||||
relevant_tokens_len = self._calculate_relevant_tokens_len(req)
|
||||
assert len(req.input_token_logprobs_val) == relevant_tokens_len
|
||||
assert len(req.input_token_logprobs_idx) == relevant_tokens_len
|
||||
if req.top_logprobs_num > 0:
|
||||
|
||||
@@ -182,6 +182,8 @@ class TokenizerManager(TokenizerCommunicatorMixin):
            if speculative_algorithm.is_none()
            else server_args.speculative_num_draft_tokens
        )
        # Initialize delimiter text for multi-item scoring (will be set after tokenizer is loaded)
        self.multi_item_delimiter_text = None

        if self.model_config.is_multimodal:
            import_processors("sglang.srt.multimodal.processors")
@@ -223,6 +225,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
            self.processor = _processor
            self.tokenizer = get_tokenizer_from_processor(self.processor)
            os.environ["TOKENIZERS_PARALLELISM"] = "false"
            self._initialize_multi_item_delimiter_text()
        else:
            self.mm_processor = self.processor = None

@@ -235,6 +238,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                trust_remote_code=server_args.trust_remote_code,
                revision=server_args.revision,
            )
            self._initialize_multi_item_delimiter_text()
        # Initialize async dynamic batch tokenizer if enabled (common for both multimodal and non-multimodal)
        if (
            server_args.enable_dynamic_batch_tokenizer
@@ -1678,6 +1682,201 @@ class TokenizerManager(TokenizerCommunicatorMixin):
        if len(self.model_update_tmp) == self.server_args.dp_size:
            self.model_update_result.set_result(self.model_update_tmp)
    def _initialize_multi_item_delimiter_text(self):
        """Initialize multi-item delimiter text from token ID after tokenizer is loaded."""
        if (
            hasattr(self.server_args, "multi_item_scoring_delimiter")
            and self.server_args.multi_item_scoring_delimiter is not None
            and self.tokenizer is not None
        ):
            try:
                self.multi_item_delimiter_text = self.tokenizer.decode(
                    [self.server_args.multi_item_scoring_delimiter],
                    skip_special_tokens=False,
                )
            except Exception as e:
                logger.warning(
                    f"Failed to decode delimiter token {self.server_args.multi_item_scoring_delimiter}: {e}"
                )
                self.multi_item_delimiter_text = None
    def _build_multi_item_token_sequence(
        self, query: List[int], items: List[List[int]], delimiter_token_id: int
    ) -> List[int]:
        """
        Build a single token sequence for multi-item scoring.
        Format: query<delimiter>item1<delimiter>item2<delimiter>item3<delimiter>

        Args:
            query: Query token IDs
            items: List of item token ID sequences
            delimiter_token_id: Token ID to use as delimiter

        Returns:
            Combined token sequence
        """
        combined_sequence = query[:]  # Start with query

        for item in items:
            combined_sequence.append(delimiter_token_id)  # Add delimiter
            combined_sequence.extend(item)  # Add item tokens

        # Add final delimiter after the last item for logprob extraction
        combined_sequence.append(delimiter_token_id)

        return combined_sequence
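A standalone restatement of the builder for a quick check (hypothetical token ids), including the trailing delimiter that gives the last item its scoring position:

    from typing import List

    def build_multi_item_token_sequence(
        query: List[int], items: List[List[int]], delimiter_token_id: int
    ) -> List[int]:
        seq = query[:]
        for item in items:
            seq.append(delimiter_token_id)
            seq.extend(item)
        seq.append(delimiter_token_id)
        return seq

    assert build_multi_item_token_sequence([101, 102], [[7, 8], [9]], 99) == [
        101, 102, 99, 7, 8, 99, 9, 99
    ]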
    def _extract_logprobs_for_tokens(
        self, logprobs_data: List, label_token_ids: List[int]
    ) -> Dict[int, float]:
        """
        Extract logprobs for specified token IDs from logprobs data.

        Args:
            logprobs_data: List of (logprob, token_id, text) tuples
            label_token_ids: Token IDs to extract logprobs for

        Returns:
            Dictionary mapping token_id to logprob
        """
        logprobs = {}
        if logprobs_data:
            for logprob, token_id, _ in logprobs_data:
                if token_id in label_token_ids:
                    logprobs[token_id] = logprob
        return logprobs
    def _convert_logprobs_to_scores(
        self,
        logprobs: Dict[int, float],
        label_token_ids: List[int],
        apply_softmax: bool,
    ) -> List[float]:
        """
        Convert logprobs dictionary to ordered score list.

        Args:
            logprobs: Dictionary mapping token_id to logprob
            label_token_ids: Token IDs in desired order
            apply_softmax: Whether to apply softmax normalization

        Returns:
            List of scores in the same order as label_token_ids
        """
        score_list = [
            logprobs.get(token_id, float("-inf")) for token_id in label_token_ids
        ]

        if apply_softmax:
            score_list = torch.softmax(torch.tensor(score_list), dim=0).tolist()
        else:
            # Convert logprobs to probabilities if not using softmax
            score_list = [
                math.exp(x) if x != float("-inf") else 0.0 for x in score_list
            ]

        return score_list
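A worked conversion (hypothetical label ids): with apply_softmax=True the label logprobs are renormalized against each other, so probabilities that already sum to 1 pass through unchanged.

    import math

    import torch

    logprobs = {319: math.log(0.25), 350: math.log(0.75)}
    label_token_ids = [319, 350]
    ordered = torch.tensor([logprobs[t] for t in label_token_ids])
    print(torch.softmax(ordered, dim=0).tolist())  # ~[0.25, 0.75]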
    def _process_multi_item_scoring_results(
        self,
        results: Any,
        items: List,
        label_token_ids: List[int],
        apply_softmax: bool,
        batch_request=None,
    ) -> List[List[float]]:
        """
        Process results from multi-item scoring request.
        Extracts logprobs at delimiter positions from input_token_ids_logprobs.

        Args:
            results: Results from generate_request
            items: List of items being scored
            label_token_ids: Token IDs to extract scores for
            apply_softmax: Whether to apply softmax normalization
            batch_request: The original batch request containing input sequence

        Returns:
            List of score lists, one for each item
        """
        single_result = results[0] if isinstance(results, list) else results

        # For multi-item scoring, logprobs are in input_token_ids_logprobs
        input_logprobs = single_result["meta_info"].get("input_token_ids_logprobs", [])

        if not input_logprobs:
            raise RuntimeError(
                f"input_token_ids_logprobs is empty for multi-item scoring request {single_result['meta_info'].get('id', '<unknown>')}. "
                "This indicates token_ids_logprobs were not computed properly for multi-item scoring."
            )

        scores = []
        num_items = len(items) if isinstance(items, list) else 1

        # Check if we have the expected number of logprobs
        expected_logprobs_count = num_items + 1
        if len(input_logprobs) != expected_logprobs_count:
            raise RuntimeError(
                f"Expected {expected_logprobs_count} input_token_ids_logprobs for multi-item scoring "
                f"with {num_items} items, but got {len(input_logprobs)}. "
                f"Request ID: {single_result['meta_info'].get('id', '<unknown>')}"
            )

        # Skip the first delimiter (between query and first item) and process remaining delimiter positions
        # We want to exclude the first one since it represents the boundary between query and first item, not an item boundary
        start_idx = 1 if len(input_logprobs) > 1 else 0

        # Process logprobs for each item position (excluding first delimiter)
        for item_idx in range(num_items):
            logprob_idx = start_idx + item_idx
            item_logprobs_data = input_logprobs[logprob_idx]
            logprobs = self._extract_logprobs_for_tokens(
                item_logprobs_data, label_token_ids
            )
            score_list = self._convert_logprobs_to_scores(
                logprobs, label_token_ids, apply_softmax
            )
            scores.append(score_list)

        return scores
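Count sanity for the check above (illustrative): N items produce N + 1 delimiters, hence N + 1 logprob entries; the first belongs to the query/item boundary and is skipped.

    items = ["London", "Paris", "Berlin"]
    expected_logprobs_count = len(items) + 1  # 4
    start_idx = 1                             # skip the query<delim> boundary
    scored_indices = [start_idx + i for i in range(len(items))]
    assert scored_indices == [1, 2, 3]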
    def _process_single_item_scoring_results(
        self, results: Any, label_token_ids: List[int], apply_softmax: bool
    ) -> List[List[float]]:
        """
        Process results from single-item scoring request.
        Single-item scoring results are stored in output_token_ids_logprobs.

        Args:
            results: Results from generate_request
            label_token_ids: Token IDs to extract scores for
            apply_softmax: Whether to apply softmax normalization

        Returns:
            List of score lists, one for each result
        """
        scores = []

        for result in results:
            # For single-item scoring, logprobs are in output_token_ids_logprobs
            output_logprobs = result["meta_info"].get("output_token_ids_logprobs", [])

            if not output_logprobs or len(output_logprobs) == 0:
                raise RuntimeError(
                    f"output_logprobs is empty for request {result['meta_info'].get('id', '<unknown>')}."
                )

            # Extract logprobs for the first (and only) position
            logprobs = self._extract_logprobs_for_tokens(
                output_logprobs[0], label_token_ids
            )
            score_list = self._convert_logprobs_to_scores(
                logprobs, label_token_ids, apply_softmax
            )
            scores.append(score_list)

        return scores
    async def score_request(
        self,
        query: Optional[Union[str, List[int]]] = None,
@@ -1688,7 +1887,29 @@ class TokenizerManager(TokenizerCommunicatorMixin):
        request: Optional[Any] = None,
    ) -> List[List[float]]:
        """
-        See Engine.score() for more details.
        Score the probability of specified token IDs appearing after the given (query + item) pair.

        This method supports two scoring approaches:
        1. Single-Item scoring (default): Process each query+item pair independently
        2. Multi-Item scoring: When multi_item_scoring_delimiter is set, combine query and
           multiple items into a single sequence using delimiter for efficient processing.
           Note: item_first parameter is ignored in multi-item scoring mode since it uses
           a fixed format: query<delimiter>item1<delimiter>item2<delimiter>item3<delimiter>

        Multi-item scoring works with both text and pre-tokenized inputs:
        - Text: query<delimiter_text>item1<delimiter_text>item2<delimiter_text>item3<delimiter_text>
        - Tokens: query<delimiter_token_id>item1<delimiter_token_id>item2<delimiter_token_id>item3<delimiter_token_id>

        Args:
            query: The query text or pre-tokenized query token IDs
            items: The item text(s) or pre-tokenized item token IDs
            label_token_ids: List of token IDs to compute probabilities for
            apply_softmax: Whether to normalize probabilities using softmax
            item_first: If True, prepend items to query. Ignored for multi-item scoring.
            request: Optional FastAPI request object

        Returns:
            List of lists containing probabilities for each item and each label token
        """
        if label_token_ids is None:
            raise ValueError("label_token_ids must be provided")
@@ -1701,9 +1922,17 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                    f"Token ID {token_id} is out of vocabulary (vocab size: {vocab_size})"
                )

        # Check if multi-item scoring is enabled by presence of delimiter
        use_multi_item_scoring = (
            self.server_args.multi_item_scoring_delimiter is not None
            and self.multi_item_delimiter_text is not None
        )

        batch_request = GenerateReqInput(
            token_ids_logprob=label_token_ids,
            return_logprob=True,
            # Set logprob_start_len=0 for multi-item scoring since we want logprobs at all delimiter positions
            logprob_start_len=0 if use_multi_item_scoring else -1,
            stream=False,
            sampling_params={"max_new_tokens": 0},
        )
@@ -1715,12 +1944,23 @@ class TokenizerManager(TokenizerCommunicatorMixin):
        ):
            # Both query and items are text
            items_list = [items] if isinstance(items, str) else items
-            if item_first:
-                prompts = [f"{item}{query}" for item in items_list]
-            else:
-                prompts = [f"{query}{item}" for item in items_list]
-
-            batch_request.text = prompts
            if use_multi_item_scoring:
                # Multi-item scoring: create single prompt with delimiter text
                # Always use format: query<delimiter>item1<delimiter>item2<delimiter>item3<delimiter>
                # (item_first is ignored for multi-item scoring)
                delimiter = self.multi_item_delimiter_text
                combined_items = delimiter.join(items_list)
                # Add final delimiter after the last item for logprob extraction
                single_prompt = f"{query}{delimiter}{combined_items}{delimiter}"
                batch_request.text = [single_prompt]
            else:
                # Single-item scoring: create separate prompts for each item
                if item_first:
                    prompts = [f"{item}{query}" for item in items_list]
                else:
                    prompts = [f"{query}{item}" for item in items_list]
                batch_request.text = prompts
        elif (
            isinstance(query, list)
@@ -1729,61 +1969,38 @@ class TokenizerManager(TokenizerCommunicatorMixin):
            and isinstance(items[0], list)
        ):
            # Both query and items are token IDs
-            if item_first:
-                input_ids_list = [item + query for item in items]
            if use_multi_item_scoring:
                # Multi-item scoring: concatenate with delimiter token ID
                # Format: query<delimiter_token_id>item1<delimiter_token_id>item2<delimiter_token_id>item3<delimiter_token_id>
                delimiter_token_id = self.server_args.multi_item_scoring_delimiter
                combined_input_ids = self._build_multi_item_token_sequence(
                    query, items, delimiter_token_id
                )
                batch_request.input_ids = [combined_input_ids]
            else:
-                input_ids_list = [query + item for item in items]
-
-            batch_request.input_ids = input_ids_list
                # Single-item scoring: process each item separately
                if item_first:
                    input_ids_list = [item + query for item in items]
                else:
                    input_ids_list = [query + item for item in items]
                batch_request.input_ids = input_ids_list
        else:
            raise ValueError(
                "Invalid combination of query/items types for score_request."
            )

        results = await self.generate_request(batch_request, request).__anext__()
-        scores = []
-
-        for result in results:
-            # Get logprobs for each token
-            logprobs = {}
-
-            # For scoring requests, we read from output_token_ids_logprobs since we want
-            # the logprobs for specific tokens mentioned in the label_token_ids at
-            # the next position after the last token in the prompt
-            output_logprobs = result["meta_info"].get("output_token_ids_logprobs", [])
-
-            # Check if output_logprobs is properly populated
-            if (
-                output_logprobs is None
-                or not output_logprobs
-                or len(output_logprobs) == 0
-            ):
-                raise RuntimeError(
-                    f"output_logprobs is empty for request {result['meta_info'].get('id', '<unknown>')}. "
-                    "This indicates token_ids_logprobs were not computed properly for the scoring request."
-                )
-
-            for logprob, token_id, _ in output_logprobs[0]:
-                if token_id in label_token_ids:
-                    logprobs[token_id] = logprob
-
-            # Get scores in order of label_token_ids
-            score_list = [
-                logprobs.get(token_id, float("-inf")) for token_id in label_token_ids
-            ]
-
-            # Apply softmax to logprobs if needed
-            if apply_softmax:
-                score_list = torch.softmax(torch.tensor(score_list), dim=0).tolist()
-            else:
-                # Convert logprobs to probabilities if not using softmax
-                score_list = [
-                    math.exp(x) if x != float("-inf") else 0.0 for x in score_list
-                ]
-
-            scores.append(score_list)
-
-        return scores
        if use_multi_item_scoring:
            # Multi-item scoring: extract scores from input_token_ids_logprobs
            return self._process_multi_item_scoring_results(
                results, items, label_token_ids, apply_softmax, batch_request
            )
        else:
            # Single-item scoring: process each result separately
            return self._process_single_item_scoring_results(
                results, label_token_ids, apply_softmax
            )

    async def watch_load_thread(self):
        # Only for dp_controller when dp_size > 1
@@ -266,10 +266,16 @@ class TpModelWorker:

        if model_worker_batch.is_prefill_only:
            # For prefill-only requests, create dummy token IDs on CPU
-            batch_result.next_token_ids = torch.zeros_like(
-                model_worker_batch.input_ids, dtype=torch.long
            # The size should match the batch size (number of sequences), not total tokens
            batch_result.next_token_ids = torch.zeros(
                len(model_worker_batch.seq_lens),
                dtype=torch.long,
                device=model_worker_batch.input_ids.device,
            )
-            if model_worker_batch.return_logprob:
            if (
                model_worker_batch.return_logprob
                and logits_output.next_token_logits is not None
            ):
                # NOTE: Compute logprobs without full sampling
                self.model_runner.compute_logprobs_only(
                    logits_output, model_worker_batch
@@ -278,6 +278,9 @@ class ForwardBatch:
    can_run_dp_cuda_graph: bool = False
    global_forward_mode: Optional[ForwardMode] = None

    # Whether this batch is prefill-only (no token generation needed)
    is_prefill_only: bool = False

    # Speculative decoding
    spec_info: Optional[SpecInput] = None
    spec_algorithm: SpeculativeAlgorithm = None
@@ -325,6 +328,7 @@ class ForwardBatch:
        is_extend_in_batch=batch.is_extend_in_batch,
        can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
        global_forward_mode=batch.global_forward_mode,
        is_prefill_only=batch.is_prefill_only,
        lora_ids=batch.lora_ids,
        sampling_info=batch.sampling_info,
        req_to_token_pool=model_runner.req_to_token_pool,
@@ -382,6 +382,12 @@ class ServerArgs:
    offload_prefetch_step: int = 1
    offload_mode: str = "cpu"

    # Scoring configuration
    # Delimiter token ID used to combine Query and Items into a single sequence for multi-item scoring.
    # Format: Query<delimiter>Item1<delimiter>Item2<delimiter>...
    # This enables efficient batch processing of multiple items against a single query.
    multi_item_scoring_delimiter: Optional[int] = None

    # Optimization/debug options
    disable_radix_cache: bool = False
    cuda_graph_max_bs: Optional[int] = None
@@ -2334,7 +2340,13 @@ class ServerArgs:
            choices=["float32", "bfloat16"],
            help="The data type of the SSM states in mamba cache.",
        )

        # Args for multi-item scoring
        parser.add_argument(
            "--multi-item-scoring-delimiter",
            type=int,
            default=ServerArgs.multi_item_scoring_delimiter,
            help="Delimiter token ID for multi-item scoring. Used to combine Query and Items into a single sequence: Query<delimiter>Item1<delimiter>Item2<delimiter>... This enables efficient batch processing of multiple items against a single query.",
        )
        # Hierarchical cache
        parser.add_argument(
            "--enable-hierarchical-cache",
@@ -3004,6 +3016,17 @@ class ServerArgs:
                "lof",
            ], f"To use priority scheduling, schedule_policy must be 'fcfs' or 'lof'. '{self.schedule_policy}' is not supported."

        # Check multi-item scoring
        if self.multi_item_scoring_delimiter is not None:
            assert self.disable_radix_cache, (
                "Multi-item scoring requires radix cache to be disabled. "
                "Please set --disable-radix-cache when using --multi-item-scoring-delimiter."
            )
            assert self.chunked_prefill_size == -1, (
                "Multi-item scoring requires chunked prefill to be disabled. "
                "Please set --chunked-prefill-size -1 when using --multi-item-scoring-delimiter."
            )

    def check_lora_server_args(self):
        assert self.max_loras_per_batch > 0, "max_loras_per_batch must be positive"
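An end-to-end usage sketch under stated assumptions (hypothetical model path, delimiter id, and label token ids; Engine.score mirrors the score_request signature shown earlier), honoring the two constraint checks above:

    import sglang as sgl

    engine = sgl.Engine(
        model_path="meta-llama/Llama-3.1-8B-Instruct",  # any causal LM, assumption
        multi_item_scoring_delimiter=128002,            # hypothetical reserved token id
        disable_radix_cache=True,                       # required by the check above
        chunked_prefill_size=-1,                        # required by the check above
    )
    scores = engine.score(
        query="What is the capital of France? ",
        items=["London", "Paris", "Berlin"],
        label_token_ids=[9642, 2822],                   # hypothetical "Yes"/"No" ids
        apply_softmax=True,
    )
    print(scores)  # one [p_yes, p_no] pair per item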
@@ -667,6 +667,7 @@ class TboForwardBatchPreparer:
            "can_run_dp_cuda_graph",
            "dp_padding_mode",
            "global_forward_mode",
            "is_prefill_only",
            "spec_algorithm",
            "capture_hidden_mode",
            "padded_static_len",