[Generative Score API] Multi-Item scoring with custom attention mask. (#10979)

Sundara Raman Ramachandran
2025-10-08 18:47:32 -07:00
committed by GitHub
parent e22b13c569
commit 53bd00d975
10 changed files with 1121 additions and 129 deletions

View File

@@ -7,6 +7,7 @@ FlashInfer is faster and Triton is easier to customize.
Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
"""
import logging
import os
from dataclasses import dataclass
from enum import Enum, auto
@@ -16,11 +17,11 @@ from typing import TYPE_CHECKING, Callable, List, Optional, Union
import torch
if os.environ.get("SGLANG_ENABLE_TORCH_COMPILE") == "1":
torch._logging.set_logs(dynamo=logging.ERROR)
torch._dynamo.config.suppress_errors = True
logger = logging.getLogger(__name__)
from sglang.global_config import global_config
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
@@ -58,6 +59,36 @@ class WrapperDispatch(Enum):
CROSS_ATTENTION = auto()
@dataclass
class MultiItemScoringParams:
"""Parameters for multi-item scoring in attention computation.
Used when processing sequences with multiple items separated by delimiters,
where each item needs specific attention patterns that respect item boundaries.
Attributes:
prefix_len_ptr: A uint32 1D tensor indicating the prefix length of each prompt.
The tensor size is equal to the batch size.
token_pos_in_items_ptr: A uint16 1D tensor giving each token's position within
its item, restarting from 0 at every delimiter. For batch size > 1,
the per-sequence tensors are zero-padded to a common length and concatenated.
token_pos_in_items_len: The common padded length of each sequence's entry in
token_pos_in_items_ptr, required to handle the batch_size > 1 case.
max_item_len_ptr: A uint16 tensor containing the max token length over all items
for each prompt in the batch.
"""
prefix_len_ptr: Optional[torch.Tensor] = None
token_pos_in_items_ptr: Optional[torch.Tensor] = None
token_pos_in_items_len: int = 0
max_item_len_ptr: Optional[torch.Tensor] = None
def is_enabled(self) -> bool:
"""Check if multi-item scoring is enabled."""
return self.prefix_len_ptr is not None
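
For concreteness, a minimal sketch (not part of this diff, and assuming the dataclass above is in scope) that fills in these fields by hand for a single-prompt batch: a 7-token query followed by three single-token items, each preceded and followed by a delimiter. The values match the Case 1 example documented further below.

import torch

params = MultiItemScoringParams(
    prefix_len_ptr=torch.tensor([7], dtype=torch.uint32),  # 7 query tokens before the first delimiter
    token_pos_in_items_ptr=torch.tensor([0, 1, 0, 1, 0, 1, 0], dtype=torch.uint16),  # 0 at delimiters
    token_pos_in_items_len=7,  # unpadded length, since batch size is 1
    max_item_len_ptr=torch.tensor([1], dtype=torch.uint16),  # every item is a single token
)
assert params.is_enabled()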
@dataclass
class DecodeMetadata:
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper]
@@ -68,6 +99,7 @@ class PrefillMetadata:
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper]
use_ragged: bool
extend_no_prefix: bool
multi_item_params: Optional[MultiItemScoringParams] = None
# Reuse this workspace buffer across all flashinfer wrappers
@@ -90,6 +122,11 @@ class FlashInferAttnBackend(AttentionBackend):
):
super().__init__()
# Store multi-item scoring delimiter for efficient access
self.multi_item_scoring_delimiter = (
model_runner.server_args.multi_item_scoring_delimiter
)
# Parse constants
self.decode_use_tensor_cores = should_use_tensor_core(
kv_cache_dtype=model_runner.kv_cache_dtype,
@@ -229,10 +266,133 @@ class FlashInferAttnBackend(AttentionBackend):
# Other metadata
self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
self.decode_cuda_graph_metadata = {}
self.prefill_cuda_graph_metadata = {} # For verify
self.draft_extend_cuda_graph_metadata = {} # For draft extend
def _process_multi_item_scoring(
self, forward_batch: ForwardBatch
) -> MultiItemScoringParams:
"""Process multi-item scoring tensors for FlashInfer attention.
This method handles sequences containing multiple "items" separated by delimiter tokens,
where each item needs specific attention patterns that respect item boundaries.
The method produces four key tensors for FlashInfer:
- prefix_len_ptr: uint32 tensor with prefix length for each prompt in batch
- token_pos_in_items_ptr: uint16 tensor with token positions starting from 0 at delimiters
- token_pos_in_items_len: padding length for batch processing
- max_item_len_ptr: uint16 tensor with max item length for each prompt
Args:
forward_batch: The forward batch containing input sequences and delimiter info
Returns:
MultiItemScoringParams: The processed multi-item scoring parameters
Examples:
Following the FlashInfer definition, for 3 items of length 3, 2, and 4 respectively:
token_pos_in_items_ptr = [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0]
Case 1: Single sequence
Text: "What is the capital of France? <delim> London <delim> Paris <delim> Berlin <delim>"
Tokens: [What, is, the, capital, of, France, ?, <delim>, London, <delim>, Paris, <delim>, Berlin, <delim>]
Indices: [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
- prefix_len_ptr: [7] (query length before first delimiter)
- token_pos_in_items_ptr: [0, 1, 0, 1, 0, 1, 0] (delim=0, London=1, delim=0, Paris=1, delim=0, Berlin=1, delim=0)
- token_pos_in_items_len: 7 (actual length)
- max_item_len_ptr: [1] (max item length is 1 token - all options are single tokens)
Case 2: Batch processing (batch_size=2)
Sequence 1: 2 items of length 2, 1 → [0, 1, 2, 0, 1, 0] (6 elements)
Sequence 2: 3 items of length 1, 3, 2 → [0, 1, 0, 1, 2, 3, 0, 1, 2, 0] (10 elements)
After padding both to length 10:
- token_pos_in_items_ptr: [0, 1, 2, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 2, 3, 0, 1, 2, 0]
- token_pos_in_items_len: 10 (padded length for batch processing)
- max_item_len_ptr: [2, 3] (max lengths per sequence)
"""
delimiter = self.multi_item_scoring_delimiter
if delimiter is None or forward_batch.forward_mode == ForwardMode.DECODE:
return MultiItemScoringParams()
delimiter_mask = forward_batch.input_ids == delimiter
prefix_cache_lens = getattr(forward_batch, "extend_prefix_lens", None)
extend_seq_lens = getattr(forward_batch, "extend_seq_lens", None)
prefix_len_ptr, token_pos_in_items_ptr = [], []
token_pos_in_items_len = 0
# If extend_seq_lens is missing or has a single entry, treat the whole batch as one sequence
if extend_seq_lens is None or len(extend_seq_lens) <= 1:
extend_seq_lens = [forward_batch.input_ids.size(0)]
seq_start = 0
for i, seq_len in enumerate(extend_seq_lens):
seq_end = seq_start + seq_len
mask = delimiter_mask[seq_start:seq_end]
pos = forward_batch.positions[seq_start:seq_end]
delimiter_indices = torch.nonzero(mask, as_tuple=True)[0]
if len(delimiter_indices) > 0:
first_delim = delimiter_indices[0]
# Prefix length: store as scalar
prefix_len = first_delim + (
prefix_cache_lens[i] if prefix_cache_lens is not None else 0
)
prefix_len_ptr.append(
prefix_len.item() if torch.is_tensor(prefix_len) else prefix_len
)
# Compute relative positions within items: cummax over the boolean mask
# forward-fills the index of the most recent delimiter, so subtracting it
# (and then the base position of the first delimiter) yields each token's
# offset from the delimiter that opens its item
diff = pos[first_delim:] - torch.cummax(mask[first_delim:], 0)[1]
token_pos = (diff - pos[first_delim]).to(torch.uint16)
token_pos_in_items_ptr.append(token_pos)
# Rewrite forward_batch positions in-place so each item's tokens continue
# from the shared prefix instead of from the preceding items
pos[first_delim:] = diff - 1
forward_batch.positions[seq_start:seq_end] = pos
seq_start = seq_end
# Pad token_pos_in_items_ptr for batch processing
if token_pos_in_items_ptr:
token_pos_in_items_len = max(t.numel() for t in token_pos_in_items_ptr)
device = forward_batch.input_ids.device
token_pos_in_items_ptr = [
torch.cat(
[
t,
torch.zeros(
token_pos_in_items_len - t.numel(),
dtype=torch.uint16,
device=device,
),
]
)
for t in token_pos_in_items_ptr
]
if not prefix_len_ptr or not token_pos_in_items_ptr:
return MultiItemScoringParams()
# Build final params
device = forward_batch.input_ids.device
return MultiItemScoringParams(
prefix_len_ptr=torch.tensor(
prefix_len_ptr, dtype=torch.uint32, device=device
),
token_pos_in_items_ptr=torch.cat(token_pos_in_items_ptr, dim=0),
token_pos_in_items_len=token_pos_in_items_len & 0xFFFFFFFF,
max_item_len_ptr=torch.stack(
[
t.to(torch.int32).max().to(torch.uint16)
for t in token_pos_in_items_ptr
],
dim=0,
),
)
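
The position arithmetic above can be checked in isolation. A standalone sketch that reproduces Case 1 with plain tensor ops (assumptions: a single sequence, no cached prefix, and a made-up delimiter id):

import torch

DELIM = 99  # made-up delimiter token id
input_ids = torch.tensor(
    [11, 12, 13, 14, 15, 16, 17,              # 7-token query ("What is the capital of France ?")
     DELIM, 21, DELIM, 22, DELIM, 23, DELIM]  # three single-token items, delimiter-wrapped
)
positions = torch.arange(input_ids.numel())

mask = input_ids == DELIM
first_delim = torch.nonzero(mask, as_tuple=True)[0][0]

# cummax over the boolean mask forward-fills the index of the most recent
# delimiter; subtracting it from the absolute position, then removing the
# base offset, gives each token's position inside its item.
last_delim_idx = torch.cummax(mask[first_delim:], 0)[1]
diff = positions[first_delim:] - last_delim_idx
token_pos_in_items = (diff - positions[first_delim]).to(torch.uint16)

print(first_delim.item())           # 7  -> becomes the prefix_len_ptr entry
print(token_pos_in_items.tolist())  # [0, 1, 0, 1, 0, 1, 0]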
def init_forward_metadata(self, forward_batch: ForwardBatch):
if forward_batch.forward_mode.is_decode_or_idle():
self.indices_updater_decode.update(
@@ -280,13 +440,26 @@ class FlashInferAttnBackend(AttentionBackend):
else:
prefix_lens = forward_batch.extend_prefix_lens
-if self.is_multimodal:
+# Disable ragged wrapper and ensure prefix handling for multimodal and multi-item scoring
+if self.is_multimodal or self.multi_item_scoring_delimiter is not None:
# use_ragged = False: Multi-item scoring requires the paged wrapper because:
# 1. Ragged wrapper doesn't support the specialized multi-item parameters
# (prefix_len_ptr, token_pos_in_items_ptr, etc.)
# 2. Paged wrapper provides better control over attention masking needed
# for respecting item boundaries in multi-item sequences
# 3. Custom masking logic conflicts with ragged wrapper's assumptions
use_ragged = False
extend_no_prefix = False
else:
use_ragged = not self.enable_deterministic
extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
# Process multi-item scoring in attention backend instead of ForwardBatch
multi_item_params = MultiItemScoringParams()
if self.multi_item_scoring_delimiter is not None:
# Use new backend-specific implementation
multi_item_params = self._process_multi_item_scoring(forward_batch)
self.indices_updater_prefill.update(
forward_batch.req_pool_indices,
forward_batch.seq_lens,
@@ -298,9 +471,13 @@ class FlashInferAttnBackend(AttentionBackend):
encoder_lens=forward_batch.encoder_lens,
spec_info=None,
fixed_split_size=self.prefill_split_tile_size,
multi_item_params=multi_item_params,
)
self.forward_metadata = PrefillMetadata(
-self.prefill_wrappers_paged, use_ragged, extend_no_prefix
+self.prefill_wrappers_paged,
+use_ragged,
+extend_no_prefix,
+multi_item_params,
)
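
The branching above reduces to a small pure function. A condensed sketch (hypothetical helper, not in the commit) of the wrapper-selection rule:

from typing import Optional, Sequence, Tuple

def choose_prefill_config(
    is_multimodal: bool,
    delimiter: Optional[int],
    extend_prefix_lens_cpu: Sequence[int],
    enable_deterministic: bool,
) -> Tuple[bool, bool]:
    # Multi-item scoring (like multimodal) forces the paged wrapper and
    # explicit prefix handling; otherwise fall back to the usual heuristics.
    if is_multimodal or delimiter is not None:
        return False, False  # use_ragged, extend_no_prefix
    return not enable_deterministic, not any(extend_prefix_lens_cpu)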
def init_cuda_graph_state(
@@ -531,7 +708,20 @@ class FlashInferAttnBackend(AttentionBackend):
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
causal=not layer.is_cross_attention,
sm_scale=layer.scaling,
-window_left=layer.sliding_window_size,
+# Disable sliding window attention for multi-item scoring:
+# - Sliding window could cut across item boundaries, breaking semantic coherence
+# - Multi-item sequences need full attention to properly handle delimiter tokens
+# - Specialized multi-item parameters (prefix_len_ptr, token_pos_in_items_ptr)
+#   provide more precise attention control than simple sliding windows
+# - Item-aware masking takes precedence over window-based masking
+window_left=(
+    layer.sliding_window_size
+    if not (
+        self.forward_metadata.multi_item_params
+        and self.forward_metadata.multi_item_params.is_enabled()
+    )
+    else -1
+),
logits_soft_cap=logits_soft_cap,
# Must use _float to avoid device-to-host copy that breaks cuda graph capture.
k_scale=layer.k_scale_float,
@@ -952,6 +1142,7 @@ class FlashInferIndicesUpdaterPrefill:
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[SpecInput],
fixed_split_size: Optional[int] = None,
multi_item_params: Optional[MultiItemScoringParams] = None,
):
if use_ragged:
# TODO: remove this device sync, we can use forward_batch.extend_prefix_lens_cpu
@@ -976,6 +1167,7 @@ class FlashInferIndicesUpdaterPrefill:
use_ragged,
spec_info,
fixed_split_size=fixed_split_size,
multi_item_params=multi_item_params,
)
def update_sliding_window(
@@ -990,6 +1182,7 @@ class FlashInferIndicesUpdaterPrefill:
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[SpecInput],
fixed_split_size: Optional[int] = None,
multi_item_params: Optional[MultiItemScoringParams] = None,
):
for wrapper_id in range(2):
if wrapper_id == 0:
@@ -1023,6 +1216,7 @@ class FlashInferIndicesUpdaterPrefill:
use_ragged,
spec_info,
use_sliding_window_kv_pool=use_sliding_window_kv_pool,
multi_item_params=multi_item_params,
)
def update_cross_attention(
@@ -1037,6 +1231,7 @@ class FlashInferIndicesUpdaterPrefill:
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[SpecInput],
fixed_split_size: Optional[int] = None,
multi_item_params: Optional[MultiItemScoringParams] = None,
):
for wrapper_id in range(2):
if wrapper_id == 0:
@@ -1063,6 +1258,7 @@ class FlashInferIndicesUpdaterPrefill:
self.qo_indptr[wrapper_id],
use_ragged,
spec_info,
multi_item_params=multi_item_params,
)
def call_begin_forward(
@@ -1081,6 +1277,7 @@ class FlashInferIndicesUpdaterPrefill:
spec_info: Optional[SpecInput],
use_sliding_window_kv_pool: bool = False,
fixed_split_size: Optional[int] = None,
multi_item_params: Optional[MultiItemScoringParams] = None,
):
bs = len(seq_lens)
if spec_info is None:
@@ -1136,6 +1333,22 @@ class FlashInferIndicesUpdaterPrefill:
)
# cached part
# Conditionally set multi-item parameters
if multi_item_params is not None and multi_item_params.is_enabled():
# Multi-item scoring is active - use specialized parameters and disable generic custom_mask
use_custom_mask = None
prefix_len_ptr = multi_item_params.prefix_len_ptr
token_pos_in_items_ptr = multi_item_params.token_pos_in_items_ptr
token_pos_in_items_len = multi_item_params.token_pos_in_items_len
max_item_len_ptr = multi_item_params.max_item_len_ptr
else:
# No multi-item scoring - use standard parameters
use_custom_mask = custom_mask
prefix_len_ptr = None
token_pos_in_items_ptr = None
token_pos_in_items_len = 0
max_item_len_ptr = None
wrapper_paged.begin_forward(
qo_indptr,
kv_indptr,
@@ -1147,9 +1360,13 @@ class FlashInferIndicesUpdaterPrefill:
1,
q_data_type=self.q_data_type,
kv_data_type=self.data_type,
-custom_mask=custom_mask,
+custom_mask=use_custom_mask,
non_blocking=True,
fixed_split_size=fixed_split_size,
prefix_len_ptr=prefix_len_ptr,
token_pos_in_items_ptr=token_pos_in_items_ptr,
token_pos_in_items_len=token_pos_in_items_len,
max_item_len_ptr=max_item_len_ptr,
)
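
End to end, the delimiter enters through the server arguments. A hedged usage sketch (the model name and token id are placeholders, and the keyword argument relies on sglang's usual mapping of Engine kwargs onto ServerArgs fields; only the server_args.multi_item_scoring_delimiter field itself is confirmed by this diff):

import sglang as sgl

engine = sgl.Engine(
    model_path="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
    multi_item_scoring_delimiter=128009,            # placeholder delimiter token id
)
# Prefill-only scoring requests whose input_ids contain this delimiter are then
# routed through _process_multi_item_scoring above and through
# compute_logprobs_for_multi_item_scoring in the logits processor below.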

View File

@@ -60,7 +60,8 @@ _is_npu = is_npu()
class LogitsProcessorOutput:
## Part 1: This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
# The logits of the next tokens. shape: [#seq, vocab_size]
-next_token_logits: torch.Tensor
+# Can be None for certain prefill-only requests (e.g., multi-item scoring) that don't need next-token generation
+next_token_logits: Optional[torch.Tensor]
# Used by speculative decoding (EAGLE)
# The last hidden layers
hidden_states: Optional[torch.Tensor] = None
@@ -85,7 +86,10 @@ class LogitsProcessorOutput:
input_top_logprobs_val: List = None
input_top_logprobs_idx: List = None
# The logprobs and ids of the requested token ids in input positions. shape: [#seq, n] (n is the number of requested token ids)
-input_token_ids_logprobs_val: Optional[List] = None
+# Can contain either lists or GPU tensors (for delayed GPU-to-CPU transfer optimization)
+input_token_ids_logprobs_val: Optional[List[Union[List[float], torch.Tensor]]] = None
input_token_ids_logprobs_idx: Optional[List] = None
@@ -127,6 +131,9 @@ class LogitsMetadata:
# for padding
padded_static_len: int = -1
# Whether this batch is prefill-only (no token generation needed)
is_prefill_only: bool = False
@classmethod
def from_forward_batch(cls, forward_batch: ForwardBatch):
if (
@@ -169,6 +176,7 @@ class LogitsMetadata:
token_ids_logprobs=forward_batch.token_ids_logprobs,
extend_input_logprob_token_ids_gpu=forward_batch.extend_input_logprob_token_ids_gpu,
padded_static_len=forward_batch.padded_static_len,
is_prefill_only=forward_batch.is_prefill_only,
global_num_tokens_gpu=forward_batch.global_num_tokens_gpu,
dp_local_start_pos=forward_batch.dp_local_start_pos,
dp_local_num_tokens=forward_batch.dp_local_num_tokens,
@@ -247,6 +255,108 @@ class LogitsProcessor(nn.Module):
"debug_tensor_dump_output_folder", None
)
def compute_logprobs_for_multi_item_scoring(
self,
input_ids,
hidden_states,
lm_head: VocabParallelEmbedding,
logits_metadata: Union[LogitsMetadata, ForwardBatch],
delimiter_token: int,
):
"""
Compute logprobs for multi-item scoring using delimiter-based token extraction.
This method is designed for scenarios where you want to score multiple items/candidates
against a single query by combining them into one sequence separated by delimiters.
Sequence format: Query<delimiter>Item1<delimiter>Item2<delimiter>...
Scoring positions: Extracts logprobs at positions before each <delimiter>
Args:
input_ids (torch.Tensor): Input token IDs containing query and items separated by delimiters.
Shape: [total_sequence_length] for single request or [batch_total_length] for batch.
hidden_states (torch.Tensor): Hidden states from the model.
Shape: [sequence_length, hidden_dim].
lm_head (VocabParallelEmbedding): Language model head for computing logits.
logits_metadata (Union[LogitsMetadata, ForwardBatch]): Metadata containing batch info
and token ID specifications for logprob extraction.
delimiter_token (int): Token ID used as delimiter between query and items.
Returns:
LogitsProcessorOutput: Contains:
- next_token_logits: None (not needed for scoring-only requests)
- input_token_logprobs: Logprobs of delimiter tokens at scoring positions
- input_top_logprobs_val: Top-k logprobs at delimiter positions (if requested)
- input_top_logprobs_idx: Top-k token indices at delimiter positions (if requested)
- input_token_ids_logprobs_val: Logprobs for user-requested token IDs (if any)
- input_token_ids_logprobs_idx: Indices for user-requested token IDs (if any)
"""
multi_item_indices = (input_ids == delimiter_token).nonzero(as_tuple=True)[0] - 1
# Gather hidden states at the positions immediately preceding each delimiter
sliced_hidden = hidden_states[multi_item_indices]
sliced_logits = self._get_logits(sliced_hidden, lm_head, logits_metadata)
sliced_logprobs = torch.nn.functional.log_softmax(sliced_logits, dim=-1)
# Initialize return values
input_token_ids_logprobs_val = []
input_token_ids_logprobs_idx = []
input_top_logprobs_val = None
input_top_logprobs_idx = None
# Recalculate extend_logprob_pruned_lens_cpu to match delimiter counts per request
# Original contains sequence lengths, but we need delimiter counts for sliced_logprobs
if (
logits_metadata.token_ids_logprobs
or logits_metadata.extend_return_top_logprob
):
logits_metadata.extend_logprob_pruned_lens_cpu = []
if logits_metadata.extend_seq_lens_cpu is not None:
# Multi-request batch: count delimiters per request
input_pt = 0
for req_seq_len in logits_metadata.extend_seq_lens_cpu:
req_input_ids = input_ids[input_pt : input_pt + req_seq_len]
delimiter_count = (req_input_ids == delimiter_token).sum().item()
logits_metadata.extend_logprob_pruned_lens_cpu.append(
delimiter_count
)
input_pt += req_seq_len
else:
# Single request case: one request gets all delimiters
total_delimiters = (input_ids == delimiter_token).sum().item()
logits_metadata.extend_logprob_pruned_lens_cpu = [total_delimiters]
# Get the logprobs of specified token ids
if logits_metadata.extend_token_ids_logprob:
(
input_token_ids_logprobs_val,
input_token_ids_logprobs_idx,
) = self.get_token_ids_logprobs(
sliced_logprobs, logits_metadata, delay_cpu_copy=True
)
# Get the logprob of top-k tokens
if logits_metadata.extend_return_top_logprob:
(
input_top_logprobs_val,
input_top_logprobs_idx,
) = self.get_top_logprobs(sliced_logprobs, logits_metadata)
# For input_token_logprobs, use delimiter token logprobs
input_token_logprobs = sliced_logprobs[:, delimiter_token]
return LogitsProcessorOutput(
next_token_logits=None, # Multi-item scoring doesn't need next token logits
input_token_logprobs=input_token_logprobs,
input_top_logprobs_val=input_top_logprobs_val,
input_top_logprobs_idx=input_top_logprobs_idx,
input_token_ids_logprobs_val=input_token_ids_logprobs_val,
input_token_ids_logprobs_idx=input_token_ids_logprobs_idx,
)
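
A toy illustration of the index arithmetic above (standalone; the vocab size and delimiter id are arbitrary): logprobs are read at the position immediately before each delimiter, i.e. at the last token of the query and of every item.

import torch

delim = 7  # arbitrary delimiter token id
input_ids = torch.tensor([1, 2, 3, delim, 4, delim, 5, delim])
logprobs = torch.log_softmax(torch.randn(input_ids.numel(), 10), dim=-1)  # fake per-position logprobs

score_positions = (input_ids == delim).nonzero(as_tuple=True)[0] - 1
print(score_positions.tolist())          # [2, 4, 6]: last token of the query and of each item
print(logprobs[score_positions, delim])  # the delimiter's logprob at each scoring position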
def forward(
self,
input_ids,
@@ -257,6 +367,16 @@ class LogitsProcessor(nn.Module):
) -> LogitsProcessorOutput:
if isinstance(logits_metadata, ForwardBatch):
logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata)
# Check if multi-item scoring is enabled via server args (only for prefill-only requests)
multi_item_delimiter = global_server_args_dict.get(
"multi_item_scoring_delimiter"
)
if multi_item_delimiter is not None and logits_metadata.is_prefill_only:
return self.compute_logprobs_for_multi_item_scoring(
input_ids, hidden_states, lm_head, logits_metadata, multi_item_delimiter
)
# Get the last hidden states and last logits for the next token prediction
if (
logits_metadata.forward_mode.is_decode_or_idle()
@@ -584,7 +704,9 @@ class LogitsProcessor(nn.Module):
@staticmethod
def get_token_ids_logprobs(
-all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata
+all_logprobs: torch.Tensor,
+logits_metadata: LogitsMetadata,
+delay_cpu_copy: bool = False,
):
input_token_ids_logprobs_val, input_token_ids_logprobs_idx = [], []
pt = 0
@@ -597,9 +719,17 @@ class LogitsProcessor(nn.Module):
input_token_ids_logprobs_idx.append([])
continue
-input_token_ids_logprobs_val.append(
-    [all_logprobs[pt + j, token_ids].tolist() for j in range(pruned_len)]
-)
+position_logprobs = all_logprobs[
+    pt : pt + pruned_len, token_ids
+]  # Shape: [pruned_len, num_tokens]
+if delay_cpu_copy:
+    # Keep as tensor to delay GPU-to-CPU transfer
+    input_token_ids_logprobs_val.append(position_logprobs)
+else:
+    # Convert to list immediately (default behavior)
+    input_token_ids_logprobs_val.append(position_logprobs.tolist())
input_token_ids_logprobs_idx.append([token_ids for _ in range(pruned_len)])
pt += pruned_len
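
The effect of the new delay_cpu_copy flag in miniature (toy tensors; shapes follow the comment above):

import torch

all_logprobs = torch.randn(4, 32)                 # [pruned_len, vocab_size]
token_ids = [3, 9]                                # token ids requested by the user
position_logprobs = all_logprobs[0:4, token_ids]  # shape [4, 2], one row per input position

tensor_result = position_logprobs          # delay_cpu_copy=True: stays on device, no sync yet
list_result = position_logprobs.tolist()   # delay_cpu_copy=False: device-to-host copy happens here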