make json_schema usable from gen (#1254)

This commit is contained in:
Enrique Shockwave
2024-08-29 02:57:10 +01:00
committed by GitHub
parent 13ac95b894
commit 6c34d6339c
3 changed files with 8 additions and 0 deletions

View File

@@ -78,6 +78,7 @@ def gen(
choices: Optional[List[str]] = None,
choices_method: Optional[ChoicesSamplingMethod] = None,
regex: Optional[str] = None,
json_schema: Optional[str] = None,
):
"""Call the model to generate. See the meaning of the arguments in docs/en/sampling_params.md"""
@@ -114,6 +115,7 @@ def gen(
return_text_in_logprobs,
dtype,
regex,
json_schema,
)

View File

@@ -673,6 +673,7 @@ class StreamExecutor:
"return_text_in_logprobs",
"dtype",
"regex",
"json_schema",
]:
value = getattr(sampling_params, item, None)
if value is not None:

View File

@@ -30,6 +30,7 @@ class SglSamplingParams:
logprob_start_len: Optional[int] = (None,)
top_logprobs_num: Optional[int] = (None,)
return_text_in_logprobs: Optional[bool] = (None,)
json_schema: Optional[str] = None
# for constrained generation, not included in to_xxx_kwargs
dtype: Optional[str] = None
@@ -51,6 +52,7 @@ class SglSamplingParams:
self.logprob_start_len,
self.top_logprobs_num,
self.return_text_in_logprobs,
self.json_schema,
)
def to_openai_kwargs(self):
@@ -121,6 +123,7 @@ class SglSamplingParams:
"presence_penalty": self.presence_penalty,
"ignore_eos": self.ignore_eos,
"regex": self.regex,
"json_schema": self.json_schema,
}
@@ -425,6 +428,7 @@ class SglGen(SglExpr):
return_text_in_logprobs: Optional[bool] = None,
dtype: Optional[type] = None,
regex: Optional[str] = None,
json_schema: Optional[str] = None,
):
"""Call the model to generate. See the meaning of the arguments in docs/en/sampling_params.md"""
super().__init__()
@@ -446,6 +450,7 @@ class SglGen(SglExpr):
return_text_in_logprobs=return_text_in_logprobs,
dtype=dtype,
regex=regex,
json_schema=json_schema,
)
def __repr__(self):