diff --git a/python/sglang/api.py b/python/sglang/api.py index 3a2f747be..9405606b7 100644 --- a/python/sglang/api.py +++ b/python/sglang/api.py @@ -78,6 +78,7 @@ def gen( choices: Optional[List[str]] = None, choices_method: Optional[ChoicesSamplingMethod] = None, regex: Optional[str] = None, + json_schema: Optional[str] = None, ): """Call the model to generate. See the meaning of the arguments in docs/en/sampling_params.md""" @@ -114,6 +115,7 @@ def gen( return_text_in_logprobs, dtype, regex, + json_schema, ) diff --git a/python/sglang/lang/interpreter.py b/python/sglang/lang/interpreter.py index 306d280c7..91f48456a 100644 --- a/python/sglang/lang/interpreter.py +++ b/python/sglang/lang/interpreter.py @@ -673,6 +673,7 @@ class StreamExecutor: "return_text_in_logprobs", "dtype", "regex", + "json_schema", ]: value = getattr(sampling_params, item, None) if value is not None: diff --git a/python/sglang/lang/ir.py b/python/sglang/lang/ir.py index 199a7ac7a..99a3e8e68 100644 --- a/python/sglang/lang/ir.py +++ b/python/sglang/lang/ir.py @@ -30,6 +30,7 @@ class SglSamplingParams: logprob_start_len: Optional[int] = (None,) top_logprobs_num: Optional[int] = (None,) return_text_in_logprobs: Optional[bool] = (None,) + json_schema: Optional[str] = None # for constrained generation, not included in to_xxx_kwargs dtype: Optional[str] = None @@ -51,6 +52,7 @@ class SglSamplingParams: self.logprob_start_len, self.top_logprobs_num, self.return_text_in_logprobs, + self.json_schema, ) def to_openai_kwargs(self): @@ -121,6 +123,7 @@ class SglSamplingParams: "presence_penalty": self.presence_penalty, "ignore_eos": self.ignore_eos, "regex": self.regex, + "json_schema": self.json_schema, } @@ -425,6 +428,7 @@ class SglGen(SglExpr): return_text_in_logprobs: Optional[bool] = None, dtype: Optional[type] = None, regex: Optional[str] = None, + json_schema: Optional[str] = None, ): """Call the model to generate. See the meaning of the arguments in docs/en/sampling_params.md""" super().__init__() @@ -446,6 +450,7 @@ class SglGen(SglExpr): return_text_in_logprobs=return_text_in_logprobs, dtype=dtype, regex=regex, + json_schema=json_schema, ) def __repr__(self):