Use dtype to control generate (#1082)

Co-authored-by: zhyncs <me@zhyncs.com>
This commit is contained in:
Liangsheng Yin
2024-08-14 08:58:07 -07:00
committed by GitHub
parent 67c0d832a6
commit a34dd86a7d
12 changed files with 110 additions and 88 deletions

View File

@@ -103,13 +103,13 @@ def test_decode_int():
def test_decode_json_regex():
@sgl.function
def decode_json(s):
from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STRING
from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STR
s += "Generate a JSON object to describe the basic city information of Paris.\n"
with s.var_scope("json_output"):
s += "{\n"
s += ' "name": ' + sgl.gen(regex=REGEX_STRING + ",") + "\n"
s += ' "name": ' + sgl.gen(regex=REGEX_STR + ",") + "\n"
s += ' "population": ' + sgl.gen(regex=REGEX_INT + ",") + "\n"
s += ' "area": ' + sgl.gen(regex=REGEX_INT + ",") + "\n"
s += ' "latitude": ' + sgl.gen(regex=REGEX_FLOAT) + "\n"
@@ -359,6 +359,30 @@ def test_regex():
assert re.match(regex, answer)
def test_dtype_gen():
@sgl.function
def dtype_gen(s):
s += "Q: What is the full name of DNS?\n"
s += "A: The full nams is " + sgl.gen("str_res", dtype=str, stop="\n") + "\n"
s += "Q: Which year was DNS invented?\n"
s += "A: " + sgl.gen("int_res", dtype=int) + "\n"
s += "Q: What is the value of pi?\n"
s += "A: " + sgl.gen("float_res", dtype=float) + "\n"
s += "Q: Is the sky blue?\n"
s += "A: " + sgl.gen("bool_res", dtype=bool) + "\n"
state = dtype_gen.run()
try:
state["int_res"] = int(state["int_res"])
state["float_res"] = float(state["float_res"])
state["bool_res"] = bool(state["bool_res"])
# assert state["str_res"].startswith('"') and state["str_res"].endswith('"')
except ValueError:
print(state)
raise
def test_completion_speculative():
@sgl.function(num_api_spec_tokens=64)
def gen_character_spec(s):