[overlap-spec] fix stop condition and trimming (#11819)
@@ -142,6 +142,7 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
             if output[-1] == 200012 and self.is_tool_call_parser_gpt_oss:
                 return output
             assert len(output) > 0
+            # NOTE: We can always assume the last token is the matched stop token
             return output[:-1]
         return output
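Aside (illustration, not part of the commit): the trimming above relies on the invariant stated in the NOTE, which the scheduler changes below establish. Once output_ids is truncated at the stop position, the matched stop token is guaranteed to be the final element, so stop trimming reduces to dropping it. A minimal sketch with hypothetical token ids:

output = [11, 22, 2]   # hypothetical ids; 2 is the matched stop token
assert len(output) > 0
trimmed = output[:-1]  # [11, 22]; the stop token is always last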
@@ -486,6 +486,8 @@ class Req:
         # Check finish
         self.tokenizer = None
         self.finished_reason = None
+        # finished position (in output_ids), used when checking stop conditions with speculative decoding
+        self.finished_len = None
         # Whether this request has finished output
         self.finished_output = None
         # If we want to abort the request in the middle of the event loop, set this to true
@@ -651,6 +653,13 @@ class Req:
         spec_alg = get_global_server_args().speculative_algorithm
         return self.sampling_params.max_new_tokens == 0 and spec_alg is None
 
+    @property
+    def output_ids_through_stop(self) -> List[int]:
+        """Get the output ids through the stop condition. Stop position is included."""
+        if self.finished_len is not None:
+            return self.output_ids[: self.finished_len]
+        return self.output_ids
+
     def add_latency(self, stage: RequestStage):
         if self.metrics_collector is None:
             return
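Aside (illustration, not part of the commit): with speculative decoding, one verify step can accept several draft tokens at once, so the stop token may land in the middle of output_ids. finished_len records the position just past the stop token, and the property above slices up to it. A minimal sketch with made-up values:

output_ids = [11, 22, 2, 33, 44]  # 2 is the stop token; 33 and 44 were
                                  # accepted in the same verify step
finished_len = 3                  # stop position + 1 (stop token included)
print(output_ids[:finished_len])  # [11, 22, 2]; overflow tokens dropped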
@@ -702,18 +711,20 @@ class Req:
     def init_incremental_detokenize(self):
         first_iter = self.surr_offset is None or self.read_offset is None
 
+        output_ids = self.output_ids_through_stop
+
         if first_iter:
             self.read_offset = len(self.origin_input_ids_unpadded)
             self.surr_offset = max(
                 self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0
             )
             self.surr_and_decode_ids = (
-                self.origin_input_ids_unpadded[self.surr_offset :] + self.output_ids
+                self.origin_input_ids_unpadded[self.surr_offset :] + output_ids
             )
-            self.cur_decode_ids_len = len(self.output_ids)
+            self.cur_decode_ids_len = len(output_ids)
         else:
-            self.surr_and_decode_ids.extend(self.output_ids[self.cur_decode_ids_len :])
-            self.cur_decode_ids_len = len(self.output_ids)
+            self.surr_and_decode_ids.extend(output_ids[self.cur_decode_ids_len :])
+            self.cur_decode_ids_len = len(output_ids)
 
         return self.surr_and_decode_ids, self.read_offset - self.surr_offset
@@ -760,7 +771,72 @@ class Req:
 
         return False
 
-    def check_finished(self):
+    def _check_token_based_finish(self, new_accepted_tokens: List[int]) -> bool:
+        if self.sampling_params.ignore_eos:
+            return False
+
+        # Check stop token ids
+        matched_eos = False
+
+        for i, token_id in enumerate(new_accepted_tokens):
+            if self.sampling_params.stop_token_ids:
+                matched_eos |= token_id in self.sampling_params.stop_token_ids
+            if self.eos_token_ids:
+                matched_eos |= token_id in self.eos_token_ids
+            if self.tokenizer is not None:
+                matched_eos |= token_id == self.tokenizer.eos_token_id
+                if self.tokenizer.additional_stop_token_ids:
+                    matched_eos |= token_id in self.tokenizer.additional_stop_token_ids
+            if matched_eos:
+                self.finished_reason = FINISH_MATCHED_TOKEN(matched=token_id)
+                matched_pos = len(self.output_ids) - len(new_accepted_tokens) + i
+                self.finished_len = matched_pos + 1
+                return True
+
+        return False
+
+    def _check_str_based_finish(self):
+        if (
+            len(self.sampling_params.stop_strs) > 0
+            or len(self.sampling_params.stop_regex_strs) > 0
+        ):
+            tail_str = self.tail_str()
+
+            # Check stop strings
+            if len(self.sampling_params.stop_strs) > 0:
+                for stop_str in self.sampling_params.stop_strs:
+                    if stop_str in tail_str or stop_str in self.decoded_text:
+                        self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
+                        return True
+
+            # Check stop regex
+            if len(self.sampling_params.stop_regex_strs) > 0:
+                for stop_regex_str in self.sampling_params.stop_regex_strs:
+                    if re.search(stop_regex_str, tail_str):
+                        self.finished_reason = FINISHED_MATCHED_REGEX(
+                            matched=stop_regex_str
+                        )
+                        return True
+
+        return False
+
+    def _check_vocab_boundary_finish(self, new_accepted_tokens: List[int] = None):
+        for i, token_id in enumerate(new_accepted_tokens):
+            if token_id > self.vocab_size or token_id < 0:
+                offset = len(self.output_ids) - len(new_accepted_tokens) + i
+                if self.sampling_params.stop_token_ids:
+                    self.output_ids[offset] = next(
+                        iter(self.sampling_params.stop_token_ids)
+                    )
+                if self.eos_token_ids:
+                    self.output_ids[offset] = next(iter(self.eos_token_ids))
+                self.finished_reason = FINISH_MATCHED_STR(matched="NaN happened")
+                self.finished_len = offset + 1
+                return True
+
+        return False
+
+    def check_finished(self, new_accepted_len: int = 1):
         if self.finished():
             return
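Aside (illustration, not part of the commit): the index arithmetic in _check_token_based_finish and _check_vocab_boundary_finish maps the i-th newly accepted token back to its absolute position in output_ids, since the accepted tokens form the tail of the list. A self-contained sketch (first_stop_pos is a hypothetical helper name):

from typing import List, Optional

def first_stop_pos(
    output_ids: List[int], new_accepted_tokens: List[int], stop_ids: set
) -> Optional[int]:
    # Token i of the newly accepted tail sits at absolute index
    # len(output_ids) - len(new_accepted_tokens) + i.
    base = len(output_ids) - len(new_accepted_tokens)
    for i, token_id in enumerate(new_accepted_tokens):
        if token_id in stop_ids:
            return base + i  # finished_len is then this position + 1
    return None

assert first_stop_pos([5, 6, 2, 7], [6, 2, 7], {2}) == 2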
@@ -774,6 +850,7 @@ class Req:
             self.finished_reason = FINISH_LENGTH(
                 length=self.sampling_params.max_new_tokens
             )
+            self.finished_len = self.sampling_params.max_new_tokens
             return
 
         if self.grammar is not None:
@@ -781,55 +858,16 @@ class Req:
             self.finished_reason = FINISH_MATCHED_TOKEN(matched=self.output_ids[-1])
             return
 
-        last_token_id = self.output_ids[-1]
+        new_accepted_tokens = self.output_ids[-new_accepted_len:]
 
-        if not self.sampling_params.ignore_eos:
-            matched_eos = False
-
-            # Check stop token ids
-            if self.sampling_params.stop_token_ids:
-                matched_eos = last_token_id in self.sampling_params.stop_token_ids
-            if self.eos_token_ids:
-                matched_eos |= last_token_id in self.eos_token_ids
-            if self.tokenizer is not None:
-                matched_eos |= last_token_id == self.tokenizer.eos_token_id
-                if self.tokenizer.additional_stop_token_ids:
-                    matched_eos |= (
-                        last_token_id in self.tokenizer.additional_stop_token_ids
-                    )
-            if matched_eos:
-                self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
-                return
-
-        if last_token_id > self.vocab_size or last_token_id < 0:
-            if self.sampling_params.stop_token_ids:
-                self.output_ids[-1] = next(iter(self.sampling_params.stop_token_ids))
-            if self.eos_token_ids:
-                self.output_ids[-1] = next(iter(self.eos_token_ids))
-            self.finished_reason = FINISH_MATCHED_STR(matched="NaN happened")
+        if self._check_token_based_finish(new_accepted_tokens):
             return
 
-        if (
-            len(self.sampling_params.stop_strs) > 0
-            or len(self.sampling_params.stop_regex_strs) > 0
-        ):
-            tail_str = self.tail_str()
+        if self._check_vocab_boundary_finish(new_accepted_tokens):
+            return
 
-            # Check stop strings
-            if len(self.sampling_params.stop_strs) > 0:
-                for stop_str in self.sampling_params.stop_strs:
-                    if stop_str in tail_str or stop_str in self.decoded_text:
-                        self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
-                        return
-
-            # Check stop regex
-            if len(self.sampling_params.stop_regex_strs) > 0:
-                for stop_regex_str in self.sampling_params.stop_regex_strs:
-                    if re.search(stop_regex_str, tail_str):
-                        self.finished_reason = FINISHED_MATCHED_REGEX(
-                            matched=stop_regex_str
-                        )
-                        return
+        if self._check_str_based_finish():
+            return
 
     def reset_for_retract(self):
         self.prefix_indices = torch.empty((0,), dtype=torch.int64)
@@ -286,13 +286,16 @@ class SchedulerOutputProcessorMixin:
                 self.token_to_kv_pool_allocator.free(indices_to_free)
                 continue
 
+            new_accepted_len = 1
             if batch.spec_algorithm.is_none():
                 req.output_ids.append(next_token_id)
             elif batch.is_v2_eagle:
                 # Only v2 eagle's output_ids are updated here.
                 req.output_ids.extend(next_token_id)
+                new_accepted_len = len(next_token_id)
+
+            req.check_finished(new_accepted_len)
 
-            req.check_finished()
             if req.finished():
                 if batch.is_v2_eagle and self.cur_batch.forward_mode.is_extend():
                     # FIXME(lsyin): fix the messy logic here
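Aside (illustration, not part of the commit): under v2 eagle, next_token_id is a list of accepted tokens rather than a single id, so check_finished must scan all of them. The slice it takes, sketched with made-up values:

output_ids = [5, 6]
next_token_id = [101, 2, 103]          # hypothetical accepted draft tokens
output_ids.extend(next_token_id)
new_accepted_len = len(next_token_id)
print(output_ids[-new_accepted_len:])  # [101, 2, 103]; every accepted
                                       # token is checked, not just the last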
@@ -734,6 +737,8 @@ class SchedulerOutputProcessorMixin:
                     # because of the one additional delayed token. This "continue" prevented the dummy output.
                     continue
                 req.finished_output = True
+                if req.finished_len is None:
+                    req.finished_len = len(req.output_ids)
                 should_output = True
             else:
                 if req.stream:
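Aside (illustration, not part of the commit): only the token-based, vocab-boundary, and length finish paths set finished_len; string/regex stops leave it unset, so the fallback above pins it to the full output length and output_ids_through_stop degenerates to the whole sequence:

finished_len = None               # e.g. the request finished via a stop string
output_ids = [11, 22, 33]
if finished_len is None:
    finished_len = len(output_ids)
print(output_ids[:finished_len])  # [11, 22, 33]; nothing is trimmed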
@@ -776,17 +781,20 @@ class SchedulerOutputProcessorMixin:
             else:
                 decode_ids_list.append(decode_ids[req.send_decode_id_offset :])
 
+            # Exclude the tokens after stop condition
+            output_ids_ = req.output_ids_through_stop
+
             req.send_decode_id_offset = len(decode_ids)
             read_offsets.append(read_offset)
-            output_ids.append(req.output_ids[send_token_offset:])
-            req.send_token_offset = len(req.output_ids)
+            output_ids.append(output_ids_[send_token_offset:])
+            req.send_token_offset = len(output_ids_)
             skip_special_tokens.append(req.sampling_params.skip_special_tokens)
             spaces_between_special_tokens.append(
                 req.sampling_params.spaces_between_special_tokens
            )
             no_stop_trim.append(req.sampling_params.no_stop_trim)
             prompt_tokens.append(len(req.origin_input_ids))
-            completion_tokens.append(len(req.output_ids))
+            completion_tokens.append(len(output_ids_))
             cached_tokens.append(req.cached_tokens)
 
         if not self.spec_algorithm.is_none():
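Aside (illustration, not part of the commit): streaming sends incremental slices of the trimmed ids, so tokens past the stop position are never shipped to the detokenizer, and completion_tokens counts only through the stop. Sketch with hypothetical values:

output_ids_through_stop = [11, 22, 2]  # already trimmed at the stop token
send_token_offset = 1                  # [11] was sent in an earlier batch
to_send = output_ids_through_stop[send_token_offset:]  # [22, 2]
send_token_offset = len(output_ids_through_stop)       # 3: nothing resent
completion_tokens = len(output_ids_through_stop)       # counts through stop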