[Feature] support regex strings as a stopping condition (#10635)
This commit is contained in:
@@ -49,6 +49,7 @@ python -m sglang.launch_server --model-path <MODEL> --sampling-defaults openai
|
||||
| max_new_tokens | `int = 128` | The maximum output length measured in tokens. |
|
||||
| stop | `Optional[Union[str, List[str]]] = None` | One or multiple [stop words](https://platform.openai.com/docs/api-reference/chat/create#chat-create-stop). Generation will stop if one of these words is sampled. |
|
||||
| stop_token_ids | `Optional[List[int]] = None` | Provide stop words in the form of token IDs. Generation will stop if one of these token IDs is sampled. |
|
||||
| stop_regex | `Optional[Union[str, List[str]]] = None` | Stop when hitting any of the regex patterns in this list |
|
||||
| temperature | `float (model default; fallback 1.0)` | [Temperature](https://platform.openai.com/docs/api-reference/chat/create#chat-create-temperature) when sampling the next token. `temperature = 0` corresponds to greedy sampling, a higher temperature leads to more diversity. |
|
||||
| top_p | `float (model default; fallback 1.0)` | [Top-p](https://platform.openai.com/docs/api-reference/chat/create#chat-create-top_p) selects tokens from the smallest sorted set whose cumulative probability exceeds `top_p`. When `top_p = 1`, this reduces to unrestricted sampling from all tokens. |
|
||||
| top_k | `int (model default; fallback -1)` | [Top-k](https://developer.nvidia.com/blog/how-to-get-better-outputs-from-your-large-language-model/#predictability_vs_creativity) randomly selects from the `k` highest-probability tokens. |
|
||||
|
||||
@@ -79,6 +79,7 @@ def gen(
|
||||
n: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stop_token_ids: Optional[List[int]] = None,
|
||||
stop_regex: Optional[Union[str, List[str]]] = None,
|
||||
temperature: Optional[float] = None,
|
||||
top_p: Optional[float] = None,
|
||||
top_k: Optional[int] = None,
|
||||
@@ -120,6 +121,7 @@ def gen(
|
||||
n,
|
||||
stop,
|
||||
stop_token_ids,
|
||||
stop_regex,
|
||||
temperature,
|
||||
top_p,
|
||||
top_k,
|
||||
@@ -143,6 +145,7 @@ def gen_int(
|
||||
n: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stop_token_ids: Optional[List[int]] = None,
|
||||
stop_regex: Optional[Union[str, List[str]]] = None,
|
||||
temperature: Optional[float] = None,
|
||||
top_p: Optional[float] = None,
|
||||
top_k: Optional[int] = None,
|
||||
@@ -162,6 +165,7 @@ def gen_int(
|
||||
n,
|
||||
stop,
|
||||
stop_token_ids,
|
||||
stop_regex,
|
||||
temperature,
|
||||
top_p,
|
||||
top_k,
|
||||
@@ -184,6 +188,7 @@ def gen_string(
|
||||
n: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stop_token_ids: Optional[List[int]] = None,
|
||||
stop_regex: Optional[Union[str, List[str]]] = None,
|
||||
temperature: Optional[float] = None,
|
||||
top_p: Optional[float] = None,
|
||||
top_k: Optional[int] = None,
|
||||
@@ -203,6 +208,7 @@ def gen_string(
|
||||
n,
|
||||
stop,
|
||||
stop_token_ids,
|
||||
stop_regex,
|
||||
temperature,
|
||||
top_p,
|
||||
top_k,
|
||||
|
||||
@@ -792,6 +792,7 @@ class StreamExecutor:
|
||||
"n",
|
||||
"stop",
|
||||
"stop_token_ids",
|
||||
"stop_regex",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"top_k",
|
||||
|
||||
@@ -21,6 +21,7 @@ class SglSamplingParams:
|
||||
n: int = 1
|
||||
stop: Union[str, List[str]] = ()
|
||||
stop_token_ids: Optional[List[int]] = ()
|
||||
stop_regex: Optional[Union[str, List[str]]] = ()
|
||||
temperature: float = 1.0
|
||||
top_p: float = 1.0
|
||||
top_k: int = -1 # -1 means disable
|
||||
@@ -45,6 +46,7 @@ class SglSamplingParams:
|
||||
self.n,
|
||||
self.stop,
|
||||
self.stop_token_ids,
|
||||
self.stop_regex,
|
||||
self.temperature,
|
||||
self.top_p,
|
||||
self.top_k,
|
||||
@@ -123,6 +125,7 @@ class SglSamplingParams:
|
||||
"n": self.n,
|
||||
"stop": self.stop,
|
||||
"stop_token_ids": self.stop_token_ids,
|
||||
"stop_regex": self.stop_regex,
|
||||
"temperature": self.temperature,
|
||||
"top_p": self.top_p,
|
||||
"top_k": self.top_k,
|
||||
@@ -161,6 +164,7 @@ class SglFunction:
|
||||
n: int = 1,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stop_token_ids: Optional[List[int]] = None,
|
||||
stop_regex: Optional[Union[str, List[str]]] = None,
|
||||
temperature: float = 1.0,
|
||||
top_p: float = 1.0,
|
||||
top_k: int = -1,
|
||||
@@ -184,12 +188,15 @@ class SglFunction:
|
||||
stop = []
|
||||
if stop_token_ids is None:
|
||||
stop_token_ids = []
|
||||
if stop_regex is None:
|
||||
stop_regex = []
|
||||
|
||||
default_sampling_para = SglSamplingParams(
|
||||
max_new_tokens=max_new_tokens,
|
||||
n=n,
|
||||
stop=stop,
|
||||
stop_token_ids=stop_token_ids,
|
||||
stop_regex=stop_regex,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
@@ -221,6 +228,7 @@ class SglFunction:
|
||||
n: int = 1,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stop_token_ids: Optional[List[int]] = None,
|
||||
stop_regex: Optional[Union[str, List[str]]] = None,
|
||||
temperature: float = 1.0,
|
||||
top_p: float = 1.0,
|
||||
top_k: int = -1,
|
||||
@@ -243,6 +251,8 @@ class SglFunction:
|
||||
stop = []
|
||||
if stop_token_ids is None:
|
||||
stop_token_ids = []
|
||||
if stop_regex is None:
|
||||
stop_regex = []
|
||||
|
||||
assert isinstance(batch_kwargs, (list, tuple))
|
||||
if len(batch_kwargs) == 0:
|
||||
@@ -267,6 +277,7 @@ class SglFunction:
|
||||
n=n,
|
||||
stop=stop,
|
||||
stop_token_ids=stop_token_ids,
|
||||
stop_regex=stop_regex,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
@@ -451,6 +462,7 @@ class SglGen(SglExpr):
|
||||
n: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stop_token_ids: Optional[List[int]] = None,
|
||||
stop_regex: Optional[Union[str, List[str]]] = None,
|
||||
temperature: Optional[float] = None,
|
||||
top_p: Optional[float] = None,
|
||||
top_k: Optional[int] = None,
|
||||
@@ -474,6 +486,7 @@ class SglGen(SglExpr):
|
||||
min_new_tokens=min_new_tokens,
|
||||
n=n,
|
||||
stop=stop,
|
||||
stop_regex=stop_regex,
|
||||
stop_token_ids=stop_token_ids,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
|
||||
@@ -221,6 +221,7 @@ class CompletionRequest(BaseModel):
|
||||
ebnf: Optional[str] = None
|
||||
repetition_penalty: float = 1.0
|
||||
stop_token_ids: Optional[List[int]] = None
|
||||
stop_regex: Optional[Union[str, List[str]]] = None
|
||||
no_stop_trim: bool = False
|
||||
ignore_eos: bool = False
|
||||
skip_special_tokens: bool = True
|
||||
@@ -474,6 +475,7 @@ class ChatCompletionRequest(BaseModel):
|
||||
ebnf: Optional[str] = None
|
||||
repetition_penalty: Optional[float] = None
|
||||
stop_token_ids: Optional[List[int]] = None
|
||||
stop_regex: Optional[Union[str, List[str]]] = None
|
||||
no_stop_trim: bool = False
|
||||
ignore_eos: bool = False
|
||||
continue_final_message: bool = False
|
||||
@@ -602,6 +604,7 @@ class ChatCompletionRequest(BaseModel):
|
||||
"min_new_tokens": self.min_tokens,
|
||||
"stop": stop,
|
||||
"stop_token_ids": self.stop_token_ids,
|
||||
"stop_regex": self.stop_regex,
|
||||
"top_p": get_param("top_p"),
|
||||
"top_k": get_param("top_k"),
|
||||
"min_p": get_param("min_p"),
|
||||
|
||||
@@ -123,6 +123,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
|
||||
"min_new_tokens": request.min_tokens,
|
||||
"stop": request.stop,
|
||||
"stop_token_ids": request.stop_token_ids,
|
||||
"stop_regex": request.stop_regex,
|
||||
"top_p": request.top_p,
|
||||
"top_k": request.top_k,
|
||||
"min_p": request.min_p,
|
||||
|
||||
@@ -36,6 +36,7 @@ TODO(lmzheng): ModelWorkerBatch seems a bit redundant and we consider removing i
|
||||
import copy
|
||||
import dataclasses
|
||||
import logging
|
||||
import re
|
||||
import threading
|
||||
import time
|
||||
from enum import Enum, auto
|
||||
@@ -154,6 +155,18 @@ class FINISH_MATCHED_STR(BaseFinishReason):
|
||||
}
|
||||
|
||||
|
||||
class FINISHED_MATCHED_REGEX(BaseFinishReason):
|
||||
def __init__(self, matched: str):
|
||||
super().__init__()
|
||||
self.matched = matched
|
||||
|
||||
def to_json(self):
|
||||
return {
|
||||
"type": "stop", # to match OpenAI API's return value
|
||||
"matched": self.matched,
|
||||
}
|
||||
|
||||
|
||||
class FINISH_LENGTH(BaseFinishReason):
|
||||
def __init__(self, length: int):
|
||||
super().__init__()
|
||||
@@ -735,8 +748,17 @@ class Req:
|
||||
return self.surr_and_decode_ids, self.read_offset - self.surr_offset
|
||||
|
||||
def tail_str(self) -> str:
|
||||
tail_len = self.sampling_params.stop_str_max_len + 1
|
||||
tail_len = min(tail_len, len(self.output_ids))
|
||||
# Check stop strings and stop regex patterns together
|
||||
if (
|
||||
len(self.sampling_params.stop_strs) > 0
|
||||
or len(self.sampling_params.stop_regex_strs) > 0
|
||||
):
|
||||
max_len_tail_str = max(
|
||||
self.sampling_params.stop_str_max_len + 1,
|
||||
self.sampling_params.stop_regex_max_len + 1,
|
||||
)
|
||||
|
||||
tail_len = min((max_len_tail_str + 1), len(self.output_ids))
|
||||
return self.tokenizer.decode(self.output_ids[-tail_len:])
|
||||
|
||||
def check_match_stop_str_prefix(self) -> bool:
|
||||
@@ -817,14 +839,27 @@ class Req:
|
||||
self.finished_reason = FINISH_MATCHED_STR(matched="NaN happened")
|
||||
return
|
||||
|
||||
# Check stop strings
|
||||
if len(self.sampling_params.stop_strs) > 0:
|
||||
if (
|
||||
len(self.sampling_params.stop_strs) > 0
|
||||
or len(self.sampling_params.stop_regex_strs) > 0
|
||||
):
|
||||
tail_str = self.tail_str()
|
||||
|
||||
for stop_str in self.sampling_params.stop_strs:
|
||||
if stop_str in tail_str or stop_str in self.decoded_text:
|
||||
self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
|
||||
return
|
||||
# Check stop strings
|
||||
if len(self.sampling_params.stop_strs) > 0:
|
||||
for stop_str in self.sampling_params.stop_strs:
|
||||
if stop_str in tail_str or stop_str in self.decoded_text:
|
||||
self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
|
||||
return
|
||||
|
||||
# Check stop regex
|
||||
if len(self.sampling_params.stop_regex_strs) > 0:
|
||||
for stop_regex_str in self.sampling_params.stop_regex_strs:
|
||||
if re.search(stop_regex_str, tail_str):
|
||||
self.finished_reason = FINISHED_MATCHED_REGEX(
|
||||
matched=stop_regex_str
|
||||
)
|
||||
return
|
||||
|
||||
def reset_for_retract(self):
|
||||
self.prefix_indices = torch.empty((0,), dtype=torch.int64)
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
# ==============================================================================
|
||||
"""Sampling parameters for text generation."""
|
||||
|
||||
import logging
|
||||
import sre_parse
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from sglang.srt.utils import get_bool_env_var
|
||||
@@ -20,6 +22,8 @@ from sglang.srt.utils import get_bool_env_var
|
||||
_SAMPLING_EPS = 1e-6
|
||||
TOP_K_ALL = 1 << 30
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SamplingParams:
|
||||
"""
|
||||
@@ -35,6 +39,7 @@ class SamplingParams:
|
||||
max_new_tokens: int = 128,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stop_token_ids: Optional[List[int]] = None,
|
||||
stop_regex: Optional[Union[str, List[str]]] = None,
|
||||
temperature: float = 1.0,
|
||||
top_p: float = 1.0,
|
||||
top_k: int = -1,
|
||||
@@ -63,6 +68,7 @@ class SamplingParams:
|
||||
self.stop_token_ids = set(stop_token_ids)
|
||||
else:
|
||||
self.stop_token_ids = None
|
||||
self.stop_regex_strs = stop_regex
|
||||
self.temperature = temperature
|
||||
self.top_p = top_p
|
||||
self.top_k = top_k
|
||||
@@ -170,3 +176,67 @@ class SamplingParams:
|
||||
else:
|
||||
stop_str_max_len = max(stop_str_max_len, len(stop_str))
|
||||
self.stop_str_max_len = stop_str_max_len
|
||||
|
||||
# Process stop regex strings
|
||||
if self.stop_regex_strs is None:
|
||||
self.stop_regex_strs = []
|
||||
self.stop_regex_max_len = 0
|
||||
else:
|
||||
if isinstance(self.stop_regex_strs, str):
|
||||
self.stop_regex_strs = [self.stop_regex_strs]
|
||||
|
||||
stop_regex_max_len = 0
|
||||
for stop_regex in self.stop_regex_strs:
|
||||
stop_regex_max_len = max(
|
||||
stop_regex_max_len, get_max_seq_length(stop_regex)
|
||||
)
|
||||
|
||||
self.stop_regex_max_len = stop_regex_max_len
|
||||
|
||||
|
||||
# This function gets a strict upperbound on the maximum number of tokens that would need
|
||||
# to be buffered to match the input regex string
|
||||
# NOTE: in the worst case, one character that needs to be buffered corresponds to one
|
||||
# token
|
||||
def get_max_seq_length(regex_str: str):
|
||||
return _max_length_from_subpattern(sre_parse.parse(regex_str))
|
||||
|
||||
|
||||
MAX_LEN = 2**30
|
||||
|
||||
|
||||
def _max_length_from_subpattern(subpattern: sre_parse.SubPattern):
|
||||
total = 0
|
||||
for token, value in subpattern:
|
||||
if token in {
|
||||
sre_parse.LITERAL, # `value` is any one character
|
||||
sre_parse.IN, # Any character within `value`
|
||||
sre_parse.ANY, # "."
|
||||
}:
|
||||
total += 1
|
||||
elif token == sre_parse.SUBPATTERN:
|
||||
# EG: (a\d+) ->
|
||||
# [(SUBPATTERN,
|
||||
# (1, 0, 0, [(LITERAL, 97),
|
||||
# (MAX_REPEAT, (1, MAXREPEAT, [(IN, [(CATEGORY, CATEGORY_DIGIT)])]))]))]
|
||||
_, _, _, inner_subpattern = value
|
||||
total += _max_length_from_subpattern(inner_subpattern)
|
||||
elif token == sre_parse.BRANCH:
|
||||
_, branches = value
|
||||
total += max(_max_length_from_subpattern(branch) for branch in branches)
|
||||
elif token in {sre_parse.MAX_REPEAT, sre_parse.MIN_REPEAT}:
|
||||
_, max_num_repeat, inner_subpattern = value
|
||||
if max_num_repeat == sre_parse.MAXREPEAT:
|
||||
total += MAX_LEN
|
||||
else:
|
||||
total += max_num_repeat * _max_length_from_subpattern(inner_subpattern)
|
||||
elif token == sre_parse.AT:
|
||||
# These are zero-width assertions like ^, $, and \b that don't add to the max
|
||||
# length
|
||||
total += 0
|
||||
else:
|
||||
logger.warning(f"Got unhandled regex token: {token}")
|
||||
|
||||
total += MAX_LEN
|
||||
|
||||
return total
|
||||
|
||||
@@ -3,6 +3,7 @@ import unittest
|
||||
|
||||
import requests
|
||||
|
||||
from sglang.srt.sampling.sampling_params import MAX_LEN, get_max_seq_length
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_MODEL_NAME_FOR_TEST,
|
||||
@@ -40,6 +41,7 @@ class TestMatchedStop(CustomTestCase):
|
||||
prompt=MANY_NEW_TOKENS_PROMPT,
|
||||
max_tokens=1,
|
||||
stop=None,
|
||||
stop_regex=None,
|
||||
finish_reason=None,
|
||||
matched_stop=None,
|
||||
):
|
||||
@@ -54,6 +56,9 @@ class TestMatchedStop(CustomTestCase):
|
||||
if stop is not None:
|
||||
payload["stop"] = stop
|
||||
|
||||
if stop_regex is not None:
|
||||
payload["stop_regex"] = stop_regex
|
||||
|
||||
response_completions = requests.post(
|
||||
self.base_url + "/v1/completions",
|
||||
json=payload,
|
||||
@@ -71,6 +76,7 @@ class TestMatchedStop(CustomTestCase):
|
||||
prompt=MANY_NEW_TOKENS_PROMPT,
|
||||
max_tokens=1,
|
||||
stop=None,
|
||||
stop_regex=None,
|
||||
finish_reason=None,
|
||||
matched_stop=None,
|
||||
):
|
||||
@@ -88,6 +94,9 @@ class TestMatchedStop(CustomTestCase):
|
||||
if stop is not None:
|
||||
chat_payload["stop"] = stop
|
||||
|
||||
if stop_regex is not None:
|
||||
chat_payload["stop_regex"] = stop_regex
|
||||
|
||||
response_chat = requests.post(
|
||||
self.base_url + "/v1/chat/completions",
|
||||
json=chat_payload,
|
||||
@@ -106,6 +115,30 @@ class TestMatchedStop(CustomTestCase):
|
||||
max_tokens=1000, stop="\n", finish_reason="stop", matched_stop="\n"
|
||||
)
|
||||
|
||||
def test_finish_stop_regex_str(self):
|
||||
STOP_REGEX_STR = r"and|or"
|
||||
self.run_completions_generation(
|
||||
max_tokens=1000,
|
||||
stop_regex=STOP_REGEX_STR,
|
||||
finish_reason="stop",
|
||||
matched_stop=STOP_REGEX_STR,
|
||||
)
|
||||
self.run_chat_completions_generation(
|
||||
max_tokens=1000,
|
||||
stop_regex=STOP_REGEX_STR,
|
||||
finish_reason="stop",
|
||||
matched_stop=STOP_REGEX_STR,
|
||||
)
|
||||
|
||||
# Match a complete sentence
|
||||
STOP_REGEX_STR_SENTENCE = r"[.!?]\s*$"
|
||||
self.run_chat_completions_generation(
|
||||
max_tokens=1000,
|
||||
stop_regex=STOP_REGEX_STR_SENTENCE,
|
||||
finish_reason="stop",
|
||||
matched_stop=STOP_REGEX_STR_SENTENCE,
|
||||
)
|
||||
|
||||
def test_finish_stop_eos(self):
|
||||
llama_format_prompt = """
|
||||
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
|
||||
@@ -136,5 +169,53 @@ class TestMatchedStop(CustomTestCase):
|
||||
)
|
||||
|
||||
|
||||
class TestRegexPatternMaxLength(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.regex_str_to_max_len = {
|
||||
"((ab|cd(e|f){2}){3,5}g|hij)*k": MAX_LEN,
|
||||
# - '*' → infinite tokens need to be stored
|
||||
"abc*?k": MAX_LEN,
|
||||
# - '*?' → infinite tokens still need to be stored even if lazy matching used
|
||||
"^spec(foo|at)$": 7,
|
||||
# - '^' and '$' don't add any characters to the max length
|
||||
# "spec" → 4
|
||||
# "(foo|at)" → max(3, 2) = 3
|
||||
# Whole regex = 7
|
||||
"(a(bca|de(fg|hi){2,3})j){2}kl": 22,
|
||||
# - Innermost alt: "fg" vs "hi" → 2
|
||||
# - Repeat {2,3}: max = 3 * 2 = 6
|
||||
# - Inner group "de(...)": 2 (for "de") + 6 = 8.
|
||||
# - "bca" or "de(...)" → max(3, 8) = 8
|
||||
# - Whole group: "a" (1) + group (8) + "j"(1) = 10
|
||||
# - Repeat {2} → 20
|
||||
# - Add "kl"(2) → 22
|
||||
"(foo(bar|baz(qux){1,2}))|(x(yz){5,10})": 21,
|
||||
# Branch 1:
|
||||
# "foo"(3) + max("bar"(3), "baz"(3)+"qux"{2} = 3 + 6 = 9) = 3 + 9 = 12
|
||||
# Branch 2:
|
||||
# "x"(1) + "yz"{10} = 1 + 20 =21
|
||||
# Whole regex = max(12, 21) = 21
|
||||
"(((a|bc){1,3}(d(e|f){2}|gh){2,4})|(ijk|lmp(no|p){3})){5}": 90,
|
||||
# Branch A:
|
||||
# (a|bc){1,3} → max = 3 * 2 = 6
|
||||
# Inside: d(e|f){2} = 1 + 2 * 1 = 3 vs gh = 2 → max = 3
|
||||
# Repeat {2,4} → 4 * 3 = 12
|
||||
# Branch A total = 18
|
||||
# Branch B:
|
||||
# "ijk"(3) vs "lmp(no|p){3}" = 3 + 3 * max(2, 1) = 3 + 6 = 9 → max = 9
|
||||
# Branch B total = 9
|
||||
# Whole outer alt = max(18, 9) = 18
|
||||
# Repeat {5} → 90
|
||||
}
|
||||
|
||||
def test_get_max_length(self):
|
||||
for regex_str, max_len in self.regex_str_to_max_len.items():
|
||||
if max_len == MAX_LEN:
|
||||
self.assertGreaterEqual(get_max_seq_length(regex_str), MAX_LEN)
|
||||
else:
|
||||
self.assertEqual(get_max_seq_length(regex_str), max_len)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user