[Feature] support regex strings as a stopping condition (#10635)

This commit is contained in:
Glen Liu
2025-10-11 22:53:15 -04:00
committed by GitHub
parent 9fcf73069f
commit 47c606d3dc
9 changed files with 219 additions and 8 deletions

View File

@@ -79,6 +79,7 @@ def gen(
n: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
stop_regex: Optional[Union[str, List[str]]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
@@ -120,6 +121,7 @@ def gen(
n,
stop,
stop_token_ids,
stop_regex,
temperature,
top_p,
top_k,
@@ -143,6 +145,7 @@ def gen_int(
n: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
stop_regex: Optional[Union[str, List[str]]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
@@ -162,6 +165,7 @@ def gen_int(
n,
stop,
stop_token_ids,
stop_regex,
temperature,
top_p,
top_k,
@@ -184,6 +188,7 @@ def gen_string(
n: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
stop_regex: Optional[Union[str, List[str]]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
@@ -203,6 +208,7 @@ def gen_string(
n,
stop,
stop_token_ids,
stop_regex,
temperature,
top_p,
top_k,

View File

@@ -792,6 +792,7 @@ class StreamExecutor:
"n",
"stop",
"stop_token_ids",
"stop_regex",
"temperature",
"top_p",
"top_k",

View File

@@ -21,6 +21,7 @@ class SglSamplingParams:
n: int = 1
stop: Union[str, List[str]] = ()
stop_token_ids: Optional[List[int]] = ()
stop_regex: Optional[Union[str, List[str]]] = ()
temperature: float = 1.0
top_p: float = 1.0
top_k: int = -1 # -1 means disable
@@ -45,6 +46,7 @@ class SglSamplingParams:
self.n,
self.stop,
self.stop_token_ids,
self.stop_regex,
self.temperature,
self.top_p,
self.top_k,
@@ -123,6 +125,7 @@ class SglSamplingParams:
"n": self.n,
"stop": self.stop,
"stop_token_ids": self.stop_token_ids,
"stop_regex": self.stop_regex,
"temperature": self.temperature,
"top_p": self.top_p,
"top_k": self.top_k,
@@ -161,6 +164,7 @@ class SglFunction:
n: int = 1,
stop: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
stop_regex: Optional[Union[str, List[str]]] = None,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1,
@@ -184,12 +188,15 @@ class SglFunction:
stop = []
if stop_token_ids is None:
stop_token_ids = []
if stop_regex is None:
stop_regex = []
default_sampling_para = SglSamplingParams(
max_new_tokens=max_new_tokens,
n=n,
stop=stop,
stop_token_ids=stop_token_ids,
stop_regex=stop_regex,
temperature=temperature,
top_p=top_p,
top_k=top_k,
@@ -221,6 +228,7 @@ class SglFunction:
n: int = 1,
stop: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
stop_regex: Optional[Union[str, List[str]]] = None,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1,
@@ -243,6 +251,8 @@ class SglFunction:
stop = []
if stop_token_ids is None:
stop_token_ids = []
if stop_regex is None:
stop_regex = []
assert isinstance(batch_kwargs, (list, tuple))
if len(batch_kwargs) == 0:
@@ -267,6 +277,7 @@ class SglFunction:
n=n,
stop=stop,
stop_token_ids=stop_token_ids,
stop_regex=stop_regex,
temperature=temperature,
top_p=top_p,
top_k=top_k,
@@ -451,6 +462,7 @@ class SglGen(SglExpr):
n: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
stop_regex: Optional[Union[str, List[str]]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
top_k: Optional[int] = None,
@@ -474,6 +486,7 @@ class SglGen(SglExpr):
min_new_tokens=min_new_tokens,
n=n,
stop=stop,
stop_regex=stop_regex,
stop_token_ids=stop_token_ids,
temperature=temperature,
top_p=top_p,