Use dtype to control generate (#1082)

Co-authored-by: zhyncs <me@zhyncs.com>
This commit is contained in:
Liangsheng Yin
2024-08-14 08:58:07 -07:00
committed by GitHub
parent 67c0d832a6
commit a34dd86a7d
12 changed files with 110 additions and 88 deletions

View File

@@ -3,14 +3,14 @@ import json
import time
import sglang as sgl
from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STRING
from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STR
from sglang.test.test_utils import (
add_common_sglang_args_and_parse,
select_sglang_backend,
)
from sglang.utils import dump_state_text, read_jsonl
REGEX_LIST = r"\[(" + REGEX_STRING + ", )*" + REGEX_STRING + r"\]"
REGEX_LIST = r"\[(" + REGEX_STR + ", )*" + REGEX_STR + r"\]"
# fmt: off
@sgl.function
@@ -18,8 +18,8 @@ def json_warm_up(s):
s += "The information about Hogwarts is in the following JSON format.\n"
with s.var_scope("json_output"):
s += "{\n"
s += ' "name": ' + sgl.gen("name", max_tokens=8, regex=REGEX_STRING + ",") + "\n"
s += ' "country": ' + sgl.gen("country", max_tokens=8, regex=REGEX_STRING + ",") + "\n"
s += ' "name": ' + sgl.gen("name", max_tokens=8, regex=REGEX_STR + ",") + "\n"
s += ' "country": ' + sgl.gen("country", max_tokens=8, regex=REGEX_STR + ",") + "\n"
s += ' "latitude": ' + sgl.gen("latitude", max_tokens=8, regex=REGEX_FLOAT + ",") + "\n"
s += ' "population": ' + sgl.gen("population", max_tokens=8, regex=REGEX_INT + ",") + "\n"
s += ' "top 3 landmarks": ' + sgl.gen( "landmarks", max_tokens=24, regex=REGEX_LIST) + "\n"
@@ -35,8 +35,8 @@ def json_decode(s, document):
s += "Here is the name, country, and symbol of the city in JSON format.\n"
with s.var_scope("json_output"):
s += "{\n"
s += ' "name": ' + sgl.gen("name", max_tokens=8, regex=REGEX_STRING + ",") + "\n"
s += ' "country": ' + sgl.gen("country", max_tokens=8, regex=REGEX_STRING + ",") + "\n"
s += ' "name": ' + sgl.gen("name", max_tokens=8, regex=REGEX_STR + ",") + "\n"
s += ' "country": ' + sgl.gen("country", max_tokens=8, regex=REGEX_STR + ",") + "\n"
s += ' "latitude": ' + sgl.gen("latitude", max_tokens=8, regex=REGEX_FLOAT + ",") + "\n"
s += ' "population": ' + sgl.gen("population", max_tokens=8, regex=REGEX_INT + ",") + "\n"
s += ' "top 3 landmarks": ' + sgl.gen( "landmarks", max_tokens=24, regex=REGEX_LIST) + "\n"