Support stop_token_ids in sglang API (#1092)
This commit is contained in:
@@ -20,7 +20,6 @@ from sglang.lang.ir import (
|
||||
SglConstantText,
|
||||
SglExpr,
|
||||
SglExprList,
|
||||
SglFunction,
|
||||
SglGen,
|
||||
SglImage,
|
||||
SglRoleBegin,
|
||||
@@ -181,8 +180,10 @@ class StreamExecutor:
 num_api_spec_tokens=None,
 use_thread=True,
 ):
+from sglang.lang.backend.base_backend import BaseBackend
+
 self.sid = uuid.uuid4().hex
-self.backend = backend
+self.backend: BaseBackend = backend
 self.arguments: Dict[str, Any] = arguments
 self.default_sampling_para = default_sampling_para
 self.stream = stream
|
||||
@@ -658,6 +659,7 @@ class StreamExecutor:
 for item in [
 "max_new_tokens",
 "stop",
+"stop_token_ids",
 "temperature",
 "top_p",
 "top_k",
|
||||
|
||||
@@ -18,6 +18,7 @@ REGEX_STR = r"\"[\w\d\s]*\"" # bugs with regex r"\".*\"" in interegular pkg
 class SglSamplingParams:
 max_new_tokens: int = 128
 stop: Union[str, List[str]] = ()
+stop_token_ids: Optional[List[int]] = ()
 temperature: float = 1.0
 top_p: float = 1.0
 top_k: int = -1 # -1 means disable
|
||||
@@ -37,6 +38,7 @@ class SglSamplingParams:
 return SglSamplingParams(
 self.max_new_tokens,
 self.stop,
+self.stop_token_ids,
 self.temperature,
 self.top_p,
 self.top_k,
|
||||
@@ -108,6 +110,7 @@ class SglSamplingParams:
 return {
 "max_new_tokens": self.max_new_tokens,
 "stop": self.stop,
+"stop_token_ids": self.stop_token_ids,
 "temperature": self.temperature,
 "top_p": self.top_p,
 "top_k": self.top_k,
|
||||
@@ -141,7 +144,8 @@ class SglFunction:
 self,
 *args,
 max_new_tokens: int = 128,
-stop: Union[str, List[str]] = (),
+stop: Union[str, List[str]] = [],
+stop_token_ids: Optional[List[int]] = [],
 temperature: float = 1.0,
 top_p: float = 1.0,
 top_k: int = -1,
|
||||
@@ -161,6 +165,7 @@ class SglFunction:
 default_sampling_para = SglSamplingParams(
 max_new_tokens=max_new_tokens,
 stop=stop,
+stop_token_ids=stop_token_ids,
 temperature=temperature,
 top_p=top_p,
 top_k=top_k,
|
||||
@@ -181,6 +186,7 @@ class SglFunction:
 *,
 max_new_tokens: int = 128,
 stop: Union[str, List[str]] = (),
+stop_token_ids: Optional[List[int]] = [],
 temperature: float = 1.0,
 top_p: float = 1.0,
 top_k: int = -1,
|
||||
@@ -218,6 +224,7 @@ class SglFunction:
 default_sampling_para = SglSamplingParams(
 max_new_tokens=max_new_tokens,
 stop=stop,
+stop_token_ids=stop_token_ids,
 temperature=temperature,
 top_p=top_p,
 top_k=top_k,
|
||||
@@ -397,6 +404,7 @@ class SglGen(SglExpr):
 name: Optional[str] = None,
 max_new_tokens: Optional[int] = None,
 stop: Optional[Union[str, List[str]]] = None,
+stop_token_ids: Optional[List[int]] = None,
 temperature: Optional[float] = None,
 top_p: Optional[float] = None,
 top_k: Optional[int] = None,
|
||||
@@ -416,6 +424,7 @@ class SglGen(SglExpr):
 self.sampling_params = SglSamplingParams(
 max_new_tokens=max_new_tokens,
 stop=stop,
+stop_token_ids=stop_token_ids,
 temperature=temperature,
 top_p=top_p,
 top_k=top_k,
|
||||
|
||||
Reference in New Issue
Block a user