[FEAT] JSON constrained support (#1125)
Co-authored-by: Yineng Zhang <me@zhyncs.com>
This commit is contained in:
@@ -15,6 +15,8 @@ limitations under the License.
|
||||
|
||||
"""Cache for the compressed finite state machine."""
|
||||
|
||||
from outlines.fsm.json_schema import build_regex_from_schema
|
||||
|
||||
from sglang.srt.constrained import RegexGuide, TransformerTokenizer
|
||||
from sglang.srt.constrained.base_tool_cache import BaseToolCache
|
||||
|
||||
@@ -26,9 +28,12 @@ class FSMCache(BaseToolCache):
|
||||
tokenizer_args_dict,
|
||||
enable=True,
|
||||
skip_tokenizer_init=False,
|
||||
json_schema_mode=False,
|
||||
):
|
||||
super().__init__(enable=enable)
|
||||
|
||||
self.json_schema_mode = json_schema_mode
|
||||
|
||||
if (
|
||||
skip_tokenizer_init
|
||||
or tokenizer_path.endswith(".json")
|
||||
@@ -72,5 +77,9 @@ class FSMCache(BaseToolCache):
|
||||
tokenizer_path, **tokenizer_args_dict
|
||||
)
|
||||
|
||||
def init_value(self, regex):
|
||||
return RegexGuide(regex, self.outlines_tokenizer)
|
||||
def init_value(self, value):
|
||||
if self.json_schema_mode:
|
||||
regex = build_regex_from_schema(value)
|
||||
return RegexGuide(regex, self.outlines_tokenizer), regex
|
||||
else:
|
||||
return RegexGuide(value, self.outlines_tokenizer)
|
||||
|
||||
@@ -23,6 +23,7 @@ from collections import defaultdict
|
||||
|
||||
import interegular
|
||||
import outlines.caching
|
||||
from outlines.fsm.json_schema import build_regex_from_schema
|
||||
|
||||
from sglang.srt.constrained import (
|
||||
FSMInfo,
|
||||
|
||||
@@ -268,7 +268,14 @@ class Req:
|
||||
|
||||
all_text = self.origin_input_text + self.decoded_text + jump_forward_str
|
||||
all_ids = self.tokenizer.encode(all_text)
|
||||
if not all_ids:
|
||||
warnings.warn("Encoded all_text resulted in empty all_ids")
|
||||
return False
|
||||
|
||||
prompt_tokens = len(self.origin_input_ids_unpadded)
|
||||
if prompt_tokens > len(all_ids):
|
||||
warnings.warn("prompt_tokens is larger than encoded all_ids")
|
||||
return False
|
||||
|
||||
if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]:
|
||||
# TODO(lsyin): fix token fusion
|
||||
|
||||
@@ -197,6 +197,16 @@ class ModelTpServer:
|
||||
"trust_remote_code": server_args.trust_remote_code,
|
||||
},
|
||||
skip_tokenizer_init=server_args.skip_tokenizer_init,
|
||||
json_schema_mode=False,
|
||||
)
|
||||
self.json_fsm_cache = FSMCache(
|
||||
server_args.tokenizer_path,
|
||||
{
|
||||
"tokenizer_mode": server_args.tokenizer_mode,
|
||||
"trust_remote_code": server_args.trust_remote_code,
|
||||
},
|
||||
skip_tokenizer_init=server_args.skip_tokenizer_init,
|
||||
json_schema_mode=True,
|
||||
)
|
||||
self.jump_forward_cache = JumpForwardCache()
|
||||
|
||||
@@ -349,8 +359,17 @@ class ModelTpServer:
|
||||
req.top_logprobs_num = recv_req.top_logprobs_num
|
||||
req.stream = recv_req.stream
|
||||
|
||||
# Init regex fsm fron json
|
||||
if req.sampling_params.json_schema is not None:
|
||||
req.regex_fsm, computed_regex_string = self.json_fsm_cache.query(
|
||||
req.sampling_params.json_schema
|
||||
)
|
||||
if not self.disable_regex_jump_forward:
|
||||
req.jump_forward_map = self.jump_forward_cache.query(
|
||||
computed_regex_string
|
||||
)
|
||||
# Init regex fsm
|
||||
if req.sampling_params.regex is not None:
|
||||
elif req.sampling_params.regex is not None:
|
||||
req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
|
||||
if not self.disable_regex_jump_forward:
|
||||
req.jump_forward_map = self.jump_forward_cache.query(
|
||||
|
||||
@@ -434,6 +434,7 @@ def v1_generate_request(all_requests: List[CompletionRequest]):
|
||||
"frequency_penalty": request.frequency_penalty,
|
||||
"repetition_penalty": request.repetition_penalty,
|
||||
"regex": request.regex,
|
||||
"json_schema": request.json_schema,
|
||||
"n": request.n,
|
||||
"ignore_eos": request.ignore_eos,
|
||||
}
|
||||
@@ -802,6 +803,7 @@ def v1_chat_generate_request(
|
||||
"frequency_penalty": request.frequency_penalty,
|
||||
"repetition_penalty": request.repetition_penalty,
|
||||
"regex": request.regex,
|
||||
"json_schema": request.json_schema,
|
||||
"n": request.n,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -161,6 +161,7 @@ class CompletionRequest(BaseModel):
|
||||
|
||||
# Extra parameters for SRT backend only and will be ignored by OpenAI models.
|
||||
regex: Optional[str] = None
|
||||
json_schema: Optional[str] = None
|
||||
ignore_eos: Optional[bool] = False
|
||||
min_tokens: Optional[int] = 0
|
||||
repetition_penalty: Optional[float] = 1.0
|
||||
@@ -262,6 +263,7 @@ class ChatCompletionRequest(BaseModel):
|
||||
|
||||
# Extra parameters for SRT backend only and will be ignored by OpenAI models.
|
||||
regex: Optional[str] = None
|
||||
json_schema: Optional[str] = None
|
||||
min_tokens: Optional[int] = 0
|
||||
repetition_penalty: Optional[float] = 1.0
|
||||
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
|
||||
|
||||
@@ -39,6 +39,7 @@ class SamplingParams:
|
||||
spaces_between_special_tokens: bool = True,
|
||||
regex: Optional[str] = None,
|
||||
n: int = 1,
|
||||
json_schema: Optional[str] = None,
|
||||
) -> None:
|
||||
self.temperature = temperature
|
||||
self.top_p = top_p
|
||||
@@ -56,6 +57,7 @@ class SamplingParams:
|
||||
self.spaces_between_special_tokens = spaces_between_special_tokens
|
||||
self.regex = regex
|
||||
self.n = n
|
||||
self.json_schema = json_schema
|
||||
|
||||
# Process some special cases
|
||||
if self.temperature < _SAMPLING_EPS:
|
||||
@@ -106,6 +108,8 @@ class SamplingParams:
|
||||
f"min_new_tokens must be in (0, max_new_tokens({self.max_new_tokens})], got "
|
||||
f"{self.min_new_tokens}."
|
||||
)
|
||||
if self.regex is not None and self.json_schema is not None:
|
||||
raise ValueError("regex and json_schema cannot be both set.")
|
||||
|
||||
def normalize(self, tokenizer):
|
||||
# Process stop strings
|
||||
|
||||
Reference in New Issue
Block a user