Add constrained_json_whitespace_pattern to ServerArgs (#1438)
This commit is contained in:
@@ -29,6 +29,7 @@ class FSMCache(BaseToolCache):
|
|||||||
tokenizer_args_dict,
|
tokenizer_args_dict,
|
||||||
enable=True,
|
enable=True,
|
||||||
skip_tokenizer_init=False,
|
skip_tokenizer_init=False,
|
||||||
|
constrained_json_whitespace_pattern=None,
|
||||||
):
|
):
|
||||||
super().__init__(enable=enable)
|
super().__init__(enable=enable)
|
||||||
|
|
||||||
@@ -63,11 +64,14 @@ class FSMCache(BaseToolCache):
|
|||||||
self.outlines_tokenizer.vocabulary = (
|
self.outlines_tokenizer.vocabulary = (
|
||||||
self.outlines_tokenizer.tokenizer.get_vocab()
|
self.outlines_tokenizer.tokenizer.get_vocab()
|
||||||
)
|
)
|
||||||
|
self.constrained_json_whitespace_pattern = constrained_json_whitespace_pattern
|
||||||
|
|
||||||
def init_value(self, key):
|
def init_value(self, key):
|
||||||
key_type, key_string = key
|
key_type, key_string = key
|
||||||
if key_type == "json":
|
if key_type == "json":
|
||||||
regex = build_regex_from_schema(key_string, whitespace_pattern=r"[\n\t ]*")
|
regex = build_regex_from_schema(
|
||||||
|
key_string, whitespace_pattern=self.constrained_json_whitespace_pattern
|
||||||
|
)
|
||||||
elif key_type == "regex":
|
elif key_type == "regex":
|
||||||
regex = key_string
|
regex = key_string
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -198,6 +198,7 @@ class ModelTpServer:
|
|||||||
"trust_remote_code": server_args.trust_remote_code,
|
"trust_remote_code": server_args.trust_remote_code,
|
||||||
},
|
},
|
||||||
skip_tokenizer_init=server_args.skip_tokenizer_init,
|
skip_tokenizer_init=server_args.skip_tokenizer_init,
|
||||||
|
constrained_json_whitespace_pattern=server_args.constrained_json_whitespace_pattern,
|
||||||
)
|
)
|
||||||
self.jump_forward_cache = JumpForwardCache()
|
self.jump_forward_cache = JumpForwardCache()
|
||||||
|
|
||||||
@@ -807,12 +808,10 @@ class ModelTpServer:
|
|||||||
unfinished_indices.append(i)
|
unfinished_indices.append(i)
|
||||||
|
|
||||||
if req.finished() or (
|
if req.finished() or (
|
||||||
(
|
req.stream
|
||||||
req.stream
|
and (
|
||||||
and (
|
self.decode_forward_ct % self.stream_interval == 0
|
||||||
self.decode_forward_ct % self.stream_interval == 0
|
or len(req.output_ids) == 1
|
||||||
or len(req.output_ids) == 1
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
):
|
):
|
||||||
output_rids.append(req.rid)
|
output_rids.append(req.rid)
|
||||||
|
|||||||
@@ -70,6 +70,7 @@ class ServerArgs:
|
|||||||
tp_size: int = 1
|
tp_size: int = 1
|
||||||
stream_interval: int = 1
|
stream_interval: int = 1
|
||||||
random_seed: Optional[int] = None
|
random_seed: Optional[int] = None
|
||||||
|
constrained_json_whitespace_pattern: Optional[str] = None
|
||||||
|
|
||||||
# Logging
|
# Logging
|
||||||
log_level: str = "info"
|
log_level: str = "info"
|
||||||
@@ -370,6 +371,12 @@ class ServerArgs:
|
|||||||
default=ServerArgs.random_seed,
|
default=ServerArgs.random_seed,
|
||||||
help="The random seed.",
|
help="The random seed.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--constrained-json-whitespace-pattern",
|
||||||
|
type=str,
|
||||||
|
default=ServerArgs.constrained_json_whitespace_pattern,
|
||||||
|
help=r"Regex pattern for syntactic whitespaces allowed in JSON constrained output. For example, to allow the model generate consecutive whitespaces, set the pattern to [\n\t ]*",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--log-level",
|
"--log-level",
|
||||||
type=str,
|
type=str,
|
||||||
|
|||||||
Reference in New Issue
Block a user