[Generative Score API] Optimization to Remove Decode. (#8840)
This commit is contained in:
committed by
GitHub
parent
9e426466af
commit
a027a9b4b3
@@ -699,7 +699,7 @@ class TokenizerManager:
|
||||
# Process all requests
|
||||
tokenized_objs = []
|
||||
for i, req in enumerate(requests):
|
||||
self._validate_token_len(obj[i], input_ids_list[i])
|
||||
self._validate_one_request(obj[i], input_ids_list[i])
|
||||
tokenized_objs.append(
|
||||
self._create_tokenized_object(
|
||||
req, req.text, input_ids_list[i], None, None
|
||||
@@ -1892,6 +1892,13 @@ class TokenizerManager:
|
||||
f"Token ID {token_id} is out of vocabulary (vocab size: {vocab_size})"
|
||||
)
|
||||
|
||||
batch_request = GenerateReqInput(
|
||||
token_ids_logprob=label_token_ids,
|
||||
return_logprob=True,
|
||||
stream=False,
|
||||
sampling_params={"max_new_tokens": 0},
|
||||
)
|
||||
|
||||
# Handle string or tokenized query/items
|
||||
if isinstance(query, str) and (
|
||||
isinstance(items, str)
|
||||
@@ -1903,13 +1910,9 @@ class TokenizerManager:
|
||||
prompts = [f"{item}{query}" for item in items_list]
|
||||
else:
|
||||
prompts = [f"{query}{item}" for item in items_list]
|
||||
batch_request = GenerateReqInput(
|
||||
text=prompts,
|
||||
return_logprob=True,
|
||||
token_ids_logprob=label_token_ids,
|
||||
stream=False,
|
||||
sampling_params={"max_new_tokens": 1},
|
||||
)
|
||||
|
||||
batch_request.text = prompts
|
||||
|
||||
elif (
|
||||
isinstance(query, list)
|
||||
and isinstance(items, list)
|
||||
@@ -1921,13 +1924,8 @@ class TokenizerManager:
|
||||
input_ids_list = [item + query for item in items]
|
||||
else:
|
||||
input_ids_list = [query + item for item in items]
|
||||
batch_request = GenerateReqInput(
|
||||
input_ids=input_ids_list,
|
||||
return_logprob=True,
|
||||
token_ids_logprob=label_token_ids,
|
||||
stream=False,
|
||||
sampling_params={"max_new_tokens": 1},
|
||||
)
|
||||
|
||||
batch_request.input_ids = input_ids_list
|
||||
else:
|
||||
raise ValueError(
|
||||
"Invalid combination of query/items types for score_request."
|
||||
@@ -1939,9 +1937,20 @@ class TokenizerManager:
|
||||
for result in results:
|
||||
# Get logprobs for each token
|
||||
logprobs = {}
|
||||
for logprob, token_id, _ in result["meta_info"].get(
|
||||
"output_token_ids_logprobs", []
|
||||
)[0]:
|
||||
|
||||
# For scoring requests, we read from output_token_ids_logprobs since we want
|
||||
# the logprobs for specific tokens mentioned in the label_token_ids at
|
||||
# the next position after the last token in the prompt
|
||||
output_logprobs = result["meta_info"].get("output_token_ids_logprobs", [])
|
||||
|
||||
# Throw an error here if output_logprobs is None
|
||||
if output_logprobs is None:
|
||||
raise RuntimeError(
|
||||
f"output_logprobs is None for request {result['meta_info'].get('id', '<unknown>')}. "
|
||||
"This usually indicates a problem with the scoring request or the backend output."
|
||||
)
|
||||
|
||||
for logprob, token_id, _ in output_logprobs[0]:
|
||||
if token_id in label_token_ids:
|
||||
logprobs[token_id] = logprob
|
||||
|
||||
|
||||
Reference in New Issue
Block a user