make json_schema usable from gen (#1254)
This commit is contained in:
committed by
GitHub
parent
13ac95b894
commit
6c34d6339c
@@ -78,6 +78,7 @@ def gen(
|
|||||||
choices: Optional[List[str]] = None,
|
choices: Optional[List[str]] = None,
|
||||||
choices_method: Optional[ChoicesSamplingMethod] = None,
|
choices_method: Optional[ChoicesSamplingMethod] = None,
|
||||||
regex: Optional[str] = None,
|
regex: Optional[str] = None,
|
||||||
|
json_schema: Optional[str] = None,
|
||||||
):
|
):
|
||||||
"""Call the model to generate. See the meaning of the arguments in docs/en/sampling_params.md"""
|
"""Call the model to generate. See the meaning of the arguments in docs/en/sampling_params.md"""
|
||||||
|
|
||||||
@@ -114,6 +115,7 @@ def gen(
|
|||||||
return_text_in_logprobs,
|
return_text_in_logprobs,
|
||||||
dtype,
|
dtype,
|
||||||
regex,
|
regex,
|
||||||
|
json_schema,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -673,6 +673,7 @@ class StreamExecutor:
|
|||||||
"return_text_in_logprobs",
|
"return_text_in_logprobs",
|
||||||
"dtype",
|
"dtype",
|
||||||
"regex",
|
"regex",
|
||||||
|
"json_schema",
|
||||||
]:
|
]:
|
||||||
value = getattr(sampling_params, item, None)
|
value = getattr(sampling_params, item, None)
|
||||||
if value is not None:
|
if value is not None:
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ class SglSamplingParams:
|
|||||||
logprob_start_len: Optional[int] = (None,)
|
logprob_start_len: Optional[int] = (None,)
|
||||||
top_logprobs_num: Optional[int] = (None,)
|
top_logprobs_num: Optional[int] = (None,)
|
||||||
return_text_in_logprobs: Optional[bool] = (None,)
|
return_text_in_logprobs: Optional[bool] = (None,)
|
||||||
|
json_schema: Optional[str] = None
|
||||||
|
|
||||||
# for constrained generation, not included in to_xxx_kwargs
|
# for constrained generation, not included in to_xxx_kwargs
|
||||||
dtype: Optional[str] = None
|
dtype: Optional[str] = None
|
||||||
@@ -51,6 +52,7 @@ class SglSamplingParams:
|
|||||||
self.logprob_start_len,
|
self.logprob_start_len,
|
||||||
self.top_logprobs_num,
|
self.top_logprobs_num,
|
||||||
self.return_text_in_logprobs,
|
self.return_text_in_logprobs,
|
||||||
|
self.json_schema,
|
||||||
)
|
)
|
||||||
|
|
||||||
def to_openai_kwargs(self):
|
def to_openai_kwargs(self):
|
||||||
@@ -121,6 +123,7 @@ class SglSamplingParams:
|
|||||||
"presence_penalty": self.presence_penalty,
|
"presence_penalty": self.presence_penalty,
|
||||||
"ignore_eos": self.ignore_eos,
|
"ignore_eos": self.ignore_eos,
|
||||||
"regex": self.regex,
|
"regex": self.regex,
|
||||||
|
"json_schema": self.json_schema,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -425,6 +428,7 @@ class SglGen(SglExpr):
|
|||||||
return_text_in_logprobs: Optional[bool] = None,
|
return_text_in_logprobs: Optional[bool] = None,
|
||||||
dtype: Optional[type] = None,
|
dtype: Optional[type] = None,
|
||||||
regex: Optional[str] = None,
|
regex: Optional[str] = None,
|
||||||
|
json_schema: Optional[str] = None,
|
||||||
):
|
):
|
||||||
"""Call the model to generate. See the meaning of the arguments in docs/en/sampling_params.md"""
|
"""Call the model to generate. See the meaning of the arguments in docs/en/sampling_params.md"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -446,6 +450,7 @@ class SglGen(SglExpr):
|
|||||||
return_text_in_logprobs=return_text_in_logprobs,
|
return_text_in_logprobs=return_text_in_logprobs,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
regex=regex,
|
regex=regex,
|
||||||
|
json_schema=json_schema,
|
||||||
)
|
)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user