Use dtype to control generate (#1082)
Co-authored-by: zhyncs <me@zhyncs.com>
This commit is contained in:
@@ -6,11 +6,11 @@ from functools import partial
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STRING
|
||||
from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STR
|
||||
from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate
|
||||
from sglang.utils import dump_state_text, read_jsonl
|
||||
|
||||
REGEX_LIST = r"\[(" + REGEX_STRING + ", )*" + REGEX_STRING + r"\]"
|
||||
REGEX_LIST = r"\[(" + REGEX_STR + ", )*" + REGEX_STR + r"\]"
|
||||
|
||||
|
||||
# fmt: off
|
||||
@@ -20,9 +20,9 @@ def json_decode(document, generate):
|
||||
s += "Here is the name, country, and symbol of the city in JSON format.\n"
|
||||
s += "{\n"
|
||||
s += ' "name": '
|
||||
s += generate(s, max_tokens=8, regex=REGEX_STRING + ",") + "\n"
|
||||
s += generate(s, max_tokens=8, regex=REGEX_STR + ",") + "\n"
|
||||
s += ' "country": '
|
||||
s += generate(s, max_tokens=8, regex=REGEX_STRING + ",") + "\n"
|
||||
s += generate(s, max_tokens=8, regex=REGEX_STR + ",") + "\n"
|
||||
s += ' "latitude": '
|
||||
s += generate(s, max_tokens=8, regex=REGEX_FLOAT + ",") + "\n"
|
||||
s += ' "population": '
|
||||
|
||||
Reference in New Issue
Block a user