diff --git a/docs/basic_usage/sampling_params.md b/docs/basic_usage/sampling_params.md index f1c61314f..f6faf72d9 100644 --- a/docs/basic_usage/sampling_params.md +++ b/docs/basic_usage/sampling_params.md @@ -49,6 +49,7 @@ python -m sglang.launch_server --model-path --sampling-defaults openai | max_new_tokens | `int = 128` | The maximum output length measured in tokens. | | stop | `Optional[Union[str, List[str]]] = None` | One or multiple [stop words](https://platform.openai.com/docs/api-reference/chat/create#chat-create-stop). Generation will stop if one of these words is sampled. | | stop_token_ids | `Optional[List[int]] = None` | Provide stop words in the form of token IDs. Generation will stop if one of these token IDs is sampled. | +| stop_regex | `Optional[Union[str, List[str]]] = None` | Stop when hitting any of the regex patterns in this list | | temperature | `float (model default; fallback 1.0)` | [Temperature](https://platform.openai.com/docs/api-reference/chat/create#chat-create-temperature) when sampling the next token. `temperature = 0` corresponds to greedy sampling, a higher temperature leads to more diversity. | | top_p | `float (model default; fallback 1.0)` | [Top-p](https://platform.openai.com/docs/api-reference/chat/create#chat-create-top_p) selects tokens from the smallest sorted set whose cumulative probability exceeds `top_p`. When `top_p = 1`, this reduces to unrestricted sampling from all tokens. | | top_k | `int (model default; fallback -1)` | [Top-k](https://developer.nvidia.com/blog/how-to-get-better-outputs-from-your-large-language-model/#predictability_vs_creativity) randomly selects from the `k` highest-probability tokens. 
| diff --git a/python/sglang/lang/api.py b/python/sglang/lang/api.py index a8d2e43e6..745c656ee 100644 --- a/python/sglang/lang/api.py +++ b/python/sglang/lang/api.py @@ -79,6 +79,7 @@ def gen( n: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None, stop_token_ids: Optional[List[int]] = None, + stop_regex: Optional[Union[str, List[str]]] = None, temperature: Optional[float] = None, top_p: Optional[float] = None, top_k: Optional[int] = None, @@ -120,6 +121,7 @@ def gen( n, stop, stop_token_ids, + stop_regex, temperature, top_p, top_k, @@ -143,6 +145,7 @@ def gen_int( n: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None, stop_token_ids: Optional[List[int]] = None, + stop_regex: Optional[Union[str, List[str]]] = None, temperature: Optional[float] = None, top_p: Optional[float] = None, top_k: Optional[int] = None, @@ -162,6 +165,7 @@ def gen_int( n, stop, stop_token_ids, + stop_regex, temperature, top_p, top_k, @@ -184,6 +188,7 @@ def gen_string( n: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None, stop_token_ids: Optional[List[int]] = None, + stop_regex: Optional[Union[str, List[str]]] = None, temperature: Optional[float] = None, top_p: Optional[float] = None, top_k: Optional[int] = None, @@ -203,6 +208,7 @@ def gen_string( n, stop, stop_token_ids, + stop_regex, temperature, top_p, top_k, diff --git a/python/sglang/lang/interpreter.py b/python/sglang/lang/interpreter.py index 8b8cdf9c5..0b59e91b5 100644 --- a/python/sglang/lang/interpreter.py +++ b/python/sglang/lang/interpreter.py @@ -792,6 +792,7 @@ class StreamExecutor: "n", "stop", "stop_token_ids", + "stop_regex", "temperature", "top_p", "top_k", diff --git a/python/sglang/lang/ir.py b/python/sglang/lang/ir.py index 531705ebe..ad690f0f3 100644 --- a/python/sglang/lang/ir.py +++ b/python/sglang/lang/ir.py @@ -21,6 +21,7 @@ class SglSamplingParams: n: int = 1 stop: Union[str, List[str]] = () stop_token_ids: Optional[List[int]] = () + stop_regex: Optional[Union[str, 
List[str]]] = () temperature: float = 1.0 top_p: float = 1.0 top_k: int = -1 # -1 means disable @@ -45,6 +46,7 @@ class SglSamplingParams: self.n, self.stop, self.stop_token_ids, + self.stop_regex, self.temperature, self.top_p, self.top_k, @@ -123,6 +125,7 @@ class SglSamplingParams: "n": self.n, "stop": self.stop, "stop_token_ids": self.stop_token_ids, + "stop_regex": self.stop_regex, "temperature": self.temperature, "top_p": self.top_p, "top_k": self.top_k, @@ -161,6 +164,7 @@ class SglFunction: n: int = 1, stop: Optional[Union[str, List[str]]] = None, stop_token_ids: Optional[List[int]] = None, + stop_regex: Optional[Union[str, List[str]]] = None, temperature: float = 1.0, top_p: float = 1.0, top_k: int = -1, @@ -184,12 +188,15 @@ class SglFunction: stop = [] if stop_token_ids is None: stop_token_ids = [] + if stop_regex is None: + stop_regex = [] default_sampling_para = SglSamplingParams( max_new_tokens=max_new_tokens, n=n, stop=stop, stop_token_ids=stop_token_ids, + stop_regex=stop_regex, temperature=temperature, top_p=top_p, top_k=top_k, @@ -221,6 +228,7 @@ class SglFunction: n: int = 1, stop: Optional[Union[str, List[str]]] = None, stop_token_ids: Optional[List[int]] = None, + stop_regex: Optional[Union[str, List[str]]] = None, temperature: float = 1.0, top_p: float = 1.0, top_k: int = -1, @@ -243,6 +251,8 @@ class SglFunction: stop = [] if stop_token_ids is None: stop_token_ids = [] + if stop_regex is None: + stop_regex = [] assert isinstance(batch_kwargs, (list, tuple)) if len(batch_kwargs) == 0: @@ -267,6 +277,7 @@ class SglFunction: n=n, stop=stop, stop_token_ids=stop_token_ids, + stop_regex=stop_regex, temperature=temperature, top_p=top_p, top_k=top_k, @@ -451,6 +462,7 @@ class SglGen(SglExpr): n: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None, stop_token_ids: Optional[List[int]] = None, + stop_regex: Optional[Union[str, List[str]]] = None, temperature: Optional[float] = None, top_p: Optional[float] = None, top_k: Optional[int] = 
None, @@ -474,6 +486,7 @@ class SglGen(SglExpr): min_new_tokens=min_new_tokens, n=n, stop=stop, + stop_regex=stop_regex, stop_token_ids=stop_token_ids, temperature=temperature, top_p=top_p, diff --git a/python/sglang/srt/entrypoints/openai/protocol.py b/python/sglang/srt/entrypoints/openai/protocol.py index 735f6a998..871dcfd06 100644 --- a/python/sglang/srt/entrypoints/openai/protocol.py +++ b/python/sglang/srt/entrypoints/openai/protocol.py @@ -221,6 +221,7 @@ class CompletionRequest(BaseModel): ebnf: Optional[str] = None repetition_penalty: float = 1.0 stop_token_ids: Optional[List[int]] = None + stop_regex: Optional[Union[str, List[str]]] = None no_stop_trim: bool = False ignore_eos: bool = False skip_special_tokens: bool = True @@ -474,6 +475,7 @@ class ChatCompletionRequest(BaseModel): ebnf: Optional[str] = None repetition_penalty: Optional[float] = None stop_token_ids: Optional[List[int]] = None + stop_regex: Optional[Union[str, List[str]]] = None no_stop_trim: bool = False ignore_eos: bool = False continue_final_message: bool = False @@ -602,6 +604,7 @@ class ChatCompletionRequest(BaseModel): "min_new_tokens": self.min_tokens, "stop": stop, "stop_token_ids": self.stop_token_ids, + "stop_regex": self.stop_regex, "top_p": get_param("top_p"), "top_k": get_param("top_k"), "min_p": get_param("min_p"), diff --git a/python/sglang/srt/entrypoints/openai/serving_completions.py b/python/sglang/srt/entrypoints/openai/serving_completions.py index b065984aa..aaf3b097c 100644 --- a/python/sglang/srt/entrypoints/openai/serving_completions.py +++ b/python/sglang/srt/entrypoints/openai/serving_completions.py @@ -123,6 +123,7 @@ class OpenAIServingCompletion(OpenAIServingBase): "min_new_tokens": request.min_tokens, "stop": request.stop, "stop_token_ids": request.stop_token_ids, + "stop_regex": request.stop_regex, "top_p": request.top_p, "top_k": request.top_k, "min_p": request.min_p, diff --git a/python/sglang/srt/managers/schedule_batch.py 
b/python/sglang/srt/managers/schedule_batch.py
index 6eab0a088..1d42cd8f7 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -36,6 +36,7 @@ TODO(lmzheng): ModelWorkerBatch seems a bit redundant and we consider removing i
 import copy
 import dataclasses
 import logging
+import re
 import threading
 import time
 from enum import Enum, auto
@@ -154,6 +155,18 @@ class FINISH_MATCHED_STR(BaseFinishReason):
         }
 
 
+class FINISHED_MATCHED_REGEX(BaseFinishReason):
+    def __init__(self, matched: str):
+        super().__init__()
+        self.matched = matched
+
+    def to_json(self):
+        return {
+            "type": "stop",  # to match OpenAI API's return value
+            "matched": self.matched,
+        }
+
+
 class FINISH_LENGTH(BaseFinishReason):
     def __init__(self, length: int):
         super().__init__()
@@ -735,8 +748,14 @@ class Req:
         return self.surr_and_decode_ids, self.read_offset - self.surr_offset
 
     def tail_str(self) -> str:
-        tail_len = self.sampling_params.stop_str_max_len + 1
-        tail_len = min(tail_len, len(self.output_ids))
+        # Decode just enough of the tail to cover the longest stop string or
+        # stop regex match, +1 for the newly sampled token. Computed
+        # unconditionally so tail_str() is total even with no stops set.
+        max_len_tail_str = max(
+            self.sampling_params.stop_str_max_len,
+            self.sampling_params.stop_regex_max_len,
+        )
+        tail_len = min(max_len_tail_str + 1, len(self.output_ids))
         return self.tokenizer.decode(self.output_ids[-tail_len:])
 
     def check_match_stop_str_prefix(self) -> bool:
@@ -817,14 +836,27 @@ class Req:
             self.finished_reason = FINISH_MATCHED_STR(matched="NaN happened")
             return
 
-        # Check stop strings
-        if len(self.sampling_params.stop_strs) > 0:
+        if (
+            len(self.sampling_params.stop_strs) > 0
+            or len(self.sampling_params.stop_regex_strs) > 0
+        ):
             tail_str = self.tail_str()
 
-            for stop_str in self.sampling_params.stop_strs:
-                if stop_str in tail_str or stop_str in self.decoded_text:
-                    self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
-                    return
+            # Check stop 
strings + if len(self.sampling_params.stop_strs) > 0: + for stop_str in self.sampling_params.stop_strs: + if stop_str in tail_str or stop_str in self.decoded_text: + self.finished_reason = FINISH_MATCHED_STR(matched=stop_str) + return + + # Check stop regex + if len(self.sampling_params.stop_regex_strs) > 0: + for stop_regex_str in self.sampling_params.stop_regex_strs: + if re.search(stop_regex_str, tail_str): + self.finished_reason = FINISHED_MATCHED_REGEX( + matched=stop_regex_str + ) + return def reset_for_retract(self): self.prefix_indices = torch.empty((0,), dtype=torch.int64) diff --git a/python/sglang/srt/sampling/sampling_params.py b/python/sglang/srt/sampling/sampling_params.py index d978e9587..73be70026 100644 --- a/python/sglang/srt/sampling/sampling_params.py +++ b/python/sglang/srt/sampling/sampling_params.py @@ -13,6 +13,8 @@ # ============================================================================== """Sampling parameters for text generation.""" +import logging +import sre_parse from typing import Any, Dict, List, Optional, Union from sglang.srt.utils import get_bool_env_var @@ -20,6 +22,8 @@ from sglang.srt.utils import get_bool_env_var _SAMPLING_EPS = 1e-6 TOP_K_ALL = 1 << 30 +logger = logging.getLogger(__name__) + class SamplingParams: """ @@ -35,6 +39,7 @@ class SamplingParams: max_new_tokens: int = 128, stop: Optional[Union[str, List[str]]] = None, stop_token_ids: Optional[List[int]] = None, + stop_regex: Optional[Union[str, List[str]]] = None, temperature: float = 1.0, top_p: float = 1.0, top_k: int = -1, @@ -63,6 +68,7 @@ class SamplingParams: self.stop_token_ids = set(stop_token_ids) else: self.stop_token_ids = None + self.stop_regex_strs = stop_regex self.temperature = temperature self.top_p = top_p self.top_k = top_k @@ -170,3 +176,67 @@ class SamplingParams: else: stop_str_max_len = max(stop_str_max_len, len(stop_str)) self.stop_str_max_len = stop_str_max_len + + # Process stop regex strings + if self.stop_regex_strs is None: + 
self.stop_regex_strs = [] + self.stop_regex_max_len = 0 + else: + if isinstance(self.stop_regex_strs, str): + self.stop_regex_strs = [self.stop_regex_strs] + + stop_regex_max_len = 0 + for stop_regex in self.stop_regex_strs: + stop_regex_max_len = max( + stop_regex_max_len, get_max_seq_length(stop_regex) + ) + + self.stop_regex_max_len = stop_regex_max_len + + +# This function gets a strict upperbound on the maximum number of tokens that would need +# to be buffered to match the input regex string +# NOTE: in the worst case, one character that needs to be buffered corresponds to one +# token +def get_max_seq_length(regex_str: str): + return _max_length_from_subpattern(sre_parse.parse(regex_str)) + + +MAX_LEN = 2**30 + + +def _max_length_from_subpattern(subpattern: sre_parse.SubPattern): + total = 0 + for token, value in subpattern: + if token in { + sre_parse.LITERAL, # `value` is any one character + sre_parse.IN, # Any character within `value` + sre_parse.ANY, # "." + }: + total += 1 + elif token == sre_parse.SUBPATTERN: + # EG: (a\d+) -> + # [(SUBPATTERN, + # (1, 0, 0, [(LITERAL, 97), + # (MAX_REPEAT, (1, MAXREPEAT, [(IN, [(CATEGORY, CATEGORY_DIGIT)])]))]))] + _, _, _, inner_subpattern = value + total += _max_length_from_subpattern(inner_subpattern) + elif token == sre_parse.BRANCH: + _, branches = value + total += max(_max_length_from_subpattern(branch) for branch in branches) + elif token in {sre_parse.MAX_REPEAT, sre_parse.MIN_REPEAT}: + _, max_num_repeat, inner_subpattern = value + if max_num_repeat == sre_parse.MAXREPEAT: + total += MAX_LEN + else: + total += max_num_repeat * _max_length_from_subpattern(inner_subpattern) + elif token == sre_parse.AT: + # These are zero-width assertions like ^, $, and \b that don't add to the max + # length + total += 0 + else: + logger.warning(f"Got unhandled regex token: {token}") + + total += MAX_LEN + + return total diff --git a/test/srt/openai_server/validation/test_matched_stop.py 
b/test/srt/openai_server/validation/test_matched_stop.py index 357b07f31..5c264853a 100644 --- a/test/srt/openai_server/validation/test_matched_stop.py +++ b/test/srt/openai_server/validation/test_matched_stop.py @@ -3,6 +3,7 @@ import unittest import requests +from sglang.srt.sampling.sampling_params import MAX_LEN, get_max_seq_length from sglang.srt.utils import kill_process_tree from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST, @@ -40,6 +41,7 @@ class TestMatchedStop(CustomTestCase): prompt=MANY_NEW_TOKENS_PROMPT, max_tokens=1, stop=None, + stop_regex=None, finish_reason=None, matched_stop=None, ): @@ -54,6 +56,9 @@ class TestMatchedStop(CustomTestCase): if stop is not None: payload["stop"] = stop + if stop_regex is not None: + payload["stop_regex"] = stop_regex + response_completions = requests.post( self.base_url + "/v1/completions", json=payload, @@ -71,6 +76,7 @@ class TestMatchedStop(CustomTestCase): prompt=MANY_NEW_TOKENS_PROMPT, max_tokens=1, stop=None, + stop_regex=None, finish_reason=None, matched_stop=None, ): @@ -88,6 +94,9 @@ class TestMatchedStop(CustomTestCase): if stop is not None: chat_payload["stop"] = stop + if stop_regex is not None: + chat_payload["stop_regex"] = stop_regex + response_chat = requests.post( self.base_url + "/v1/chat/completions", json=chat_payload, @@ -106,6 +115,30 @@ class TestMatchedStop(CustomTestCase): max_tokens=1000, stop="\n", finish_reason="stop", matched_stop="\n" ) + def test_finish_stop_regex_str(self): + STOP_REGEX_STR = r"and|or" + self.run_completions_generation( + max_tokens=1000, + stop_regex=STOP_REGEX_STR, + finish_reason="stop", + matched_stop=STOP_REGEX_STR, + ) + self.run_chat_completions_generation( + max_tokens=1000, + stop_regex=STOP_REGEX_STR, + finish_reason="stop", + matched_stop=STOP_REGEX_STR, + ) + + # Match a complete sentence + STOP_REGEX_STR_SENTENCE = r"[.!?]\s*$" + self.run_chat_completions_generation( + max_tokens=1000, + stop_regex=STOP_REGEX_STR_SENTENCE, + 
finish_reason="stop", + matched_stop=STOP_REGEX_STR_SENTENCE, + ) + def test_finish_stop_eos(self): llama_format_prompt = """ <|begin_of_text|><|start_header_id|>system<|end_header_id|> @@ -136,5 +169,53 @@ class TestMatchedStop(CustomTestCase): ) +class TestRegexPatternMaxLength(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.regex_str_to_max_len = { + "((ab|cd(e|f){2}){3,5}g|hij)*k": MAX_LEN, + # - '*' → infinite tokens need to be stored + "abc*?k": MAX_LEN, + # - '*?' → infinite tokens still need to be stored even if lazy matching used + "^spec(foo|at)$": 7, + # - '^' and '$' don't add any characters to the max length + # "spec" → 4 + # "(foo|at)" → max(3, 2) = 3 + # Whole regex = 7 + "(a(bca|de(fg|hi){2,3})j){2}kl": 22, + # - Innermost alt: "fg" vs "hi" → 2 + # - Repeat {2,3}: max = 3 * 2 = 6 + # - Inner group "de(...)": 2 (for "de") + 6 = 8. + # - "bca" or "de(...)" → max(3, 8) = 8 + # - Whole group: "a" (1) + group (8) + "j"(1) = 10 + # - Repeat {2} → 20 + # - Add "kl"(2) → 22 + "(foo(bar|baz(qux){1,2}))|(x(yz){5,10})": 21, + # Branch 1: + # "foo"(3) + max("bar"(3), "baz"(3)+"qux"{2} = 3 + 6 = 9) = 3 + 9 = 12 + # Branch 2: + # "x"(1) + "yz"{10} = 1 + 20 =21 + # Whole regex = max(12, 21) = 21 + "(((a|bc){1,3}(d(e|f){2}|gh){2,4})|(ijk|lmp(no|p){3})){5}": 90, + # Branch A: + # (a|bc){1,3} → max = 3 * 2 = 6 + # Inside: d(e|f){2} = 1 + 2 * 1 = 3 vs gh = 2 → max = 3 + # Repeat {2,4} → 4 * 3 = 12 + # Branch A total = 18 + # Branch B: + # "ijk"(3) vs "lmp(no|p){3}" = 3 + 3 * max(2, 1) = 3 + 6 = 9 → max = 9 + # Branch B total = 9 + # Whole outer alt = max(18, 9) = 18 + # Repeat {5} → 90 + } + + def test_get_max_length(self): + for regex_str, max_len in self.regex_str_to_max_len.items(): + if max_len == MAX_LEN: + self.assertGreaterEqual(get_max_seq_length(regex_str), MAX_LEN) + else: + self.assertEqual(get_max_seq_length(regex_str), max_len) + + if __name__ == "__main__": unittest.main()