[Feature] support regex strings as a stopping condition (#10635)

This commit is contained in:
Glen Liu
2025-10-11 22:53:15 -04:00
committed by GitHub
parent 9fcf73069f
commit 47c606d3dc
9 changed files with 219 additions and 8 deletions

View File

@@ -36,6 +36,7 @@ TODO(lmzheng): ModelWorkerBatch seems a bit redundant and we consider removing i
import copy
import dataclasses
import logging
import re
import threading
import time
from enum import Enum, auto
@@ -154,6 +155,18 @@ class FINISH_MATCHED_STR(BaseFinishReason):
}
class FINISHED_MATCHED_REGEX(BaseFinishReason):
    """Finish reason for a request that stopped on a stop-regex match.

    Attributes are limited to ``matched``, the string handed to the
    constructor. NOTE(review): the visible call site passes the regex
    pattern itself rather than the text it matched -- confirm that this is
    what API consumers expect.
    """

    def __init__(self, matched: str):
        """Store the string that identifies the triggering stop regex.

        Args:
            matched: Value reported back under the ``"matched"`` key.
        """
        super().__init__()
        self.matched = matched

    def to_json(self):
        """Return a dict with the finish ``type`` and the ``matched`` value."""
        # The type is reported as "stop" to match OpenAI API's return value.
        return dict(type="stop", matched=self.matched)
class FINISH_LENGTH(BaseFinishReason):
def __init__(self, length: int):
super().__init__()
@@ -735,8 +748,17 @@ class Req:
return self.surr_and_decode_ids, self.read_offset - self.surr_offset
def tail_str(self) -> str:
    """Decode the trailing output ids that stop conditions are checked against.

    The window is sized from ``sampling_params`` so that only the end of the
    output is decoded rather than the whole sequence:

    * With no stop strings and no stop regexes configured, the window is
      ``stop_str_max_len + 1`` trailing ids.
    * Otherwise stop strings and stop regex patterns are handled together:
      the window is the larger of ``stop_str_max_len + 1`` and
      ``stop_regex_max_len + 1``, plus one more id.

    The window never exceeds ``len(self.output_ids)``.

    Returns:
        The result of ``self.tokenizer.decode`` applied to the selected
        trailing slice of ``self.output_ids``.
    """
    params = self.sampling_params

    # Default window; always bound so the final slice can never hit an
    # undefined local when no stop condition is configured.
    tail_len = params.stop_str_max_len + 1

    # Check stop strings and stop regex patterns together
    if len(params.stop_strs) > 0 or len(params.stop_regex_strs) > 0:
        max_len_tail_str = max(
            params.stop_str_max_len + 1,
            params.stop_regex_max_len + 1,
        )
        tail_len = max_len_tail_str + 1

    # Clamp to what has actually been generated. With an empty output this
    # yields 0, and ``[-0:]`` on an empty list is still the empty list.
    tail_len = min(tail_len, len(self.output_ids))
    return self.tokenizer.decode(self.output_ids[-tail_len:])
def check_match_stop_str_prefix(self) -> bool:
@@ -817,14 +839,27 @@ class Req:
self.finished_reason = FINISH_MATCHED_STR(matched="NaN happened")
return
# Check stop strings
if len(self.sampling_params.stop_strs) > 0:
if (
len(self.sampling_params.stop_strs) > 0
or len(self.sampling_params.stop_regex_strs) > 0
):
tail_str = self.tail_str()
for stop_str in self.sampling_params.stop_strs:
if stop_str in tail_str or stop_str in self.decoded_text:
self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
return
# Check stop strings
if len(self.sampling_params.stop_strs) > 0:
for stop_str in self.sampling_params.stop_strs:
if stop_str in tail_str or stop_str in self.decoded_text:
self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
return
# Check stop regex
if len(self.sampling_params.stop_regex_strs) > 0:
for stop_regex_str in self.sampling_params.stop_regex_strs:
if re.search(stop_regex_str, tail_str):
self.finished_reason = FINISHED_MATCHED_REGEX(
matched=stop_regex_str
)
return
def reset_for_retract(self):
self.prefix_indices = torch.empty((0,), dtype=torch.int64)