[Generative Score API] Multi-Item scoring with custom attention mask. (#10979)
This commit is contained in:
committed by
GitHub
parent
e22b13c569
commit
53bd00d975
@@ -7,6 +7,7 @@ FlashInfer is faster and Triton is easier to customize.
|
|||||||
Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
|
Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import Enum, auto
|
from enum import Enum, auto
|
||||||
@@ -16,11 +17,11 @@ from typing import TYPE_CHECKING, Callable, List, Optional, Union
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
# Use .get() so that an unset SGLANG_ENABLE_TORCH_COMPILE is treated as
# "disabled" instead of raising KeyError while this module is being imported.
if os.environ.get("SGLANG_ENABLE_TORCH_COMPILE") == "1":
    # Only log dynamo messages at ERROR level and let dynamo suppress its own
    # compilation errors rather than propagating them.
    torch._logging.set_logs(dynamo=logging.ERROR)
    torch._dynamo.config.suppress_errors = True
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
from sglang.global_config import global_config
|
from sglang.global_config import global_config
|
||||||
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
||||||
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
|
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
|
||||||
@@ -58,6 +59,36 @@ class WrapperDispatch(Enum):
|
|||||||
CROSS_ATTENTION = auto()
|
CROSS_ATTENTION = auto()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class MultiItemScoringParams:
    """Parameters for multi-item scoring in attention computation.

    Used when processing sequences with multiple items separated by delimiters,
    where each item needs specific attention patterns that respect item boundaries.

    A default-constructed instance (``prefix_len_ptr`` is ``None``) means that
    multi-item scoring is disabled; see :meth:`is_enabled`.

    Attributes:
        prefix_len_ptr: A uint32 1D tensor indicating the prefix length of each prompt.
            The tensor size is equal to the batch size.
        token_pos_in_items_ptr: A uint16 1D tensor indicating the token position of each
            item, starting from 0 (delimiter) for each item. For batch size > 1,
            sequences are concatenated with zero padding to ensure same length.
        token_pos_in_items_len: Zero padding length for token_pos_in_items_ptr to handle
            the batch_size > 1 case. Defines the padded length for each sequence.
        max_item_len_ptr: A uint16 tensor containing the max token length of all items
            for each prompt in the batch.
    """

    prefix_len_ptr: Optional[torch.Tensor] = None
    token_pos_in_items_ptr: Optional[torch.Tensor] = None
    token_pos_in_items_len: int = 0
    max_item_len_ptr: Optional[torch.Tensor] = None

    def is_enabled(self) -> bool:
        """Return True when multi-item scoring is enabled (``prefix_len_ptr`` is set)."""
        return self.prefix_len_ptr is not None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DecodeMetadata:
|
class DecodeMetadata:
|
||||||
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper]
|
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper]
|
||||||
@@ -68,6 +99,7 @@ class PrefillMetadata:
|
|||||||
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper]
|
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper]
|
||||||
use_ragged: bool
|
use_ragged: bool
|
||||||
extend_no_prefix: bool
|
extend_no_prefix: bool
|
||||||
|
multi_item_params: Optional[MultiItemScoringParams] = None
|
||||||
|
|
||||||
|
|
||||||
# Reuse this workspace buffer across all flashinfer wrappers
|
# Reuse this workspace buffer across all flashinfer wrappers
|
||||||
@@ -90,6 +122,11 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
# Store multi-item scoring delimiter for efficient access
|
||||||
|
self.multi_item_scoring_delimiter = (
|
||||||
|
model_runner.server_args.multi_item_scoring_delimiter
|
||||||
|
)
|
||||||
|
|
||||||
# Parse constants
|
# Parse constants
|
||||||
self.decode_use_tensor_cores = should_use_tensor_core(
|
self.decode_use_tensor_cores = should_use_tensor_core(
|
||||||
kv_cache_dtype=model_runner.kv_cache_dtype,
|
kv_cache_dtype=model_runner.kv_cache_dtype,
|
||||||
@@ -229,10 +266,133 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
|
|
||||||
# Other metadata
|
# Other metadata
|
||||||
self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
|
self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
|
||||||
|
|
||||||
self.decode_cuda_graph_metadata = {}
|
self.decode_cuda_graph_metadata = {}
|
||||||
self.prefill_cuda_graph_metadata = {} # For verify
|
self.prefill_cuda_graph_metadata = {} # For verify
|
||||||
self.draft_extend_cuda_graph_metadata = {} # For draft extend
|
self.draft_extend_cuda_graph_metadata = {} # For draft extend
|
||||||
|
|
||||||
|
def _process_multi_item_scoring(
    self, forward_batch: ForwardBatch
) -> MultiItemScoringParams:
    """Process multi-item scoring tensors for FlashInfer attention.

    This method handles sequences containing multiple "items" separated by delimiter tokens,
    where each item needs specific attention patterns that respect item boundaries.

    The method produces four key tensors for FlashInfer:
    - prefix_len_ptr: uint32 tensor with prefix length for each prompt in batch
    - token_pos_in_items_ptr: uint16 tensor with token positions starting from 0 at delimiters
    - token_pos_in_items_len: padding length for batch processing
    - max_item_len_ptr: uint16 tensor with max item length for each prompt

    Side effect: for every sequence that contains a delimiter,
    ``forward_batch.positions`` is rewritten in place from the first delimiter onwards.

    Args:
        forward_batch: The forward batch containing input sequences and delimiter info

    Returns:
        MultiItemScoringParams: The processed multi-item scoring parameters. A default
        (disabled) instance is returned when no delimiter is configured, when the batch
        is in DECODE mode, or when no sequence contains the delimiter.

    Examples:
        Following FlashInfer definition: for 3 items of length 3, 2, 4 respectively:
        token_pos_in_items_ptr = [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0]

        Case 1: Single sequence
        Text: "What is the capital of France? <delim> London <delim> Paris <delim> Berlin <delim>"
        Tokens: [What, is, the, capital, of, France, ?, <delim>, London, <delim>, Paris, <delim>, Berlin, <delim>]
        Indices: [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
        - prefix_len_ptr: [7] (query length before first delimiter)
        - token_pos_in_items_ptr: [0, 1, 0, 1, 0, 1, 0] (delim=0, London=1, delim=0, Paris=1, delim=0, Berlin=1, delim=0)
        - token_pos_in_items_len: 7 (actual length)
        - max_item_len_ptr: [1] (max item length is 1 token - all options are single tokens)

        Case 2: Batch processing (batch_size=2)
        Sequence 1: 2 items of length 2, 1 → [0, 1, 2, 0, 1, 0] (6 elements)
        Sequence 2: 3 items of length 1, 3, 2 → [0, 1, 0, 1, 2, 3, 0, 1, 2, 0] (10 elements)
        After padding both to length 10:
        - token_pos_in_items_ptr: [0, 1, 2, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 2, 3, 0, 1, 2, 0]
        - token_pos_in_items_len: 10 (padded length for batch processing)
        - max_item_len_ptr: [2, 3] (max lengths per sequence)
    """
    delimiter = self.multi_item_scoring_delimiter
    if delimiter is None or forward_batch.forward_mode == ForwardMode.DECODE:
        return MultiItemScoringParams()

    input_ids = forward_batch.input_ids
    device = input_ids.device
    delimiter_mask = input_ids == delimiter
    prefix_cache_lens = getattr(forward_batch, "extend_prefix_lens", None)
    extend_seq_lens = getattr(forward_batch, "extend_seq_lens", None)

    # If no extend_seq_lens, treat whole batch as one sequence
    if extend_seq_lens is None or len(extend_seq_lens) <= 1:
        extend_seq_lens = [input_ids.size(0)]

    prefix_lens: List[int] = []
    token_pos_per_seq: List[torch.Tensor] = []

    seq_start = 0
    for i, seq_len in enumerate(extend_seq_lens):
        seq_end = seq_start + seq_len
        mask = delimiter_mask[seq_start:seq_end]
        pos = forward_batch.positions[seq_start:seq_end]
        delimiter_indices = torch.nonzero(mask, as_tuple=True)[0]

        # NOTE(review): sequences without any delimiter are skipped entirely, so the
        # returned tensors can have fewer entries than the batch size even though the
        # parameter documentation describes them as per-prompt — confirm this is what
        # the FlashInfer wrapper expects for mixed batches.
        if len(delimiter_indices) > 0:
            first_delim = delimiter_indices[0]

            # Prefix length = tokens before the first delimiter, plus any cached prefix
            prefix_len = first_delim + (
                prefix_cache_lens[i] if prefix_cache_lens is not None else 0
            )
            # Store as a plain Python scalar
            prefix_lens.append(
                prefix_len.item() if torch.is_tensor(prefix_len) else prefix_len
            )

            # Compute relative positions within items after delimiters.
            # This must happen before `pos` is overwritten below, because it
            # reads the original value of pos[first_delim].
            diff = pos[first_delim:] - torch.cummax(mask[first_delim:], 0)[1]
            token_pos = (diff - pos[first_delim]).to(torch.uint16)
            token_pos_per_seq.append(token_pos)

            # Update forward_batch positions in-place. `pos` comes from slicing
            # forward_batch.positions; the explicit write-back keeps the update
            # correct even if that slice does not alias the original storage.
            pos[first_delim:] = diff - 1
            forward_batch.positions[seq_start:seq_end] = pos

        seq_start = seq_end

    # Both lists are appended together, so one emptiness check covers both.
    if not token_pos_per_seq:
        return MultiItemScoringParams()

    # Pad every per-sequence tensor with zeros to the longest one for batch processing
    token_pos_in_items_len = max(t.numel() for t in token_pos_per_seq)
    padded_token_pos = [
        torch.cat(
            [
                t,
                torch.zeros(
                    token_pos_in_items_len - t.numel(),
                    dtype=torch.uint16,
                    device=device,
                ),
            ]
        )
        for t in token_pos_per_seq
    ]

    # Build final params
    return MultiItemScoringParams(
        prefix_len_ptr=torch.tensor(prefix_lens, dtype=torch.uint32, device=device),
        token_pos_in_items_ptr=torch.cat(padded_token_pos, dim=0),
        token_pos_in_items_len=token_pos_in_items_len & 0xFFFFFFFF,
        # The max is taken in int32 and cast back to uint16, as in the original code.
        max_item_len_ptr=torch.stack(
            [t.to(torch.int32).max().to(torch.uint16) for t in padded_token_pos],
            dim=0,
        ),
    )
|
||||||
|
|
||||||
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||||
if forward_batch.forward_mode.is_decode_or_idle():
|
if forward_batch.forward_mode.is_decode_or_idle():
|
||||||
self.indices_updater_decode.update(
|
self.indices_updater_decode.update(
|
||||||
@@ -280,13 +440,26 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
else:
|
else:
|
||||||
prefix_lens = forward_batch.extend_prefix_lens
|
prefix_lens = forward_batch.extend_prefix_lens
|
||||||
|
|
||||||
if self.is_multimodal:
|
# Disable ragged wrapper and ensure prefix handling for multimodal and multi-item scoring
|
||||||
|
if self.is_multimodal or self.multi_item_scoring_delimiter is not None:
|
||||||
|
# use_ragged = False: Multi-item scoring requires the paged wrapper because:
|
||||||
|
# 1. Ragged wrapper doesn't support the specialized multi-item parameters
|
||||||
|
# (prefix_len_ptr, token_pos_in_items_ptr, etc.)
|
||||||
|
# 2. Paged wrapper provides better control over attention masking needed
|
||||||
|
# for respecting item boundaries in multi-item sequences
|
||||||
|
# 3. Custom masking logic conflicts with ragged wrapper's assumptions
|
||||||
use_ragged = False
|
use_ragged = False
|
||||||
extend_no_prefix = False
|
extend_no_prefix = False
|
||||||
else:
|
else:
|
||||||
use_ragged = not self.enable_deterministic
|
use_ragged = not self.enable_deterministic
|
||||||
extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
|
extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
|
||||||
|
|
||||||
|
# Process multi-item scoring in attention backend instead of ForwardBatch
|
||||||
|
multi_item_params = MultiItemScoringParams()
|
||||||
|
if self.multi_item_scoring_delimiter is not None:
|
||||||
|
# Use new backend-specific implementation
|
||||||
|
multi_item_params = self._process_multi_item_scoring(forward_batch)
|
||||||
|
|
||||||
self.indices_updater_prefill.update(
|
self.indices_updater_prefill.update(
|
||||||
forward_batch.req_pool_indices,
|
forward_batch.req_pool_indices,
|
||||||
forward_batch.seq_lens,
|
forward_batch.seq_lens,
|
||||||
@@ -298,9 +471,13 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
encoder_lens=forward_batch.encoder_lens,
|
encoder_lens=forward_batch.encoder_lens,
|
||||||
spec_info=None,
|
spec_info=None,
|
||||||
fixed_split_size=self.prefill_split_tile_size,
|
fixed_split_size=self.prefill_split_tile_size,
|
||||||
|
multi_item_params=multi_item_params,
|
||||||
)
|
)
|
||||||
self.forward_metadata = PrefillMetadata(
|
self.forward_metadata = PrefillMetadata(
|
||||||
self.prefill_wrappers_paged, use_ragged, extend_no_prefix
|
self.prefill_wrappers_paged,
|
||||||
|
use_ragged,
|
||||||
|
extend_no_prefix,
|
||||||
|
multi_item_params,
|
||||||
)
|
)
|
||||||
|
|
||||||
def init_cuda_graph_state(
|
def init_cuda_graph_state(
|
||||||
@@ -531,7 +708,20 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
|
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
|
||||||
causal=not layer.is_cross_attention,
|
causal=not layer.is_cross_attention,
|
||||||
sm_scale=layer.scaling,
|
sm_scale=layer.scaling,
|
||||||
window_left=layer.sliding_window_size,
|
# Disable sliding window attention for multi-item scoring:
|
||||||
|
# - Sliding window could cut across item boundaries, breaking semantic coherence
|
||||||
|
# - Multi-item sequences need full attention to properly handle delimiter tokens
|
||||||
|
# - Specialized multi-item parameters (prefix_len_ptr, token_pos_in_items_ptr)
|
||||||
|
# provide more precise attention control than simple sliding windows
|
||||||
|
# - Item-aware masking takes precedence over window-based masking
|
||||||
|
window_left=(
|
||||||
|
layer.sliding_window_size
|
||||||
|
if not (
|
||||||
|
self.forward_metadata.multi_item_params
|
||||||
|
and self.forward_metadata.multi_item_params.is_enabled()
|
||||||
|
)
|
||||||
|
else -1
|
||||||
|
),
|
||||||
logits_soft_cap=logits_soft_cap,
|
logits_soft_cap=logits_soft_cap,
|
||||||
# Must use _float to avoid device-to-host copy that breaks cuda graph capture.
|
# Must use _float to avoid device-to-host copy that breaks cuda graph capture.
|
||||||
k_scale=layer.k_scale_float,
|
k_scale=layer.k_scale_float,
|
||||||
@@ -952,6 +1142,7 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
encoder_lens: Optional[torch.Tensor],
|
encoder_lens: Optional[torch.Tensor],
|
||||||
spec_info: Optional[SpecInput],
|
spec_info: Optional[SpecInput],
|
||||||
fixed_split_size: Optional[int] = None,
|
fixed_split_size: Optional[int] = None,
|
||||||
|
multi_item_params: Optional[MultiItemScoringParams] = None,
|
||||||
):
|
):
|
||||||
if use_ragged:
|
if use_ragged:
|
||||||
# TODO: remove this device sync, we can use forward_batch.extend_prefix_lens_cpu
|
# TODO: remove this device sync, we can use forward_batch.extend_prefix_lens_cpu
|
||||||
@@ -976,6 +1167,7 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
use_ragged,
|
use_ragged,
|
||||||
spec_info,
|
spec_info,
|
||||||
fixed_split_size=fixed_split_size,
|
fixed_split_size=fixed_split_size,
|
||||||
|
multi_item_params=multi_item_params,
|
||||||
)
|
)
|
||||||
|
|
||||||
def update_sliding_window(
|
def update_sliding_window(
|
||||||
@@ -990,6 +1182,7 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
encoder_lens: Optional[torch.Tensor],
|
encoder_lens: Optional[torch.Tensor],
|
||||||
spec_info: Optional[SpecInput],
|
spec_info: Optional[SpecInput],
|
||||||
fixed_split_size: Optional[int] = None,
|
fixed_split_size: Optional[int] = None,
|
||||||
|
multi_item_params: Optional[MultiItemScoringParams] = None,
|
||||||
):
|
):
|
||||||
for wrapper_id in range(2):
|
for wrapper_id in range(2):
|
||||||
if wrapper_id == 0:
|
if wrapper_id == 0:
|
||||||
@@ -1023,6 +1216,7 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
use_ragged,
|
use_ragged,
|
||||||
spec_info,
|
spec_info,
|
||||||
use_sliding_window_kv_pool=use_sliding_window_kv_pool,
|
use_sliding_window_kv_pool=use_sliding_window_kv_pool,
|
||||||
|
multi_item_params=multi_item_params,
|
||||||
)
|
)
|
||||||
|
|
||||||
def update_cross_attention(
|
def update_cross_attention(
|
||||||
@@ -1037,6 +1231,7 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
encoder_lens: Optional[torch.Tensor],
|
encoder_lens: Optional[torch.Tensor],
|
||||||
spec_info: Optional[SpecInput],
|
spec_info: Optional[SpecInput],
|
||||||
fixed_split_size: Optional[int] = None,
|
fixed_split_size: Optional[int] = None,
|
||||||
|
multi_item_params: Optional[MultiItemScoringParams] = None,
|
||||||
):
|
):
|
||||||
for wrapper_id in range(2):
|
for wrapper_id in range(2):
|
||||||
if wrapper_id == 0:
|
if wrapper_id == 0:
|
||||||
@@ -1063,6 +1258,7 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
self.qo_indptr[wrapper_id],
|
self.qo_indptr[wrapper_id],
|
||||||
use_ragged,
|
use_ragged,
|
||||||
spec_info,
|
spec_info,
|
||||||
|
multi_item_params=multi_item_params,
|
||||||
)
|
)
|
||||||
|
|
||||||
def call_begin_forward(
|
def call_begin_forward(
|
||||||
@@ -1081,6 +1277,7 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
spec_info: Optional[SpecInput],
|
spec_info: Optional[SpecInput],
|
||||||
use_sliding_window_kv_pool: bool = False,
|
use_sliding_window_kv_pool: bool = False,
|
||||||
fixed_split_size: Optional[int] = None,
|
fixed_split_size: Optional[int] = None,
|
||||||
|
multi_item_params: Optional[MultiItemScoringParams] = None,
|
||||||
):
|
):
|
||||||
bs = len(seq_lens)
|
bs = len(seq_lens)
|
||||||
if spec_info is None:
|
if spec_info is None:
|
||||||
@@ -1136,6 +1333,22 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# cached part
|
# cached part
|
||||||
|
# Conditionally set multi-item parameters
|
||||||
|
if multi_item_params is not None and multi_item_params.is_enabled():
|
||||||
|
# Multi-item scoring is active - use specialized parameters and disable generic custom_mask
|
||||||
|
use_custom_mask = None
|
||||||
|
prefix_len_ptr = multi_item_params.prefix_len_ptr
|
||||||
|
token_pos_in_items_ptr = multi_item_params.token_pos_in_items_ptr
|
||||||
|
token_pos_in_items_len = multi_item_params.token_pos_in_items_len
|
||||||
|
max_item_len_ptr = multi_item_params.max_item_len_ptr
|
||||||
|
else:
|
||||||
|
# No multi-item scoring - use standard parameters
|
||||||
|
use_custom_mask = custom_mask
|
||||||
|
prefix_len_ptr = None
|
||||||
|
token_pos_in_items_ptr = None
|
||||||
|
token_pos_in_items_len = 0
|
||||||
|
max_item_len_ptr = None
|
||||||
|
|
||||||
wrapper_paged.begin_forward(
|
wrapper_paged.begin_forward(
|
||||||
qo_indptr,
|
qo_indptr,
|
||||||
kv_indptr,
|
kv_indptr,
|
||||||
@@ -1147,9 +1360,13 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
1,
|
1,
|
||||||
q_data_type=self.q_data_type,
|
q_data_type=self.q_data_type,
|
||||||
kv_data_type=self.data_type,
|
kv_data_type=self.data_type,
|
||||||
custom_mask=custom_mask,
|
custom_mask=use_custom_mask,
|
||||||
non_blocking=True,
|
non_blocking=True,
|
||||||
fixed_split_size=fixed_split_size,
|
fixed_split_size=fixed_split_size,
|
||||||
|
prefix_len_ptr=prefix_len_ptr,
|
||||||
|
token_pos_in_items_ptr=token_pos_in_items_ptr,
|
||||||
|
token_pos_in_items_len=token_pos_in_items_len,
|
||||||
|
max_item_len_ptr=max_item_len_ptr,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -60,7 +60,8 @@ _is_npu = is_npu()
|
|||||||
class LogitsProcessorOutput:
|
class LogitsProcessorOutput:
|
||||||
## Part 1: This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
|
## Part 1: This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
|
||||||
# The logits of the next tokens. shape: [#seq, vocab_size]
|
# The logits of the next tokens. shape: [#seq, vocab_size]
|
||||||
next_token_logits: torch.Tensor
|
# Can be None for certain prefill-only requests (e.g., multi-item scoring) that don't need next token generation
|
||||||
|
next_token_logits: Optional[torch.Tensor]
|
||||||
# Used by speculative decoding (EAGLE)
|
# Used by speculative decoding (EAGLE)
|
||||||
# The last hidden layers
|
# The last hidden layers
|
||||||
hidden_states: Optional[torch.Tensor] = None
|
hidden_states: Optional[torch.Tensor] = None
|
||||||
@@ -85,7 +86,10 @@ class LogitsProcessorOutput:
|
|||||||
input_top_logprobs_val: List = None
|
input_top_logprobs_val: List = None
|
||||||
input_top_logprobs_idx: List = None
|
input_top_logprobs_idx: List = None
|
||||||
# The logprobs and ids of the requested token ids in input positions. shape: [#seq, n] (n is the number of requested token ids)
|
# The logprobs and ids of the requested token ids in input positions. shape: [#seq, n] (n is the number of requested token ids)
|
||||||
input_token_ids_logprobs_val: Optional[List] = None
|
# Can contain either lists or GPU tensors (for delayed GPU-to-CPU transfer optimization)
|
||||||
|
input_token_ids_logprobs_val: Optional[List[Union[List[float], torch.Tensor]]] = (
|
||||||
|
None
|
||||||
|
)
|
||||||
input_token_ids_logprobs_idx: Optional[List] = None
|
input_token_ids_logprobs_idx: Optional[List] = None
|
||||||
|
|
||||||
|
|
||||||
@@ -127,6 +131,9 @@ class LogitsMetadata:
|
|||||||
# for padding
|
# for padding
|
||||||
padded_static_len: int = -1
|
padded_static_len: int = -1
|
||||||
|
|
||||||
|
# Whether this batch is prefill-only (no token generation needed)
|
||||||
|
is_prefill_only: bool = False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_forward_batch(cls, forward_batch: ForwardBatch):
|
def from_forward_batch(cls, forward_batch: ForwardBatch):
|
||||||
if (
|
if (
|
||||||
@@ -169,6 +176,7 @@ class LogitsMetadata:
|
|||||||
token_ids_logprobs=forward_batch.token_ids_logprobs,
|
token_ids_logprobs=forward_batch.token_ids_logprobs,
|
||||||
extend_input_logprob_token_ids_gpu=forward_batch.extend_input_logprob_token_ids_gpu,
|
extend_input_logprob_token_ids_gpu=forward_batch.extend_input_logprob_token_ids_gpu,
|
||||||
padded_static_len=forward_batch.padded_static_len,
|
padded_static_len=forward_batch.padded_static_len,
|
||||||
|
is_prefill_only=forward_batch.is_prefill_only,
|
||||||
global_num_tokens_gpu=forward_batch.global_num_tokens_gpu,
|
global_num_tokens_gpu=forward_batch.global_num_tokens_gpu,
|
||||||
dp_local_start_pos=forward_batch.dp_local_start_pos,
|
dp_local_start_pos=forward_batch.dp_local_start_pos,
|
||||||
dp_local_num_tokens=forward_batch.dp_local_num_tokens,
|
dp_local_num_tokens=forward_batch.dp_local_num_tokens,
|
||||||
@@ -247,6 +255,108 @@ class LogitsProcessor(nn.Module):
|
|||||||
"debug_tensor_dump_output_folder", None
|
"debug_tensor_dump_output_folder", None
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def compute_logprobs_for_multi_item_scoring(
    self,
    input_ids,
    hidden_states,
    lm_head: VocabParallelEmbedding,
    logits_metadata: Union[LogitsMetadata, ForwardBatch],
    delimiter_token: int,
):
    """
    Compute logprobs for multi-item scoring using delimiter-based token extraction.

    This method is designed for scenarios where you want to score multiple items/candidates
    against a single query by combining them into one sequence separated by delimiters.

    Sequence format: Query<delimiter>Item1<delimiter>Item2<delimiter>...
    Scoring positions: Extracts logprobs at positions before each <delimiter>

    Side effect: when token-id logprobs or top logprobs are requested,
    ``logits_metadata.extend_logprob_pruned_lens_cpu`` is overwritten with the number
    of delimiters found in each request.

    Args:
        input_ids (torch.Tensor): Input token IDs containing query and items separated by delimiters.
            Shape: [total_sequence_length] for single request or [batch_total_length] for batch.
        hidden_states (torch.Tensor): Hidden states from the model.
            Shape: [sequence_length, hidden_dim].
        lm_head (VocabParallelEmbedding): Language model head for computing logits.
        logits_metadata (Union[LogitsMetadata, ForwardBatch]): Metadata containing batch info
            and token ID specifications for logprob extraction.
        delimiter_token (int): Token ID used as delimiter between query and items.

    Returns:
        LogitsProcessorOutput: Contains:
            - next_token_logits: None (not needed for scoring-only requests)
            - input_token_logprobs: Logprobs of delimiter tokens at scoring positions
            - input_top_logprobs_val: Top-k logprobs at delimiter positions (if requested)
            - input_top_logprobs_idx: Top-k token indices at delimiter positions (if requested)
            - input_token_ids_logprobs_val: Logprobs for user-requested token IDs (if any)
            - input_token_ids_logprobs_idx: Indices for user-requested token IDs (if any)
    """
    # Locate the delimiters once and reuse the mask for every count below.
    delimiter_mask = input_ids == delimiter_token

    # The scoring position of each item is the token right before its delimiter.
    # NOTE(review): a delimiter at index 0 yields index -1, which selects the last
    # hidden state — presumably inputs always start with a non-empty query; confirm.
    multi_item_indices = delimiter_mask.nonzero(as_tuple=True)[0] - 1

    # Extract hidden states at the scoring positions for multi-item scoring
    sliced_hidden = hidden_states[multi_item_indices]

    sliced_logits = self._get_logits(sliced_hidden, lm_head, logits_metadata)
    sliced_logprobs = torch.nn.functional.log_softmax(sliced_logits, dim=-1)

    # Initialize return values
    input_token_ids_logprobs_val = []
    input_token_ids_logprobs_idx = []
    input_top_logprobs_val = None
    input_top_logprobs_idx = None

    # Recalculate extend_logprob_pruned_lens_cpu to match delimiter counts per request.
    # Original contains sequence lengths, but we need delimiter counts for sliced_logprobs.
    if (
        logits_metadata.token_ids_logprobs
        or logits_metadata.extend_return_top_logprob
    ):
        if logits_metadata.extend_seq_lens_cpu is not None:
            # Multi-request batch: count delimiters per request
            delimiter_counts = []
            input_pt = 0
            for req_seq_len in logits_metadata.extend_seq_lens_cpu:
                req_mask = delimiter_mask[input_pt : input_pt + req_seq_len]
                delimiter_counts.append(req_mask.sum().item())
                input_pt += req_seq_len
        else:
            # Single request case: one request gets all delimiters. The number of
            # nonzero entries found above is exactly the delimiter count.
            delimiter_counts = [multi_item_indices.numel()]
        logits_metadata.extend_logprob_pruned_lens_cpu = delimiter_counts

    # Get the logprobs of specified token ids
    if logits_metadata.extend_token_ids_logprob:
        (
            input_token_ids_logprobs_val,
            input_token_ids_logprobs_idx,
        ) = self.get_token_ids_logprobs(
            sliced_logprobs, logits_metadata, delay_cpu_copy=True
        )

    # Get the logprob of top-k tokens
    if logits_metadata.extend_return_top_logprob:
        (
            input_top_logprobs_val,
            input_top_logprobs_idx,
        ) = self.get_top_logprobs(sliced_logprobs, logits_metadata)

    # For input_token_logprobs, use delimiter token logprobs
    input_token_logprobs = sliced_logprobs[:, delimiter_token]

    return LogitsProcessorOutput(
        next_token_logits=None,  # Multi-item scoring doesn't need next token logits
        input_token_logprobs=input_token_logprobs,
        input_top_logprobs_val=input_top_logprobs_val,
        input_top_logprobs_idx=input_top_logprobs_idx,
        input_token_ids_logprobs_val=input_token_ids_logprobs_val,
        input_token_ids_logprobs_idx=input_token_ids_logprobs_idx,
    )
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids,
|
input_ids,
|
||||||
@@ -257,6 +367,16 @@ class LogitsProcessor(nn.Module):
|
|||||||
) -> LogitsProcessorOutput:
|
) -> LogitsProcessorOutput:
|
||||||
if isinstance(logits_metadata, ForwardBatch):
|
if isinstance(logits_metadata, ForwardBatch):
|
||||||
logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata)
|
logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata)
|
||||||
|
|
||||||
|
# Check if multi-item scoring is enabled via server args (only for prefill-only requests)
|
||||||
|
multi_item_delimiter = global_server_args_dict.get(
|
||||||
|
"multi_item_scoring_delimiter"
|
||||||
|
)
|
||||||
|
if multi_item_delimiter is not None and logits_metadata.is_prefill_only:
|
||||||
|
return self.compute_logprobs_for_multi_item_scoring(
|
||||||
|
input_ids, hidden_states, lm_head, logits_metadata, multi_item_delimiter
|
||||||
|
)
|
||||||
|
|
||||||
# Get the last hidden states and last logits for the next token prediction
|
# Get the last hidden states and last logits for the next token prediction
|
||||||
if (
|
if (
|
||||||
logits_metadata.forward_mode.is_decode_or_idle()
|
logits_metadata.forward_mode.is_decode_or_idle()
|
||||||
@@ -584,7 +704,9 @@ class LogitsProcessor(nn.Module):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_token_ids_logprobs(
|
def get_token_ids_logprobs(
|
||||||
all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata
|
all_logprobs: torch.Tensor,
|
||||||
|
logits_metadata: LogitsMetadata,
|
||||||
|
delay_cpu_copy: bool = False,
|
||||||
):
|
):
|
||||||
input_token_ids_logprobs_val, input_token_ids_logprobs_idx = [], []
|
input_token_ids_logprobs_val, input_token_ids_logprobs_idx = [], []
|
||||||
pt = 0
|
pt = 0
|
||||||
@@ -597,9 +719,17 @@ class LogitsProcessor(nn.Module):
|
|||||||
input_token_ids_logprobs_idx.append([])
|
input_token_ids_logprobs_idx.append([])
|
||||||
continue
|
continue
|
||||||
|
|
||||||
input_token_ids_logprobs_val.append(
|
position_logprobs = all_logprobs[
|
||||||
[all_logprobs[pt + j, token_ids].tolist() for j in range(pruned_len)]
|
pt : pt + pruned_len, token_ids
|
||||||
)
|
] # Shape: [pruned_len, num_tokens]
|
||||||
|
|
||||||
|
if delay_cpu_copy:
|
||||||
|
# Keep as tensor to delay GPU-to-CPU transfer
|
||||||
|
input_token_ids_logprobs_val.append(position_logprobs)
|
||||||
|
else:
|
||||||
|
# Convert to list immediately (default behavior)
|
||||||
|
input_token_ids_logprobs_val.append(position_logprobs.tolist())
|
||||||
|
|
||||||
input_token_ids_logprobs_idx.append([token_ids for _ in range(pruned_len)])
|
input_token_ids_logprobs_idx.append([token_ids for _ in range(pruned_len)])
|
||||||
pt += pruned_len
|
pt += pruned_len
|
||||||
|
|
||||||
|
|||||||
@@ -114,6 +114,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
|
|||||||
"enable_deterministic_inference",
|
"enable_deterministic_inference",
|
||||||
"nsa_prefill",
|
"nsa_prefill",
|
||||||
"nsa_decode",
|
"nsa_decode",
|
||||||
|
"multi_item_scoring_delimiter",
|
||||||
]
|
]
|
||||||
|
|
||||||
# Put some global args for easy access
|
# Put some global args for easy access
|
||||||
@@ -666,9 +667,11 @@ class Req:
|
|||||||
def is_prefill_only(self) -> bool:
|
def is_prefill_only(self) -> bool:
|
||||||
"""Check if this request is prefill-only (no token generation needed)."""
|
"""Check if this request is prefill-only (no token generation needed)."""
|
||||||
# NOTE: when spec is enabled, prefill_only optimizations are disabled
|
# NOTE: when spec is enabled, prefill_only optimizations are disabled
|
||||||
return (
|
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||||
self.sampling_params.max_new_tokens == 0
|
|
||||||
and global_server_args_dict["speculative_algorithm"] is None
|
spec_alg = global_server_args_dict["speculative_algorithm"]
|
||||||
|
return self.sampling_params.max_new_tokens == 0 and (
|
||||||
|
spec_alg is None or spec_alg == SpeculativeAlgorithm.NONE
|
||||||
)
|
)
|
||||||
|
|
||||||
def add_latency(self, stage: RequestStage):
|
def add_latency(self, stage: RequestStage):
|
||||||
|
|||||||
@@ -104,7 +104,10 @@ class SchedulerOutputProcessorMixin:
|
|||||||
assert extend_input_len_per_req is not None
|
assert extend_input_len_per_req is not None
|
||||||
extend_logprob_start_len = extend_logprob_start_len_per_req[i]
|
extend_logprob_start_len = extend_logprob_start_len_per_req[i]
|
||||||
extend_input_len = extend_input_len_per_req[i]
|
extend_input_len = extend_input_len_per_req[i]
|
||||||
num_input_logprobs = extend_input_len - extend_logprob_start_len
|
|
||||||
|
num_input_logprobs = self._calculate_num_input_logprobs(
|
||||||
|
req, extend_input_len, extend_logprob_start_len
|
||||||
|
)
|
||||||
|
|
||||||
if req.return_logprob:
|
if req.return_logprob:
|
||||||
self.add_logprob_return_values(
|
self.add_logprob_return_values(
|
||||||
@@ -159,8 +162,8 @@ class SchedulerOutputProcessorMixin:
|
|||||||
extend_input_len = extend_input_len_per_req[i]
|
extend_input_len = extend_input_len_per_req[i]
|
||||||
if extend_logprob_start_len < extend_input_len:
|
if extend_logprob_start_len < extend_input_len:
|
||||||
# Update input logprobs.
|
# Update input logprobs.
|
||||||
num_input_logprobs = (
|
num_input_logprobs = self._calculate_num_input_logprobs(
|
||||||
extend_input_len - extend_logprob_start_len
|
req, extend_input_len, extend_logprob_start_len
|
||||||
)
|
)
|
||||||
if req.return_logprob:
|
if req.return_logprob:
|
||||||
self.add_input_logprob_return_values(
|
self.add_input_logprob_return_values(
|
||||||
@@ -303,6 +306,153 @@ class SchedulerOutputProcessorMixin:
|
|||||||
):
|
):
|
||||||
self.log_decode_stats(can_run_cuda_graph, running_batch=batch)
|
self.log_decode_stats(can_run_cuda_graph, running_batch=batch)
|
||||||
|
|
||||||
|
def _process_input_token_logprobs(
|
||||||
|
self, req: Req, input_token_logprobs: List
|
||||||
|
) -> None:
|
||||||
|
"""Process input token logprobs values and indices."""
|
||||||
|
is_multi_item_scoring = self._is_multi_item_scoring(req)
|
||||||
|
|
||||||
|
# Process logprob values - handle multi-item scoring vs regular requests
|
||||||
|
if is_multi_item_scoring:
|
||||||
|
# Multi-item scoring: use all logprobs as-is
|
||||||
|
req.input_token_logprobs_val = input_token_logprobs
|
||||||
|
else:
|
||||||
|
# Regular request: add None at start, remove last (sampling token)
|
||||||
|
req.input_token_logprobs_val = [None] + input_token_logprobs[:-1]
|
||||||
|
|
||||||
|
# Process logprob indices based on scoring type
|
||||||
|
if is_multi_item_scoring:
|
||||||
|
# Multi-item scoring: only include delimiter token positions
|
||||||
|
relevant_tokens = req.origin_input_ids[req.logprob_start_len :]
|
||||||
|
input_token_logprobs_idx = [
|
||||||
|
token_id
|
||||||
|
for token_id in relevant_tokens
|
||||||
|
if token_id == self.server_args.multi_item_scoring_delimiter
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
# Regular request: include all tokens from logprob_start_len onwards
|
||||||
|
input_token_logprobs_idx = req.origin_input_ids[req.logprob_start_len :]
|
||||||
|
|
||||||
|
# Clip padded hash values from image tokens to prevent detokenization errors
|
||||||
|
req.input_token_logprobs_idx = [
|
||||||
|
x if x < self.model_config.vocab_size - 1 else 0
|
||||||
|
for x in input_token_logprobs_idx
|
||||||
|
]
|
||||||
|
|
||||||
|
def _process_input_top_logprobs(self, req: Req) -> None:
|
||||||
|
"""Process input top logprobs."""
|
||||||
|
if req.top_logprobs_num <= 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
is_multi_item_scoring = self._is_multi_item_scoring(req)
|
||||||
|
|
||||||
|
# Initialize arrays - multi-item scoring starts empty, others start with None
|
||||||
|
req.input_top_logprobs_val = [] if is_multi_item_scoring else [None]
|
||||||
|
req.input_top_logprobs_idx = [] if is_multi_item_scoring else [None]
|
||||||
|
|
||||||
|
# Extend arrays with temp values
|
||||||
|
for val, idx in zip(
|
||||||
|
req.temp_input_top_logprobs_val,
|
||||||
|
req.temp_input_top_logprobs_idx,
|
||||||
|
strict=True,
|
||||||
|
):
|
||||||
|
req.input_top_logprobs_val.extend(val)
|
||||||
|
req.input_top_logprobs_idx.extend(idx)
|
||||||
|
|
||||||
|
# Remove last token (sampling token) for non multi-item scoring requests
|
||||||
|
if not is_multi_item_scoring:
|
||||||
|
req.input_top_logprobs_val.pop()
|
||||||
|
req.input_top_logprobs_idx.pop()
|
||||||
|
|
||||||
|
# Clean up temp storage
|
||||||
|
req.temp_input_top_logprobs_idx = None
|
||||||
|
req.temp_input_top_logprobs_val = None
|
||||||
|
|
||||||
|
def _process_input_token_ids_logprobs(self, req: Req) -> None:
|
||||||
|
"""Process input token IDs logprobs."""
|
||||||
|
if req.token_ids_logprob is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
is_multi_item_scoring = self._is_multi_item_scoring(req)
|
||||||
|
|
||||||
|
# Initialize arrays - multi-item scoring starts empty, others start with None
|
||||||
|
req.input_token_ids_logprobs_val = [] if is_multi_item_scoring else [None]
|
||||||
|
req.input_token_ids_logprobs_idx = [] if is_multi_item_scoring else [None]
|
||||||
|
|
||||||
|
# Process temp values - convert tensors to lists and extend arrays
|
||||||
|
for val, idx in zip(
|
||||||
|
req.temp_input_token_ids_logprobs_val,
|
||||||
|
req.temp_input_token_ids_logprobs_idx,
|
||||||
|
strict=True,
|
||||||
|
):
|
||||||
|
val_list = val.tolist() if isinstance(val, torch.Tensor) else val
|
||||||
|
req.input_token_ids_logprobs_val.extend(
|
||||||
|
val_list if isinstance(val_list, list) else [val_list]
|
||||||
|
)
|
||||||
|
req.input_token_ids_logprobs_idx.extend(idx)
|
||||||
|
|
||||||
|
# Remove last token (sampling token) for non multi-item scoring requests
|
||||||
|
if not is_multi_item_scoring:
|
||||||
|
req.input_token_ids_logprobs_val.pop()
|
||||||
|
req.input_token_ids_logprobs_idx.pop()
|
||||||
|
|
||||||
|
# Clean up temp storage
|
||||||
|
req.temp_input_token_ids_logprobs_idx = None
|
||||||
|
req.temp_input_token_ids_logprobs_val = None
|
||||||
|
|
||||||
|
def _calculate_relevant_tokens_len(self, req: Req) -> int:
|
||||||
|
"""Calculate the expected length of logprob arrays based on whether multi-item scoring is enabled.
|
||||||
|
|
||||||
|
For multi-item scoring, only delimiter positions have logprobs.
|
||||||
|
For regular requests, all positions from logprob_start_len onwards have logprobs.
|
||||||
|
"""
|
||||||
|
is_multi_item_scoring = self._is_multi_item_scoring(req)
|
||||||
|
|
||||||
|
if is_multi_item_scoring:
|
||||||
|
# Multi-item scoring: count delimiter tokens from logprob_start_len onwards
|
||||||
|
relevant_tokens = req.origin_input_ids[req.logprob_start_len :]
|
||||||
|
return sum(
|
||||||
|
1
|
||||||
|
for token_id in relevant_tokens
|
||||||
|
if token_id == self.server_args.multi_item_scoring_delimiter
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Regular request: all tokens from logprob_start_len onwards
|
||||||
|
return len(req.origin_input_ids) - req.logprob_start_len
|
||||||
|
|
||||||
|
def _calculate_num_input_logprobs(
|
||||||
|
self, req: Req, extend_input_len: int, extend_logprob_start_len: int
|
||||||
|
) -> int:
|
||||||
|
"""Calculate the number of input logprobs based on whether multi-item scoring is enabled.
|
||||||
|
|
||||||
|
For multi-item scoring, only delimiter positions have logprobs.
|
||||||
|
For regular requests, all positions in the range have logprobs.
|
||||||
|
"""
|
||||||
|
is_multi_item_scoring = self._is_multi_item_scoring(req)
|
||||||
|
|
||||||
|
if is_multi_item_scoring:
|
||||||
|
# Multi-item scoring: count delimiter tokens in the relevant portion
|
||||||
|
relevant_tokens = req.origin_input_ids[
|
||||||
|
extend_logprob_start_len:extend_input_len
|
||||||
|
]
|
||||||
|
return sum(
|
||||||
|
1
|
||||||
|
for token_id in relevant_tokens
|
||||||
|
if token_id == self.server_args.multi_item_scoring_delimiter
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Regular request: all tokens in the range
|
||||||
|
return extend_input_len - extend_logprob_start_len
|
||||||
|
|
||||||
|
def _is_multi_item_scoring(self, req: Req) -> bool:
|
||||||
|
"""Check if request uses multi-item scoring.
|
||||||
|
|
||||||
|
Multi-item scoring applies to prefill-only requests when a delimiter
|
||||||
|
token is configured. In this mode, only positions containing the
|
||||||
|
delimiter token receive logprobs.
|
||||||
|
"""
|
||||||
|
return req.is_prefill_only and self.server_args.multi_item_scoring_delimiter
|
||||||
|
|
||||||
def add_input_logprob_return_values(
|
def add_input_logprob_return_values(
|
||||||
self: Scheduler,
|
self: Scheduler,
|
||||||
i: int,
|
i: int,
|
||||||
@@ -371,63 +521,14 @@ class SchedulerOutputProcessorMixin:
|
|||||||
assert req.input_top_logprobs_val is None
|
assert req.input_top_logprobs_val is None
|
||||||
assert req.input_top_logprobs_idx is None
|
assert req.input_top_logprobs_idx is None
|
||||||
|
|
||||||
# Compute input_token_logprobs_val
|
# Process all input logprob types using helper functions
|
||||||
# Always pad the first one with None.
|
self._process_input_token_logprobs(req, input_token_logprobs)
|
||||||
req.input_token_logprobs_val = [None]
|
self._process_input_top_logprobs(req)
|
||||||
req.input_token_logprobs_val.extend(input_token_logprobs)
|
|
||||||
# The last input logprob is for sampling, so just pop it out.
|
|
||||||
req.input_token_logprobs_val.pop()
|
|
||||||
|
|
||||||
# Compute input_token_logprobs_idx
|
self._process_input_token_ids_logprobs(req)
|
||||||
input_token_logprobs_idx = req.origin_input_ids[req.logprob_start_len :]
|
|
||||||
# Clip the padded hash values from image tokens.
|
|
||||||
# Otherwise, it will lead to detokenization errors.
|
|
||||||
input_token_logprobs_idx = [
|
|
||||||
x if x < self.model_config.vocab_size - 1 else 0
|
|
||||||
for x in input_token_logprobs_idx
|
|
||||||
]
|
|
||||||
req.input_token_logprobs_idx = input_token_logprobs_idx
|
|
||||||
|
|
||||||
if req.top_logprobs_num > 0:
|
|
||||||
req.input_top_logprobs_val = [None]
|
|
||||||
req.input_top_logprobs_idx = [None]
|
|
||||||
assert len(req.temp_input_token_ids_logprobs_val) == len(
|
|
||||||
req.temp_input_token_ids_logprobs_idx
|
|
||||||
)
|
|
||||||
for val, idx in zip(
|
|
||||||
req.temp_input_top_logprobs_val,
|
|
||||||
req.temp_input_top_logprobs_idx,
|
|
||||||
strict=True,
|
|
||||||
):
|
|
||||||
req.input_top_logprobs_val.extend(val)
|
|
||||||
req.input_top_logprobs_idx.extend(idx)
|
|
||||||
|
|
||||||
# Last token is a sample token.
|
|
||||||
req.input_top_logprobs_val.pop()
|
|
||||||
req.input_top_logprobs_idx.pop()
|
|
||||||
req.temp_input_top_logprobs_idx = None
|
|
||||||
req.temp_input_top_logprobs_val = None
|
|
||||||
|
|
||||||
if req.token_ids_logprob is not None:
|
|
||||||
req.input_token_ids_logprobs_val = [None]
|
|
||||||
req.input_token_ids_logprobs_idx = [None]
|
|
||||||
|
|
||||||
for val, idx in zip(
|
|
||||||
req.temp_input_token_ids_logprobs_val,
|
|
||||||
req.temp_input_token_ids_logprobs_idx,
|
|
||||||
strict=True,
|
|
||||||
):
|
|
||||||
req.input_token_ids_logprobs_val.extend(val)
|
|
||||||
req.input_token_ids_logprobs_idx.extend(idx)
|
|
||||||
|
|
||||||
# Last token is a sample token.
|
|
||||||
req.input_token_ids_logprobs_val.pop()
|
|
||||||
req.input_token_ids_logprobs_idx.pop()
|
|
||||||
req.temp_input_token_ids_logprobs_idx = None
|
|
||||||
req.temp_input_token_ids_logprobs_val = None
|
|
||||||
|
|
||||||
if req.return_logprob:
|
if req.return_logprob:
|
||||||
relevant_tokens_len = len(req.origin_input_ids) - req.logprob_start_len
|
relevant_tokens_len = self._calculate_relevant_tokens_len(req)
|
||||||
assert len(req.input_token_logprobs_val) == relevant_tokens_len
|
assert len(req.input_token_logprobs_val) == relevant_tokens_len
|
||||||
assert len(req.input_token_logprobs_idx) == relevant_tokens_len
|
assert len(req.input_token_logprobs_idx) == relevant_tokens_len
|
||||||
if req.top_logprobs_num > 0:
|
if req.top_logprobs_num > 0:
|
||||||
|
|||||||
@@ -182,6 +182,8 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
|||||||
if speculative_algorithm.is_none()
|
if speculative_algorithm.is_none()
|
||||||
else server_args.speculative_num_draft_tokens
|
else server_args.speculative_num_draft_tokens
|
||||||
)
|
)
|
||||||
|
# Initialize delimiter text for multi-item scoring (will be set after tokenizer is loaded)
|
||||||
|
self.multi_item_delimiter_text = None
|
||||||
|
|
||||||
if self.model_config.is_multimodal:
|
if self.model_config.is_multimodal:
|
||||||
import_processors("sglang.srt.multimodal.processors")
|
import_processors("sglang.srt.multimodal.processors")
|
||||||
@@ -223,6 +225,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
|||||||
self.processor = _processor
|
self.processor = _processor
|
||||||
self.tokenizer = get_tokenizer_from_processor(self.processor)
|
self.tokenizer = get_tokenizer_from_processor(self.processor)
|
||||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
self._initialize_multi_item_delimiter_text()
|
||||||
else:
|
else:
|
||||||
self.mm_processor = self.processor = None
|
self.mm_processor = self.processor = None
|
||||||
|
|
||||||
@@ -235,6 +238,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
|||||||
trust_remote_code=server_args.trust_remote_code,
|
trust_remote_code=server_args.trust_remote_code,
|
||||||
revision=server_args.revision,
|
revision=server_args.revision,
|
||||||
)
|
)
|
||||||
|
self._initialize_multi_item_delimiter_text()
|
||||||
# Initialize async dynamic batch tokenizer if enabled (common for both multimodal and non-multimodal)
|
# Initialize async dynamic batch tokenizer if enabled (common for both multimodal and non-multimodal)
|
||||||
if (
|
if (
|
||||||
server_args.enable_dynamic_batch_tokenizer
|
server_args.enable_dynamic_batch_tokenizer
|
||||||
@@ -1678,6 +1682,201 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
|||||||
if len(self.model_update_tmp) == self.server_args.dp_size:
|
if len(self.model_update_tmp) == self.server_args.dp_size:
|
||||||
self.model_update_result.set_result(self.model_update_tmp)
|
self.model_update_result.set_result(self.model_update_tmp)
|
||||||
|
|
||||||
|
def _initialize_multi_item_delimiter_text(self):
|
||||||
|
"""Initialize multi-item delimiter text from token ID after tokenizer is loaded."""
|
||||||
|
if (
|
||||||
|
hasattr(self.server_args, "multi_item_scoring_delimiter")
|
||||||
|
and self.server_args.multi_item_scoring_delimiter is not None
|
||||||
|
and self.tokenizer is not None
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
self.multi_item_delimiter_text = self.tokenizer.decode(
|
||||||
|
[self.server_args.multi_item_scoring_delimiter],
|
||||||
|
skip_special_tokens=False,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"Failed to decode delimiter token {self.server_args.multi_item_scoring_delimiter}: {e}"
|
||||||
|
)
|
||||||
|
self.multi_item_delimiter_text = None
|
||||||
|
|
||||||
|
def _build_multi_item_token_sequence(
|
||||||
|
self, query: List[int], items: List[List[int]], delimiter_token_id: int
|
||||||
|
) -> List[int]:
|
||||||
|
"""
|
||||||
|
Build a single token sequence for multi-item scoring.
|
||||||
|
Format: query<delimiter>item1<delimiter>item2<delimiter>item3<delimiter>
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: Query token IDs
|
||||||
|
items: List of item token ID sequences
|
||||||
|
delimiter_token_id: Token ID to use as delimiter
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Combined token sequence
|
||||||
|
"""
|
||||||
|
combined_sequence = query[:] # Start with query
|
||||||
|
|
||||||
|
for item in items:
|
||||||
|
combined_sequence.append(delimiter_token_id) # Add delimiter
|
||||||
|
combined_sequence.extend(item) # Add item tokens
|
||||||
|
|
||||||
|
# Add final delimiter after the last item for logprob extraction
|
||||||
|
combined_sequence.append(delimiter_token_id)
|
||||||
|
|
||||||
|
return combined_sequence
|
||||||
|
|
||||||
|
def _extract_logprobs_for_tokens(
|
||||||
|
self, logprobs_data: List, label_token_ids: List[int]
|
||||||
|
) -> Dict[int, float]:
|
||||||
|
"""
|
||||||
|
Extract logprobs for specified token IDs from logprobs data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
logprobs_data: List of (logprob, token_id, text) tuples
|
||||||
|
label_token_ids: Token IDs to extract logprobs for
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary mapping token_id to logprob
|
||||||
|
"""
|
||||||
|
logprobs = {}
|
||||||
|
if logprobs_data:
|
||||||
|
for logprob, token_id, _ in logprobs_data:
|
||||||
|
if token_id in label_token_ids:
|
||||||
|
logprobs[token_id] = logprob
|
||||||
|
return logprobs
|
||||||
|
|
||||||
|
def _convert_logprobs_to_scores(
|
||||||
|
self,
|
||||||
|
logprobs: Dict[int, float],
|
||||||
|
label_token_ids: List[int],
|
||||||
|
apply_softmax: bool,
|
||||||
|
) -> List[float]:
|
||||||
|
"""
|
||||||
|
Convert logprobs dictionary to ordered score list.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
logprobs: Dictionary mapping token_id to logprob
|
||||||
|
label_token_ids: Token IDs in desired order
|
||||||
|
apply_softmax: Whether to apply softmax normalization
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of scores in the same order as label_token_ids
|
||||||
|
"""
|
||||||
|
score_list = [
|
||||||
|
logprobs.get(token_id, float("-inf")) for token_id in label_token_ids
|
||||||
|
]
|
||||||
|
|
||||||
|
if apply_softmax:
|
||||||
|
score_list = torch.softmax(torch.tensor(score_list), dim=0).tolist()
|
||||||
|
else:
|
||||||
|
# Convert logprobs to probabilities if not using softmax
|
||||||
|
score_list = [
|
||||||
|
math.exp(x) if x != float("-inf") else 0.0 for x in score_list
|
||||||
|
]
|
||||||
|
|
||||||
|
return score_list
|
||||||
|
|
||||||
|
def _process_multi_item_scoring_results(
|
||||||
|
self,
|
||||||
|
results: Any,
|
||||||
|
items: List,
|
||||||
|
label_token_ids: List[int],
|
||||||
|
apply_softmax: bool,
|
||||||
|
batch_request=None,
|
||||||
|
) -> List[List[float]]:
|
||||||
|
"""
|
||||||
|
Process results from multi-item scoring request.
|
||||||
|
Extracts logprobs at delimiter positions from input_token_ids_logprobs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
results: Results from generate_request
|
||||||
|
items: List of items being scored
|
||||||
|
label_token_ids: Token IDs to extract scores for
|
||||||
|
apply_softmax: Whether to apply softmax normalization
|
||||||
|
batch_request: The original batch request containing input sequence
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of score lists, one for each item
|
||||||
|
"""
|
||||||
|
single_result = results[0] if isinstance(results, list) else results
|
||||||
|
|
||||||
|
# For multi-item scoring, logprobs are in input_token_ids_logprobs
|
||||||
|
input_logprobs = single_result["meta_info"].get("input_token_ids_logprobs", [])
|
||||||
|
|
||||||
|
if not input_logprobs:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"input_token_ids_logprobs is empty for multi-item scoring request {single_result['meta_info'].get('id', '<unknown>')}. "
|
||||||
|
"This indicates token_ids_logprobs were not computed properly for Mutil Item Scoring."
|
||||||
|
)
|
||||||
|
|
||||||
|
scores = []
|
||||||
|
num_items = len(items) if isinstance(items, list) else 1
|
||||||
|
|
||||||
|
# Check if we have the expected number of logprobs
|
||||||
|
expected_logprobs_count = num_items + 1
|
||||||
|
if len(input_logprobs) != expected_logprobs_count:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Expected {expected_logprobs_count} input_token_ids_logprobs for multi-item scoring "
|
||||||
|
f"with {num_items} items, but got {len(input_logprobs)}. "
|
||||||
|
f"Request ID: {single_result['meta_info'].get('id', '<unknown>')}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Skip the first delimiter (between query and first item) and process remaining delimiter positions
|
||||||
|
# We want to exclude the first one since it represents the boundary between query and first item, not an item boundary
|
||||||
|
start_idx = 1 if len(input_logprobs) > 1 else 0
|
||||||
|
|
||||||
|
# Process logprobs for each item position (excluding first delimiter)
|
||||||
|
for item_idx in range(num_items):
|
||||||
|
logprob_idx = start_idx + item_idx
|
||||||
|
item_logprobs_data = input_logprobs[logprob_idx]
|
||||||
|
logprobs = self._extract_logprobs_for_tokens(
|
||||||
|
item_logprobs_data, label_token_ids
|
||||||
|
)
|
||||||
|
score_list = self._convert_logprobs_to_scores(
|
||||||
|
logprobs, label_token_ids, apply_softmax
|
||||||
|
)
|
||||||
|
scores.append(score_list)
|
||||||
|
|
||||||
|
return scores
|
||||||
|
|
||||||
|
def _process_single_item_scoring_results(
|
||||||
|
self, results: Any, label_token_ids: List[int], apply_softmax: bool
|
||||||
|
) -> List[List[float]]:
|
||||||
|
"""
|
||||||
|
Process results from single-item scoring request.
|
||||||
|
Single-item scoring results are stored in output_token_ids_logprobs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
results: Results from generate_request
|
||||||
|
label_token_ids: Token IDs to extract scores for
|
||||||
|
apply_softmax: Whether to apply softmax normalization
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of score lists, one for each result
|
||||||
|
"""
|
||||||
|
scores = []
|
||||||
|
|
||||||
|
for result in results:
|
||||||
|
# For single-item scoring, logprobs are in output_token_ids_logprobs
|
||||||
|
output_logprobs = result["meta_info"].get("output_token_ids_logprobs", [])
|
||||||
|
|
||||||
|
if not output_logprobs or len(output_logprobs) == 0:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"output_logprobs is empty for request {result['meta_info'].get('id', '<unknown>')}."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract logprobs for the first (and only) position
|
||||||
|
logprobs = self._extract_logprobs_for_tokens(
|
||||||
|
output_logprobs[0], label_token_ids
|
||||||
|
)
|
||||||
|
score_list = self._convert_logprobs_to_scores(
|
||||||
|
logprobs, label_token_ids, apply_softmax
|
||||||
|
)
|
||||||
|
scores.append(score_list)
|
||||||
|
|
||||||
|
return scores
|
||||||
|
|
||||||
async def score_request(
|
async def score_request(
|
||||||
self,
|
self,
|
||||||
query: Optional[Union[str, List[int]]] = None,
|
query: Optional[Union[str, List[int]]] = None,
|
||||||
@@ -1688,7 +1887,29 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
|||||||
request: Optional[Any] = None,
|
request: Optional[Any] = None,
|
||||||
) -> List[List[float]]:
|
) -> List[List[float]]:
|
||||||
"""
|
"""
|
||||||
See Engine.score() for more details.
|
Score the probability of specified token IDs appearing after the given (query + item) pair.
|
||||||
|
|
||||||
|
This method supports two scoring approaches:
|
||||||
|
1. Single-Item scoring (default): Process each query+item pair independently
|
||||||
|
2. Multi-Item scoring: When multi_item_scoring_delimiter is set, combine query and
|
||||||
|
multiple items into a single sequence using delimiter for efficient processing.
|
||||||
|
Note: item_first parameter is ignored in multi-item scoring mode since it uses
|
||||||
|
a fixed format: query<delimiter>item1<delimiter>item2<delimiter>item3<delimiter>
|
||||||
|
|
||||||
|
Multi-item scoring works with both text and pre-tokenized inputs:
|
||||||
|
- Text: query<delimiter_text>item1<delimiter_text>item2<delimiter_text>item3<delimiter_text>
|
||||||
|
- Tokens: query<delimiter_token_id>item1<delimiter_token_id>item2<delimiter_token_id>item3<delimiter_token_id>
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: The query text or pre-tokenized query token IDs
|
||||||
|
items: The item text(s) or pre-tokenized item token IDs
|
||||||
|
label_token_ids: List of token IDs to compute probabilities for
|
||||||
|
apply_softmax: Whether to normalize probabilities using softmax
|
||||||
|
item_first: If True, prepend items to query. Ignored for multi-item scoring.
|
||||||
|
request: Optional FastAPI request object
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of lists containing probabilities for each item and each label token
|
||||||
"""
|
"""
|
||||||
if label_token_ids is None:
|
if label_token_ids is None:
|
||||||
raise ValueError("label_token_ids must be provided")
|
raise ValueError("label_token_ids must be provided")
|
||||||
@@ -1701,9 +1922,17 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
|||||||
f"Token ID {token_id} is out of vocabulary (vocab size: {vocab_size})"
|
f"Token ID {token_id} is out of vocabulary (vocab size: {vocab_size})"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Check if multi-item scoring is enabled by presence of delimiter
|
||||||
|
use_multi_item_scoring = (
|
||||||
|
self.server_args.multi_item_scoring_delimiter is not None
|
||||||
|
and self.multi_item_delimiter_text is not None
|
||||||
|
)
|
||||||
|
|
||||||
batch_request = GenerateReqInput(
|
batch_request = GenerateReqInput(
|
||||||
token_ids_logprob=label_token_ids,
|
token_ids_logprob=label_token_ids,
|
||||||
return_logprob=True,
|
return_logprob=True,
|
||||||
|
# Set logprob_start_len=0 for multi-item scoring since we want logprobs at all delimiter positions
|
||||||
|
logprob_start_len=0 if use_multi_item_scoring else -1,
|
||||||
stream=False,
|
stream=False,
|
||||||
sampling_params={"max_new_tokens": 0},
|
sampling_params={"max_new_tokens": 0},
|
||||||
)
|
)
|
||||||
@@ -1715,12 +1944,23 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
|||||||
):
|
):
|
||||||
# Both query and items are text
|
# Both query and items are text
|
||||||
items_list = [items] if isinstance(items, str) else items
|
items_list = [items] if isinstance(items, str) else items
|
||||||
if item_first:
|
|
||||||
prompts = [f"{item}{query}" for item in items_list]
|
|
||||||
else:
|
|
||||||
prompts = [f"{query}{item}" for item in items_list]
|
|
||||||
|
|
||||||
batch_request.text = prompts
|
if use_multi_item_scoring:
|
||||||
|
# Multi-item scoring: create single prompt with delimiter text
|
||||||
|
# Always use format: query<delimiter>item1<delimiter>item2<delimiter>item3<delimiter>
|
||||||
|
# (item_first is ignored for multi-item scoring)
|
||||||
|
delimiter = self.multi_item_delimiter_text
|
||||||
|
combined_items = delimiter.join(items_list)
|
||||||
|
# Add final delimiter after the last item for logprob extraction
|
||||||
|
single_prompt = f"{query}{delimiter}{combined_items}{delimiter}"
|
||||||
|
batch_request.text = [single_prompt]
|
||||||
|
else:
|
||||||
|
# Single-item scoring: create separate prompts for each item
|
||||||
|
if item_first:
|
||||||
|
prompts = [f"{item}{query}" for item in items_list]
|
||||||
|
else:
|
||||||
|
prompts = [f"{query}{item}" for item in items_list]
|
||||||
|
batch_request.text = prompts
|
||||||
|
|
||||||
elif (
|
elif (
|
||||||
isinstance(query, list)
|
isinstance(query, list)
|
||||||
@@ -1729,61 +1969,38 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
|||||||
and isinstance(items[0], list)
|
and isinstance(items[0], list)
|
||||||
):
|
):
|
||||||
# Both query and items are token IDs
|
# Both query and items are token IDs
|
||||||
if item_first:
|
if use_multi_item_scoring:
|
||||||
input_ids_list = [item + query for item in items]
|
# Multi-item scoring: concatenate with delimiter token ID
|
||||||
|
# Format: query<delimiter_token_id>item1<delimiter_token_id>item2<delimiter_token_id>item3<delimiter_token_id>
|
||||||
|
delimiter_token_id = self.server_args.multi_item_scoring_delimiter
|
||||||
|
combined_input_ids = self._build_multi_item_token_sequence(
|
||||||
|
query, items, delimiter_token_id
|
||||||
|
)
|
||||||
|
batch_request.input_ids = [combined_input_ids]
|
||||||
else:
|
else:
|
||||||
input_ids_list = [query + item for item in items]
|
# Single-item scoring: process each item separately
|
||||||
|
if item_first:
|
||||||
batch_request.input_ids = input_ids_list
|
input_ids_list = [item + query for item in items]
|
||||||
|
else:
|
||||||
|
input_ids_list = [query + item for item in items]
|
||||||
|
batch_request.input_ids = input_ids_list
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Invalid combination of query/items types for score_request."
|
"Invalid combination of query/items types for score_request."
|
||||||
)
|
)
|
||||||
|
|
||||||
results = await self.generate_request(batch_request, request).__anext__()
|
results = await self.generate_request(batch_request, request).__anext__()
|
||||||
scores = []
|
|
||||||
|
|
||||||
for result in results:
|
if use_multi_item_scoring:
|
||||||
# Get logprobs for each token
|
# Multi-item scoring: extract scores from input_token_ids_logprobs
|
||||||
logprobs = {}
|
return self._process_multi_item_scoring_results(
|
||||||
|
results, items, label_token_ids, apply_softmax, batch_request
|
||||||
# For scoring requests, we read from output_token_ids_logprobs since we want
|
)
|
||||||
# the logprobs for specific tokens mentioned in the label_token_ids at
|
else:
|
||||||
# the next position after the last token in the prompt
|
# Single-item scoring: process each result separately
|
||||||
output_logprobs = result["meta_info"].get("output_token_ids_logprobs", [])
|
return self._process_single_item_scoring_results(
|
||||||
|
results, label_token_ids, apply_softmax
|
||||||
# Check if output_logprobs is properly populated
|
)
|
||||||
if (
|
|
||||||
output_logprobs is None
|
|
||||||
or not output_logprobs
|
|
||||||
or len(output_logprobs) == 0
|
|
||||||
):
|
|
||||||
raise RuntimeError(
|
|
||||||
f"output_logprobs is empty for request {result['meta_info'].get('id', '<unknown>')}. "
|
|
||||||
"This indicates token_ids_logprobs were not computed properly for the scoring request."
|
|
||||||
)
|
|
||||||
|
|
||||||
for logprob, token_id, _ in output_logprobs[0]:
|
|
||||||
if token_id in label_token_ids:
|
|
||||||
logprobs[token_id] = logprob
|
|
||||||
|
|
||||||
# Get scores in order of label_token_ids
|
|
||||||
score_list = [
|
|
||||||
logprobs.get(token_id, float("-inf")) for token_id in label_token_ids
|
|
||||||
]
|
|
||||||
|
|
||||||
# Apply softmax to logprobs if needed
|
|
||||||
if apply_softmax:
|
|
||||||
score_list = torch.softmax(torch.tensor(score_list), dim=0).tolist()
|
|
||||||
else:
|
|
||||||
# Convert logprobs to probabilities if not using softmax
|
|
||||||
score_list = [
|
|
||||||
math.exp(x) if x != float("-inf") else 0.0 for x in score_list
|
|
||||||
]
|
|
||||||
|
|
||||||
scores.append(score_list)
|
|
||||||
|
|
||||||
return scores
|
|
||||||
|
|
||||||
async def watch_load_thread(self):
|
async def watch_load_thread(self):
|
||||||
# Only for dp_controller when dp_size > 1
|
# Only for dp_controller when dp_size > 1
|
||||||
|
|||||||
@@ -266,10 +266,16 @@ class TpModelWorker:
|
|||||||
|
|
||||||
if model_worker_batch.is_prefill_only:
|
if model_worker_batch.is_prefill_only:
|
||||||
# For prefill-only requests, create dummy token IDs on CPU
|
# For prefill-only requests, create dummy token IDs on CPU
|
||||||
batch_result.next_token_ids = torch.zeros_like(
|
# The size should match the batch size (number of sequences), not total tokens
|
||||||
model_worker_batch.input_ids, dtype=torch.long
|
batch_result.next_token_ids = torch.zeros(
|
||||||
|
len(model_worker_batch.seq_lens),
|
||||||
|
dtype=torch.long,
|
||||||
|
device=model_worker_batch.input_ids.device,
|
||||||
)
|
)
|
||||||
if model_worker_batch.return_logprob:
|
if (
|
||||||
|
model_worker_batch.return_logprob
|
||||||
|
and logits_output.next_token_logits is not None
|
||||||
|
):
|
||||||
# NOTE: Compute logprobs without full sampling
|
# NOTE: Compute logprobs without full sampling
|
||||||
self.model_runner.compute_logprobs_only(
|
self.model_runner.compute_logprobs_only(
|
||||||
logits_output, model_worker_batch
|
logits_output, model_worker_batch
|
||||||
|
|||||||
@@ -278,6 +278,9 @@ class ForwardBatch:
|
|||||||
can_run_dp_cuda_graph: bool = False
|
can_run_dp_cuda_graph: bool = False
|
||||||
global_forward_mode: Optional[ForwardMode] = None
|
global_forward_mode: Optional[ForwardMode] = None
|
||||||
|
|
||||||
|
# Whether this batch is prefill-only (no token generation needed)
|
||||||
|
is_prefill_only: bool = False
|
||||||
|
|
||||||
# Speculative decoding
|
# Speculative decoding
|
||||||
spec_info: Optional[SpecInput] = None
|
spec_info: Optional[SpecInput] = None
|
||||||
spec_algorithm: SpeculativeAlgorithm = None
|
spec_algorithm: SpeculativeAlgorithm = None
|
||||||
@@ -325,6 +328,7 @@ class ForwardBatch:
|
|||||||
is_extend_in_batch=batch.is_extend_in_batch,
|
is_extend_in_batch=batch.is_extend_in_batch,
|
||||||
can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
|
can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
|
||||||
global_forward_mode=batch.global_forward_mode,
|
global_forward_mode=batch.global_forward_mode,
|
||||||
|
is_prefill_only=batch.is_prefill_only,
|
||||||
lora_ids=batch.lora_ids,
|
lora_ids=batch.lora_ids,
|
||||||
sampling_info=batch.sampling_info,
|
sampling_info=batch.sampling_info,
|
||||||
req_to_token_pool=model_runner.req_to_token_pool,
|
req_to_token_pool=model_runner.req_to_token_pool,
|
||||||
|
|||||||
@@ -382,6 +382,12 @@ class ServerArgs:
|
|||||||
offload_prefetch_step: int = 1
|
offload_prefetch_step: int = 1
|
||||||
offload_mode: str = "cpu"
|
offload_mode: str = "cpu"
|
||||||
|
|
||||||
|
# Scoring configuration
|
||||||
|
# Delimiter token ID used to combine Query and Items into a single sequence for multi-item scoring.
|
||||||
|
# Format: Query<delimiter>Item1<delimiter>Item2<delimiter>...
|
||||||
|
# This enables efficient batch processing of multiple items against a single query.
|
||||||
|
multi_item_scoring_delimiter: Optional[Union[int]] = None
|
||||||
|
|
||||||
# Optimization/debug options
|
# Optimization/debug options
|
||||||
disable_radix_cache: bool = False
|
disable_radix_cache: bool = False
|
||||||
cuda_graph_max_bs: Optional[int] = None
|
cuda_graph_max_bs: Optional[int] = None
|
||||||
@@ -2334,7 +2340,13 @@ class ServerArgs:
|
|||||||
choices=["float32", "bfloat16"],
|
choices=["float32", "bfloat16"],
|
||||||
help="The data type of the SSM states in mamba cache.",
|
help="The data type of the SSM states in mamba cache.",
|
||||||
)
|
)
|
||||||
|
# Args for multi-item-scoring
|
||||||
|
parser.add_argument(
|
||||||
|
"--multi-item-scoring-delimiter",
|
||||||
|
type=int,
|
||||||
|
default=ServerArgs.multi_item_scoring_delimiter,
|
||||||
|
help="Delimiter token ID for multi-item scoring. Used to combine Query and Items into a single sequence: Query<delimiter>Item1<delimiter>Item2<delimiter>... This enables efficient batch processing of multiple items against a single query.",
|
||||||
|
)
|
||||||
# Hierarchical cache
|
# Hierarchical cache
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--enable-hierarchical-cache",
|
"--enable-hierarchical-cache",
|
||||||
@@ -3004,6 +3016,17 @@ class ServerArgs:
|
|||||||
"lof",
|
"lof",
|
||||||
], f"To use priority scheduling, schedule_policy must be 'fcfs' or 'lof'. '{self.schedule_policy}' is not supported."
|
], f"To use priority scheduling, schedule_policy must be 'fcfs' or 'lof'. '{self.schedule_policy}' is not supported."
|
||||||
|
|
||||||
|
# Check multi-item scoring
|
||||||
|
if self.multi_item_scoring_delimiter is not None:
|
||||||
|
assert self.disable_radix_cache, (
|
||||||
|
"Multi-item scoring requires radix cache to be disabled. "
|
||||||
|
"Please set --disable-radix-cache when using --multi-item-scoring-delimiter."
|
||||||
|
)
|
||||||
|
assert self.chunked_prefill_size == -1, (
|
||||||
|
"Multi-item scoring requires chunked prefill to be disabled. "
|
||||||
|
"Please set --chunked-prefill-size -1 when using --multi-item-scoring-delimiter."
|
||||||
|
)
|
||||||
|
|
||||||
def check_lora_server_args(self):
|
def check_lora_server_args(self):
|
||||||
assert self.max_loras_per_batch > 0, "max_loras_per_batch must be positive"
|
assert self.max_loras_per_batch > 0, "max_loras_per_batch must be positive"
|
||||||
|
|
||||||
|
|||||||
@@ -667,6 +667,7 @@ class TboForwardBatchPreparer:
|
|||||||
"can_run_dp_cuda_graph",
|
"can_run_dp_cuda_graph",
|
||||||
"dp_padding_mode",
|
"dp_padding_mode",
|
||||||
"global_forward_mode",
|
"global_forward_mode",
|
||||||
|
"is_prefill_only",
|
||||||
"spec_algorithm",
|
"spec_algorithm",
|
||||||
"capture_hidden_mode",
|
"capture_hidden_mode",
|
||||||
"padded_static_len",
|
"padded_static_len",
|
||||||
|
|||||||
@@ -295,6 +295,296 @@ class TestScoreAPI(CustomTestCase):
|
|||||||
)
|
)
|
||||||
self.assertFalse(request.stream, "Scoring requests should not stream")
|
self.assertFalse(request.stream, "Scoring requests should not stream")
|
||||||
|
|
||||||
|
def test_multi_item_scoring_basic(self):
|
||||||
|
"""Test basic multi-item scoring functionality."""
|
||||||
|
# Test with a simple query and items
|
||||||
|
query = "What is the capital of California? Answer Yes or No for each of the following options:"
|
||||||
|
items = ["Sacramento", "San Jose", "San Francisco"]
|
||||||
|
label_token_ids = [9454, 2753] # "Yes" and "No" tokens
|
||||||
|
|
||||||
|
# Get scores using SGLang
|
||||||
|
scores = self.engine.score(
|
||||||
|
query=query,
|
||||||
|
items=items,
|
||||||
|
label_token_ids=label_token_ids,
|
||||||
|
apply_softmax=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify we get the expected number of scores
|
||||||
|
self.assertEqual(len(scores), len(items), "Should get one score list per item")
|
||||||
|
|
||||||
|
# Verify each score list has the correct length
|
||||||
|
for i, score_list in enumerate(scores):
|
||||||
|
self.assertEqual(
|
||||||
|
len(score_list),
|
||||||
|
len(label_token_ids),
|
||||||
|
f"Item {i} should have {len(label_token_ids)} scores",
|
||||||
|
)
|
||||||
|
# Verify scores are probabilities (sum to 1)
|
||||||
|
self.assertAlmostEqual(
|
||||||
|
sum(score_list),
|
||||||
|
1.0,
|
||||||
|
places=6,
|
||||||
|
msg=f"Scores for item {i} should sum to 1",
|
||||||
|
)
|
||||||
|
# Verify all scores are non-negative
|
||||||
|
for j, score in enumerate(score_list):
|
||||||
|
self.assertGreaterEqual(
|
||||||
|
score, 0, f"Score {j} for item {i} should be non-negative"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_multi_item_scoring_consistency(self):
|
||||||
|
"""Test that multi-item scoring gives consistent results."""
|
||||||
|
query = "Choose the best option:"
|
||||||
|
items = ["Option A", "Option B", "Option C"]
|
||||||
|
label_token_ids = [1, 2, 3]
|
||||||
|
|
||||||
|
# Run the same test multiple times
|
||||||
|
scores1 = self.engine.score(
|
||||||
|
query=query,
|
||||||
|
items=items,
|
||||||
|
label_token_ids=label_token_ids,
|
||||||
|
apply_softmax=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
scores2 = self.engine.score(
|
||||||
|
query=query,
|
||||||
|
items=items,
|
||||||
|
label_token_ids=label_token_ids,
|
||||||
|
apply_softmax=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Results should be identical (deterministic)
|
||||||
|
self.assertEqual(len(scores1), len(scores2), "Should get same number of items")
|
||||||
|
for i, (s1, s2) in enumerate(zip(scores1, scores2)):
|
||||||
|
self.assertEqual(
|
||||||
|
len(s1), len(s2), f"Item {i} should have same number of scores"
|
||||||
|
)
|
||||||
|
for j, (score1, score2) in enumerate(zip(s1, s2)):
|
||||||
|
self.assertAlmostEqual(
|
||||||
|
score1,
|
||||||
|
score2,
|
||||||
|
places=6,
|
||||||
|
msg=f"Score {j} for item {i} should be identical",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_multi_item_scoring_different_sizes(self):
|
||||||
|
"""Test multi-item scoring with different numbers of items."""
|
||||||
|
query = "Rate each option:"
|
||||||
|
label_token_ids = [1, 2, 3, 4, 5]
|
||||||
|
|
||||||
|
# Test with different numbers of items
|
||||||
|
test_cases = [
|
||||||
|
["Single item"],
|
||||||
|
["Item 1", "Item 2"],
|
||||||
|
["A", "B", "C", "D"],
|
||||||
|
["X", "Y", "Z", "W", "V", "U"],
|
||||||
|
]
|
||||||
|
|
||||||
|
for items in test_cases:
|
||||||
|
with self.subTest(items=items):
|
||||||
|
scores = self.engine.score(
|
||||||
|
query=query,
|
||||||
|
items=items,
|
||||||
|
label_token_ids=label_token_ids,
|
||||||
|
apply_softmax=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
len(scores), len(items), f"Should get {len(items)} score lists"
|
||||||
|
)
|
||||||
|
|
||||||
|
for i, score_list in enumerate(scores):
|
||||||
|
self.assertEqual(
|
||||||
|
len(score_list),
|
||||||
|
len(label_token_ids),
|
||||||
|
f"Item {i} should have {len(label_token_ids)} scores",
|
||||||
|
)
|
||||||
|
self.assertAlmostEqual(sum(score_list), 1.0, places=6)
|
||||||
|
|
||||||
|
def test_multi_item_scoring_empty_items(self):
|
||||||
|
"""Test multi-item scoring with empty items list."""
|
||||||
|
query = "Test query"
|
||||||
|
items = []
|
||||||
|
label_token_ids = [1, 2]
|
||||||
|
|
||||||
|
scores = self.engine.score(
|
||||||
|
query=query,
|
||||||
|
items=items,
|
||||||
|
label_token_ids=label_token_ids,
|
||||||
|
apply_softmax=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(len(scores), 0, "Should return empty list for empty items")
|
||||||
|
|
||||||
|
def test_multi_item_scoring_single_item(self):
|
||||||
|
"""Test multi-item scoring with single item (should work like regular scoring)."""
|
||||||
|
query = "Complete this sentence: The capital of France is"
|
||||||
|
items = ["Paris"]
|
||||||
|
label_token_ids = [1, 2, 3]
|
||||||
|
|
||||||
|
scores = self.engine.score(
|
||||||
|
query=query,
|
||||||
|
items=items,
|
||||||
|
label_token_ids=label_token_ids,
|
||||||
|
apply_softmax=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(len(scores), 1, "Should get one score list")
|
||||||
|
self.assertEqual(
|
||||||
|
len(scores[0]), len(label_token_ids), "Should have correct number of scores"
|
||||||
|
)
|
||||||
|
self.assertAlmostEqual(sum(scores[0]), 1.0, places=6)
|
||||||
|
|
||||||
|
def test_multi_item_scoring_different_queries(self):
|
||||||
|
"""Test multi-item scoring with different types of queries."""
|
||||||
|
items = ["Yes", "No"]
|
||||||
|
label_token_ids = [1, 2]
|
||||||
|
|
||||||
|
test_queries = [
|
||||||
|
"Is this true?",
|
||||||
|
"Choose the correct answer:",
|
||||||
|
"What is the best option?",
|
||||||
|
"Select all that apply:",
|
||||||
|
"", # Empty query
|
||||||
|
]
|
||||||
|
|
||||||
|
for query in test_queries:
|
||||||
|
with self.subTest(query=query):
|
||||||
|
scores = self.engine.score(
|
||||||
|
query=query,
|
||||||
|
items=items,
|
||||||
|
label_token_ids=label_token_ids,
|
||||||
|
apply_softmax=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
len(scores),
|
||||||
|
len(items),
|
||||||
|
f"Should get {len(items)} score lists for query: '{query}'",
|
||||||
|
)
|
||||||
|
|
||||||
|
for i, score_list in enumerate(scores):
|
||||||
|
self.assertEqual(len(score_list), len(label_token_ids))
|
||||||
|
self.assertAlmostEqual(sum(score_list), 1.0, places=6)
|
||||||
|
|
||||||
|
def test_multi_item_scoring_different_label_tokens(self):
|
||||||
|
"""Test multi-item scoring with different label token sets."""
|
||||||
|
query = "Choose the best option:"
|
||||||
|
items = ["Option A", "Option B"]
|
||||||
|
|
||||||
|
test_label_tokens = [
|
||||||
|
[1, 2], # Two tokens
|
||||||
|
[1, 2, 3, 4], # Four tokens
|
||||||
|
[1], # Single token
|
||||||
|
[1, 2, 3, 4, 5, 6, 7, 8], # Many tokens
|
||||||
|
]
|
||||||
|
|
||||||
|
for label_token_ids in test_label_tokens:
|
||||||
|
with self.subTest(label_tokens=label_token_ids):
|
||||||
|
scores = self.engine.score(
|
||||||
|
query=query,
|
||||||
|
items=items,
|
||||||
|
label_token_ids=label_token_ids,
|
||||||
|
apply_softmax=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(len(scores), len(items))
|
||||||
|
|
||||||
|
for i, score_list in enumerate(scores):
|
||||||
|
self.assertEqual(
|
||||||
|
len(score_list),
|
||||||
|
len(label_token_ids),
|
||||||
|
f"Item {i} should have {len(label_token_ids)} scores",
|
||||||
|
)
|
||||||
|
self.assertAlmostEqual(sum(score_list), 1.0, places=6)
|
||||||
|
|
||||||
|
def test_multi_item_scoring_without_softmax(self):
|
||||||
|
"""Test multi-item scoring without softmax normalization."""
|
||||||
|
query = "Rate each option:"
|
||||||
|
items = ["Good", "Bad", "Neutral"]
|
||||||
|
label_token_ids = [1, 2, 3]
|
||||||
|
|
||||||
|
scores = self.engine.score(
|
||||||
|
query=query,
|
||||||
|
items=items,
|
||||||
|
label_token_ids=label_token_ids,
|
||||||
|
apply_softmax=False, # No softmax
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(len(scores), len(items))
|
||||||
|
|
||||||
|
for i, score_list in enumerate(scores):
|
||||||
|
self.assertEqual(len(score_list), len(label_token_ids))
|
||||||
|
# Without softmax, scores don't need to sum to 1
|
||||||
|
# But they should still be valid logits/probabilities
|
||||||
|
for j, score in enumerate(score_list):
|
||||||
|
self.assertIsInstance(
|
||||||
|
score, (int, float), f"Score {j} for item {i} should be numeric"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_multi_item_scoring_large_batch(self):
|
||||||
|
"""Test multi-item scoring with a large number of items."""
|
||||||
|
query = "Classify each item:"
|
||||||
|
items = [f"Item {i}" for i in range(20)] # 20 items
|
||||||
|
label_token_ids = [1, 2, 3]
|
||||||
|
|
||||||
|
scores = self.engine.score(
|
||||||
|
query=query,
|
||||||
|
items=items,
|
||||||
|
label_token_ids=label_token_ids,
|
||||||
|
apply_softmax=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(len(scores), len(items), "Should handle large batches")
|
||||||
|
|
||||||
|
for i, score_list in enumerate(scores):
|
||||||
|
self.assertEqual(len(score_list), len(label_token_ids))
|
||||||
|
self.assertAlmostEqual(sum(score_list), 1.0, places=6)
|
||||||
|
|
||||||
|
def test_multi_item_scoring_unicode(self):
|
||||||
|
"""Test multi-item scoring with unicode characters."""
|
||||||
|
query = "选择最佳选项:"
|
||||||
|
items = ["选项A", "选项B", "选项C"]
|
||||||
|
label_token_ids = [1, 2, 3]
|
||||||
|
|
||||||
|
scores = self.engine.score(
|
||||||
|
query=query,
|
||||||
|
items=items,
|
||||||
|
label_token_ids=label_token_ids,
|
||||||
|
apply_softmax=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(len(scores), len(items))
|
||||||
|
|
||||||
|
for i, score_list in enumerate(scores):
|
||||||
|
self.assertEqual(len(score_list), len(label_token_ids))
|
||||||
|
self.assertAlmostEqual(sum(score_list), 1.0, places=6)
|
||||||
|
|
||||||
|
def test_multi_item_scoring_error_handling(self):
|
||||||
|
"""Test multi-item scoring error handling."""
|
||||||
|
query = "Test query"
|
||||||
|
items = ["Item 1", "Item 2"]
|
||||||
|
label_token_ids = [1, 2]
|
||||||
|
|
||||||
|
# Test with invalid label_token_ids
|
||||||
|
with self.assertRaises((ValueError, TypeError)):
|
||||||
|
self.engine.score(
|
||||||
|
query=query,
|
||||||
|
items=items,
|
||||||
|
label_token_ids="invalid", # Should be list of ints
|
||||||
|
apply_softmax=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test with None items
|
||||||
|
with self.assertRaises((ValueError, TypeError)):
|
||||||
|
self.engine.score(
|
||||||
|
query=query,
|
||||||
|
items=None,
|
||||||
|
label_token_ids=label_token_ids,
|
||||||
|
apply_softmax=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user