make json_schema usable from gen (#1254)
This commit is contained in:
committed by
GitHub
parent
13ac95b894
commit
6c34d6339c
@@ -78,6 +78,7 @@ def gen(
|
||||
choices: Optional[List[str]] = None,
|
||||
choices_method: Optional[ChoicesSamplingMethod] = None,
|
||||
regex: Optional[str] = None,
|
||||
json_schema: Optional[str] = None,
|
||||
):
|
||||
"""Call the model to generate. See the meaning of the arguments in docs/en/sampling_params.md"""
|
||||
|
||||
@@ -114,6 +115,7 @@ def gen(
|
||||
return_text_in_logprobs,
|
||||
dtype,
|
||||
regex,
|
||||
json_schema,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -673,6 +673,7 @@ class StreamExecutor:
|
||||
"return_text_in_logprobs",
|
||||
"dtype",
|
||||
"regex",
|
||||
"json_schema",
|
||||
]:
|
||||
value = getattr(sampling_params, item, None)
|
||||
if value is not None:
|
||||
|
||||
@@ -30,6 +30,7 @@ class SglSamplingParams:
|
||||
logprob_start_len: Optional[int] = (None,)
|
||||
top_logprobs_num: Optional[int] = (None,)
|
||||
return_text_in_logprobs: Optional[bool] = (None,)
|
||||
json_schema: Optional[str] = None
|
||||
|
||||
# for constrained generation, not included in to_xxx_kwargs
|
||||
dtype: Optional[str] = None
|
||||
@@ -51,6 +52,7 @@ class SglSamplingParams:
|
||||
self.logprob_start_len,
|
||||
self.top_logprobs_num,
|
||||
self.return_text_in_logprobs,
|
||||
self.json_schema,
|
||||
)
|
||||
|
||||
def to_openai_kwargs(self):
|
||||
@@ -121,6 +123,7 @@ class SglSamplingParams:
|
||||
"presence_penalty": self.presence_penalty,
|
||||
"ignore_eos": self.ignore_eos,
|
||||
"regex": self.regex,
|
||||
"json_schema": self.json_schema,
|
||||
}
|
||||
|
||||
|
||||
@@ -425,6 +428,7 @@ class SglGen(SglExpr):
|
||||
return_text_in_logprobs: Optional[bool] = None,
|
||||
dtype: Optional[type] = None,
|
||||
regex: Optional[str] = None,
|
||||
json_schema: Optional[str] = None,
|
||||
):
|
||||
"""Call the model to generate. See the meaning of the arguments in docs/en/sampling_params.md"""
|
||||
super().__init__()
|
||||
@@ -446,6 +450,7 @@ class SglGen(SglExpr):
|
||||
return_text_in_logprobs=return_text_in_logprobs,
|
||||
dtype=dtype,
|
||||
regex=regex,
|
||||
json_schema=json_schema,
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
|
||||
Reference in New Issue
Block a user