Add constrained_json_whitespace_pattern to ServerArgs (#1438)

This commit is contained in:
zifeitong
2024-09-16 13:29:18 -07:00
committed by GitHub
parent 2abe4f1cb6
commit 93dffd699b
3 changed files with 17 additions and 7 deletions

View File

@@ -29,6 +29,7 @@ class FSMCache(BaseToolCache):
tokenizer_args_dict,
enable=True,
skip_tokenizer_init=False,
constrained_json_whitespace_pattern=None,
):
super().__init__(enable=enable)
@@ -63,11 +64,14 @@ class FSMCache(BaseToolCache):
self.outlines_tokenizer.vocabulary = (
self.outlines_tokenizer.tokenizer.get_vocab()
)
self.constrained_json_whitespace_pattern = constrained_json_whitespace_pattern
def init_value(self, key):
key_type, key_string = key
if key_type == "json":
regex = build_regex_from_schema(key_string, whitespace_pattern=r"[\n\t ]*")
regex = build_regex_from_schema(
key_string, whitespace_pattern=self.constrained_json_whitespace_pattern
)
elif key_type == "regex":
regex = key_string
else:

View File

@@ -198,6 +198,7 @@ class ModelTpServer:
"trust_remote_code": server_args.trust_remote_code,
},
skip_tokenizer_init=server_args.skip_tokenizer_init,
constrained_json_whitespace_pattern=server_args.constrained_json_whitespace_pattern,
)
self.jump_forward_cache = JumpForwardCache()
@@ -807,12 +808,10 @@ class ModelTpServer:
unfinished_indices.append(i)
if req.finished() or (
(
req.stream
and (
self.decode_forward_ct % self.stream_interval == 0
or len(req.output_ids) == 1
)
req.stream
and (
self.decode_forward_ct % self.stream_interval == 0
or len(req.output_ids) == 1
)
):
output_rids.append(req.rid)

View File

@@ -70,6 +70,7 @@ class ServerArgs:
tp_size: int = 1
stream_interval: int = 1
random_seed: Optional[int] = None
constrained_json_whitespace_pattern: Optional[str] = None
# Logging
log_level: str = "info"
@@ -370,6 +371,12 @@ class ServerArgs:
default=ServerArgs.random_seed,
help="The random seed.",
)
# Whitespace regex used for JSON-constrained decoding. The value (None unless
# the flag is given) is forwarded unchanged to
# build_regex_from_schema(whitespace_pattern=...).
parser.add_argument(
    "--constrained-json-whitespace-pattern",
    type=str,
    default=ServerArgs.constrained_json_whitespace_pattern,
    help=(
        r"Regex pattern for syntactic whitespace allowed in JSON constrained output. "
        r"For example, to allow the model to generate consecutive whitespace, "
        r"set the pattern to [\n\t ]*"
    ),
)
parser.add_argument(
"--log-level",
type=str,