fix stop when stream (#11462)
Signed-off-by: ybyang <ybyang7@iflytek.com> Co-authored-by: Liangsheng Yin <lsyincs@gmail.com> Co-authored-by: Liangsheng Yin <hnyls2002@gmail.com>
This commit is contained in:
@@ -734,6 +734,40 @@ class Req:
|
||||
|
||||
return self.surr_and_decode_ids, self.read_offset - self.surr_offset
|
||||
|
||||
def tail_str(self) -> str:
|
||||
tail_len = self.sampling_params.stop_str_max_len + 1
|
||||
tail_len = min(tail_len, len(self.output_ids))
|
||||
return self.tokenizer.decode(self.output_ids[-tail_len:])
|
||||
|
||||
def check_match_stop_str_prefix(self) -> bool:
|
||||
"""
|
||||
Check if the suffix of tail_str overlaps with any stop_str prefix
|
||||
"""
|
||||
if not self.sampling_params.stop_strs:
|
||||
return False
|
||||
|
||||
tail_str = self.tail_str()
|
||||
|
||||
# Early return if tail_str is empty
|
||||
if not tail_str:
|
||||
return False
|
||||
|
||||
for stop_str in self.sampling_params.stop_strs:
|
||||
if not stop_str:
|
||||
continue
|
||||
# Check if stop_str is contained in tail_str (fastest check first)
|
||||
if stop_str in tail_str:
|
||||
return True
|
||||
|
||||
# Check if tail_str suffix matches stop_str prefix
|
||||
# Only check if stop_str is not empty, it's for stream output
|
||||
min_len = min(len(tail_str), len(stop_str))
|
||||
for i in range(1, min_len + 1):
|
||||
if tail_str[-i:] == stop_str[:i]:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def check_finished(self):
|
||||
if self.finished():
|
||||
return
|
||||
@@ -785,9 +819,7 @@ class Req:
|
||||
|
||||
# Check stop strings
|
||||
if len(self.sampling_params.stop_strs) > 0:
|
||||
tail_str = self.tokenizer.decode(
|
||||
self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
|
||||
)
|
||||
tail_str = self.tail_str()
|
||||
|
||||
for stop_str in self.sampling_params.stop_strs:
|
||||
if stop_str in tail_str or stop_str in self.decoded_text:
|
||||
|
||||
@@ -680,12 +680,18 @@ class SchedulerOutputProcessorMixin:
|
||||
stream_interval = (
|
||||
req.sampling_params.stream_interval or self.stream_interval
|
||||
)
|
||||
|
||||
# origin stream_interval logic
|
||||
should_output = (
|
||||
len(req.output_ids) % stream_interval == 1
|
||||
if not self.model_config.is_multimodal_gen
|
||||
and stream_interval > 1
|
||||
else len(req.output_ids) % stream_interval == 0
|
||||
)
|
||||
|
||||
if should_output:
|
||||
# check_match_stop_str_prefix if tail_str's suffix match stop_str prefix
|
||||
should_output &= not req.check_match_stop_str_prefix()
|
||||
else:
|
||||
should_output = (
|
||||
len(req.output_ids) % DEFAULT_FORCE_STREAM_INTERVAL == 0
|
||||
|
||||
Reference in New Issue
Block a user