[Generative Score API] Multi-Item scoring with custom attention mask. (#10979)
This commit is contained in:
committed by
GitHub
parent
e22b13c569
commit
53bd00d975
@@ -7,6 +7,7 @@ FlashInfer is faster and Triton is easier to customize.
|
||||
Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
@@ -16,11 +17,11 @@ from typing import TYPE_CHECKING, Callable, List, Optional, Union
|
||||
import torch
|
||||
|
||||
if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
|
||||
import logging
|
||||
|
||||
torch._logging.set_logs(dynamo=logging.ERROR)
|
||||
torch._dynamo.config.suppress_errors = True
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from sglang.global_config import global_config
|
||||
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
||||
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
|
||||
@@ -58,6 +59,36 @@ class WrapperDispatch(Enum):
|
||||
CROSS_ATTENTION = auto()
|
||||
|
||||
|
||||
@dataclass
class MultiItemScoringParams:
    """FlashInfer parameters for delimiter-separated multi-item scoring.

    A single packed sequence may contain several candidate items separated
    by delimiter tokens; attention must then respect item boundaries.
    These tensors describe that layout to FlashInfer.

    Attributes:
        prefix_len_ptr: uint32 1D tensor, one entry per batch element,
            giving the shared-prefix length of each prompt.
        token_pos_in_items_ptr: uint16 1D tensor holding each token's
            position inside its item, restarting at 0 on every delimiter.
            For batch size > 1 the per-sequence runs are concatenated,
            each zero-padded to a common length.
        token_pos_in_items_len: the common (zero-padded) per-sequence
            length used inside token_pos_in_items_ptr when batch size > 1.
        max_item_len_ptr: uint16 tensor with the longest item length of
            every prompt in the batch.
    """

    prefix_len_ptr: Optional[torch.Tensor] = None
    token_pos_in_items_ptr: Optional[torch.Tensor] = None
    token_pos_in_items_len: int = 0
    max_item_len_ptr: Optional[torch.Tensor] = None

    def is_enabled(self) -> bool:
        """Return True when multi-item scoring parameters were populated."""
        # A populated prefix-length tensor is the marker that the whole
        # parameter set was built; default construction leaves it None.
        return self.prefix_len_ptr is not None
|
||||
|
||||
@dataclass
class DecodeMetadata:
    """Per-batch metadata captured for a decode forward pass."""

    # FlashInfer paged-KV decode wrappers planned for the current batch.
    # NOTE(review): presumably one wrapper per dispatch slot (e.g. full vs.
    # sliding-window attention) — confirm against the wrapper setup code.
    decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper]
