[Generative Score API] Optimization to Remove Decode. (#8840)

This commit is contained in:
Sundara Raman Ramachandran
2025-08-13 14:12:24 -07:00
committed by GitHub
parent 9e426466af
commit a027a9b4b3
6 changed files with 843 additions and 20 deletions

View File

@@ -699,7 +699,7 @@ class TokenizerManager:
# Process all requests
tokenized_objs = []
for i, req in enumerate(requests):
self._validate_token_len(obj[i], input_ids_list[i])
self._validate_one_request(obj[i], input_ids_list[i])
tokenized_objs.append(
self._create_tokenized_object(
req, req.text, input_ids_list[i], None, None
@@ -1892,6 +1892,13 @@ class TokenizerManager:
f"Token ID {token_id} is out of vocabulary (vocab size: {vocab_size})"
)
batch_request = GenerateReqInput(
token_ids_logprob=label_token_ids,
return_logprob=True,
stream=False,
sampling_params={"max_new_tokens": 0},
)
# Handle string or tokenized query/items
if isinstance(query, str) and (
isinstance(items, str)
@@ -1903,13 +1910,9 @@ class TokenizerManager:
prompts = [f"{item}{query}" for item in items_list]
else:
prompts = [f"{query}{item}" for item in items_list]
batch_request = GenerateReqInput(
text=prompts,
return_logprob=True,
token_ids_logprob=label_token_ids,
stream=False,
sampling_params={"max_new_tokens": 1},
)
batch_request.text = prompts
elif (
isinstance(query, list)
and isinstance(items, list)
@@ -1921,13 +1924,8 @@ class TokenizerManager:
input_ids_list = [item + query for item in items]
else:
input_ids_list = [query + item for item in items]
batch_request = GenerateReqInput(
input_ids=input_ids_list,
return_logprob=True,
token_ids_logprob=label_token_ids,
stream=False,
sampling_params={"max_new_tokens": 1},
)
batch_request.input_ids = input_ids_list
else:
raise ValueError(
"Invalid combination of query/items types for score_request."
@@ -1939,9 +1937,20 @@ class TokenizerManager:
for result in results:
# Get logprobs for each token
logprobs = {}
for logprob, token_id, _ in result["meta_info"].get(
"output_token_ids_logprobs", []
)[0]:
# For scoring requests, we read from output_token_ids_logprobs since we want
# the logprobs for specific tokens mentioned in the label_token_ids at
# the next position after the last token in the prompt
output_logprobs = result["meta_info"].get("output_token_ids_logprobs", [])
# Throw an error here if output_logprobs is None
if output_logprobs is None:
raise RuntimeError(
f"output_logprobs is None for request {result['meta_info'].get('id', '<unknown>')}. "
"This usually indicates a problem with the scoring request or the backend output."
)
for logprob, token_id, _ in output_logprobs[0]:
if token_id in label_token_ids:
logprobs[token_id] = logprob