[Feature] support regex strings as a stopping condition (#10635)
This commit is contained in:
@@ -79,6 +79,7 @@ def gen(
|
||||
n: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stop_token_ids: Optional[List[int]] = None,
|
||||
stop_regex: Optional[Union[str, List[str]]] = None,
|
||||
temperature: Optional[float] = None,
|
||||
top_p: Optional[float] = None,
|
||||
top_k: Optional[int] = None,
|
||||
@@ -120,6 +121,7 @@ def gen(
|
||||
n,
|
||||
stop,
|
||||
stop_token_ids,
|
||||
stop_regex,
|
||||
temperature,
|
||||
top_p,
|
||||
top_k,
|
||||
@@ -143,6 +145,7 @@ def gen_int(
|
||||
n: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stop_token_ids: Optional[List[int]] = None,
|
||||
stop_regex: Optional[Union[str, List[str]]] = None,
|
||||
temperature: Optional[float] = None,
|
||||
top_p: Optional[float] = None,
|
||||
top_k: Optional[int] = None,
|
||||
@@ -162,6 +165,7 @@ def gen_int(
|
||||
n,
|
||||
stop,
|
||||
stop_token_ids,
|
||||
stop_regex,
|
||||
temperature,
|
||||
top_p,
|
||||
top_k,
|
||||
@@ -184,6 +188,7 @@ def gen_string(
|
||||
n: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stop_token_ids: Optional[List[int]] = None,
|
||||
stop_regex: Optional[Union[str, List[str]]] = None,
|
||||
temperature: Optional[float] = None,
|
||||
top_p: Optional[float] = None,
|
||||
top_k: Optional[int] = None,
|
||||
@@ -203,6 +208,7 @@ def gen_string(
|
||||
n,
|
||||
stop,
|
||||
stop_token_ids,
|
||||
stop_regex,
|
||||
temperature,
|
||||
top_p,
|
||||
top_k,
|
||||
|
||||
@@ -792,6 +792,7 @@ class StreamExecutor:
|
||||
"n",
|
||||
"stop",
|
||||
"stop_token_ids",
|
||||
"stop_regex",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"top_k",
|
||||
|
||||
@@ -21,6 +21,7 @@ class SglSamplingParams:
|
||||
n: int = 1
|
||||
stop: Union[str, List[str]] = ()
|
||||
stop_token_ids: Optional[List[int]] = ()
|
||||
stop_regex: Optional[Union[str, List[str]]] = ()
|
||||
temperature: float = 1.0
|
||||
top_p: float = 1.0
|
||||
top_k: int = -1 # -1 means disable
|
||||
@@ -45,6 +46,7 @@ class SglSamplingParams:
|
||||
self.n,
|
||||
self.stop,
|
||||
self.stop_token_ids,
|
||||
self.stop_regex,
|
||||
self.temperature,
|
||||
self.top_p,
|
||||
self.top_k,
|
||||
@@ -123,6 +125,7 @@ class SglSamplingParams:
|
||||
"n": self.n,
|
||||
"stop": self.stop,
|
||||
"stop_token_ids": self.stop_token_ids,
|
||||
"stop_regex": self.stop_regex,
|
||||
"temperature": self.temperature,
|
||||
"top_p": self.top_p,
|
||||
"top_k": self.top_k,
|
||||
@@ -161,6 +164,7 @@ class SglFunction:
|
||||
n: int = 1,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stop_token_ids: Optional[List[int]] = None,
|
||||
stop_regex: Optional[Union[str, List[str]]] = None,
|
||||
temperature: float = 1.0,
|
||||
top_p: float = 1.0,
|
||||
top_k: int = -1,
|
||||
@@ -184,12 +188,15 @@ class SglFunction:
|
||||
stop = []
|
||||
if stop_token_ids is None:
|
||||
stop_token_ids = []
|
||||
if stop_regex is None:
|
||||
stop_regex = []
|
||||
|
||||
default_sampling_para = SglSamplingParams(
|
||||
max_new_tokens=max_new_tokens,
|
||||
n=n,
|
||||
stop=stop,
|
||||
stop_token_ids=stop_token_ids,
|
||||
stop_regex=stop_regex,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
@@ -221,6 +228,7 @@ class SglFunction:
|
||||
n: int = 1,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stop_token_ids: Optional[List[int]] = None,
|
||||
stop_regex: Optional[Union[str, List[str]]] = None,
|
||||
temperature: float = 1.0,
|
||||
top_p: float = 1.0,
|
||||
top_k: int = -1,
|
||||
@@ -243,6 +251,8 @@ class SglFunction:
|
||||
stop = []
|
||||
if stop_token_ids is None:
|
||||
stop_token_ids = []
|
||||
if stop_regex is None:
|
||||
stop_regex = []
|
||||
|
||||
assert isinstance(batch_kwargs, (list, tuple))
|
||||
if len(batch_kwargs) == 0:
|
||||
@@ -267,6 +277,7 @@ class SglFunction:
|
||||
n=n,
|
||||
stop=stop,
|
||||
stop_token_ids=stop_token_ids,
|
||||
stop_regex=stop_regex,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
@@ -451,6 +462,7 @@ class SglGen(SglExpr):
|
||||
n: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stop_token_ids: Optional[List[int]] = None,
|
||||
stop_regex: Optional[Union[str, List[str]]] = None,
|
||||
temperature: Optional[float] = None,
|
||||
top_p: Optional[float] = None,
|
||||
top_k: Optional[int] = None,
|
||||
@@ -474,6 +486,7 @@ class SglGen(SglExpr):
|
||||
min_new_tokens=min_new_tokens,
|
||||
n=n,
|
||||
stop=stop,
|
||||
stop_regex=stop_regex,
|
||||
stop_token_ids=stop_token_ids,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
|
||||
Reference in New Issue
Block a user