@@ -68,6 +99,7 @@ class PrefillMetadata:
|
||||
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper]
|
||||
use_ragged: bool
|
||||
extend_no_prefix: bool
|
||||
multi_item_params: Optional[MultiItemScoringParams] = None
|
||||
|
||||
|
||||
# Reuse this workspace buffer across all flashinfer wrappers
|
||||
@@ -90,6 +122,11 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Store multi-item scoring delimiter for efficient access
|
||||
self.multi_item_scoring_delimiter = (
|
||||
model_runner.server_args.multi_item_scoring_delimiter
|
||||
)
|
||||
|
||||
# Parse constants
|
||||
self.decode_use_tensor_cores = should_use_tensor_core(
|
||||
kv_cache_dtype=model_runner.kv_cache_dtype,
|
||||
@@ -229,10 +266,133 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
|
||||
# Other metadata
|
||||
self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
|
||||
|
||||
self.decode_cuda_graph_metadata = {}
|
||||
self.prefill_cuda_graph_metadata = {} # For verify
|
||||
self.draft_extend_cuda_graph_metadata = {} # For draft extend
|
||||
|
||||
def _process_multi_item_scoring(
    self, forward_batch: ForwardBatch
) -> MultiItemScoringParams:
    """Process multi-item scoring tensors for FlashInfer attention.

    This method handles sequences containing multiple "items" separated by delimiter tokens,
    where each item needs specific attention patterns that respect item boundaries.

    The method produces four key tensors for FlashInfer:
    - prefix_len_ptr: uint32 tensor with prefix length for each prompt in batch
    - token_pos_in_items_ptr: uint16 tensor with token positions starting from 0 at delimiters
    - token_pos_in_items_len: padding length for batch processing
    - max_item_len_ptr: uint16 tensor with max item length for each prompt

    Args:
        forward_batch: The forward batch containing input sequences and delimiter info

    Returns:
        MultiItemScoringParams: The processed multi-item scoring parameters

    Examples:
        Following FlashInfer definition: for 3 items of length 3, 2, 4 respectively:
        token_pos_in_items_ptr = [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0]

        Case 1: Single sequence
        Text: "What is the capital of France? <delim> London <delim> Paris <delim> Berlin <delim>"
        Tokens: [What, is, the, capital, of, France, ?, <delim>, London, <delim>, Paris, <delim>, Berlin, <delim>]
        Indices: [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
        - prefix_len_ptr: [7] (query length before first delimiter)
        - token_pos_in_items_ptr: [0, 1, 0, 1, 0, 1, 0] (delim=0, London=1, delim=0, Paris=1, delim=0, Berlin=1, delim=0)
        - token_pos_in_items_len: 7 (actual length)
        - max_item_len_ptr: [1] (max item length is 1 token - all options are single tokens)

        Case 2: Batch processing (batch_size=2)
        Sequence 1: 2 items of length 2, 1 -> [0, 1, 2, 0, 1, 0] (6 elements)
        Sequence 2: 3 items of length 1, 3, 2 -> [0, 1, 0, 1, 2, 3, 0, 1, 2, 0] (10 elements)
        After padding both to length 10:
        - token_pos_in_items_ptr: [0, 1, 2, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 2, 3, 0, 1, 2, 0]
        - token_pos_in_items_len: 10 (padded length for batch processing)
        - max_item_len_ptr: [2, 3] (max lengths per sequence)
    """

    delimiter = self.multi_item_scoring_delimiter
    # Decode steps process one token at a time, so there is no item
    # structure to extract; likewise, no delimiter means nothing to do.
    if delimiter is None or forward_batch.forward_mode == ForwardMode.DECODE:
        return MultiItemScoringParams()

    # Boolean mask over the whole flattened batch: True at delimiter tokens.
    delimiter_mask = forward_batch.input_ids == delimiter
    prefix_cache_lens = getattr(forward_batch, "extend_prefix_lens", None)
    extend_seq_lens = getattr(forward_batch, "extend_seq_lens", None)
    prefix_len_ptr, token_pos_in_items_ptr = [], []
    token_pos_in_items_len = 0

    # If no extend_seq_lens, treat whole batch as one sequence
    if extend_seq_lens is None or len(extend_seq_lens) <= 1:
        extend_seq_lens = [forward_batch.input_ids.size(0)]

    seq_start = 0
    for i, seq_len in enumerate(extend_seq_lens):
        seq_end = seq_start + seq_len
        mask = delimiter_mask[seq_start:seq_end]
        pos = forward_batch.positions[seq_start:seq_end]
        delimiter_indices = torch.nonzero(mask, as_tuple=True)[0]

        if len(delimiter_indices) > 0:
            first_delim = delimiter_indices[0]
            # Prefix length: store as scalar
            # (tokens before the first delimiter, plus any cached prefix).
            prefix_len = first_delim + (
                prefix_cache_lens[i] if prefix_cache_lens is not None else 0
            )
            prefix_len_ptr.append(
                prefix_len.item() if torch.is_tensor(prefix_len) else prefix_len
            )

            # Compute relative positions within items after delimiters.
            # cummax over the boolean mask returns, at every position, the
            # index of the most recent delimiter seen so far, so the
            # subtraction restarts counting at 0 on each delimiter.
            diff = pos[first_delim:] - torch.cummax(mask[first_delim:], 0)[1]
            token_pos = (diff - pos[first_delim]).to(torch.uint16)
            token_pos_in_items_ptr.append(token_pos)

            # Update forward_batch positions in-place
            # NOTE(review): this mutates forward_batch.positions for the
            # delimiter-relative layout — later consumers see the adjusted
            # positions; confirm that is intended for all call paths.
            pos[first_delim:] = diff - 1
            forward_batch.positions[seq_start:seq_end] = pos

        seq_start = seq_end

    # Pad token_pos_in_items_ptr for batch processing
    if token_pos_in_items_ptr:
        token_pos_in_items_len = max(t.numel() for t in token_pos_in_items_ptr)
        device = forward_batch.input_ids.device
        # Right-pad every per-sequence run with zeros to the common length.
        token_pos_in_items_ptr = [
            torch.cat(
                [
                    t,
                    torch.zeros(
                        token_pos_in_items_len - t.numel(),
                        dtype=torch.uint16,
                        device=device,
                    ),
                ]
            )
            for t in token_pos_in_items_ptr
        ]

    # No sequence contained a delimiter: multi-item scoring is a no-op.
    if not prefix_len_ptr or not token_pos_in_items_ptr:
        return MultiItemScoringParams()

    # Build final params
    device = forward_batch.input_ids.device
    return MultiItemScoringParams(
        prefix_len_ptr=torch.tensor(
            prefix_len_ptr, dtype=torch.uint32, device=device
        ),
        token_pos_in_items_ptr=torch.cat(token_pos_in_items_ptr, dim=0),
        # Clamp to 32 bits to match the downstream kernel's expected range.
        token_pos_in_items_len=token_pos_in_items_len & 0xFFFFFFFF,
        # Max over the padded run is safe: padding is zeros, and item
        # positions are >= 0, so padding never wins the max.
        max_item_len_ptr=torch.stack(
            [
                t.to(torch.int32).max().to(torch.uint16)
                for t in token_pos_in_items_ptr
            ],
            dim=0,
        ),
    )
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||
if forward_batch.forward_mode.is_decode_or_idle():
|
||||
self.indices_updater_decode.update(
|
||||
@@ -280,13 +440,26 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
else:
|
||||
prefix_lens = forward_batch.extend_prefix_lens
|
||||
|
||||
if self.is_multimodal:
|
||||
# Disable ragged wrapper and ensure prefix handling for multimodal and multi-item scoring
|
||||
if self.is_multimodal or self.multi_item_scoring_delimiter is not None:
|
||||
# use_ragged = False: Multi-item scoring requires the paged wrapper because:
|
||||
# 1. Ragged wrapper doesn't support the specialized multi-item parameters
|
||||
# (prefix_len_ptr, token_pos_in_items_ptr, etc.)
|
||||
# 2. Paged wrapper provides better control over attention masking needed
|
||||
# for respecting item boundaries in multi-item sequences
|
||||
# 3. Custom masking logic conflicts with ragged wrapper's assumptions
|
||||
use_ragged = False
|
||||
extend_no_prefix = False
|
||||
else:
|
||||
use_ragged = not self.enable_deterministic
|
||||
extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
|
||||
|
||||
# Process multi-item scoring in attention backend instead of ForwardBatch
|
||||
multi_item_params = MultiItemScoringParams()
|
||||
if self.multi_item_scoring_delimiter is not None:
|
||||
# Use new backend-specific implementation
|
||||
multi_item_params = self._process_multi_item_scoring(forward_batch)
|
||||
|
||||
self.indices_updater_prefill.update(
|
||||
forward_batch.req_pool_indices,
|
||||
forward_batch.seq_lens,
|
||||
@@ -298,9 +471,13 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
encoder_lens=forward_batch.encoder_lens,
|
||||
spec_info=None,
|
||||
fixed_split_size=self.prefill_split_tile_size,
|
||||
multi_item_params=multi_item_params,
|
||||
)
|
||||
self.forward_metadata = PrefillMetadata(
|
||||
self.prefill_wrappers_paged, use_ragged, extend_no_prefix
|
||||
self.prefill_wrappers_paged,
|
||||
use_ragged,
|
||||
extend_no_prefix,
|
||||
multi_item_params,
|
||||
)
|
||||
|
||||
def init_cuda_graph_state(
|
||||
@@ -531,7 +708,20 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
|
||||
causal=not layer.is_cross_attention,
|
||||
sm_scale=layer.scaling,
|
||||
window_left=layer.sliding_window_size,
|
||||
# Disable sliding window attention for multi-item scoring:
|
||||
# - Sliding window could cut across item boundaries, breaking semantic coherence
|
||||
# - Multi-item sequences need full attention to properly handle delimiter tokens
|
||||
# - Specialized multi-item parameters (prefix_len_ptr, token_pos_in_items_ptr)
|
||||
# provide more precise attention control than simple sliding windows
|
||||
# - Item-aware masking takes precedence over window-based masking
|
||||
window_left=(
|
||||
layer.sliding_window_size
|
||||
if not (
|
||||
self.forward_metadata.multi_item_params
|
||||
and self.forward_metadata.multi_item_params.is_enabled()
|
||||
)
|
||||
else -1
|
||||
),
|
||||
logits_soft_cap=logits_soft_cap,
|
||||
# Must use _float to avoid device-to-host copy that breaks cuda graph capture.
|
||||
k_scale=layer.k_scale_float,
|
||||
@@ -952,6 +1142,7 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
spec_info: Optional[SpecInput],
|
||||
fixed_split_size: Optional[int] = None,
|
||||
multi_item_params: Optional[MultiItemScoringParams] = None,
|
||||
):
|
||||
if use_ragged:
|
||||
# TODO: remove this device sync, we can use forward_batch.extend_prefix_lens_cpu
|
||||
@@ -976,6 +1167,7 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
use_ragged,
|
||||
spec_info,
|
||||
fixed_split_size=fixed_split_size,
|
||||
multi_item_params=multi_item_params,
|
||||
)
|
||||
|
||||
def update_sliding_window(
|
||||
@@ -990,6 +1182,7 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
spec_info: Optional[SpecInput],
|
||||
fixed_split_size: Optional[int] = None,
|
||||
multi_item_params: Optional[MultiItemScoringParams] = None,
|
||||
):
|
||||
for wrapper_id in range(2):
|
||||
if wrapper_id == 0:
|
||||
@@ -1023,6 +1216,7 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
use_ragged,
|
||||
spec_info,
|
||||
use_sliding_window_kv_pool=use_sliding_window_kv_pool,
|
||||
multi_item_params=multi_item_params,
|
||||
)
|
||||
|
||||
def update_cross_attention(
|
||||
@@ -1037,6 +1231,7 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
spec_info: Optional[SpecInput],
|
||||
fixed_split_size: Optional[int] = None,
|
||||
multi_item_params: Optional[MultiItemScoringParams] = None,
|
||||
):
|
||||
for wrapper_id in range(2):
|
||||
if wrapper_id == 0:
|
||||
@@ -1063,6 +1258,7 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
self.qo_indptr[wrapper_id],
|
||||
use_ragged,
|
||||
spec_info,
|
||||
multi_item_params=multi_item_params,
|
||||
)
|
||||
|
||||
def call_begin_forward(
|
||||
@@ -1081,6 +1277,7 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
spec_info: Optional[SpecInput],
|
||||
use_sliding_window_kv_pool: bool = False,
|
||||
fixed_split_size: Optional[int] = None,
|
||||
multi_item_params: Optional[MultiItemScoringParams] = None,
|
||||
):
|
||||
bs = len(seq_lens)
|
||||
if spec_info is None:
|
||||
@@ -1136,6 +1333,22 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
)
|
||||
|
||||
# cached part
|
||||
# Conditionally set multi-item parameters
|
||||
if multi_item_params is not None and multi_item_params.is_enabled():
|
||||
# Multi-item scoring is active - use specialized parameters and disable generic custom_mask
|
||||
use_custom_mask = None
|
||||
prefix_len_ptr = multi_item_params.prefix_len_ptr
|
||||
token_pos_in_items_ptr = multi_item_params.token_pos_in_items_ptr
|
||||
token_pos_in_items_len = multi_item_params.token_pos_in_items_len
|
||||
max_item_len_ptr = multi_item_params.max_item_len_ptr
|
||||
else:
|
||||
# No multi-item scoring - use standard parameters
|
||||
use_custom_mask = custom_mask
|
||||
prefix_len_ptr = None
|
||||
token_pos_in_items_ptr = None
|
||||
token_pos_in_items_len = 0
|
||||
max_item_len_ptr = None
|
||||
|
||||
wrapper_paged.begin_forward(
|
||||
qo_indptr,
|
||||
kv_indptr,
|
||||
@@ -1147,9 +1360,13 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
1,
|
||||
q_data_type=self.q_data_type,
|
||||
kv_data_type=self.data_type,
|
||||
custom_mask=custom_mask,
|
||||
custom_mask=use_custom_mask,
|
||||
non_blocking=True,
|
||||
fixed_split_size=fixed_split_size,
|
||||
prefix_len_ptr=prefix_len_ptr,
|
||||
token_pos_in_items_ptr=token_pos_in_items_ptr,
|
||||
token_pos_in_items_len=token_pos_in_items_len,
|
||||
max_item_len_ptr=max_item_len_ptr,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -60,7 +60,8 @@ _is_npu = is_npu()
|
||||
class LogitsProcessorOutput:
|
||||
## Part 1: This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
|
||||
# The logits of the next tokens. shape: [#seq, vocab_size]
|
||||
next_token_logits: torch.Tensor
|
||||
# Can be None for certain prefill-only requests (e.g., multi-item scoring) that don't need next token generation
|
||||
next_token_logits: Optional[torch.Tensor]
|
||||
# Used by speculative decoding (EAGLE)
|
||||
# The last hidden layers
|
||||
hidden_states: Optional[torch.Tensor] = None
|
||||
@@ -85,7 +86,10 @@ class LogitsProcessorOutput:
|
||||
input_top_logprobs_val: List = None
|
||||
input_top_logprobs_idx: List = None
|
||||
# The logprobs and ids of the requested token ids in input positions. shape: [#seq, n] (n is the number of requested token ids)
|
||||
input_token_ids_logprobs_val: Optional[List] = None
|
||||
# Can contain either lists or GPU tensors (for delayed GPU-to-CPU transfer optimization)
|
||||
input_token_ids_logprobs_val: Optional[List[Union[List[float], torch.Tensor]]] = (
|
||||
None
|
||||
)
|
||||
input_token_ids_logprobs_idx: Optional[List] = None
|
||||
|
||||
|
||||
@@ -127,6 +131,9 @@ class LogitsMetadata:
|
||||
# for padding
|
||||
padded_static_len: int = -1
|
||||
|
||||
# Whether this batch is prefill-only (no token generation needed)
|
||||
is_prefill_only: bool = False
|
||||
|
||||
@classmethod
|
||||
def from_forward_batch(cls, forward_batch: ForwardBatch):
|
||||
if (
|
||||
@@ -169,6 +176,7 @@ class LogitsMetadata:
|
||||
token_ids_logprobs=forward_batch.token_ids_logprobs,
|
||||
extend_input_logprob_token_ids_gpu=forward_batch.extend_input_logprob_token_ids_gpu,
|
||||
padded_static_len=forward_batch.padded_static_len,
|
||||
is_prefill_only=forward_batch.is_prefill_only,
|
||||
global_num_tokens_gpu=forward_batch.global_num_tokens_gpu,
|
||||
dp_local_start_pos=forward_batch.dp_local_start_pos,
|
||||
dp_local_num_tokens=forward_batch.dp_local_num_tokens,
|
||||
@@ -247,6 +255,108 @@ class LogitsProcessor(nn.Module):
|
||||
"debug_tensor_dump_output_folder", None
|
||||
)
|
||||
|
||||
def compute_logprobs_for_multi_item_scoring(
    self,
    input_ids,
    hidden_states,
    lm_head: VocabParallelEmbedding,
    logits_metadata: Union[LogitsMetadata, ForwardBatch],
    delimiter_token: int,
):
    """
    Compute logprobs for multi-item scoring using delimiter-based token extraction.

    This method is designed for scenarios where you want to score multiple items/candidates
    against a single query by combining them into one sequence separated by delimiters.

    Sequence format: Query<delimiter>Item1<delimiter>Item2<delimiter>...
    Scoring positions: Extracts logprobs at positions before each <delimiter>

    Args:
        input_ids (torch.Tensor): Input token IDs containing query and items separated by delimiters.
            Shape: [total_sequence_length] for single request or [batch_total_length] for batch.
        hidden_states (torch.Tensor): Hidden states from the model.
            Shape: [sequence_length, hidden_dim].
        lm_head (VocabParallelEmbedding): Language model head for computing logits.
        logits_metadata (Union[LogitsMetadata, ForwardBatch]): Metadata containing batch info
            and token ID specifications for logprob extraction.
        delimiter_token (int): Token ID used as delimiter between query and items.

    Returns:
        LogitsProcessorOutput: Contains:
            - next_token_logits: None (not needed for scoring-only requests)
            - input_token_logprobs: Logprobs of delimiter tokens at scoring positions
            - input_top_logprobs_val: Top-k logprobs at delimiter positions (if requested)
            - input_top_logprobs_idx: Top-k token indices at delimiter positions (if requested)
            - input_token_ids_logprobs_val: Logprobs for user-requested token IDs (if any)
            - input_token_ids_logprobs_idx: Indices for user-requested token IDs (if any)
    """
    # Positions immediately BEFORE each delimiter: the hidden state there
    # is the one that predicts the delimiter token.
    # NOTE(review): a delimiter at position 0 would yield index -1 and wrap
    # to the last row — confirm callers guarantee the query precedes the
    # first delimiter.
    multi_item_indices = (input_ids == delimiter_token).nonzero(as_tuple=True)[
        0
    ] - 1
    # Extract hidden states at delimiter positions for multi-item scoring
    sliced_hidden = hidden_states[multi_item_indices]

    sliced_logits = self._get_logits(sliced_hidden, lm_head, logits_metadata)
    sliced_logprobs = torch.nn.functional.log_softmax(sliced_logits, dim=-1)

    # Initialize return values
    input_token_ids_logprobs_val = []
    input_token_ids_logprobs_idx = []
    input_top_logprobs_val = None
    input_top_logprobs_idx = None

    # Recalculate extend_logprob_pruned_lens_cpu to match delimiter counts per request
    # Original contains sequence lengths, but we need delimiter counts for sliced_logprobs
    # (sliced_logprobs has one row per delimiter, not per input token).
    if (
        logits_metadata.token_ids_logprobs
        or logits_metadata.extend_return_top_logprob
    ):
        logits_metadata.extend_logprob_pruned_lens_cpu = []

        if logits_metadata.extend_seq_lens_cpu is not None:
            # Multi-request batch: count delimiters per request
            input_pt = 0
            for req_seq_len in logits_metadata.extend_seq_lens_cpu:
                req_input_ids = input_ids[input_pt : input_pt + req_seq_len]
                delimiter_count = (req_input_ids == delimiter_token).sum().item()
                logits_metadata.extend_logprob_pruned_lens_cpu.append(
                    delimiter_count
                )
                input_pt += req_seq_len
        else:
            # Single request case: one request gets all delimiters
            total_delimiters = (input_ids == delimiter_token).sum().item()
            logits_metadata.extend_logprob_pruned_lens_cpu = [total_delimiters]

    # Get the logprobs of specified token ids
    if logits_metadata.extend_token_ids_logprob:
        (
            input_token_ids_logprobs_val,
            input_token_ids_logprobs_idx,
        ) = self.get_token_ids_logprobs(
            # delay_cpu_copy keeps results on GPU to avoid a sync here.
            sliced_logprobs, logits_metadata, delay_cpu_copy=True
        )

    # Get the logprob of top-k tokens
    if logits_metadata.extend_return_top_logprob:
        (
            input_top_logprobs_val,
            input_top_logprobs_idx,
        ) = self.get_top_logprobs(sliced_logprobs, logits_metadata)

    # For input_token_logprobs, use delimiter token logprobs
    # (the probability that the model would emit the delimiter at each
    # scoring position).
    input_token_logprobs = sliced_logprobs[:, delimiter_token]

    return LogitsProcessorOutput(
        next_token_logits=None,  # Multi-item scoring doesn't need next token logits
        input_token_logprobs=input_token_logprobs,
        input_top_logprobs_val=input_top_logprobs_val,
        input_top_logprobs_idx=input_top_logprobs_idx,
        input_token_ids_logprobs_val=input_token_ids_logprobs_val,
        input_token_ids_logprobs_idx=input_token_ids_logprobs_idx,
    )
def forward(
|
||||
self,
|
||||
input_ids,
|
||||
@@ -257,6 +367,16 @@ class LogitsProcessor(nn.Module):
|
||||
) -> LogitsProcessorOutput:
|
||||
if isinstance(logits_metadata, ForwardBatch):
|
||||
logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata)
|
||||
|
||||
# Check if multi-item scoring is enabled via server args (only for prefill-only requests)
|
||||
multi_item_delimiter = global_server_args_dict.get(
|
||||
"multi_item_scoring_delimiter"
|
||||
)
|
||||
if multi_item_delimiter is not None and logits_metadata.is_prefill_only:
|
||||
return self.compute_logprobs_for_multi_item_scoring(
|
||||
input_ids, hidden_states, lm_head, logits_metadata, multi_item_delimiter
|
||||
)
|
||||
|
||||
# Get the last hidden states and last logits for the next token prediction
|
||||
if (
|
||||
logits_metadata.forward_mode.is_decode_or_idle()
|
||||
@@ -584,7 +704,9 @@ class LogitsProcessor(nn.Module):
|
||||
|
||||
@staticmethod
|
||||
def get_token_ids_logprobs(
|
||||
all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata
|
||||
all_logprobs: torch.Tensor,
|
||||
logits_metadata: LogitsMetadata,
|
||||
delay_cpu_copy: bool = False,
|
||||
):
|
||||
input_token_ids_logprobs_val, input_token_ids_logprobs_idx = [], []
|
||||
pt = 0
|
||||
@@ -597,9 +719,17 @@ class LogitsProcessor(nn.Module):
|
||||
input_token_ids_logprobs_idx.append([])
|
||||
continue
|
||||
|
||||
input_token_ids_logprobs_val.append(
|
||||
[all_logprobs[pt + j, token_ids].tolist() for j in range(pruned_len)]
|
||||
)
|
||||
position_logprobs = all_logprobs[
|
||||
pt : pt + pruned_len, token_ids
|
||||
] # Shape: [pruned_len, num_tokens]
|
||||
|
||||
if delay_cpu_copy:
|
||||
# Keep as tensor to delay GPU-to-CPU transfer
|
||||
input_token_ids_logprobs_val.append(position_logprobs)
|
||||
else:
|
||||
# Convert to list immediately (default behavior)
|
||||
input_token_ids_logprobs_val.append(position_logprobs.tolist())
|
||||
|
||||
input_token_ids_logprobs_idx.append([token_ids for _ in range(pruned_len)])
|
||||
pt += pruned_len
|
||||
|
||||
|
||||
Reference in New Issue
Block a user