Use dtype to control generate (#1082)
Co-authored-by: zhyncs <me@zhyncs.com>
This commit is contained in:
@@ -3,14 +3,14 @@ import json
|
||||
import time
|
||||
|
||||
import sglang as sgl
|
||||
from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STRING
|
||||
from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STR
|
||||
from sglang.test.test_utils import (
|
||||
add_common_sglang_args_and_parse,
|
||||
select_sglang_backend,
|
||||
)
|
||||
from sglang.utils import dump_state_text, read_jsonl
|
||||
|
||||
REGEX_LIST = r"\[(" + REGEX_STRING + ", )*" + REGEX_STRING + r"\]"
|
||||
REGEX_LIST = r"\[(" + REGEX_STR + ", )*" + REGEX_STR + r"\]"
|
||||
|
||||
# fmt: off
|
||||
@sgl.function
|
||||
@@ -18,8 +18,8 @@ def json_warm_up(s):
|
||||
s += "The information about Hogwarts is in the following JSON format.\n"
|
||||
with s.var_scope("json_output"):
|
||||
s += "{\n"
|
||||
s += ' "name": ' + sgl.gen("name", max_tokens=8, regex=REGEX_STRING + ",") + "\n"
|
||||
s += ' "country": ' + sgl.gen("country", max_tokens=8, regex=REGEX_STRING + ",") + "\n"
|
||||
s += ' "name": ' + sgl.gen("name", max_tokens=8, regex=REGEX_STR + ",") + "\n"
|
||||
s += ' "country": ' + sgl.gen("country", max_tokens=8, regex=REGEX_STR + ",") + "\n"
|
||||
s += ' "latitude": ' + sgl.gen("latitude", max_tokens=8, regex=REGEX_FLOAT + ",") + "\n"
|
||||
s += ' "population": ' + sgl.gen("population", max_tokens=8, regex=REGEX_INT + ",") + "\n"
|
||||
s += ' "top 3 landmarks": ' + sgl.gen( "landmarks", max_tokens=24, regex=REGEX_LIST) + "\n"
|
||||
@@ -35,8 +35,8 @@ def json_decode(s, document):
|
||||
s += "Here is the name, country, and symbol of the city in JSON format.\n"
|
||||
with s.var_scope("json_output"):
|
||||
s += "{\n"
|
||||
s += ' "name": ' + sgl.gen("name", max_tokens=8, regex=REGEX_STRING + ",") + "\n"
|
||||
s += ' "country": ' + sgl.gen("country", max_tokens=8, regex=REGEX_STRING + ",") + "\n"
|
||||
s += ' "name": ' + sgl.gen("name", max_tokens=8, regex=REGEX_STR + ",") + "\n"
|
||||
s += ' "country": ' + sgl.gen("country", max_tokens=8, regex=REGEX_STR + ",") + "\n"
|
||||
s += ' "latitude": ' + sgl.gen("latitude", max_tokens=8, regex=REGEX_FLOAT + ",") + "\n"
|
||||
s += ' "population": ' + sgl.gen("population", max_tokens=8, regex=REGEX_INT + ",") + "\n"
|
||||
s += ' "top 3 landmarks": ' + sgl.gen( "landmarks", max_tokens=24, regex=REGEX_LIST) + "\n"
|
||||
|
||||
Reference in New Issue
Block a user