Add constrained_json_whitespace_pattern to ServerArgs (#1438)
This commit is contained in:
@@ -29,6 +29,7 @@ class FSMCache(BaseToolCache):
|
||||
tokenizer_args_dict,
|
||||
enable=True,
|
||||
skip_tokenizer_init=False,
|
||||
constrained_json_whitespace_pattern=None,
|
||||
):
|
||||
super().__init__(enable=enable)
|
||||
|
||||
@@ -63,11 +64,14 @@ class FSMCache(BaseToolCache):
|
||||
self.outlines_tokenizer.vocabulary = (
|
||||
self.outlines_tokenizer.tokenizer.get_vocab()
|
||||
)
|
||||
self.constrained_json_whitespace_pattern = constrained_json_whitespace_pattern
|
||||
|
||||
def init_value(self, key):
|
||||
key_type, key_string = key
|
||||
if key_type == "json":
|
||||
regex = build_regex_from_schema(key_string, whitespace_pattern=r"[\n\t ]*")
|
||||
regex = build_regex_from_schema(
|
||||
key_string, whitespace_pattern=self.constrained_json_whitespace_pattern
|
||||
)
|
||||
elif key_type == "regex":
|
||||
regex = key_string
|
||||
else:
|
||||
|
||||
@@ -198,6 +198,7 @@ class ModelTpServer:
|
||||
"trust_remote_code": server_args.trust_remote_code,
|
||||
},
|
||||
skip_tokenizer_init=server_args.skip_tokenizer_init,
|
||||
constrained_json_whitespace_pattern=server_args.constrained_json_whitespace_pattern,
|
||||
)
|
||||
self.jump_forward_cache = JumpForwardCache()
|
||||
|
||||
@@ -807,12 +808,10 @@ class ModelTpServer:
|
||||
unfinished_indices.append(i)
|
||||
|
||||
if req.finished() or (
|
||||
(
|
||||
req.stream
|
||||
and (
|
||||
self.decode_forward_ct % self.stream_interval == 0
|
||||
or len(req.output_ids) == 1
|
||||
)
|
||||
req.stream
|
||||
and (
|
||||
self.decode_forward_ct % self.stream_interval == 0
|
||||
or len(req.output_ids) == 1
|
||||
)
|
||||
):
|
||||
output_rids.append(req.rid)
|
||||
|
||||
@@ -70,6 +70,7 @@ class ServerArgs:
|
||||
tp_size: int = 1
|
||||
stream_interval: int = 1
|
||||
random_seed: Optional[int] = None
|
||||
constrained_json_whitespace_pattern: Optional[str] = None
|
||||
|
||||
# Logging
|
||||
log_level: str = "info"
|
||||
@@ -370,6 +371,12 @@ class ServerArgs:
|
||||
default=ServerArgs.random_seed,
|
||||
help="The random seed.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--constrained-json-whitespace-pattern",
|
||||
type=str,
|
||||
default=ServerArgs.constrained_json_whitespace_pattern,
|
||||
help=r"Regex pattern for syntactic whitespaces allowed in JSON constrained output. For example, to allow the model generate consecutive whitespaces, set the pattern to [\n\t ]*",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--log-level",
|
||||
type=str,
|
||||
|
||||
Reference in New Issue
Block a user