diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 3b18b9452..6eab0a088 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -734,6 +734,40 @@ class Req: return self.surr_and_decode_ids, self.read_offset - self.surr_offset + def tail_str(self) -> str: + tail_len = self.sampling_params.stop_str_max_len + 1 + tail_len = min(tail_len, len(self.output_ids)) + return self.tokenizer.decode(self.output_ids[-tail_len:]) + + def check_match_stop_str_prefix(self) -> bool: + """ + Check if the suffix of tail_str overlaps with any stop_str prefix + """ + if not self.sampling_params.stop_strs: + return False + + tail_str = self.tail_str() + + # Early return if tail_str is empty + if not tail_str: + return False + + for stop_str in self.sampling_params.stop_strs: + if not stop_str: + continue + # Check if stop_str is contained in tail_str (fastest check first) + if stop_str in tail_str: + return True + + # Check if tail_str suffix matches stop_str prefix + # Only check if stop_str is not empty, it's for stream output + min_len = min(len(tail_str), len(stop_str)) + for i in range(1, min_len + 1): + if tail_str[-i:] == stop_str[:i]: + return True + + return False + def check_finished(self): if self.finished(): return @@ -785,9 +819,7 @@ class Req: # Check stop strings if len(self.sampling_params.stop_strs) > 0: - tail_str = self.tokenizer.decode( - self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :] - ) + tail_str = self.tail_str() for stop_str in self.sampling_params.stop_strs: if stop_str in tail_str or stop_str in self.decoded_text: diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py index 2072f9f68..a224bdc34 100644 --- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py +++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py @@ -680,12 +680,18 @@ class SchedulerOutputProcessorMixin: stream_interval = ( req.sampling_params.stream_interval or self.stream_interval ) + + # origin stream_interval logic should_output = ( len(req.output_ids) % stream_interval == 1 if not self.model_config.is_multimodal_gen and stream_interval > 1 else len(req.output_ids) % stream_interval == 0 ) + + if should_output: + # check_match_stop_str_prefix if tail_str's suffix match stop_str prefix + should_output &= not req.check_match_stop_str_prefix() else: should_output = ( len(req.output_ids) % DEFAULT_FORCE_STREAM_INTERVAL == 0