diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py
index 8db07db71..ab2777ff3 100644
--- a/python/sglang/srt/managers/detokenizer_manager.py
+++ b/python/sglang/srt/managers/detokenizer_manager.py
@@ -142,6 +142,7 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
             if output[-1] == 200012 and self.is_tool_call_parser_gpt_oss:
                 return output
             assert len(output) > 0
+            # NOTE: We can always assume the last token is the matched stop token
             return output[:-1]
         return output
 
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index a39a7a535..ba095da9b 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -486,6 +486,8 @@ class Req:
         # Check finish
         self.tokenizer = None
         self.finished_reason = None
+        # finished position (in output_ids), used when checking stop conditions with speculative decoding
+        self.finished_len = None
         # Whether this request has finished output
         self.finished_output = None
         # If we want to abort the request in the middle of the event loop, set this to true
@@ -651,6 +653,13 @@ class Req:
         spec_alg = get_global_server_args().speculative_algorithm
         return self.sampling_params.max_new_tokens == 0 and spec_alg is None
 
+    @property
+    def output_ids_through_stop(self) -> List[int]:
+        """Get the output ids through the stop condition. Stop position is included."""
+        if self.finished_len is not None:
+            return self.output_ids[: self.finished_len]
+        return self.output_ids
+
     def add_latency(self, stage: RequestStage):
         if self.metrics_collector is None:
             return
@@ -702,18 +711,20 @@ class Req:
     def init_incremental_detokenize(self):
         first_iter = self.surr_offset is None or self.read_offset is None
 
+        output_ids = self.output_ids_through_stop
+
         if first_iter:
             self.read_offset = len(self.origin_input_ids_unpadded)
             self.surr_offset = max(
                 self.read_offset - INIT_INCREMENTAL_DETOKENIZATION_OFFSET, 0
             )
             self.surr_and_decode_ids = (
-                self.origin_input_ids_unpadded[self.surr_offset :] + self.output_ids
+                self.origin_input_ids_unpadded[self.surr_offset :] + output_ids
             )
-            self.cur_decode_ids_len = len(self.output_ids)
+            self.cur_decode_ids_len = len(output_ids)
         else:
-            self.surr_and_decode_ids.extend(self.output_ids[self.cur_decode_ids_len :])
-            self.cur_decode_ids_len = len(self.output_ids)
+            self.surr_and_decode_ids.extend(output_ids[self.cur_decode_ids_len :])
+            self.cur_decode_ids_len = len(output_ids)
 
         return self.surr_and_decode_ids, self.read_offset - self.surr_offset
 
@@ -760,7 +771,76 @@ class Req:
 
         return False
 
-    def check_finished(self):
+    def _check_token_based_finish(self, new_accepted_tokens: List[int]) -> bool:
+        """Stop at the first newly accepted stop/EOS token; record its position."""
+        if self.sampling_params.ignore_eos:
+            return False
+
+        # Check stop token ids
+        matched_eos = False
+
+        for i, token_id in enumerate(new_accepted_tokens):
+            if self.sampling_params.stop_token_ids:
+                matched_eos |= token_id in self.sampling_params.stop_token_ids
+            if self.eos_token_ids:
+                matched_eos |= token_id in self.eos_token_ids
+            if self.tokenizer is not None:
+                matched_eos |= token_id == self.tokenizer.eos_token_id
+                if self.tokenizer.additional_stop_token_ids:
+                    matched_eos |= token_id in self.tokenizer.additional_stop_token_ids
+            if matched_eos:
+                self.finished_reason = FINISH_MATCHED_TOKEN(matched=token_id)
+                matched_pos = len(self.output_ids) - len(new_accepted_tokens) + i
+                self.finished_len = matched_pos + 1
+                return True
+
+        return False
+
+    def _check_str_based_finish(self) -> bool:
+        """Finish when a stop string or stop regex matches the decoded text."""
+        if (
+            len(self.sampling_params.stop_strs) > 0
+            or len(self.sampling_params.stop_regex_strs) > 0
+        ):
+            tail_str = self.tail_str()
+
+            # Check stop strings
+            if len(self.sampling_params.stop_strs) > 0:
+                for stop_str in self.sampling_params.stop_strs:
+                    if stop_str in tail_str or stop_str in self.decoded_text:
+                        self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
+                        return True
+
+            # Check stop regex
+            if len(self.sampling_params.stop_regex_strs) > 0:
+                for stop_regex_str in self.sampling_params.stop_regex_strs:
+                    if re.search(stop_regex_str, tail_str):
+                        self.finished_reason = FINISHED_MATCHED_REGEX(
+                            matched=stop_regex_str
+                        )
+                        return True
+
+        return False
+
+    def _check_vocab_boundary_finish(self, new_accepted_tokens: List[int]) -> bool:
+        """Finish when a newly accepted token id is outside the vocabulary range."""
+        for i, token_id in enumerate(new_accepted_tokens):
+            if token_id > self.vocab_size or token_id < 0:
+                offset = len(self.output_ids) - len(new_accepted_tokens) + i
+                if self.sampling_params.stop_token_ids:
+                    self.output_ids[offset] = next(
+                        iter(self.sampling_params.stop_token_ids)
+                    )
+                if self.eos_token_ids:
+                    self.output_ids[offset] = next(iter(self.eos_token_ids))
+                self.finished_reason = FINISH_MATCHED_STR(matched="NaN happened")
+                self.finished_len = offset + 1
+                return True
+
+        return False
+
+    def check_finished(self, new_accepted_len: int = 1):
+        """Check finish conditions; the last `new_accepted_len` output ids are new."""
         if self.finished():
             return
 
