diff --git a/python/pyproject.toml b/python/pyproject.toml index 06ec94cab..84079db35 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -19,7 +19,7 @@ dependencies = [ [project.optional-dependencies] srt = ["fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn", "zmq", "vllm>=0.2.5", - "interegular", "lark"] + "interegular", "lark", "numba"] openai = ["openai>=1.0"] anthropic = ["anthropic"] all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]"] diff --git a/python/sglang/lang/ir.py b/python/sglang/lang/ir.py index ddce5cf54..09cf9ad2a 100644 --- a/python/sglang/lang/ir.py +++ b/python/sglang/lang/ir.py @@ -6,6 +6,11 @@ from typing import List, Optional, Union from sglang.global_config import global_config +REGEX_INT = r"[-+]?[0-9]+" +REGEX_FLOAT = r"[-+]?[0-9]*\.?[0-9]+" +REGEX_BOOL = r"(True|False)" +REGEX_STRING = r"\"[\w\d\s]*\"" # bugs with regex r"\".*\"" in interegular pkg + @dataclasses.dataclass class SamplingParams: diff --git a/python/sglang/test/test_programs.py b/python/sglang/test/test_programs.py index 805e50c1e..38e102902 100644 --- a/python/sglang/test/test_programs.py +++ b/python/sglang/test/test_programs.py @@ -102,6 +102,29 @@ def test_decode_int(): assert int(ret["days"]) == 365, ret.text +def test_decode_json_regex(): + @sgl.function + def decode_json(s): + from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STRING + + s += "Generate a JSON object to describe the basic information of a city.\n" + + with s.var_scope("json_output"): + s += "{\n" + s += ' "name": ' + sgl.gen(regex=REGEX_STRING + ",") + "\n" + s += ' "population": ' + sgl.gen(regex=REGEX_INT + ",") + "\n" + s += ' "area": ' + sgl.gen(regex=REGEX_INT + ",") + "\n" + s += ' "latitude": ' + sgl.gen(regex=REGEX_FLOAT + ",") + "\n" + s += ' "country": ' + sgl.gen(regex=REGEX_STRING + ",") + "\n" + s += ' "timezone": ' + sgl.gen(regex=REGEX_STRING) + "\n" + s += "}" + + ret = decode_json.run() + js_obj = json.loads(ret["json_output"]) + assert isinstance(js_obj["name"], str) + assert isinstance(js_obj["population"], int) + + def test_decode_json(): @sgl.function def decode_json(s): diff --git a/test/lang/test_srt_backend.py b/test/lang/test_srt_backend.py index 374c53db9..102f57f24 100644 --- a/test/lang/test_srt_backend.py +++ b/test/lang/test_srt_backend.py @@ -6,7 +6,7 @@ import unittest from sglang.test.test_programs import ( test_decode_int, - test_decode_json, + test_decode_json_regex, test_expert_answer, test_few_shot_qa, test_mt_bench, @@ -44,6 +44,9 @@ class TestSRTBackend(unittest.TestCase): def test_decode_int(self): test_decode_int() + def test_decode_json_regex(self): + test_decode_json_regex() + def test_expert_answer(self): test_expert_answer()