[Generative Score API] Optimization to Remove Decode. (#8840)
This commit is contained in:
committed by
GitHub
parent
9e426466af
commit
a027a9b4b3
@@ -913,6 +913,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
# Whether to return hidden states
|
||||
return_hidden_states: bool = False
|
||||
|
||||
# Whether this batch is prefill-only (no token generation needed)
|
||||
is_prefill_only: bool = False
|
||||
|
||||
# hicache pointer for synchronizing data loading from CPU to GPU
|
||||
hicache_consumer_index: int = 0
|
||||
|
||||
@@ -953,6 +956,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
device=req_to_token_pool.device,
|
||||
spec_algorithm=spec_algorithm,
|
||||
return_hidden_states=any(req.return_hidden_states for req in reqs),
|
||||
is_prefill_only=all(
|
||||
req.sampling_params.max_new_tokens == 0 for req in reqs
|
||||
),
|
||||
chunked_req=chunked_req,
|
||||
)
|
||||
|
||||
@@ -1796,6 +1802,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
|
||||
can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
|
||||
is_extend_in_batch=self.is_extend_in_batch,
|
||||
is_prefill_only=self.is_prefill_only,
|
||||
)
|
||||
|
||||
def _evict_tree_cache_if_needed(self, num_tokens: int):
|
||||
|
||||
@@ -1466,8 +1466,9 @@ class Scheduler(
|
||||
if self.last_batch.batch_size() < last_bs:
|
||||
self.running_batch.batch_is_full = False
|
||||
|
||||
# Merge the new batch into the running batch
|
||||
if not self.last_batch.is_empty():
|
||||
# Merge the new batch into the running batch.
|
||||
# For prefill-only batch, we can avoid going through decoding step.
|
||||
if not self.last_batch.is_empty() and not self.last_batch.is_prefill_only:
|
||||
if self.running_batch.is_empty():
|
||||
self.running_batch = self.last_batch
|
||||
else:
|
||||
|
||||
@@ -699,7 +699,7 @@ class TokenizerManager:
|
||||
# Process all requests
|
||||
tokenized_objs = []
|
||||
for i, req in enumerate(requests):
|
||||
self._validate_token_len(obj[i], input_ids_list[i])
|
||||
self._validate_one_request(obj[i], input_ids_list[i])
|
||||
tokenized_objs.append(
|
||||
self._create_tokenized_object(
|
||||
req, req.text, input_ids_list[i], None, None
|
||||
@@ -1892,6 +1892,13 @@ class TokenizerManager:
|
||||
f"Token ID {token_id} is out of vocabulary (vocab size: {vocab_size})"
|
||||
)
|
||||
|
||||
batch_request = GenerateReqInput(
|
||||
token_ids_logprob=label_token_ids,
|
||||
return_logprob=True,
|
||||
stream=False,
|
||||
sampling_params={"max_new_tokens": 0},
|
||||
)
|
||||
|
||||
# Handle string or tokenized query/items
|
||||
if isinstance(query, str) and (
|
||||
isinstance(items, str)
|
||||
@@ -1903,13 +1910,9 @@ class TokenizerManager:
|
||||
prompts = [f"{item}{query}" for item in items_list]
|
||||
else:
|
||||
prompts = [f"{query}{item}" for item in items_list]
|
||||
batch_request = GenerateReqInput(
|
||||
text=prompts,
|
||||
return_logprob=True,
|
||||
token_ids_logprob=label_token_ids,
|
||||
stream=False,
|
||||
sampling_params={"max_new_tokens": 1},
|
||||
)
|
||||
|
||||
batch_request.text = prompts
|
||||
|
||||
elif (
|
||||
isinstance(query, list)
|
||||
and isinstance(items, list)
|
||||
@@ -1921,13 +1924,8 @@ class TokenizerManager:
|
||||
input_ids_list = [item + query for item in items]
|
||||
else:
|
||||
input_ids_list = [query + item for item in items]
|
||||
batch_request = GenerateReqInput(
|
||||
input_ids=input_ids_list,
|
||||
return_logprob=True,
|
||||
token_ids_logprob=label_token_ids,
|
||||
stream=False,
|
||||
sampling_params={"max_new_tokens": 1},
|
||||
)
|
||||
|
||||
batch_request.input_ids = input_ids_list
|
||||
else:
|
||||
raise ValueError(
|
||||
"Invalid combination of query/items types for score_request."
|
||||
@@ -1939,9 +1937,20 @@ class TokenizerManager:
|
||||
for result in results:
|
||||
# Get logprobs for each token
|
||||
logprobs = {}
|
||||
for logprob, token_id, _ in result["meta_info"].get(
|
||||
"output_token_ids_logprobs", []
|
||||
)[0]:
|
||||
|
||||
# For scoring requests, we read from output_token_ids_logprobs since we want
|
||||
# the logprobs for specific tokens mentioned in the label_token_ids at
|
||||
# the next position after the last token in the prompt
|
||||
output_logprobs = result["meta_info"].get("output_token_ids_logprobs", [])
|
||||
|
||||
# Throw an error here if output_logprobs is None
|
||||
if output_logprobs is None:
|
||||
raise RuntimeError(
|
||||
f"output_logprobs is None for request {result['meta_info'].get('id', '<unknown>')}. "
|
||||
"This usually indicates a problem with the scoring request or the backend output."
|
||||
)
|
||||
|
||||
for logprob, token_id, _ in output_logprobs[0]:
|
||||
if token_id in label_token_ids:
|
||||
logprobs[token_id] = logprob
|
||||
|
||||
|
||||
Reference in New Issue
Block a user