From 73cf6834f2a6ee0d566a1ca70db5e2c05c76486b Mon Sep 17 00:00:00 2001 From: Liangsheng Yin Date: Wed, 14 Aug 2024 17:31:39 -0700 Subject: [PATCH] Support `stop_token_ids` in sglang API (#1092) --- python/sglang/api.py | 6 ++++++ python/sglang/lang/interpreter.py | 6 ++++-- python/sglang/lang/ir.py | 11 ++++++++++- python/sglang/srt/managers/schedule_batch.py | 10 ++++++---- python/sglang/test/test_programs.py | 11 +++++++---- test/srt/test_moe_serving_throughput.py | 2 +- 6 files changed, 34 insertions(+), 12 deletions(-) diff --git a/python/sglang/api.py b/python/sglang/api.py index 2242b4a4c..887ffce76 100644 --- a/python/sglang/api.py +++ b/python/sglang/api.py @@ -62,6 +62,7 @@ def gen( name: Optional[str] = None, max_tokens: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None, + stop_token_ids: Optional[List[int]] = None, temperature: Optional[float] = None, top_p: Optional[float] = None, top_k: Optional[int] = None, @@ -98,6 +99,7 @@ def gen( name, max_tokens, stop, + stop_token_ids, temperature, top_p, top_k, @@ -117,6 +119,7 @@ def gen_int( name: Optional[str] = None, max_tokens: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None, + stop_token_ids: Optional[List[int]] = None, temperature: Optional[float] = None, top_p: Optional[float] = None, top_k: Optional[int] = None, @@ -132,6 +135,7 @@ def gen_int( name, max_tokens, stop, + stop_token_ids, temperature, top_p, top_k, @@ -151,6 +155,7 @@ def gen_string( name: Optional[str] = None, max_tokens: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None, + stop_token_ids: Optional[List[int]] = None, temperature: Optional[float] = None, top_p: Optional[float] = None, top_k: Optional[int] = None, @@ -166,6 +171,7 @@ def gen_string( name, max_tokens, stop, + stop_token_ids, temperature, top_p, top_k, diff --git a/python/sglang/lang/interpreter.py b/python/sglang/lang/interpreter.py index cf53fac30..844c9d062 100644 --- a/python/sglang/lang/interpreter.py +++ b/python/sglang/lang/interpreter.py @@ -20,7 +20,6 @@ from sglang.lang.ir import ( SglConstantText, SglExpr, SglExprList, - SglFunction, SglGen, SglImage, SglRoleBegin, @@ -181,8 +180,10 @@ class StreamExecutor: num_api_spec_tokens=None, use_thread=True, ): + from sglang.lang.backend.base_backend import BaseBackend + self.sid = uuid.uuid4().hex - self.backend = backend + self.backend: BaseBackend = backend self.arguments: Dict[str, Any] = arguments self.default_sampling_para = default_sampling_para self.stream = stream @@ -658,6 +659,7 @@ class StreamExecutor: for item in [ "max_new_tokens", "stop", + "stop_token_ids", "temperature", "top_p", "top_k", diff --git a/python/sglang/lang/ir.py b/python/sglang/lang/ir.py index 0166b8687..9db5f2719 100644 --- a/python/sglang/lang/ir.py +++ b/python/sglang/lang/ir.py @@ -18,6 +18,7 @@ REGEX_STR = r"\"[\w\d\s]*\"" # bugs with regex r"\".*\"" in interegular pkg class SglSamplingParams: max_new_tokens: int = 128 stop: Union[str, List[str]] = () + stop_token_ids: Optional[List[int]] = () temperature: float = 1.0 top_p: float = 1.0 top_k: int = -1 # -1 means disable @@ -37,6 +38,7 @@ class SglSamplingParams: return SglSamplingParams( self.max_new_tokens, self.stop, + self.stop_token_ids, self.temperature, self.top_p, self.top_k, @@ -108,6 +110,7 @@ class SglSamplingParams: return { "max_new_tokens": self.max_new_tokens, "stop": self.stop, + "stop_token_ids": self.stop_token_ids, "temperature": self.temperature, "top_p": self.top_p, "top_k": self.top_k, @@ -141,7 +144,8 @@ class SglFunction: self, *args, max_new_tokens: int = 128, - stop: Union[str, List[str]] = (), + stop: Union[str, List[str]] = [], + stop_token_ids: Optional[List[int]] = [], temperature: float = 1.0, top_p: float = 1.0, top_k: int = -1, @@ -161,6 +165,7 @@ class SglFunction: default_sampling_para = SglSamplingParams( max_new_tokens=max_new_tokens, stop=stop, + stop_token_ids=stop_token_ids, temperature=temperature, top_p=top_p, top_k=top_k, @@ -181,6 +186,7 @@ class SglFunction: *, max_new_tokens: int = 128, stop: Union[str, List[str]] = (), + stop_token_ids: Optional[List[int]] = [], temperature: float = 1.0, top_p: float = 1.0, top_k: int = -1, @@ -218,6 +224,7 @@ class SglFunction: default_sampling_para = SglSamplingParams( max_new_tokens=max_new_tokens, stop=stop, + stop_token_ids=stop_token_ids, temperature=temperature, top_p=top_p, top_k=top_k, @@ -397,6 +404,7 @@ class SglGen(SglExpr): name: Optional[str] = None, max_new_tokens: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None, + stop_token_ids: Optional[List[int]] = None, temperature: Optional[float] = None, top_p: Optional[float] = None, top_k: Optional[int] = None, @@ -416,6 +424,7 @@ class SglGen(SglExpr): self.sampling_params = SglSamplingParams( max_new_tokens=max_new_tokens, stop=stop, + stop_token_ids=stop_token_ids, temperature=temperature, top_p=top_p, top_k=top_k, diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 9037f5a6e..9e86c9b18 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -235,10 +235,12 @@ class Req: return last_token_id = self.output_ids[-1] - if self.tokenizer is None: - matched_eos = last_token_id in self.sampling_params.stop_token_ids - else: - matched_eos = last_token_id == self.tokenizer.eos_token_id + + matched_eos = last_token_id in self.sampling_params.stop_token_ids + + if self.tokenizer is not None: + matched_eos |= last_token_id == self.tokenizer.eos_token_id + if matched_eos and not self.sampling_params.ignore_eos: self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id) return diff --git a/python/sglang/test/test_programs.py b/python/sglang/test/test_programs.py index 6e39f0aa9..ce4025585 100644 --- a/python/sglang/test/test_programs.py +++ b/python/sglang/test/test_programs.py @@ -106,13 +106,16 @@ def test_decode_json_regex(): from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STR s += "Generate a JSON object to describe the basic city information of Paris.\n" + s += "Here are the JSON object:\n" + + # NOTE: we recommend using dtype gen or whole regex string to control the output with s.var_scope("json_output"): s += "{\n" - s += ' "name": ' + sgl.gen(regex=REGEX_STR + ",") + "\n" - s += ' "population": ' + sgl.gen(regex=REGEX_INT + ",") + "\n" - s += ' "area": ' + sgl.gen(regex=REGEX_INT + ",") + "\n" - s += ' "latitude": ' + sgl.gen(regex=REGEX_FLOAT) + "\n" + s += ' "name": ' + sgl.gen(regex=REGEX_STR) + ",\n" + s += ' "population": ' + sgl.gen(regex=REGEX_INT, stop=[" ", "\n"]) + ",\n" + s += ' "area": ' + sgl.gen(regex=REGEX_INT, stop=[" ", "\n"]) + ",\n" + s += ' "latitude": ' + sgl.gen(regex=REGEX_FLOAT, stop=[" ", "\n"]) + "\n" s += "}" ret = decode_json.run(temperature=0.0) diff --git a/test/srt/test_moe_serving_throughput.py b/test/srt/test_moe_serving_throughput.py index 713eba7ab..80b445f49 100644 --- a/test/srt/test_moe_serving_throughput.py +++ b/test/srt/test_moe_serving_throughput.py @@ -84,7 +84,7 @@ class TestServingThroughput(unittest.TestCase): if os.getenv("SGLANG_IS_IN_CI", "false") == "true": # A100 (PCIE) performance - assert res["output_throughput"] > 940 + assert res["output_throughput"] > 930 def test_default_with_chunked_prefill(self): res = self.run_test(