[Feature] support regex strings as a stopping condition (#10635)
This commit is contained in:
@@ -36,6 +36,7 @@ TODO(lmzheng): ModelWorkerBatch seems a bit redundant and we consider removing i
|
||||
import copy
|
||||
import dataclasses
|
||||
import logging
|
||||
import re
|
||||
import threading
|
||||
import time
|
||||
from enum import Enum, auto
|
||||
@@ -154,6 +155,18 @@ class FINISH_MATCHED_STR(BaseFinishReason):
|
||||
}
|
||||
|
||||
|
||||
class FINISHED_MATCHED_REGEX(BaseFinishReason):
    """Finish reason recorded when a stop regex ended generation.

    Attributes:
        matched: The string reported under the "matched" key of the JSON
            form. The call site visible in this change passes the stop
            regex pattern string itself, not the text it matched.
    """

    def __init__(self, matched: str):
        super().__init__()
        self.matched = matched

    def to_json(self):
        """Return this finish reason as a plain dict."""
        # "stop" is used as the type to match OpenAI API's return value.
        payload = {"type": "stop"}
        payload["matched"] = self.matched
        return payload
|
||||
|
||||
|
||||
class FINISH_LENGTH(BaseFinishReason):
|
||||
def __init__(self, length: int):
|
||||
super().__init__()
|
||||
@@ -735,8 +748,17 @@ class Req:
|
||||
return self.surr_and_decode_ids, self.read_offset - self.surr_offset
|
||||
|
||||
def tail_str(self) -> str:
    """Decode the trailing window of output tokens used for stop checks.

    The window covers stop strings and stop regex patterns together: it is
    sized from the larger of ``stop_str_max_len`` and ``stop_regex_max_len``
    and capped at the number of output tokens produced so far.

    Returns:
        The text obtained by decoding the last ``tail_len`` entries of
        ``self.output_ids`` with ``self.tokenizer``.
    """
    params = self.sampling_params

    # Computed unconditionally: binding this only when a stop string or
    # stop regex is configured leaves the name undefined (UnboundLocalError)
    # if the method is ever called without any stop condition.
    max_len_tail_str = max(
        params.stop_str_max_len + 1,
        params.stop_regex_max_len + 1,
    )

    # The extra "+ 1" keeps the window size used by the stop checks unchanged.
    tail_len = min(max_len_tail_str + 1, len(self.output_ids))

    # When output_ids is empty, tail_len is 0 and [-0:] is the whole (empty)
    # list, so the decode call still receives an empty sequence.
    return self.tokenizer.decode(self.output_ids[-tail_len:])
|
||||
|
||||
def check_match_stop_str_prefix(self) -> bool:
|
||||
@@ -817,14 +839,27 @@ class Req:
|
||||
self.finished_reason = FINISH_MATCHED_STR(matched="NaN happened")
|
||||
return
|
||||
|
||||
# Check stop strings
|
||||
if len(self.sampling_params.stop_strs) > 0:
|
||||
if (
|
||||
len(self.sampling_params.stop_strs) > 0
|
||||
or len(self.sampling_params.stop_regex_strs) > 0
|
||||
):
|
||||
tail_str = self.tail_str()
|
||||
|
||||
for stop_str in self.sampling_params.stop_strs:
|
||||
if stop_str in tail_str or stop_str in self.decoded_text:
|
||||
self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
|
||||
return
|
||||
# Check stop strings
|
||||
if len(self.sampling_params.stop_strs) > 0:
|
||||
for stop_str in self.sampling_params.stop_strs:
|
||||
if stop_str in tail_str or stop_str in self.decoded_text:
|
||||
self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
|
||||
return
|
||||
|
||||
# Check stop regex
|
||||
if len(self.sampling_params.stop_regex_strs) > 0:
|
||||
for stop_regex_str in self.sampling_params.stop_regex_strs:
|
||||
if re.search(stop_regex_str, tail_str):
|
||||
self.finished_reason = FINISHED_MATCHED_REGEX(
|
||||
matched=stop_regex_str
|
||||
)
|
||||
return
|
||||
|
||||
def reset_for_retract(self):
|
||||
self.prefix_indices = torch.empty((0,), dtype=torch.int64)
|
||||
|
||||
Reference in New Issue
Block a user