[Generative Score API] Optimization to Remove Decode. (#8840)

This commit is contained in:
Sundara Raman Ramachandran
2025-08-13 14:12:24 -07:00
committed by GitHub
parent 9e426466af
commit a027a9b4b3
6 changed files with 843 additions and 20 deletions

View File

@@ -913,6 +913,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
# Whether to return hidden states
return_hidden_states: bool = False
# Whether this batch is prefill-only (no token generation needed)
is_prefill_only: bool = False
# hicache pointer for synchronizing data loading from CPU to GPU
hicache_consumer_index: int = 0
@@ -953,6 +956,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
device=req_to_token_pool.device,
spec_algorithm=spec_algorithm,
return_hidden_states=any(req.return_hidden_states for req in reqs),
is_prefill_only=all(
req.sampling_params.max_new_tokens == 0 for req in reqs
),
chunked_req=chunked_req,
)
@@ -1796,6 +1802,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
is_extend_in_batch=self.is_extend_in_batch,
is_prefill_only=self.is_prefill_only,
)
def _evict_tree_cache_if_needed(self, num_tokens: int):

View File

@@ -1466,8 +1466,9 @@ class Scheduler(
if self.last_batch.batch_size() < last_bs:
self.running_batch.batch_is_full = False
# Merge the new batch into the running batch
if not self.last_batch.is_empty():
# Merge the new batch into the running batch.
# For a prefill-only batch, we can skip the decoding step entirely.
if not self.last_batch.is_empty() and not self.last_batch.is_prefill_only:
if self.running_batch.is_empty():
self.running_batch = self.last_batch
else:

View File

@@ -699,7 +699,7 @@ class TokenizerManager:
# Process all requests
tokenized_objs = []
for i, req in enumerate(requests):
self._validate_token_len(obj[i], input_ids_list[i])
self._validate_one_request(obj[i], input_ids_list[i])
tokenized_objs.append(
self._create_tokenized_object(
req, req.text, input_ids_list[i], None, None
@@ -1892,6 +1892,13 @@ class TokenizerManager:
f"Token ID {token_id} is out of vocabulary (vocab size: {vocab_size})"
)
batch_request = GenerateReqInput(
token_ids_logprob=label_token_ids,
return_logprob=True,
stream=False,
sampling_params={"max_new_tokens": 0},
)
# Handle string or tokenized query/items
if isinstance(query, str) and (
isinstance(items, str)
@@ -1903,13 +1910,9 @@ class TokenizerManager:
prompts = [f"{item}{query}" for item in items_list]
else:
prompts = [f"{query}{item}" for item in items_list]
batch_request = GenerateReqInput(
text=prompts,
return_logprob=True,
token_ids_logprob=label_token_ids,
stream=False,
sampling_params={"max_new_tokens": 1},
)
batch_request.text = prompts
elif (
isinstance(query, list)
and isinstance(items, list)
@@ -1921,13 +1924,8 @@ class TokenizerManager:
input_ids_list = [item + query for item in items]
else:
input_ids_list = [query + item for item in items]
batch_request = GenerateReqInput(
input_ids=input_ids_list,
return_logprob=True,
token_ids_logprob=label_token_ids,
stream=False,
sampling_params={"max_new_tokens": 1},
)
batch_request.input_ids = input_ids_list
else:
raise ValueError(
"Invalid combination of query/items types for score_request."
@@ -1939,9 +1937,20 @@ class TokenizerManager:
for result in results:
# Get logprobs for each token
logprobs = {}
for logprob, token_id, _ in result["meta_info"].get(
"output_token_ids_logprobs", []
)[0]:
# For scoring requests, read from output_token_ids_logprobs, since we want the
# logprobs of the specific tokens listed in label_token_ids at the position
# immediately after the last prompt token.
output_logprobs = result["meta_info"].get("output_token_ids_logprobs", [])
# Raise an explicit error if the backend did not return output logprobs.
if output_logprobs is None:
raise RuntimeError(
f"output_logprobs is None for request {result['meta_info'].get('id', '<unknown>')}. "
"This usually indicates a problem with the scoring request or the backend output."
)
for logprob, token_id, _ in output_logprobs[0]:
if token_id in label_token_ids:
logprobs[token_id] = logprob