Support min_tokens in sgl.gen (#1573)
This commit is contained in:
@@ -668,6 +668,7 @@ class StreamExecutor:
|
||||
|
||||
for item in [
|
||||
"max_new_tokens",
|
||||
"min_new_tokens",
|
||||
"stop",
|
||||
"stop_token_ids",
|
||||
"temperature",
|
||||
|
||||
@@ -17,6 +17,7 @@ REGEX_STR = r"\"[\w\d\s]*\"" # bugs with regex r"\".*\"" in interegular pkg
|
||||
@dataclasses.dataclass
|
||||
class SglSamplingParams:
|
||||
max_new_tokens: int = 128
|
||||
min_new_tokens: int = 0
|
||||
stop: Union[str, List[str]] = ()
|
||||
stop_token_ids: Optional[List[int]] = ()
|
||||
temperature: float = 1.0
|
||||
@@ -39,6 +40,7 @@ class SglSamplingParams:
|
||||
def clone(self):
|
||||
return SglSamplingParams(
|
||||
self.max_new_tokens,
|
||||
self.min_new_tokens,
|
||||
self.stop,
|
||||
self.stop_token_ids,
|
||||
self.temperature,
|
||||
@@ -113,6 +115,7 @@ class SglSamplingParams:
|
||||
def to_srt_kwargs(self):
|
||||
return {
|
||||
"max_new_tokens": self.max_new_tokens,
|
||||
"min_new_tokens": self.min_new_tokens,
|
||||
"stop": self.stop,
|
||||
"stop_token_ids": self.stop_token_ids,
|
||||
"temperature": self.temperature,
|
||||
@@ -424,6 +427,7 @@ class SglGen(SglExpr):
|
||||
self,
|
||||
name: Optional[str] = None,
|
||||
max_new_tokens: Optional[int] = None,
|
||||
min_new_tokens: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stop_token_ids: Optional[List[int]] = None,
|
||||
temperature: Optional[float] = None,
|
||||
@@ -446,6 +450,7 @@ class SglGen(SglExpr):
|
||||
self.name = name
|
||||
self.sampling_params = SglSamplingParams(
|
||||
max_new_tokens=max_new_tokens,
|
||||
min_new_tokens=min_new_tokens,
|
||||
stop=stop,
|
||||
stop_token_ids=stop_token_ids,
|
||||
temperature=temperature,
|
||||
|
||||
Reference in New Issue
Block a user