@@ -774,6 +854,7 @@ class Req:
             self.finished_reason = FINISH_LENGTH(
                 length=self.sampling_params.max_new_tokens
             )
+            self.finished_len = self.sampling_params.max_new_tokens
             return
 
         if self.grammar is not None:
@@ -781,55 +862,16 @@ class Req:
                 self.finished_reason = FINISH_MATCHED_TOKEN(matched=self.output_ids[-1])
                 return
 
-        last_token_id = self.output_ids[-1]
+        new_accepted_tokens = self.output_ids[-new_accepted_len:]
 
-        if not self.sampling_params.ignore_eos:
-            matched_eos = False
-
-            # Check stop token ids
-            if self.sampling_params.stop_token_ids:
-                matched_eos = last_token_id in self.sampling_params.stop_token_ids
-            if self.eos_token_ids:
-                matched_eos |= last_token_id in self.eos_token_ids
-            if self.tokenizer is not None:
-                matched_eos |= last_token_id == self.tokenizer.eos_token_id
-                if self.tokenizer.additional_stop_token_ids:
-                    matched_eos |= (
-                        last_token_id in self.tokenizer.additional_stop_token_ids
-                    )
-            if matched_eos:
-                self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
-                return
-
-        if last_token_id > self.vocab_size or last_token_id < 0:
-            if self.sampling_params.stop_token_ids:
-                self.output_ids[-1] = next(iter(self.sampling_params.stop_token_ids))
-            if self.eos_token_ids:
-                self.output_ids[-1] = next(iter(self.eos_token_ids))
-            self.finished_reason = FINISH_MATCHED_STR(matched="NaN happened")
+        if self._check_token_based_finish(new_accepted_tokens):
             return
 
-        if (
-            len(self.sampling_params.stop_strs) > 0
-            or len(self.sampling_params.stop_regex_strs) > 0
-        ):
-            tail_str = self.tail_str()
+        if self._check_vocab_boundary_finish(new_accepted_tokens):
+            return
 
-            # Check stop strings
-            if len(self.sampling_params.stop_strs) > 0:
-                for stop_str in self.sampling_params.stop_strs:
-                    if stop_str in tail_str or stop_str in self.decoded_text:
-                        self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
-                        return
-
-            # Check stop regex
-            if len(self.sampling_params.stop_regex_strs) > 0:
-                for stop_regex_str in self.sampling_params.stop_regex_strs:
-                    if re.search(stop_regex_str, tail_str):
-                        self.finished_reason = FINISHED_MATCHED_REGEX(
-                            matched=stop_regex_str
-                        )
-                        return
+        if self._check_str_based_finish():
+            return
 
     def reset_for_retract(self):
         self.prefix_indices = torch.empty((0,), dtype=torch.int64)
diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py
index 55ce5ebd5..d1d78efb6 100644
--- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py
+++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py
@@ -286,13 +286,16 @@ class SchedulerOutputProcessorMixin:
                     self.token_to_kv_pool_allocator.free(indices_to_free)
                 continue
 
+            new_accepted_len = 1
             if batch.spec_algorithm.is_none():
                 req.output_ids.append(next_token_id)
             elif batch.is_v2_eagle:
                 # Only v2 eagle's output_ids are updated here.
                 req.output_ids.extend(next_token_id)
+                new_accepted_len = len(next_token_id)
+
+            req.check_finished(new_accepted_len)
 
-            req.check_finished()
             if req.finished():
                 if batch.is_v2_eagle and self.cur_batch.forward_mode.is_extend():
                     # FIXME(lsyin): fix the messy logic here
@@ -734,6 +737,8 @@ class SchedulerOutputProcessorMixin:
                     # because of the one additional delayed token. This "continue" prevented the dummy output.
                     continue
                 req.finished_output = True
+                if req.finished_len is None:
+                    req.finished_len = len(req.output_ids)
                 should_output = True
             else:
                 if req.stream:
@@ -776,17 +781,20 @@ class SchedulerOutputProcessorMixin:
                 else:
                     decode_ids_list.append(decode_ids[req.send_decode_id_offset :])
 
+                # Exclude the tokens after stop condition
+                output_ids_ = req.output_ids_through_stop
+
                 req.send_decode_id_offset = len(decode_ids)
                 read_offsets.append(read_offset)
-                output_ids.append(req.output_ids[send_token_offset:])
-                req.send_token_offset = len(req.output_ids)
+                output_ids.append(output_ids_[send_token_offset:])
+                req.send_token_offset = len(output_ids_)
                 skip_special_tokens.append(req.sampling_params.skip_special_tokens)
                 spaces_between_special_tokens.append(
                     req.sampling_params.spaces_between_special_tokens
                 )
                 no_stop_trim.append(req.sampling_params.no_stop_trim)
                 prompt_tokens.append(len(req.origin_input_ids))
-                completion_tokens.append(len(req.output_ids))
+                completion_tokens.append(len(output_ids_))
                 cached_tokens.append(req.cached_tokens)
 
                 if not self.spec_algorithm.is_none():