Use dtype to control generate (#1082)

Co-authored-by: zhyncs <me@zhyncs.com>
This commit is contained in:
Liangsheng Yin
2024-08-14 08:58:07 -07:00
committed by GitHub
parent 67c0d832a6
commit a34dd86a7d
12 changed files with 110 additions and 88 deletions

View File

@@ -6,11 +6,11 @@ from functools import partial
from tqdm import tqdm
from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STRING
from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STR
from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
from sglang.utils import dump_state_text, read_jsonl
REGEX_LIST = r"\[(" + REGEX_STRING + ", )*" + REGEX_STRING + r"\]"
REGEX_LIST = r"\[(" + REGEX_STR + ", )*" + REGEX_STR + r"\]"
# fmt: off
@@ -20,9 +20,9 @@ def json_decode(document, generate):
s += "Here is the name, country, and symbol of the city in JSON format.\n"
s += "{\n"
s += ' "name": '
s += generate(s, max_tokens=8, regex=REGEX_STRING + ",") + "\n"
s += generate(s, max_tokens=8, regex=REGEX_STR + ",") + "\n"
s += ' "country": '
s += generate(s, max_tokens=8, regex=REGEX_STRING + ",") + "\n"
s += generate(s, max_tokens=8, regex=REGEX_STR + ",") + "\n"
s += ' "latitude": '
s += generate(s, max_tokens=8, regex=REGEX_FLOAT + ",") + "\n"
s += ' "population": '