[Generative Score API] Scoring(Prefill-only) optimizations. (#9748)
This commit is contained in:
committed by
GitHub
parent
94d0f656fb
commit
a360511d7b
@@ -561,7 +561,10 @@ class Req:
|
||||
# shape: (bs, k)
|
||||
self.output_top_logprobs_val = []
|
||||
self.output_top_logprobs_idx = []
|
||||
self.output_token_ids_logprobs_val = []
|
||||
# Can contain either lists or GPU tensors (delayed copy optimization for prefill-only scoring)
|
||||
self.output_token_ids_logprobs_val: List[
|
||||
Union[List[float], torch.Tensor]
|
||||
] = []
|
||||
self.output_token_ids_logprobs_idx = []
|
||||
else:
|
||||
self.output_token_logprobs_val = self.output_token_logprobs_idx = (
|
||||
@@ -619,6 +622,11 @@ class Req:
|
||||
def seqlen(self):
|
||||
return len(self.origin_input_ids) + len(self.output_ids)
|
||||
|
||||
@property
|
||||
def is_prefill_only(self) -> bool:
|
||||
"""Check if this request is prefill-only (no token generation needed)."""
|
||||
return self.sampling_params.max_new_tokens == 0
|
||||
|
||||
def extend_image_inputs(self, image_inputs):
|
||||
if self.multimodal_inputs is None:
|
||||
self.multimodal_inputs = image_inputs
|
||||
@@ -950,9 +958,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
device=req_to_token_pool.device,
|
||||
spec_algorithm=spec_algorithm,
|
||||
return_hidden_states=any(req.return_hidden_states for req in reqs),
|
||||
is_prefill_only=all(
|
||||
req.sampling_params.max_new_tokens == 0 for req in reqs
|
||||
),
|
||||
is_prefill_only=all(req.is_prefill_only for req in reqs),
|
||||
chunked_req=chunked_req,
|
||||
)
|
||||
|
||||
@@ -1210,13 +1216,36 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
req.is_retracted = False
|
||||
|
||||
# Compute the relative logprob_start_len in an extend batch
|
||||
#
|
||||
# Key variables:
|
||||
# - logprob_start_len: Absolute position in full sequence where logprob computation begins
|
||||
# - extend_logprob_start_len: Relative position within current extend batch where logprob computation begins
|
||||
# - extend_input_len: Number of tokens that need to be processed in this extend batch
|
||||
# (= len(fill_ids) - len(prefix_indices), where fill_ids = origin_input_ids + output_ids
|
||||
# and prefix_indices are the cached/shared prefix tokens)
|
||||
#
|
||||
if req.logprob_start_len >= pre_len:
|
||||
req.extend_logprob_start_len = min(
|
||||
req.logprob_start_len - pre_len,
|
||||
req.extend_input_len,
|
||||
req.seqlen - 1,
|
||||
)
|
||||
# Optimization for prefill-only requests: When we only need logprobs at
|
||||
# positions beyond the input sequence (to score next-token likelihood), skip all
|
||||
# input logprob computation during prefill since no generation will occur.
|
||||
if self.is_prefill_only and req.logprob_start_len == len(
|
||||
req.origin_input_ids
|
||||
):
|
||||
# Skip ALL input logprobs: set extend_logprob_start_len = extend_input_len
|
||||
req.extend_logprob_start_len = req.extend_input_len
|
||||
else:
|
||||
# Convert absolute logprob_start_len to relative extend_logprob_start_len
|
||||
#
|
||||
# Example: origin_input_ids=[1,2,3,4,5] (5 tokens, positions 0-4), logprob_start_len=3
|
||||
# Regular logic: min(3-0, 5, 5-1) = min(3,5,4) = 3
|
||||
# This means: "compute logprobs from position 3 onwards in extend batch"
|
||||
req.extend_logprob_start_len = min(
|
||||
req.logprob_start_len - pre_len,
|
||||
req.extend_input_len,
|
||||
req.seqlen - 1,
|
||||
)
|
||||
else:
|
||||
# logprob_start_len is before the current extend batch, so start from beginning
|
||||
req.extend_logprob_start_len = 0
|
||||
|
||||
if self.return_logprob:
|
||||
@@ -1763,6 +1792,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
),
|
||||
extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
|
||||
launch_done=self.launch_done,
|
||||
is_prefill_only=self.is_prefill_only,
|
||||
)
|
||||
|
||||
def copy(self):
|
||||
@@ -1905,6 +1935,9 @@ class ModelWorkerBatch:
|
||||
# Overlap event
|
||||
launch_done: Optional[threading.Event] = None
|
||||
|
||||
# Whether this batch is prefill-only (no token generation needed)
|
||||
is_prefill_only: bool = False
|
||||
|
||||
|
||||
@triton.jit
|
||||
def write_req_to_token_pool_triton(
|
||||
|
||||
Reference in New Issue
Block a user