Fix stop-string handling when streaming (#11462)
Signed-off-by: ybyang <ybyang7@iflytek.com> Co-authored-by: Liangsheng Yin <lsyincs@gmail.com> Co-authored-by: Liangsheng Yin <hnyls2002@gmail.com>
This commit is contained in:
@@ -734,6 +734,40 @@ class Req:
|
|||||||
|
|
||||||
return self.surr_and_decode_ids, self.read_offset - self.surr_offset
|
return self.surr_and_decode_ids, self.read_offset - self.surr_offset
|
||||||
|
|
||||||
|
def tail_str(self) -> str:
    """Decode the trailing output ids that are scanned for stop strings.

    The window is the last ``stop_str_max_len + 1`` entries of
    ``self.output_ids``, clamped to the number of ids currently present.

    Returns:
        The result of ``self.tokenizer.decode`` on that window, or ``""``
        when the window is empty.
    """
    tail_len = min(
        self.sampling_params.stop_str_max_len + 1,
        len(self.output_ids),
    )
    # A zero-length window must short-circuit here: ``ids[-0:]`` is the
    # whole list, not an empty slice.
    if tail_len <= 0:
        return ""
    return self.tokenizer.decode(self.output_ids[-tail_len:])
|
||||||
|
|
||||||
|
def check_match_stop_str_prefix(self) -> bool:
    """Report whether the decoded tail contains, or may be starting, a stop string.

    For each non-empty entry of ``self.sampling_params.stop_strs`` this
    returns ``True`` when either:

    * the stop string occurs anywhere in ``self.tail_str()``, or
    * some non-empty prefix of the stop string equals a suffix of
      ``self.tail_str()`` (a stop string may be only partially generated).

    Returns:
        ``True`` on the first such match; ``False`` when there are no stop
        strings, the tail is empty, or nothing matches.
    """
    stop_strs = self.sampling_params.stop_strs
    if not stop_strs:
        return False

    tail_str = self.tail_str()
    if not tail_str:
        return False

    for stop_str in stop_strs:
        if not stop_str:
            continue

        # Cheapest check first: the complete stop string is already present.
        if stop_str in tail_str:
            return True

        # A prefix longer than the tail can never be its suffix, so only
        # prefixes up to the shorter of the two lengths are candidates.
        longest = min(len(tail_str), len(stop_str))
        prefixes = tuple(stop_str[:i] for i in range(1, longest + 1))
        if tail_str.endswith(prefixes):
            return True

    return False
|
||||||
|
|
||||||
def check_finished(self):
|
def check_finished(self):
|
||||||
if self.finished():
|
if self.finished():
|
||||||
return
|
return
|
||||||
@@ -785,9 +819,7 @@ class Req:
|
|||||||
|
|
||||||
# Check stop strings
|
# Check stop strings
|
||||||
if len(self.sampling_params.stop_strs) > 0:
|
if len(self.sampling_params.stop_strs) > 0:
|
||||||
tail_str = self.tokenizer.decode(
|
tail_str = self.tail_str()
|
||||||
self.output_ids[-(self.sampling_params.stop_str_max_len + 1) :]
|
|
||||||
)
|
|
||||||
|
|
||||||
for stop_str in self.sampling_params.stop_strs:
|
for stop_str in self.sampling_params.stop_strs:
|
||||||
if stop_str in tail_str or stop_str in self.decoded_text:
|
if stop_str in tail_str or stop_str in self.decoded_text:
|
||||||
|
|||||||
@@ -680,12 +680,18 @@ class SchedulerOutputProcessorMixin:
|
|||||||
stream_interval = (
|
stream_interval = (
|
||||||
req.sampling_params.stream_interval or self.stream_interval
|
req.sampling_params.stream_interval or self.stream_interval
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# original stream_interval logic
|
||||||
should_output = (
|
should_output = (
|
||||||
len(req.output_ids) % stream_interval == 1
|
len(req.output_ids) % stream_interval == 1
|
||||||
if not self.model_config.is_multimodal_gen
|
if not self.model_config.is_multimodal_gen
|
||||||
and stream_interval > 1
|
and stream_interval > 1
|
||||||
else len(req.output_ids) % stream_interval == 0
|
else len(req.output_ids) % stream_interval == 0
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if should_output:
|
||||||
|
# Hold back output while the suffix of tail_str matches a prefix of any stop_str
|
||||||
|
should_output &= not req.check_match_stop_str_prefix()
|
||||||
else:
|
else:
|
||||||
should_output = (
|
should_output = (
|
||||||
len(req.output_ids) % DEFAULT_FORCE_STREAM_INTERVAL == 0
|
len(req.output_ids) % DEFAULT_FORCE_STREAM_INTERVAL == 0
|
||||||
|
|||||||
Reference in New Issue
Block a user