[Feature] support regex strings as a stopping condition (#10635)

This commit is contained in:
Glen Liu
2025-10-11 22:53:15 -04:00
committed by GitHub
parent 9fcf73069f
commit 47c606d3dc
9 changed files with 219 additions and 8 deletions

View File

@@ -79,6 +79,7 @@ def gen(
n: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
stop_regex: Optional[Union[str, List[str]]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
@@ -120,6 +121,7 @@ def gen(
n,
stop,
stop_token_ids,
stop_regex,
temperature,
top_p,
top_k,
@@ -143,6 +145,7 @@ def gen_int(
n: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
stop_regex: Optional[Union[str, List[str]]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
@@ -162,6 +165,7 @@ def gen_int(
n,
stop,
stop_token_ids,
stop_regex,
temperature,
top_p,
top_k,
@@ -184,6 +188,7 @@ def gen_string(
n: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
stop_regex: Optional[Union[str, List[str]]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
@@ -203,6 +208,7 @@ def gen_string(
n,
stop,
stop_token_ids,
stop_regex,
temperature,
top_p,
top_k,

View File

@@ -792,6 +792,7 @@ class StreamExecutor:
"n",
"stop",
"stop_token_ids",
"stop_regex",
"temperature",
"top_p",
"top_k",

View File

@@ -21,6 +21,7 @@ class SglSamplingParams:
n: int = 1
stop: Union[str, List[str]] = ()
stop_token_ids: Optional[List[int]] = ()
stop_regex: Optional[Union[str, List[str]]] = ()
temperature: float = 1.0
top_p: float = 1.0
top_k: int = -1 # -1 means disable
@@ -45,6 +46,7 @@ class SglSamplingParams:
self.n,
self.stop,
self.stop_token_ids,
self.stop_regex,
self.temperature,
self.top_p,
self.top_k,
@@ -123,6 +125,7 @@ class SglSamplingParams:
"n": self.n,
"stop": self.stop,
"stop_token_ids": self.stop_token_ids,
"stop_regex": self.stop_regex,
"temperature": self.temperature,
"top_p": self.top_p,
"top_k": self.top_k,
@@ -161,6 +164,7 @@ class SglFunction:
n: int = 1,
stop: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
stop_regex: Optional[Union[str, List[str]]] = None,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1,
@@ -184,12 +188,15 @@ class SglFunction:
stop = []
if stop_token_ids is None:
stop_token_ids = []
if stop_regex is None:
stop_regex = []
default_sampling_para = SglSamplingParams(
max_new_tokens=max_new_tokens,
n=n,
stop=stop,
stop_token_ids=stop_token_ids,
stop_regex=stop_regex,
temperature=temperature,
top_p=top_p,
top_k=top_k,
@@ -221,6 +228,7 @@ class SglFunction:
n: int = 1,
stop: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
stop_regex: Optional[Union[str, List[str]]] = None,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1,
@@ -243,6 +251,8 @@ class SglFunction:
stop = []
if stop_token_ids is None:
stop_token_ids = []
if stop_regex is None:
stop_regex = []
assert isinstance(batch_kwargs, (list, tuple))
if len(batch_kwargs) == 0:
@@ -267,6 +277,7 @@ class SglFunction:
n=n,
stop=stop,
stop_token_ids=stop_token_ids,
stop_regex=stop_regex,
temperature=temperature,
top_p=top_p,
top_k=top_k,
@@ -451,6 +462,7 @@ class SglGen(SglExpr):
n: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
stop_regex: Optional[Union[str, List[str]]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
@@ -474,6 +486,7 @@ class SglGen(SglExpr):
min_new_tokens=min_new_tokens,
n=n,
stop=stop,
stop_regex=stop_regex,
stop_token_ids=stop_token_ids,
temperature=temperature,
top_p=top_p,

View File

@@ -221,6 +221,7 @@ class CompletionRequest(BaseModel):
ebnf: Optional[str] = None
repetition_penalty: float = 1.0
stop_token_ids: Optional[List[int]] = None
stop_regex: Optional[Union[str, List[str]]] = None
no_stop_trim: bool = False
ignore_eos: bool = False
skip_special_tokens: bool = True
@@ -474,6 +475,7 @@ class ChatCompletionRequest(BaseModel):
ebnf: Optional[str] = None
repetition_penalty: Optional[float] = None
stop_token_ids: Optional[List[int]] = None
stop_regex: Optional[Union[str, List[str]]] = None
no_stop_trim: bool = False
ignore_eos: bool = False
continue_final_message: bool = False
@@ -602,6 +604,7 @@ class ChatCompletionRequest(BaseModel):
"min_new_tokens": self.min_tokens,
"stop": stop,
"stop_token_ids": self.stop_token_ids,
"stop_regex": self.stop_regex,
"top_p": get_param("top_p"),
"top_k": get_param("top_k"),
"min_p": get_param("min_p"),

View File

@@ -123,6 +123,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
"min_new_tokens": request.min_tokens,
"stop": request.stop,
"stop_token_ids": request.stop_token_ids,
"stop_regex": request.stop_regex,
"top_p": request.top_p,
"top_k": request.top_k,
"min_p": request.min_p,

View File

@@ -36,6 +36,7 @@ TODO(lmzheng): ModelWorkerBatch seems a bit redundant and we consider removing i
import copy
import dataclasses
import logging
import re
import threading
import time
from enum import Enum, auto
@@ -154,6 +155,18 @@ class FINISH_MATCHED_STR(BaseFinishReason):
}
class FINISHED_MATCHED_REGEX(BaseFinishReason):
    """Finish reason: the decoded output matched one of the stop regex patterns.

    NOTE(review): sibling classes use the FINISH_ prefix (FINISH_MATCHED_STR,
    FINISH_LENGTH); this one uses FINISHED_ — consider renaming for consistency.
    """

    def __init__(self, matched: str):
        # `matched` stores the regex pattern string that triggered the stop
        # (not the matched substring) — see the re.search call site.
        super().__init__()
        self.matched = matched

    def to_json(self):
        return {
            "type": "stop",  # to match OpenAI API's return value
            "matched": self.matched,
        }
class FINISH_LENGTH(BaseFinishReason):
def __init__(self, length: int):
super().__init__()
@@ -735,8 +748,17 @@ class Req:
return self.surr_and_decode_ids, self.read_offset - self.surr_offset
def tail_str(self) -> str:
    """Decode a tail window of ``output_ids`` just long enough to detect any
    configured stop string or stop regex.

    In the worst case one buffered character corresponds to one token, so the
    window is bounded by the longest stop pattern (in characters) plus one
    extra position for a partial match at the token boundary.
    """
    # Default window: longest stop string + 1 (also covers the no-stop case).
    tail_len = self.sampling_params.stop_str_max_len + 1
    if (
        len(self.sampling_params.stop_strs) > 0
        or len(self.sampling_params.stop_regex_strs) > 0
    ):
        # Each operand already carries its own +1 slack, so no additional +1
        # is applied to the max (the previous extra "+ 1" was redundant and
        # decoded one more token than necessary).
        tail_len = max(
            self.sampling_params.stop_str_max_len + 1,
            self.sampling_params.stop_regex_max_len + 1,
        )
    tail_len = min(tail_len, len(self.output_ids))
    return self.tokenizer.decode(self.output_ids[-tail_len:])
def check_match_stop_str_prefix(self) -> bool:
@@ -817,14 +839,27 @@ class Req:
self.finished_reason = FINISH_MATCHED_STR(matched="NaN happened")
return
# Check stop strings
if len(self.sampling_params.stop_strs) > 0:
if (
len(self.sampling_params.stop_strs) > 0
or len(self.sampling_params.stop_regex_strs) > 0
):
tail_str = self.tail_str()
for stop_str in self.sampling_params.stop_strs:
if stop_str in tail_str or stop_str in self.decoded_text:
self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
return
# Check stop strings
if len(self.sampling_params.stop_strs) > 0:
for stop_str in self.sampling_params.stop_strs:
if stop_str in tail_str or stop_str in self.decoded_text:
self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
return
# Check stop regex
if len(self.sampling_params.stop_regex_strs) > 0:
for stop_regex_str in self.sampling_params.stop_regex_strs:
if re.search(stop_regex_str, tail_str):
self.finished_reason = FINISHED_MATCHED_REGEX(
matched=stop_regex_str
)
return
def reset_for_retract(self):
self.prefix_indices = torch.empty((0,), dtype=torch.int64)

View File

@@ -13,6 +13,8 @@
# ==============================================================================
"""Sampling parameters for text generation."""
import logging
import sre_parse
from typing import Any, Dict, List, Optional, Union
from sglang.srt.utils import get_bool_env_var
@@ -20,6 +22,8 @@ from sglang.srt.utils import get_bool_env_var
_SAMPLING_EPS = 1e-6
TOP_K_ALL = 1 << 30
logger = logging.getLogger(__name__)
class SamplingParams:
"""
@@ -35,6 +39,7 @@ class SamplingParams:
max_new_tokens: int = 128,
stop: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
stop_regex: Optional[Union[str, List[str]]] = None,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1,
@@ -63,6 +68,7 @@ class SamplingParams:
self.stop_token_ids = set(stop_token_ids)
else:
self.stop_token_ids = None
self.stop_regex_strs = stop_regex
self.temperature = temperature
self.top_p = top_p
self.top_k = top_k
@@ -170,3 +176,67 @@ class SamplingParams:
else:
stop_str_max_len = max(stop_str_max_len, len(stop_str))
self.stop_str_max_len = stop_str_max_len
# Process stop regex strings
if self.stop_regex_strs is None:
self.stop_regex_strs = []
self.stop_regex_max_len = 0
else:
if isinstance(self.stop_regex_strs, str):
self.stop_regex_strs = [self.stop_regex_strs]
stop_regex_max_len = 0
for stop_regex in self.stop_regex_strs:
stop_regex_max_len = max(
stop_regex_max_len, get_max_seq_length(stop_regex)
)
self.stop_regex_max_len = stop_regex_max_len
# This function gets a strict upperbound on the maximum number of tokens that
# would need to be buffered to match the input regex string.
# NOTE: in the worst case, one character that needs to be buffered corresponds
# to one token.
def get_max_seq_length(regex_str: str) -> int:
    """Return an upper bound (in characters) on the span of a match of
    ``regex_str``; ``MAX_LEN`` when the pattern is effectively unbounded.

    NOTE(review): ``sre_parse`` is deprecated since Python 3.11 — confirm a
    long-term replacement (e.g. a vendored parser) is planned.
    """
    return _max_length_from_subpattern(sre_parse.parse(regex_str))


# Sentinel meaning "effectively unbounded match length".
MAX_LEN = 2**30


def _max_length_from_subpattern(subpattern: sre_parse.SubPattern) -> int:
    """Recursively sum the maximum match length of a parsed sre subpattern.

    Unknown constructs fall back to ``MAX_LEN`` so the bound stays
    conservative (it may overestimate but never underestimates).
    """
    total = 0
    for token, value in subpattern:
        if token in {
            sre_parse.LITERAL,  # `value` is any one character
            sre_parse.NOT_LITERAL,  # single negated char class, e.g. `[^a]`
            sre_parse.IN,  # Any character within `value`
            sre_parse.ANY,  # "."
        }:
            total += 1
        elif token == sre_parse.SUBPATTERN:
            # EG: (a\d+) ->
            # [(SUBPATTERN,
            #   (1, 0, 0, [(LITERAL, 97),
            #     (MAX_REPEAT, (1, MAXREPEAT, [(IN, [(CATEGORY, CATEGORY_DIGIT)])]))]))]
            _, _, _, inner_subpattern = value
            total += _max_length_from_subpattern(inner_subpattern)
        elif token == sre_parse.BRANCH:
            # Alternation: the longest branch dominates the bound.
            _, branches = value
            total += max(_max_length_from_subpattern(branch) for branch in branches)
        elif token in {sre_parse.MAX_REPEAT, sre_parse.MIN_REPEAT}:
            _, max_num_repeat, inner_subpattern = value
            if max_num_repeat == sre_parse.MAXREPEAT:
                # Unbounded repeat (`*`, `+`, `{n,}`): no finite bound exists.
                total += MAX_LEN
            else:
                total += max_num_repeat * _max_length_from_subpattern(inner_subpattern)
        elif token == sre_parse.AT:
            # These are zero-width assertions like ^, $, and \b that don't add
            # to the max length.
            total += 0
        else:
            # Unhandled constructs (lookarounds, backrefs, ...): stay
            # conservative rather than risk under-buffering.
            logger.warning(f"Got unhandled regex token: {token}")
            total += MAX_LEN
    return total