[Generative Score API] Scoring (Prefill-only) optimizations. (#9748)

Author: Sundara Raman Ramachandran
Date: 2025-09-13 10:57:06 -07:00
Committed by: GitHub
Parent: 94d0f656fb
Commit: a360511d7b
9 changed files with 325 additions and 48 deletions


@@ -561,7 +561,10 @@ class Req:
             # shape: (bs, k)
             self.output_top_logprobs_val = []
             self.output_top_logprobs_idx = []
-            self.output_token_ids_logprobs_val = []
+            # Can contain either lists or GPU tensors (delayed copy optimization for prefill-only scoring)
+            self.output_token_ids_logprobs_val: List[
+                Union[List[float], torch.Tensor]
+            ] = []
             self.output_token_ids_logprobs_idx = []
         else:
             self.output_token_logprobs_val = self.output_token_logprobs_idx = (
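The retyped field above is the hook for the delayed-copy optimization: for prefill-only scoring, token-ID logprobs can stay on the GPU as tensors and only be copied to the host when results are assembled. A minimal sketch of how a consumer might normalize such an entry; the helper name is hypothetical and not part of this commit:

from typing import List, Union

import torch


def resolve_logprob_entry(entry: Union[List[float], torch.Tensor]) -> List[float]:
    """Return plain floats, deferring the device-to-host copy to read time."""
    if isinstance(entry, torch.Tensor):
        # One synchronizing .cpu() copy when the result is serialized,
        # instead of a blocking copy inside the forward pass.
        return entry.cpu().tolist()
    return entry

Deferring the copy this way concentrates the GPU-to-CPU synchronization at result-assembly time rather than paying it per request during prefill.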
@@ -619,6 +622,11 @@ class Req:
     def seqlen(self):
         return len(self.origin_input_ids) + len(self.output_ids)

+    @property
+    def is_prefill_only(self) -> bool:
+        """Check if this request is prefill-only (no token generation needed)."""
+        return self.sampling_params.max_new_tokens == 0
+
     def extend_image_inputs(self, image_inputs):
         if self.multimodal_inputs is None:
             self.multimodal_inputs = image_inputs
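A request is prefill-only exactly when it asks for zero new tokens, which is how a scoring request expresses "give me logprobs, generate nothing". A stripped-down stand-in illustrating the property's contract (the real Req carries many more fields):

from dataclasses import dataclass


@dataclass
class SamplingParams:
    max_new_tokens: int = 0


@dataclass
class Req:
    sampling_params: SamplingParams

    @property
    def is_prefill_only(self) -> bool:
        return self.sampling_params.max_new_tokens == 0


score_req = Req(SamplingParams(max_new_tokens=0))
gen_req = Req(SamplingParams(max_new_tokens=128))
assert score_req.is_prefill_only
assert not gen_req.is_prefill_only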
@@ -950,9 +958,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             device=req_to_token_pool.device,
             spec_algorithm=spec_algorithm,
             return_hidden_states=any(req.return_hidden_states for req in reqs),
-            is_prefill_only=all(
-                req.sampling_params.max_new_tokens == 0 for req in reqs
-            ),
+            is_prefill_only=all(req.is_prefill_only for req in reqs),
             chunked_req=chunked_req,
         )

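Batch construction then reduces the per-request property with all(): a batch is prefill-only only if every request in it is, so a single generative request disables the fast path for the whole batch. Continuing the stand-in objects from the sketch above:

assert all(r.is_prefill_only for r in [score_req, score_req])    # pure scoring batch
assert not all(r.is_prefill_only for r in [score_req, gen_req])  # mixed batch: fast path off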
@@ -1210,13 +1216,36 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             req.is_retracted = False

             # Compute the relative logprob_start_len in an extend batch
+            #
+            # Key variables:
+            # - logprob_start_len: Absolute position in full sequence where logprob computation begins
+            # - extend_logprob_start_len: Relative position within current extend batch where logprob computation begins
+            # - extend_input_len: Number of tokens that need to be processed in this extend batch
+            #   (= len(fill_ids) - len(prefix_indices), where fill_ids = origin_input_ids + output_ids
+            #   and prefix_indices are the cached/shared prefix tokens)
+            #
             if req.logprob_start_len >= pre_len:
-                req.extend_logprob_start_len = min(
-                    req.logprob_start_len - pre_len,
-                    req.extend_input_len,
-                    req.seqlen - 1,
-                )
+                # Optimization for prefill-only requests: when we only need logprobs at
+                # positions beyond the input sequence (to score next-token likelihood), skip all
+                # input logprob computation during prefill since no generation will occur.
+                if self.is_prefill_only and req.logprob_start_len == len(
+                    req.origin_input_ids
+                ):
+                    # Skip ALL input logprobs: set extend_logprob_start_len = extend_input_len
+                    req.extend_logprob_start_len = req.extend_input_len
+                else:
+                    # Convert absolute logprob_start_len to relative extend_logprob_start_len
+                    #
+                    # Example: origin_input_ids=[1,2,3,4,5] (5 tokens, positions 0-4), logprob_start_len=3
+                    # Regular logic: min(3-0, 5, 5-1) = min(3, 5, 4) = 3
+                    # This means: "compute logprobs from position 3 onwards in the extend batch"
+                    req.extend_logprob_start_len = min(
+                        req.logprob_start_len - pre_len,
+                        req.extend_input_len,
+                        req.seqlen - 1,
+                    )
             else:
+                # logprob_start_len is before the current extend batch, so start from the beginning
                 req.extend_logprob_start_len = 0

             if self.return_logprob:
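To make the two branches concrete, here is the arithmetic from the comments rerun as plain Python; integers stand in for the real Req fields, and pre_len = 0 assumes no cached prefix:

origin_input_len = 5    # len(req.origin_input_ids): tokens [1, 2, 3, 4, 5]
pre_len = 0             # no cached/shared prefix
extend_input_len = 5    # all 5 tokens are processed in this extend batch
seqlen = 5              # origin inputs + outputs (no outputs yet)

# Regular path: logprob_start_len = 3 -> logprobs from position 3 onward.
logprob_start_len = 3
extend_logprob_start_len = min(
    logprob_start_len - pre_len, extend_input_len, seqlen - 1
)
assert extend_logprob_start_len == 3  # min(3, 5, 4)

# Prefill-only scoring path: logprob_start_len equals the input length,
# so every input logprob is skipped and only next-token scores remain.
logprob_start_len = origin_input_len
is_prefill_only = True
if is_prefill_only and logprob_start_len == origin_input_len:
    extend_logprob_start_len = extend_input_len
assert extend_logprob_start_len == 5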
@@ -1763,6 +1792,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             ),
             extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
             launch_done=self.launch_done,
+            is_prefill_only=self.is_prefill_only,
         )

     def copy(self):
@@ -1905,6 +1935,9 @@ class ModelWorkerBatch:
     # Overlap event
     launch_done: Optional[threading.Event] = None

+    # Whether this batch is prefill-only (no token generation needed)
+    is_prefill_only: bool = False
+

 @triton.jit
 def write_req_to_token_pool_triton(
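The flag then rides along in ModelWorkerBatch so the model worker can branch on it without re-deriving the property from sampling params. A trimmed sketch showing just the fields from this hunk (the real dataclass has many required fields, so this constructor call is illustrative only):

import threading
from dataclasses import dataclass
from typing import Optional


@dataclass
class ModelWorkerBatch:
    # Overlap event
    launch_done: Optional[threading.Event] = None
    # Whether this batch is prefill-only (no token generation needed)
    is_prefill_only: bool = False


batch = ModelWorkerBatch(is_prefill_only=True)
if batch.is_prefill_only:
    pass  # e.g. a worker might skip sampling entirely for a pure scoring batch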