[overlap-spec] fix stop condition and trimming (#11819)

This commit is contained in:
Liangsheng Yin
2025-10-19 22:00:20 +08:00
committed by GitHub
parent 57e25de756
commit d658f0497e
3 changed files with 101 additions and 54 deletions

View File

@@ -142,6 +142,7 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
if output[-1] == 200012 and self.is_tool_call_parser_gpt_oss:
return output
assert len(output) > 0
# NOTE: We can always assume the last token is the matched stop token
return output[:-1]
return output

View File

@@ -486,6 +486,8 @@ class Req:
# Check finish
self.tokenizer = None
self.finished_reason = None
# finished position (in output_ids), used when checking stop conditions with speculative decoding
self.finished_len = None
# Whether this request has finished output
self.finished_output = None
# If we want to abort the request in the middle of the event loop, set this to true
@@ -651,6 +653,13 @@ class Req:
spec_alg = get_global_server_args().speculative_algorithm
return self.sampling_params.max_new_tokens == 0 and spec_alg is None
@property
def output_ids_through_stop(self) -> List[int]:
"""Get the output ids through the stop condition. Stop position is included."""
if self.finished_len is not None:
return self.output_ids[: self.finished_len]
return self.output_ids
def add_latency(self, stage: RequestStage):
if self.metrics_collector is None:
return
@@ -702,18 +711,20 @@ class Req:
def init_incremental_detokenize(self):
first_iter = self.surr_offset is None or self.read_offset is None
output_ids = self.output_ids_through_stop
if first_iter:
self.read_offset = len(self.origin_input_ids_unpadded)
self.surr_offset = max(
self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0
)
self.surr_and_decode_ids = (
self.origin_input_ids_unpadded[self.surr_offset :] + self.output_ids
self.origin_input_ids_unpadded[self.surr_offset :] + output_ids
)
self.cur_decode_ids_len = len(self.output_ids)
self.cur_decode_ids_len = len(output_ids)
else:
self.surr_and_decode_ids.extend(self.output_ids[self.cur_decode_ids_len :])
self.cur_decode_ids_len = len(self.output_ids)
self.surr_and_decode_ids.extend(output_ids[self.cur_decode_ids_len :])
self.cur_decode_ids_len = len(output_ids)
return self.surr_and_decode_ids, self.read_offset - self.surr_offset
@@ -760,7 +771,72 @@ class Req:
return False
def check_finished(self):
def _check_token_based_finish(self, new_accepted_tokens: List[int]) -> bool:
if self.sampling_params.ignore_eos:
return False
# Check stop token ids
matched_eos = False
for i, token_id in enumerate(new_accepted_tokens):
if self.sampling_params.stop_token_ids:
matched_eos |= token_id in self.sampling_params.stop_token_ids
if self.eos_token_ids:
matched_eos |= token_id in self.eos_token_ids
if self.tokenizer is not None:
matched_eos |= token_id == self.tokenizer.eos_token_id
if self.tokenizer.additional_stop_token_ids:
matched_eos |= token_id in self.tokenizer.additional_stop_token_ids
if matched_eos:
self.finished_reason = FINISH_MATCHED_TOKEN(matched=token_id)
matched_pos = len(self.output_ids) - len(new_accepted_tokens) + i
self.finished_len = matched_pos + 1
return True
return False
def _check_str_based_finish(self):
if (
len(self.sampling_params.stop_strs) > 0
or len(self.sampling_params.stop_regex_strs) > 0
):
tail_str = self.tail_str()
# Check stop strings
if len(self.sampling_params.stop_strs) > 0:
for stop_str in self.sampling_params.stop_strs:
if stop_str in tail_str or stop_str in self.decoded_text:
self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
return True
# Check stop regex
if len(self.sampling_params.stop_regex_strs) > 0:
for stop_regex_str in self.sampling_params.stop_regex_strs:
if re.search(stop_regex_str, tail_str):
self.finished_reason = FINISHED_MATCHED_REGEX(
matched=stop_regex_str
)
return True
return False
def _check_vocab_boundary_finish(self, new_accepted_tokens: List[int] = None):
for i, token_id in enumerate(new_accepted_tokens):
if token_id > self.vocab_size or token_id < 0:
offset = len(self.output_ids) - len(new_accepted_tokens) + i
if self.sampling_params.stop_token_ids:
self.output_ids[offset] = next(
iter(self.sampling_params.stop_token_ids)
)
if self.eos_token_ids:
self.output_ids[offset] = next(iter(self.eos_token_ids))
self.finished_reason = FINISH_MATCHED_STR(matched="NaN happened")
self.finished_len = offset + 1
return True
return False
def check_finished(self, new_accepted_len: int = 1):
if self.finished():
return
@@ -774,6 +850,7 @@ class Req:
self.finished_reason = FINISH_LENGTH(
length=self.sampling_params.max_new_tokens
)
self.finished_len = self.sampling_params.max_new_tokens
return
if self.grammar is not None:
@@ -781,55 +858,16 @@ class Req:
self.finished_reason = FINISH_MATCHED_TOKEN(matched=self.output_ids[-1])
return
last_token_id = self.output_ids[-1]
new_accepted_tokens = self.output_ids[-new_accepted_len:]
if not self.sampling_params.ignore_eos:
matched_eos = False
# Check stop token ids
if self.sampling_params.stop_token_ids:
matched_eos = last_token_id in self.sampling_params.stop_token_ids
if self.eos_token_ids:
matched_eos |= last_token_id in self.eos_token_ids
if self.tokenizer is not None:
matched_eos |= last_token_id == self.tokenizer.eos_token_id
if self.tokenizer.additional_stop_token_ids:
matched_eos |= (
last_token_id in self.tokenizer.additional_stop_token_ids
)
if matched_eos:
self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
return
if last_token_id > self.vocab_size or last_token_id < 0:
if self.sampling_params.stop_token_ids:
self.output_ids[-1] = next(iter(self.sampling_params.stop_token_ids))
if self.eos_token_ids:
self.output_ids[-1] = next(iter(self.eos_token_ids))
self.finished_reason = FINISH_MATCHED_STR(matched="NaN happened")
if self._check_token_based_finish(new_accepted_tokens):
return
if (
len(self.sampling_params.stop_strs) > 0
or len(self.sampling_params.stop_regex_strs) > 0
):
tail_str = self.tail_str()
if self._check_vocab_boundary_finish(new_accepted_tokens):
return
# Check stop strings
if len(self.sampling_params.stop_strs) > 0:
for stop_str in self.sampling_params.stop_strs:
if stop_str in tail_str or stop_str in self.decoded_text:
self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
return
# Check stop regex
if len(self.sampling_params.stop_regex_strs) > 0:
for stop_regex_str in self.sampling_params.stop_regex_strs:
if re.search(stop_regex_str, tail_str):
self.finished_reason = FINISHED_MATCHED_REGEX(
matched=stop_regex_str
)
return
if self._check_str_based_finish():
return
def reset_for_retract(self):
self.prefix_indices = torch.empty((0,), dtype=torch.int64)

View File

@@ -286,13 +286,16 @@ class SchedulerOutputProcessorMixin:
self.token_to_kv_pool_allocator.free(indices_to_free)
continue
new_accepted_len = 1
if batch.spec_algorithm.is_none():
req.output_ids.append(next_token_id)
elif batch.is_v2_eagle:
# Only v2 eagle's output_ids are updated here.
req.output_ids.extend(next_token_id)
new_accepted_len = len(next_token_id)
req.check_finished(new_accepted_len)
req.check_finished()
if req.finished():
if batch.is_v2_eagle and self.cur_batch.forward_mode.is_extend():
# FIXME(lsyin): fix the messy logic here
@@ -734,6 +737,8 @@ class SchedulerOutputProcessorMixin:
# because of the one additional delayed token. This "continue" prevented the dummy output.
continue
req.finished_output = True
if req.finished_len is None:
req.finished_len = len(req.output_ids)
should_output = True
else:
if req.stream:
@@ -776,17 +781,20 @@ class SchedulerOutputProcessorMixin:
else:
decode_ids_list.append(decode_ids[req.send_decode_id_offset :])
# Exclude the tokens after stop condition
output_ids_ = req.output_ids_through_stop
req.send_decode_id_offset = len(decode_ids)
read_offsets.append(read_offset)
output_ids.append(req.output_ids[send_token_offset:])
req.send_token_offset = len(req.output_ids)
output_ids.append(output_ids_[send_token_offset:])
req.send_token_offset = len(output_ids_)
skip_special_tokens.append(req.sampling_params.skip_special_tokens)
spaces_between_special_tokens.append(
req.sampling_params.spaces_between_special_tokens
)
no_stop_trim.append(req.sampling_params.no_stop_trim)
prompt_tokens.append(len(req.origin_input_ids))
completion_tokens.append(len(req.output_ids))
completion_tokens.append(len(output_ids_))
cached_tokens.append(req.cached_tokens)
if not self.spec_algorithm.is_none():