Add SRT json decode example (#2)
This commit is contained in:
@@ -19,7 +19,7 @@ dependencies = [
|
|||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
srt = ["fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn", "zmq", "vllm>=0.2.5",
|
srt = ["fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn", "zmq", "vllm>=0.2.5",
|
||||||
"interegular", "lark"]
|
"interegular", "lark", "numba"]
|
||||||
openai = ["openai>=1.0"]
|
openai = ["openai>=1.0"]
|
||||||
anthropic = ["anthropic"]
|
anthropic = ["anthropic"]
|
||||||
all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]"]
|
all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]"]
|
||||||
|
|||||||
@@ -6,6 +6,11 @@ from typing import List, Optional, Union
|
|||||||
|
|
||||||
from sglang.global_config import global_config
|
from sglang.global_config import global_config
|
||||||
|
|
||||||
|
REGEX_INT = r"[-+]?[0-9]+"
|
||||||
|
REGEX_FLOAT = r"[-+]?[0-9]*\.?[0-9]+"
|
||||||
|
REGEX_BOOL = r"(True|False)"
|
||||||
|
REGEX_STRING = r"\"[\w\d\s]*\"" # bugs with regex r"\".*\"" in interegular pkg
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class SamplingParams:
|
class SamplingParams:
|
||||||
|
|||||||
@@ -102,6 +102,29 @@ def test_decode_int():
|
|||||||
assert int(ret["days"]) == 365, ret.text
|
assert int(ret["days"]) == 365, ret.text
|
||||||
|
|
||||||
|
|
||||||
|
def test_decode_json_regex():
|
||||||
|
@sgl.function
|
||||||
|
def decode_json(s):
|
||||||
|
from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STRING
|
||||||
|
|
||||||
|
s += "Generate a JSON object to describe the basic information of a city.\n"
|
||||||
|
|
||||||
|
with s.var_scope("json_output"):
|
||||||
|
s += "{\n"
|
||||||
|
s += ' "name": ' + sgl.gen(regex=REGEX_STRING + ",") + "\n"
|
||||||
|
s += ' "population": ' + sgl.gen(regex=REGEX_INT + ",") + "\n"
|
||||||
|
s += ' "area": ' + sgl.gen(regex=REGEX_INT + ",") + "\n"
|
||||||
|
s += ' "latitude": ' + sgl.gen(regex=REGEX_FLOAT + ",") + "\n"
|
||||||
|
s += ' "country": ' + sgl.gen(regex=REGEX_STRING + ",") + "\n"
|
||||||
|
s += ' "timezone": ' + sgl.gen(regex=REGEX_STRING) + "\n"
|
||||||
|
s += "}"
|
||||||
|
|
||||||
|
ret = decode_json.run()
|
||||||
|
js_obj = json.loads(ret["json_output"])
|
||||||
|
assert isinstance(js_obj["name"], str)
|
||||||
|
assert isinstance(js_obj["population"], int)
|
||||||
|
|
||||||
|
|
||||||
def test_decode_json():
|
def test_decode_json():
|
||||||
@sgl.function
|
@sgl.function
|
||||||
def decode_json(s):
|
def decode_json(s):
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import unittest
|
|||||||
|
|
||||||
from sglang.test.test_programs import (
|
from sglang.test.test_programs import (
|
||||||
test_decode_int,
|
test_decode_int,
|
||||||
test_decode_json,
|
test_decode_json_regex,
|
||||||
test_expert_answer,
|
test_expert_answer,
|
||||||
test_few_shot_qa,
|
test_few_shot_qa,
|
||||||
test_mt_bench,
|
test_mt_bench,
|
||||||
@@ -44,6 +44,9 @@ class TestSRTBackend(unittest.TestCase):
|
|||||||
def test_decode_int(self):
|
def test_decode_int(self):
|
||||||
test_decode_int()
|
test_decode_int()
|
||||||
|
|
||||||
|
def test_decode_json_regex(self):
|
||||||
|
test_decode_json_regex()
|
||||||
|
|
||||||
def test_expert_answer(self):
|
def test_expert_answer(self):
|
||||||
test_expert_answer()
|
test_expert_answer()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user