[Feature] support regex strings as a stopping condition (#10635)
This commit is contained in:
@@ -79,6 +79,7 @@ def gen(
|
||||
n: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stop_token_ids: Optional[List[int]] = None,
|
||||
stop_regex: Optional[Union[str, List[str]]] = None,
|
||||
temperature: Optional[float] = None,
|
||||
top_p: Optional[float] = None,
|
||||
top_k: Optional[int] = None,
|
||||
@@ -120,6 +121,7 @@ def gen(
|
||||
n,
|
||||
stop,
|
||||
stop_token_ids,
|
||||
stop_regex,
|
||||
temperature,
|
||||
top_p,
|
||||
top_k,
|
||||
@@ -143,6 +145,7 @@ def gen_int(
|
||||
n: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stop_token_ids: Optional[List[int]] = None,
|
||||
stop_regex: Optional[Union[str, List[str]]] = None,
|
||||
temperature: Optional[float] = None,
|
||||
top_p: Optional[float] = None,
|
||||
top_k: Optional[int] = None,
|
||||
@@ -162,6 +165,7 @@ def gen_int(
|
||||
n,
|
||||
stop,
|
||||
stop_token_ids,
|
||||
stop_regex,
|
||||
temperature,
|
||||
top_p,
|
||||
top_k,
|
||||
@@ -184,6 +188,7 @@ def gen_string(
|
||||
n: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stop_token_ids: Optional[List[int]] = None,
|
||||
stop_regex: Optional[Union[str, List[str]]] = None,
|
||||
temperature: Optional[float] = None,
|
||||
top_p: Optional[float] = None,
|
||||
top_k: Optional[int] = None,
|
||||
@@ -203,6 +208,7 @@ def gen_string(
|
||||
n,
|
||||
stop,
|
||||
stop_token_ids,
|
||||
stop_regex,
|
||||
temperature,
|
||||
top_p,
|
||||
top_k,
|
||||
|
||||
@@ -792,6 +792,7 @@ class StreamExecutor:
|
||||
"n",
|
||||
"stop",
|
||||
"stop_token_ids",
|
||||
"stop_regex",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"top_k",
|
||||
|
||||
@@ -21,6 +21,7 @@ class SglSamplingParams:
|
||||
n: int = 1
|
||||
stop: Union[str, List[str]] = ()
|
||||
stop_token_ids: Optional[List[int]] = ()
|
||||
stop_regex: Optional[Union[str, List[str]]] = ()
|
||||
temperature: float = 1.0
|
||||
top_p: float = 1.0
|
||||
top_k: int = -1 # -1 means disable
|
||||
@@ -45,6 +46,7 @@ class SglSamplingParams:
|
||||
self.n,
|
||||
self.stop,
|
||||
self.stop_token_ids,
|
||||
self.stop_regex,
|
||||
self.temperature,
|
||||
self.top_p,
|
||||
self.top_k,
|
||||
@@ -123,6 +125,7 @@ class SglSamplingParams:
|
||||
"n": self.n,
|
||||
"stop": self.stop,
|
||||
"stop_token_ids": self.stop_token_ids,
|
||||
"stop_regex": self.stop_regex,
|
||||
"temperature": self.temperature,
|
||||
"top_p": self.top_p,
|
||||
"top_k": self.top_k,
|
||||
@@ -161,6 +164,7 @@ class SglFunction:
|
||||
n: int = 1,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stop_token_ids: Optional[List[int]] = None,
|
||||
stop_regex: Optional[Union[str, List[str]]] = None,
|
||||
temperature: float = 1.0,
|
||||
top_p: float = 1.0,
|
||||
top_k: int = -1,
|
||||
@@ -184,12 +188,15 @@ class SglFunction:
|
||||
stop = []
|
||||
if stop_token_ids is None:
|
||||
stop_token_ids = []
|
||||
if stop_regex is None:
|
||||
stop_regex = []
|
||||
|
||||
default_sampling_para = SglSamplingParams(
|
||||
max_new_tokens=max_new_tokens,
|
||||
n=n,
|
||||
stop=stop,
|
||||
stop_token_ids=stop_token_ids,
|
||||
stop_regex=stop_regex,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
@@ -221,6 +228,7 @@ class SglFunction:
|
||||
n: int = 1,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stop_token_ids: Optional[List[int]] = None,
|
||||
stop_regex: Optional[Union[str, List[str]]] = None,
|
||||
temperature: float = 1.0,
|
||||
top_p: float = 1.0,
|
||||
top_k: int = -1,
|
||||
@@ -243,6 +251,8 @@ class SglFunction:
|
||||
stop = []
|
||||
if stop_token_ids is None:
|
||||
stop_token_ids = []
|
||||
if stop_regex is None:
|
||||
stop_regex = []
|
||||
|
||||
assert isinstance(batch_kwargs, (list, tuple))
|
||||
if len(batch_kwargs) == 0:
|
||||
@@ -267,6 +277,7 @@ class SglFunction:
|
||||
n=n,
|
||||
stop=stop,
|
||||
stop_token_ids=stop_token_ids,
|
||||
stop_regex=stop_regex,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
@@ -451,6 +462,7 @@ class SglGen(SglExpr):
|
||||
n: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stop_token_ids: Optional[List[int]] = None,
|
||||
stop_regex: Optional[Union[str, List[str]]] = None,
|
||||
temperature: Optional[float] = None,
|
||||
top_p: Optional[float] = None,
|
||||
top_k: Optional[int] = None,
|
||||
@@ -474,6 +486,7 @@ class SglGen(SglExpr):
|
||||
min_new_tokens=min_new_tokens,
|
||||
n=n,
|
||||
stop=stop,
|
||||
stop_regex=stop_regex,
|
||||
stop_token_ids=stop_token_ids,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
|
||||
@@ -221,6 +221,7 @@ class CompletionRequest(BaseModel):
|
||||
ebnf: Optional[str] = None
|
||||
repetition_penalty: float = 1.0
|
||||
stop_token_ids: Optional[List[int]] = None
|
||||
stop_regex: Optional[Union[str, List[str]]] = None
|
||||
no_stop_trim: bool = False
|
||||
ignore_eos: bool = False
|
||||
skip_special_tokens: bool = True
|
||||
@@ -474,6 +475,7 @@ class ChatCompletionRequest(BaseModel):
|
||||
ebnf: Optional[str] = None
|
||||
repetition_penalty: Optional[float] = None
|
||||
stop_token_ids: Optional[List[int]] = None
|
||||
stop_regex: Optional[Union[str, List[str]]] = None
|
||||
no_stop_trim: bool = False
|
||||
ignore_eos: bool = False
|
||||
continue_final_message: bool = False
|
||||
@@ -602,6 +604,7 @@ class ChatCompletionRequest(BaseModel):
|
||||
"min_new_tokens": self.min_tokens,
|
||||
"stop": stop,
|
||||
"stop_token_ids": self.stop_token_ids,
|
||||
"stop_regex": self.stop_regex,
|
||||
"top_p": get_param("top_p"),
|
||||
"top_k": get_param("top_k"),
|
||||
"min_p": get_param("min_p"),
|
||||
|
||||
@@ -123,6 +123,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
|
||||
"min_new_tokens": request.min_tokens,
|
||||
"stop": request.stop,
|
||||
"stop_token_ids": request.stop_token_ids,
|
||||
"stop_regex": request.stop_regex,
|
||||
"top_p": request.top_p,
|
||||
"top_k": request.top_k,
|
||||
"min_p": request.min_p,
|
||||
|
||||
@@ -36,6 +36,7 @@ TODO(lmzheng): ModelWorkerBatch seems a bit redundant and we consider removing i
|
||||
import copy
|
||||
import dataclasses
|
||||
import logging
|
||||
import re
|
||||
import threading
|
||||
import time
|
||||
from enum import Enum, auto
|
||||
@@ -154,6 +155,18 @@ class FINISH_MATCHED_STR(BaseFinishReason):
|
||||
}
|
||||
|
||||
|
||||
class FINISHED_MATCHED_REGEX(BaseFinishReason):
|
||||
def __init__(self, matched: str):
|
||||
super().__init__()
|
||||
self.matched = matched
|
||||
|
||||
def to_json(self):
|
||||
return {
|
||||
"type": "stop", # to match OpenAI API's return value
|
||||
"matched": self.matched,
|
||||
}
|
||||
|
||||
|
||||
class FINISH_LENGTH(BaseFinishReason):
|
||||
def __init__(self, length: int):
|
||||
super().__init__()
|
||||
@@ -735,8 +748,17 @@ class Req:
|
||||
return self.surr_and_decode_ids, self.read_offset - self.surr_offset
|
||||
|
||||
def tail_str(self) -> str:
|
||||
tail_len = self.sampling_params.stop_str_max_len + 1
|
||||
tail_len = min(tail_len, len(self.output_ids))
|
||||
# Check stop strings and stop regex patterns together
|
||||
if (
|
||||
len(self.sampling_params.stop_strs) > 0
|
||||
or len(self.sampling_params.stop_regex_strs) > 0
|
||||
):
|
||||
max_len_tail_str = max(
|
||||
self.sampling_params.stop_str_max_len + 1,
|
||||
self.sampling_params.stop_regex_max_len + 1,
|
||||
)
|
||||
|
||||
tail_len = min((max_len_tail_str + 1), len(self.output_ids))
|
||||
return self.tokenizer.decode(self.output_ids[-tail_len:])
|
||||
|
||||
def check_match_stop_str_prefix(self) -> bool:
|
||||
@@ -817,14 +839,27 @@ class Req:
|
||||
self.finished_reason = FINISH_MATCHED_STR(matched="NaN happened")
|
||||
return
|
||||
|
||||
# Check stop strings
|
||||
if len(self.sampling_params.stop_strs) > 0:
|
||||
if (
|
||||
len(self.sampling_params.stop_strs) > 0
|
||||
or len(self.sampling_params.stop_regex_strs) > 0
|
||||
):
|
||||
tail_str = self.tail_str()
|
||||
|
||||
for stop_str in self.sampling_params.stop_strs:
|
||||
if stop_str in tail_str or stop_str in self.decoded_text:
|
||||
self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
|
||||
return
|
||||
# Check stop strings
|
||||
if len(self.sampling_params.stop_strs) > 0:
|
||||
for stop_str in self.sampling_params.stop_strs:
|
||||
if stop_str in tail_str or stop_str in self.decoded_text:
|
||||
self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
|
||||
return
|
||||
|
||||
# Check stop regex
|
||||
if len(self.sampling_params.stop_regex_strs) > 0:
|
||||
for stop_regex_str in self.sampling_params.stop_regex_strs:
|
||||
if re.search(stop_regex_str, tail_str):
|
||||
self.finished_reason = FINISHED_MATCHED_REGEX(
|
||||
matched=stop_regex_str
|
||||
)
|
||||
return
|
||||
|
||||
def reset_for_retract(self):
|
||||
self.prefix_indices = torch.empty((0,), dtype=torch.int64)
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
# ==============================================================================
|
||||
"""Sampling parameters for text generation."""
|
||||
|
||||
import logging
|
||||
import sre_parse
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from sglang.srt.utils import get_bool_env_var
|
||||
@@ -20,6 +22,8 @@ from sglang.srt.utils import get_bool_env_var
|
||||
_SAMPLING_EPS = 1e-6
|
||||
TOP_K_ALL = 1 << 30
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SamplingParams:
|
||||
"""
|
||||
@@ -35,6 +39,7 @@ class SamplingParams:
|
||||
max_new_tokens: int = 128,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stop_token_ids: Optional[List[int]] = None,
|
||||
stop_regex: Optional[Union[str, List[str]]] = None,
|
||||
temperature: float = 1.0,
|
||||
top_p: float = 1.0,
|
||||
top_k: int = -1,
|
||||
@@ -63,6 +68,7 @@ class SamplingParams:
|
||||
self.stop_token_ids = set(stop_token_ids)
|
||||
else:
|
||||
self.stop_token_ids = None
|
||||
self.stop_regex_strs = stop_regex
|
||||
self.temperature = temperature
|
||||
self.top_p = top_p
|
||||
self.top_k = top_k
|
||||
@@ -170,3 +176,67 @@ class SamplingParams:
|
||||
else:
|
||||
stop_str_max_len = max(stop_str_max_len, len(stop_str))
|
||||
self.stop_str_max_len = stop_str_max_len
|
||||
|
||||
# Process stop regex strings
|
||||
if self.stop_regex_strs is None:
|
||||
self.stop_regex_strs = []
|
||||
self.stop_regex_max_len = 0
|
||||
else:
|
||||
if isinstance(self.stop_regex_strs, str):
|
||||
self.stop_regex_strs = [self.stop_regex_strs]
|
||||
|
||||
stop_regex_max_len = 0
|
||||
for stop_regex in self.stop_regex_strs:
|
||||
stop_regex_max_len = max(
|
||||
stop_regex_max_len, get_max_seq_length(stop_regex)
|
||||
)
|
||||
|
||||
self.stop_regex_max_len = stop_regex_max_len
|
||||
|
||||
|
||||
# This function gets a strict upperbound on the maximum number of tokens that would need
|
||||
# to be buffered to match the input regex string
|
||||
# NOTE: in the worst case, one character that needs to be buffered corresponds to one
|
||||
# token
|
||||
def get_max_seq_length(regex_str: str):
|
||||
return _max_length_from_subpattern(sre_parse.parse(regex_str))
|
||||
|
||||
|
||||
MAX_LEN = 2**30
|
||||
|
||||
|
||||
def _max_length_from_subpattern(subpattern: sre_parse.SubPattern):
|
||||
total = 0
|
||||
for token, value in subpattern:
|
||||
if token in {
|
||||
sre_parse.LITERAL, # `value` is any one character
|
||||
sre_parse.IN, # Any character within `value`
|
||||
sre_parse.ANY, # "."
|
||||
}:
|
||||
total += 1
|
||||
elif token == sre_parse.SUBPATTERN:
|
||||
# EG: (a\d+) ->
|
||||
# [(SUBPATTERN,
|
||||
# (1, 0, 0, [(LITERAL, 97),
|
||||
# (MAX_REPEAT, (1, MAXREPEAT, [(IN, [(CATEGORY, CATEGORY_DIGIT)])]))]))]
|
||||
_, _, _, inner_subpattern = value
|
||||
total += _max_length_from_subpattern(inner_subpattern)
|
||||
elif token == sre_parse.BRANCH:
|
||||
_, branches = value
|
||||
total += max(_max_length_from_subpattern(branch) for branch in branches)
|
||||
elif token in {sre_parse.MAX_REPEAT, sre_parse.MIN_REPEAT}:
|
||||
_, max_num_repeat, inner_subpattern = value
|
||||
if max_num_repeat == sre_parse.MAXREPEAT:
|
||||
total += MAX_LEN
|
||||
else:
|
||||
total += max_num_repeat * _max_length_from_subpattern(inner_subpattern)
|
||||
elif token == sre_parse.AT:
|
||||
# These are zero-width assertions like ^, $, and \b that don't add to the max
|
||||
# length
|
||||
total += 0
|
||||
else:
|
||||
logger.warning(f"Got unhandled regex token: {token}")
|
||||
|
||||
total += MAX_LEN
|
||||
|
||||
return total
|
||||
|
||||
Reference in New Issue
Block a user