[Feature] support regex strings as a stopping condition (#10635)
This commit is contained in:
@@ -49,6 +49,7 @@ python -m sglang.launch_server --model-path <MODEL> --sampling-defaults openai
|
|||||||
| max_new_tokens | `int = 128` | The maximum output length measured in tokens. |
|
| max_new_tokens | `int = 128` | The maximum output length measured in tokens. |
|
||||||
| stop | `Optional[Union[str, List[str]]] = None` | One or multiple [stop words](https://platform.openai.com/docs/api-reference/chat/create#chat-create-stop). Generation will stop if one of these words is sampled. |
|
| stop | `Optional[Union[str, List[str]]] = None` | One or multiple [stop words](https://platform.openai.com/docs/api-reference/chat/create#chat-create-stop). Generation will stop if one of these words is sampled. |
|
||||||
| stop_token_ids | `Optional[List[int]] = None` | Provide stop words in the form of token IDs. Generation will stop if one of these token IDs is sampled. |
|
| stop_token_ids | `Optional[List[int]] = None` | Provide stop words in the form of token IDs. Generation will stop if one of these token IDs is sampled. |
|
||||||
|
| stop_regex | `Optional[Union[str, List[str]]] = None` | Stop when hitting any of the regex patterns in this list |
|
||||||
| temperature | `float (model default; fallback 1.0)` | [Temperature](https://platform.openai.com/docs/api-reference/chat/create#chat-create-temperature) when sampling the next token. `temperature = 0` corresponds to greedy sampling, a higher temperature leads to more diversity. |
|
| temperature | `float (model default; fallback 1.0)` | [Temperature](https://platform.openai.com/docs/api-reference/chat/create#chat-create-temperature) when sampling the next token. `temperature = 0` corresponds to greedy sampling, a higher temperature leads to more diversity. |
|
||||||
| top_p | `float (model default; fallback 1.0)` | [Top-p](https://platform.openai.com/docs/api-reference/chat/create#chat-create-top_p) selects tokens from the smallest sorted set whose cumulative probability exceeds `top_p`. When `top_p = 1`, this reduces to unrestricted sampling from all tokens. |
|
| top_p | `float (model default; fallback 1.0)` | [Top-p](https://platform.openai.com/docs/api-reference/chat/create#chat-create-top_p) selects tokens from the smallest sorted set whose cumulative probability exceeds `top_p`. When `top_p = 1`, this reduces to unrestricted sampling from all tokens. |
|
||||||
| top_k | `int (model default; fallback -1)` | [Top-k](https://developer.nvidia.com/blog/how-to-get-better-outputs-from-your-large-language-model/#predictability_vs_creativity) randomly selects from the `k` highest-probability tokens. |
|
| top_k | `int (model default; fallback -1)` | [Top-k](https://developer.nvidia.com/blog/how-to-get-better-outputs-from-your-large-language-model/#predictability_vs_creativity) randomly selects from the `k` highest-probability tokens. |
|
||||||
|
|||||||
@@ -79,6 +79,7 @@ def gen(
|
|||||||
n: Optional[int] = None,
|
n: Optional[int] = None,
|
||||||
stop: Optional[Union[str, List[str]]] = None,
|
stop: Optional[Union[str, List[str]]] = None,
|
||||||
stop_token_ids: Optional[List[int]] = None,
|
stop_token_ids: Optional[List[int]] = None,
|
||||||
|
stop_regex: Optional[Union[str, List[str]]] = None,
|
||||||
temperature: Optional[float] = None,
|
temperature: Optional[float] = None,
|
||||||
top_p: Optional[float] = None,
|
top_p: Optional[float] = None,
|
||||||
top_k: Optional[int] = None,
|
top_k: Optional[int] = None,
|
||||||
@@ -120,6 +121,7 @@ def gen(
|
|||||||
n,
|
n,
|
||||||
stop,
|
stop,
|
||||||
stop_token_ids,
|
stop_token_ids,
|
||||||
|
stop_regex,
|
||||||
temperature,
|
temperature,
|
||||||
top_p,
|
top_p,
|
||||||
top_k,
|
top_k,
|
||||||
@@ -143,6 +145,7 @@ def gen_int(
|
|||||||
n: Optional[int] = None,
|
n: Optional[int] = None,
|
||||||
stop: Optional[Union[str, List[str]]] = None,
|
stop: Optional[Union[str, List[str]]] = None,
|
||||||
stop_token_ids: Optional[List[int]] = None,
|
stop_token_ids: Optional[List[int]] = None,
|
||||||
|
stop_regex: Optional[Union[str, List[str]]] = None,
|
||||||
temperature: Optional[float] = None,
|
temperature: Optional[float] = None,
|
||||||
top_p: Optional[float] = None,
|
top_p: Optional[float] = None,
|
||||||
top_k: Optional[int] = None,
|
top_k: Optional[int] = None,
|
||||||
@@ -162,6 +165,7 @@ def gen_int(
|
|||||||
n,
|
n,
|
||||||
stop,
|
stop,
|
||||||
stop_token_ids,
|
stop_token_ids,
|
||||||
|
stop_regex,
|
||||||
temperature,
|
temperature,
|
||||||
top_p,
|
top_p,
|
||||||
top_k,
|
top_k,
|
||||||
@@ -184,6 +188,7 @@ def gen_string(
|
|||||||
n: Optional[int] = None,
|
n: Optional[int] = None,
|
||||||
stop: Optional[Union[str, List[str]]] = None,
|
stop: Optional[Union[str, List[str]]] = None,
|
||||||
stop_token_ids: Optional[List[int]] = None,
|
stop_token_ids: Optional[List[int]] = None,
|
||||||
|
stop_regex: Optional[Union[str, List[str]]] = None,
|
||||||
temperature: Optional[float] = None,
|
temperature: Optional[float] = None,
|
||||||
top_p: Optional[float] = None,
|
top_p: Optional[float] = None,
|
||||||
top_k: Optional[int] = None,
|
top_k: Optional[int] = None,
|
||||||
@@ -203,6 +208,7 @@ def gen_string(
|
|||||||
n,
|
n,
|
||||||
stop,
|
stop,
|
||||||
stop_token_ids,
|
stop_token_ids,
|
||||||
|
stop_regex,
|
||||||
temperature,
|
temperature,
|
||||||
top_p,
|
top_p,
|
||||||
top_k,
|
top_k,
|
||||||
|
|||||||
@@ -792,6 +792,7 @@ class StreamExecutor:
|
|||||||
"n",
|
"n",
|
||||||
"stop",
|
"stop",
|
||||||
"stop_token_ids",
|
"stop_token_ids",
|
||||||
|
"stop_regex",
|
||||||
"temperature",
|
"temperature",
|
||||||
"top_p",
|
"top_p",
|
||||||
"top_k",
|
"top_k",
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ class SglSamplingParams:
|
|||||||
n: int = 1
|
n: int = 1
|
||||||
stop: Union[str, List[str]] = ()
|
stop: Union[str, List[str]] = ()
|
||||||
stop_token_ids: Optional[List[int]] = ()
|
stop_token_ids: Optional[List[int]] = ()
|
||||||
|
stop_regex: Optional[Union[str, List[str]]] = ()
|
||||||
temperature: float = 1.0
|
temperature: float = 1.0
|
||||||
top_p: float = 1.0
|
top_p: float = 1.0
|
||||||
top_k: int = -1 # -1 means disable
|
top_k: int = -1 # -1 means disable
|
||||||
@@ -45,6 +46,7 @@ class SglSamplingParams:
|
|||||||
self.n,
|
self.n,
|
||||||
self.stop,
|
self.stop,
|
||||||
self.stop_token_ids,
|
self.stop_token_ids,
|
||||||
|
self.stop_regex,
|
||||||
self.temperature,
|
self.temperature,
|
||||||
self.top_p,
|
self.top_p,
|
||||||
self.top_k,
|
self.top_k,
|
||||||
@@ -123,6 +125,7 @@ class SglSamplingParams:
|
|||||||
"n": self.n,
|
"n": self.n,
|
||||||
"stop": self.stop,
|
"stop": self.stop,
|
||||||
"stop_token_ids": self.stop_token_ids,
|
"stop_token_ids": self.stop_token_ids,
|
||||||
|
"stop_regex": self.stop_regex,
|
||||||
"temperature": self.temperature,
|
"temperature": self.temperature,
|
||||||
"top_p": self.top_p,
|
"top_p": self.top_p,
|
||||||
"top_k": self.top_k,
|
"top_k": self.top_k,
|
||||||
@@ -161,6 +164,7 @@ class SglFunction:
|
|||||||
n: int = 1,
|
n: int = 1,
|
||||||
stop: Optional[Union[str, List[str]]] = None,
|
stop: Optional[Union[str, List[str]]] = None,
|
||||||
stop_token_ids: Optional[List[int]] = None,
|
stop_token_ids: Optional[List[int]] = None,
|
||||||
|
stop_regex: Optional[Union[str, List[str]]] = None,
|
||||||
temperature: float = 1.0,
|
temperature: float = 1.0,
|
||||||
top_p: float = 1.0,
|
top_p: float = 1.0,
|
||||||
top_k: int = -1,
|
top_k: int = -1,
|
||||||
@@ -184,12 +188,15 @@ class SglFunction:
|
|||||||
stop = []
|
stop = []
|
||||||
if stop_token_ids is None:
|
if stop_token_ids is None:
|
||||||
stop_token_ids = []
|
stop_token_ids = []
|
||||||
|
if stop_regex is None:
|
||||||
|
stop_regex = []
|
||||||
|
|
||||||
default_sampling_para = SglSamplingParams(
|
default_sampling_para = SglSamplingParams(
|
||||||
max_new_tokens=max_new_tokens,
|
max_new_tokens=max_new_tokens,
|
||||||
n=n,
|
n=n,
|
||||||
stop=stop,
|
stop=stop,
|
||||||
stop_token_ids=stop_token_ids,
|
stop_token_ids=stop_token_ids,
|
||||||
|
stop_regex=stop_regex,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
top_p=top_p,
|
top_p=top_p,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
@@ -221,6 +228,7 @@ class SglFunction:
|
|||||||
n: int = 1,
|
n: int = 1,
|
||||||
stop: Optional[Union[str, List[str]]] = None,
|
stop: Optional[Union[str, List[str]]] = None,
|
||||||
stop_token_ids: Optional[List[int]] = None,
|
stop_token_ids: Optional[List[int]] = None,
|
||||||
|
stop_regex: Optional[Union[str, List[str]]] = None,
|
||||||
temperature: float = 1.0,
|
temperature: float = 1.0,
|
||||||
top_p: float = 1.0,
|
top_p: float = 1.0,
|
||||||
top_k: int = -1,
|
top_k: int = -1,
|
||||||
@@ -243,6 +251,8 @@ class SglFunction:
|
|||||||
stop = []
|
stop = []
|
||||||
if stop_token_ids is None:
|
if stop_token_ids is None:
|
||||||
stop_token_ids = []
|
stop_token_ids = []
|
||||||
|
if stop_regex is None:
|
||||||
|
stop_regex = []
|
||||||
|
|
||||||
assert isinstance(batch_kwargs, (list, tuple))
|
assert isinstance(batch_kwargs, (list, tuple))
|
||||||
if len(batch_kwargs) == 0:
|
if len(batch_kwargs) == 0:
|
||||||
@@ -267,6 +277,7 @@ class SglFunction:
|
|||||||
n=n,
|
n=n,
|
||||||
stop=stop,
|
stop=stop,
|
||||||
stop_token_ids=stop_token_ids,
|
stop_token_ids=stop_token_ids,
|
||||||
|
stop_regex=stop_regex,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
top_p=top_p,
|
top_p=top_p,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
@@ -451,6 +462,7 @@ class SglGen(SglExpr):
|
|||||||
n: Optional[int] = None,
|
n: Optional[int] = None,
|
||||||
stop: Optional[Union[str, List[str]]] = None,
|
stop: Optional[Union[str, List[str]]] = None,
|
||||||
stop_token_ids: Optional[List[int]] = None,
|
stop_token_ids: Optional[List[int]] = None,
|
||||||
|
stop_regex: Optional[Union[str, List[str]]] = None,
|
||||||
temperature: Optional[float] = None,
|
temperature: Optional[float] = None,
|
||||||
top_p: Optional[float] = None,
|
top_p: Optional[float] = None,
|
||||||
top_k: Optional[int] = None,
|
top_k: Optional[int] = None,
|
||||||
@@ -474,6 +486,7 @@ class SglGen(SglExpr):
|
|||||||
min_new_tokens=min_new_tokens,
|
min_new_tokens=min_new_tokens,
|
||||||
n=n,
|
n=n,
|
||||||
stop=stop,
|
stop=stop,
|
||||||
|
stop_regex=stop_regex,
|
||||||
stop_token_ids=stop_token_ids,
|
stop_token_ids=stop_token_ids,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
top_p=top_p,
|
top_p=top_p,
|
||||||
|
|||||||
@@ -221,6 +221,7 @@ class CompletionRequest(BaseModel):
|
|||||||
ebnf: Optional[str] = None
|
ebnf: Optional[str] = None
|
||||||
repetition_penalty: float = 1.0
|
repetition_penalty: float = 1.0
|
||||||
stop_token_ids: Optional[List[int]] = None
|
stop_token_ids: Optional[List[int]] = None
|
||||||
|
stop_regex: Optional[Union[str, List[str]]] = None
|
||||||
no_stop_trim: bool = False
|
no_stop_trim: bool = False
|
||||||
ignore_eos: bool = False
|
ignore_eos: bool = False
|
||||||
skip_special_tokens: bool = True
|
skip_special_tokens: bool = True
|
||||||
@@ -474,6 +475,7 @@ class ChatCompletionRequest(BaseModel):
|
|||||||
ebnf: Optional[str] = None
|
ebnf: Optional[str] = None
|
||||||
repetition_penalty: Optional[float] = None
|
repetition_penalty: Optional[float] = None
|
||||||
stop_token_ids: Optional[List[int]] = None
|
stop_token_ids: Optional[List[int]] = None
|
||||||
|
stop_regex: Optional[Union[str, List[str]]] = None
|
||||||
no_stop_trim: bool = False
|
no_stop_trim: bool = False
|
||||||
ignore_eos: bool = False
|
ignore_eos: bool = False
|
||||||
continue_final_message: bool = False
|
continue_final_message: bool = False
|
||||||
@@ -602,6 +604,7 @@ class ChatCompletionRequest(BaseModel):
|
|||||||
"min_new_tokens": self.min_tokens,
|
"min_new_tokens": self.min_tokens,
|
||||||
"stop": stop,
|
"stop": stop,
|
||||||
"stop_token_ids": self.stop_token_ids,
|
"stop_token_ids": self.stop_token_ids,
|
||||||
|
"stop_regex": self.stop_regex,
|
||||||
"top_p": get_param("top_p"),
|
"top_p": get_param("top_p"),
|
||||||
"top_k": get_param("top_k"),
|
"top_k": get_param("top_k"),
|
||||||
"min_p": get_param("min_p"),
|
"min_p": get_param("min_p"),
|
||||||
|
|||||||
@@ -123,6 +123,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
|
|||||||
"min_new_tokens": request.min_tokens,
|
"min_new_tokens": request.min_tokens,
|
||||||
"stop": request.stop,
|
"stop": request.stop,
|
||||||
"stop_token_ids": request.stop_token_ids,
|
"stop_token_ids": request.stop_token_ids,
|
||||||
|
"stop_regex": request.stop_regex,
|
||||||
"top_p": request.top_p,
|
"top_p": request.top_p,
|
||||||
"top_k": request.top_k,
|
"top_k": request.top_k,
|
||||||
"min_p": request.min_p,
|
"min_p": request.min_p,
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ TODO(lmzheng): ModelWorkerBatch seems a bit redundant and we consider removing i
|
|||||||
import copy
|
import copy
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from enum import Enum, auto
|
from enum import Enum, auto
|
||||||
@@ -154,6 +155,18 @@ class FINISH_MATCHED_STR(BaseFinishReason):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class FINISHED_MATCHED_REGEX(BaseFinishReason):
|
||||||
|
def __init__(self, matched: str):
|
||||||
|
super().__init__()
|
||||||
|
self.matched = matched
|
||||||
|
|
||||||
|
def to_json(self):
|
||||||
|
return {
|
||||||
|
"type": "stop", # to match OpenAI API's return value
|
||||||
|
"matched": self.matched,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class FINISH_LENGTH(BaseFinishReason):
|
class FINISH_LENGTH(BaseFinishReason):
|
||||||
def __init__(self, length: int):
|
def __init__(self, length: int):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -735,8 +748,17 @@ class Req:
|
|||||||
return self.surr_and_decode_ids, self.read_offset - self.surr_offset
|
return self.surr_and_decode_ids, self.read_offset - self.surr_offset
|
||||||
|
|
||||||
def tail_str(self) -> str:
|
def tail_str(self) -> str:
|
||||||
tail_len = self.sampling_params.stop_str_max_len + 1
|
# Check stop strings and stop regex patterns together
|
||||||
tail_len = min(tail_len, len(self.output_ids))
|
if (
|
||||||
|
len(self.sampling_params.stop_strs) > 0
|
||||||
|
or len(self.sampling_params.stop_regex_strs) > 0
|
||||||
|
):
|
||||||
|
max_len_tail_str = max(
|
||||||
|
self.sampling_params.stop_str_max_len + 1,
|
||||||
|
self.sampling_params.stop_regex_max_len + 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
tail_len = min((max_len_tail_str + 1), len(self.output_ids))
|
||||||
return self.tokenizer.decode(self.output_ids[-tail_len:])
|
return self.tokenizer.decode(self.output_ids[-tail_len:])
|
||||||
|
|
||||||
def check_match_stop_str_prefix(self) -> bool:
|
def check_match_stop_str_prefix(self) -> bool:
|
||||||
@@ -817,14 +839,27 @@ class Req:
|
|||||||
self.finished_reason = FINISH_MATCHED_STR(matched="NaN happened")
|
self.finished_reason = FINISH_MATCHED_STR(matched="NaN happened")
|
||||||
return
|
return
|
||||||
|
|
||||||
# Check stop strings
|
if (
|
||||||
if len(self.sampling_params.stop_strs) > 0:
|
len(self.sampling_params.stop_strs) > 0
|
||||||
|
or len(self.sampling_params.stop_regex_strs) > 0
|
||||||
|
):
|
||||||
tail_str = self.tail_str()
|
tail_str = self.tail_str()
|
||||||
|
|
||||||
for stop_str in self.sampling_params.stop_strs:
|
# Check stop strings
|
||||||
if stop_str in tail_str or stop_str in self.decoded_text:
|
if len(self.sampling_params.stop_strs) > 0:
|
||||||
self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
|
for stop_str in self.sampling_params.stop_strs:
|
||||||
return
|
if stop_str in tail_str or stop_str in self.decoded_text:
|
||||||
|
self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Check stop regex
|
||||||
|
if len(self.sampling_params.stop_regex_strs) > 0:
|
||||||
|
for stop_regex_str in self.sampling_params.stop_regex_strs:
|
||||||
|
if re.search(stop_regex_str, tail_str):
|
||||||
|
self.finished_reason = FINISHED_MATCHED_REGEX(
|
||||||
|
matched=stop_regex_str
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
def reset_for_retract(self):
|
def reset_for_retract(self):
|
||||||
self.prefix_indices = torch.empty((0,), dtype=torch.int64)
|
self.prefix_indices = torch.empty((0,), dtype=torch.int64)
|
||||||
|
|||||||
@@ -13,6 +13,8 @@
|
|||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
"""Sampling parameters for text generation."""
|
"""Sampling parameters for text generation."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import sre_parse
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
from sglang.srt.utils import get_bool_env_var
|
from sglang.srt.utils import get_bool_env_var
|
||||||
@@ -20,6 +22,8 @@ from sglang.srt.utils import get_bool_env_var
|
|||||||
_SAMPLING_EPS = 1e-6
|
_SAMPLING_EPS = 1e-6
|
||||||
TOP_K_ALL = 1 << 30
|
TOP_K_ALL = 1 << 30
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class SamplingParams:
|
class SamplingParams:
|
||||||
"""
|
"""
|
||||||
@@ -35,6 +39,7 @@ class SamplingParams:
|
|||||||
max_new_tokens: int = 128,
|
max_new_tokens: int = 128,
|
||||||
stop: Optional[Union[str, List[str]]] = None,
|
stop: Optional[Union[str, List[str]]] = None,
|
||||||
stop_token_ids: Optional[List[int]] = None,
|
stop_token_ids: Optional[List[int]] = None,
|
||||||
|
stop_regex: Optional[Union[str, List[str]]] = None,
|
||||||
temperature: float = 1.0,
|
temperature: float = 1.0,
|
||||||
top_p: float = 1.0,
|
top_p: float = 1.0,
|
||||||
top_k: int = -1,
|
top_k: int = -1,
|
||||||
@@ -63,6 +68,7 @@ class SamplingParams:
|
|||||||
self.stop_token_ids = set(stop_token_ids)
|
self.stop_token_ids = set(stop_token_ids)
|
||||||
else:
|
else:
|
||||||
self.stop_token_ids = None
|
self.stop_token_ids = None
|
||||||
|
self.stop_regex_strs = stop_regex
|
||||||
self.temperature = temperature
|
self.temperature = temperature
|
||||||
self.top_p = top_p
|
self.top_p = top_p
|
||||||
self.top_k = top_k
|
self.top_k = top_k
|
||||||
@@ -170,3 +176,67 @@ class SamplingParams:
|
|||||||
else:
|
else:
|
||||||
stop_str_max_len = max(stop_str_max_len, len(stop_str))
|
stop_str_max_len = max(stop_str_max_len, len(stop_str))
|
||||||
self.stop_str_max_len = stop_str_max_len
|
self.stop_str_max_len = stop_str_max_len
|
||||||
|
|
||||||
|
# Process stop regex strings
|
||||||
|
if self.stop_regex_strs is None:
|
||||||
|
self.stop_regex_strs = []
|
||||||
|
self.stop_regex_max_len = 0
|
||||||
|
else:
|
||||||
|
if isinstance(self.stop_regex_strs, str):
|
||||||
|
self.stop_regex_strs = [self.stop_regex_strs]
|
||||||
|
|
||||||
|
stop_regex_max_len = 0
|
||||||
|
for stop_regex in self.stop_regex_strs:
|
||||||
|
stop_regex_max_len = max(
|
||||||
|
stop_regex_max_len, get_max_seq_length(stop_regex)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.stop_regex_max_len = stop_regex_max_len
|
||||||
|
|
||||||
|
|
||||||
|
# This function gets a strict upperbound on the maximum number of tokens that would need
|
||||||
|
# to be buffered to match the input regex string
|
||||||
|
# NOTE: in the worst case, one character that needs to be buffered corresponds to one
|
||||||
|
# token
|
||||||
|
def get_max_seq_length(regex_str: str):
|
||||||
|
return _max_length_from_subpattern(sre_parse.parse(regex_str))
|
||||||
|
|
||||||
|
|
||||||
|
MAX_LEN = 2**30
|
||||||
|
|
||||||
|
|
||||||
|
def _max_length_from_subpattern(subpattern: sre_parse.SubPattern):
|
||||||
|
total = 0
|
||||||
|
for token, value in subpattern:
|
||||||
|
if token in {
|
||||||
|
sre_parse.LITERAL, # `value` is any one character
|
||||||
|
sre_parse.IN, # Any character within `value`
|
||||||
|
sre_parse.ANY, # "."
|
||||||
|
}:
|
||||||
|
total += 1
|
||||||
|
elif token == sre_parse.SUBPATTERN:
|
||||||
|
# EG: (a\d+) ->
|
||||||
|
# [(SUBPATTERN,
|
||||||
|
# (1, 0, 0, [(LITERAL, 97),
|
||||||
|
# (MAX_REPEAT, (1, MAXREPEAT, [(IN, [(CATEGORY, CATEGORY_DIGIT)])]))]))]
|
||||||
|
_, _, _, inner_subpattern = value
|
||||||
|
total += _max_length_from_subpattern(inner_subpattern)
|
||||||
|
elif token == sre_parse.BRANCH:
|
||||||
|
_, branches = value
|
||||||
|
total += max(_max_length_from_subpattern(branch) for branch in branches)
|
||||||
|
elif token in {sre_parse.MAX_REPEAT, sre_parse.MIN_REPEAT}:
|
||||||
|
_, max_num_repeat, inner_subpattern = value
|
||||||
|
if max_num_repeat == sre_parse.MAXREPEAT:
|
||||||
|
total += MAX_LEN
|
||||||
|
else:
|
||||||
|
total += max_num_repeat * _max_length_from_subpattern(inner_subpattern)
|
||||||
|
elif token == sre_parse.AT:
|
||||||
|
# These are zero-width assertions like ^, $, and \b that don't add to the max
|
||||||
|
# length
|
||||||
|
total += 0
|
||||||
|
else:
|
||||||
|
logger.warning(f"Got unhandled regex token: {token}")
|
||||||
|
|
||||||
|
total += MAX_LEN
|
||||||
|
|
||||||
|
return total
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ import unittest
|
|||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
|
from sglang.srt.sampling.sampling_params import MAX_LEN, get_max_seq_length
|
||||||
from sglang.srt.utils import kill_process_tree
|
from sglang.srt.utils import kill_process_tree
|
||||||
from sglang.test.test_utils import (
|
from sglang.test.test_utils import (
|
||||||
DEFAULT_MODEL_NAME_FOR_TEST,
|
DEFAULT_MODEL_NAME_FOR_TEST,
|
||||||
@@ -40,6 +41,7 @@ class TestMatchedStop(CustomTestCase):
|
|||||||
prompt=MANY_NEW_TOKENS_PROMPT,
|
prompt=MANY_NEW_TOKENS_PROMPT,
|
||||||
max_tokens=1,
|
max_tokens=1,
|
||||||
stop=None,
|
stop=None,
|
||||||
|
stop_regex=None,
|
||||||
finish_reason=None,
|
finish_reason=None,
|
||||||
matched_stop=None,
|
matched_stop=None,
|
||||||
):
|
):
|
||||||
@@ -54,6 +56,9 @@ class TestMatchedStop(CustomTestCase):
|
|||||||
if stop is not None:
|
if stop is not None:
|
||||||
payload["stop"] = stop
|
payload["stop"] = stop
|
||||||
|
|
||||||
|
if stop_regex is not None:
|
||||||
|
payload["stop_regex"] = stop_regex
|
||||||
|
|
||||||
response_completions = requests.post(
|
response_completions = requests.post(
|
||||||
self.base_url + "/v1/completions",
|
self.base_url + "/v1/completions",
|
||||||
json=payload,
|
json=payload,
|
||||||
@@ -71,6 +76,7 @@ class TestMatchedStop(CustomTestCase):
|
|||||||
prompt=MANY_NEW_TOKENS_PROMPT,
|
prompt=MANY_NEW_TOKENS_PROMPT,
|
||||||
max_tokens=1,
|
max_tokens=1,
|
||||||
stop=None,
|
stop=None,
|
||||||
|
stop_regex=None,
|
||||||
finish_reason=None,
|
finish_reason=None,
|
||||||
matched_stop=None,
|
matched_stop=None,
|
||||||
):
|
):
|
||||||
@@ -88,6 +94,9 @@ class TestMatchedStop(CustomTestCase):
|
|||||||
if stop is not None:
|
if stop is not None:
|
||||||
chat_payload["stop"] = stop
|
chat_payload["stop"] = stop
|
||||||
|
|
||||||
|
if stop_regex is not None:
|
||||||
|
chat_payload["stop_regex"] = stop_regex
|
||||||
|
|
||||||
response_chat = requests.post(
|
response_chat = requests.post(
|
||||||
self.base_url + "/v1/chat/completions",
|
self.base_url + "/v1/chat/completions",
|
||||||
json=chat_payload,
|
json=chat_payload,
|
||||||
@@ -106,6 +115,30 @@ class TestMatchedStop(CustomTestCase):
|
|||||||
max_tokens=1000, stop="\n", finish_reason="stop", matched_stop="\n"
|
max_tokens=1000, stop="\n", finish_reason="stop", matched_stop="\n"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_finish_stop_regex_str(self):
|
||||||
|
STOP_REGEX_STR = r"and|or"
|
||||||
|
self.run_completions_generation(
|
||||||
|
max_tokens=1000,
|
||||||
|
stop_regex=STOP_REGEX_STR,
|
||||||
|
finish_reason="stop",
|
||||||
|
matched_stop=STOP_REGEX_STR,
|
||||||
|
)
|
||||||
|
self.run_chat_completions_generation(
|
||||||
|
max_tokens=1000,
|
||||||
|
stop_regex=STOP_REGEX_STR,
|
||||||
|
finish_reason="stop",
|
||||||
|
matched_stop=STOP_REGEX_STR,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Match a complete sentence
|
||||||
|
STOP_REGEX_STR_SENTENCE = r"[.!?]\s*$"
|
||||||
|
self.run_chat_completions_generation(
|
||||||
|
max_tokens=1000,
|
||||||
|
stop_regex=STOP_REGEX_STR_SENTENCE,
|
||||||
|
finish_reason="stop",
|
||||||
|
matched_stop=STOP_REGEX_STR_SENTENCE,
|
||||||
|
)
|
||||||
|
|
||||||
def test_finish_stop_eos(self):
|
def test_finish_stop_eos(self):
|
||||||
llama_format_prompt = """
|
llama_format_prompt = """
|
||||||
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
|
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
|
||||||
@@ -136,5 +169,53 @@ class TestMatchedStop(CustomTestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestRegexPatternMaxLength(unittest.TestCase):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
cls.regex_str_to_max_len = {
|
||||||
|
"((ab|cd(e|f){2}){3,5}g|hij)*k": MAX_LEN,
|
||||||
|
# - '*' → infinite tokens need to be stored
|
||||||
|
"abc*?k": MAX_LEN,
|
||||||
|
# - '*?' → infinite tokens still need to be stored even if lazy matching used
|
||||||
|
"^spec(foo|at)$": 7,
|
||||||
|
# - '^' and '$' don't add any characters to the max length
|
||||||
|
# "spec" → 4
|
||||||
|
# "(foo|at)" → max(3, 2) = 3
|
||||||
|
# Whole regex = 7
|
||||||
|
"(a(bca|de(fg|hi){2,3})j){2}kl": 22,
|
||||||
|
# - Innermost alt: "fg" vs "hi" → 2
|
||||||
|
# - Repeat {2,3}: max = 3 * 2 = 6
|
||||||
|
# - Inner group "de(...)": 2 (for "de") + 6 = 8.
|
||||||
|
# - "bca" or "de(...)" → max(3, 8) = 8
|
||||||
|
# - Whole group: "a" (1) + group (8) + "j"(1) = 10
|
||||||
|
# - Repeat {2} → 20
|
||||||
|
# - Add "kl"(2) → 22
|
||||||
|
"(foo(bar|baz(qux){1,2}))|(x(yz){5,10})": 21,
|
||||||
|
# Branch 1:
|
||||||
|
# "foo"(3) + max("bar"(3), "baz"(3)+"qux"{2} = 3 + 6 = 9) = 3 + 9 = 12
|
||||||
|
# Branch 2:
|
||||||
|
# "x"(1) + "yz"{10} = 1 + 20 =21
|
||||||
|
# Whole regex = max(12, 21) = 21
|
||||||
|
"(((a|bc){1,3}(d(e|f){2}|gh){2,4})|(ijk|lmp(no|p){3})){5}": 90,
|
||||||
|
# Branch A:
|
||||||
|
# (a|bc){1,3} → max = 3 * 2 = 6
|
||||||
|
# Inside: d(e|f){2} = 1 + 2 * 1 = 3 vs gh = 2 → max = 3
|
||||||
|
# Repeat {2,4} → 4 * 3 = 12
|
||||||
|
# Branch A total = 18
|
||||||
|
# Branch B:
|
||||||
|
# "ijk"(3) vs "lmp(no|p){3}" = 3 + 3 * max(2, 1) = 3 + 6 = 9 → max = 9
|
||||||
|
# Branch B total = 9
|
||||||
|
# Whole outer alt = max(18, 9) = 18
|
||||||
|
# Repeat {5} → 90
|
||||||
|
}
|
||||||
|
|
||||||
|
def test_get_max_length(self):
|
||||||
|
for regex_str, max_len in self.regex_str_to_max_len.items():
|
||||||
|
if max_len == MAX_LEN:
|
||||||
|
self.assertGreaterEqual(get_max_seq_length(regex_str), MAX_LEN)
|
||||||
|
else:
|
||||||
|
self.assertEqual(get_max_seq_length(regex_str), max_len)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user