Support stop_token_ids in sglang API (#1092)
This commit is contained in:
@@ -62,6 +62,7 @@ def gen(
|
|||||||
name: Optional[str] = None,
|
name: Optional[str] = None,
|
||||||
max_tokens: Optional[int] = None,
|
max_tokens: Optional[int] = None,
|
||||||
stop: Optional[Union[str, List[str]]] = None,
|
stop: Optional[Union[str, List[str]]] = None,
|
||||||
|
stop_token_ids: Optional[List[int]] = None,
|
||||||
temperature: Optional[float] = None,
|
temperature: Optional[float] = None,
|
||||||
top_p: Optional[float] = None,
|
top_p: Optional[float] = None,
|
||||||
top_k: Optional[int] = None,
|
top_k: Optional[int] = None,
|
||||||
@@ -98,6 +99,7 @@ def gen(
|
|||||||
name,
|
name,
|
||||||
max_tokens,
|
max_tokens,
|
||||||
stop,
|
stop,
|
||||||
|
stop_token_ids,
|
||||||
temperature,
|
temperature,
|
||||||
top_p,
|
top_p,
|
||||||
top_k,
|
top_k,
|
||||||
@@ -117,6 +119,7 @@ def gen_int(
|
|||||||
name: Optional[str] = None,
|
name: Optional[str] = None,
|
||||||
max_tokens: Optional[int] = None,
|
max_tokens: Optional[int] = None,
|
||||||
stop: Optional[Union[str, List[str]]] = None,
|
stop: Optional[Union[str, List[str]]] = None,
|
||||||
|
stop_token_ids: Optional[List[int]] = None,
|
||||||
temperature: Optional[float] = None,
|
temperature: Optional[float] = None,
|
||||||
top_p: Optional[float] = None,
|
top_p: Optional[float] = None,
|
||||||
top_k: Optional[int] = None,
|
top_k: Optional[int] = None,
|
||||||
@@ -132,6 +135,7 @@ def gen_int(
|
|||||||
name,
|
name,
|
||||||
max_tokens,
|
max_tokens,
|
||||||
stop,
|
stop,
|
||||||
|
stop_token_ids,
|
||||||
temperature,
|
temperature,
|
||||||
top_p,
|
top_p,
|
||||||
top_k,
|
top_k,
|
||||||
@@ -151,6 +155,7 @@ def gen_string(
|
|||||||
name: Optional[str] = None,
|
name: Optional[str] = None,
|
||||||
max_tokens: Optional[int] = None,
|
max_tokens: Optional[int] = None,
|
||||||
stop: Optional[Union[str, List[str]]] = None,
|
stop: Optional[Union[str, List[str]]] = None,
|
||||||
|
stop_token_ids: Optional[List[int]] = None,
|
||||||
temperature: Optional[float] = None,
|
temperature: Optional[float] = None,
|
||||||
top_p: Optional[float] = None,
|
top_p: Optional[float] = None,
|
||||||
top_k: Optional[int] = None,
|
top_k: Optional[int] = None,
|
||||||
@@ -166,6 +171,7 @@ def gen_string(
|
|||||||
name,
|
name,
|
||||||
max_tokens,
|
max_tokens,
|
||||||
stop,
|
stop,
|
||||||
|
stop_token_ids,
|
||||||
temperature,
|
temperature,
|
||||||
top_p,
|
top_p,
|
||||||
top_k,
|
top_k,
|
||||||
|
|||||||
@@ -20,7 +20,6 @@ from sglang.lang.ir import (
|
|||||||
SglConstantText,
|
SglConstantText,
|
||||||
SglExpr,
|
SglExpr,
|
||||||
SglExprList,
|
SglExprList,
|
||||||
SglFunction,
|
|
||||||
SglGen,
|
SglGen,
|
||||||
SglImage,
|
SglImage,
|
||||||
SglRoleBegin,
|
SglRoleBegin,
|
||||||
@@ -181,8 +180,10 @@ class StreamExecutor:
|
|||||||
num_api_spec_tokens=None,
|
num_api_spec_tokens=None,
|
||||||
use_thread=True,
|
use_thread=True,
|
||||||
):
|
):
|
||||||
|
from sglang.lang.backend.base_backend import BaseBackend
|
||||||
|
|
||||||
self.sid = uuid.uuid4().hex
|
self.sid = uuid.uuid4().hex
|
||||||
self.backend = backend
|
self.backend: BaseBackend = backend
|
||||||
self.arguments: Dict[str, Any] = arguments
|
self.arguments: Dict[str, Any] = arguments
|
||||||
self.default_sampling_para = default_sampling_para
|
self.default_sampling_para = default_sampling_para
|
||||||
self.stream = stream
|
self.stream = stream
|
||||||
@@ -658,6 +659,7 @@ class StreamExecutor:
|
|||||||
for item in [
|
for item in [
|
||||||
"max_new_tokens",
|
"max_new_tokens",
|
||||||
"stop",
|
"stop",
|
||||||
|
"stop_token_ids",
|
||||||
"temperature",
|
"temperature",
|
||||||
"top_p",
|
"top_p",
|
||||||
"top_k",
|
"top_k",
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ REGEX_STR = r"\"[\w\d\s]*\"" # bugs with regex r"\".*\"" in interegular pkg
|
|||||||
class SglSamplingParams:
|
class SglSamplingParams:
|
||||||
max_new_tokens: int = 128
|
max_new_tokens: int = 128
|
||||||
stop: Union[str, List[str]] = ()
|
stop: Union[str, List[str]] = ()
|
||||||
|
stop_token_ids: Optional[List[int]] = ()
|
||||||
temperature: float = 1.0
|
temperature: float = 1.0
|
||||||
top_p: float = 1.0
|
top_p: float = 1.0
|
||||||
top_k: int = -1 # -1 means disable
|
top_k: int = -1 # -1 means disable
|
||||||
@@ -37,6 +38,7 @@ class SglSamplingParams:
|
|||||||
return SglSamplingParams(
|
return SglSamplingParams(
|
||||||
self.max_new_tokens,
|
self.max_new_tokens,
|
||||||
self.stop,
|
self.stop,
|
||||||
|
self.stop_token_ids,
|
||||||
self.temperature,
|
self.temperature,
|
||||||
self.top_p,
|
self.top_p,
|
||||||
self.top_k,
|
self.top_k,
|
||||||
@@ -108,6 +110,7 @@ class SglSamplingParams:
|
|||||||
return {
|
return {
|
||||||
"max_new_tokens": self.max_new_tokens,
|
"max_new_tokens": self.max_new_tokens,
|
||||||
"stop": self.stop,
|
"stop": self.stop,
|
||||||
|
"stop_token_ids": self.stop_token_ids,
|
||||||
"temperature": self.temperature,
|
"temperature": self.temperature,
|
||||||
"top_p": self.top_p,
|
"top_p": self.top_p,
|
||||||
"top_k": self.top_k,
|
"top_k": self.top_k,
|
||||||
@@ -141,7 +144,8 @@ class SglFunction:
|
|||||||
self,
|
self,
|
||||||
*args,
|
*args,
|
||||||
max_new_tokens: int = 128,
|
max_new_tokens: int = 128,
|
||||||
stop: Union[str, List[str]] = (),
|
stop: Union[str, List[str]] = [],
|
||||||
|
stop_token_ids: Optional[List[int]] = [],
|
||||||
temperature: float = 1.0,
|
temperature: float = 1.0,
|
||||||
top_p: float = 1.0,
|
top_p: float = 1.0,
|
||||||
top_k: int = -1,
|
top_k: int = -1,
|
||||||
@@ -161,6 +165,7 @@ class SglFunction:
|
|||||||
default_sampling_para = SglSamplingParams(
|
default_sampling_para = SglSamplingParams(
|
||||||
max_new_tokens=max_new_tokens,
|
max_new_tokens=max_new_tokens,
|
||||||
stop=stop,
|
stop=stop,
|
||||||
|
stop_token_ids=stop_token_ids,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
top_p=top_p,
|
top_p=top_p,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
@@ -181,6 +186,7 @@ class SglFunction:
|
|||||||
*,
|
*,
|
||||||
max_new_tokens: int = 128,
|
max_new_tokens: int = 128,
|
||||||
stop: Union[str, List[str]] = (),
|
stop: Union[str, List[str]] = (),
|
||||||
|
stop_token_ids: Optional[List[int]] = [],
|
||||||
temperature: float = 1.0,
|
temperature: float = 1.0,
|
||||||
top_p: float = 1.0,
|
top_p: float = 1.0,
|
||||||
top_k: int = -1,
|
top_k: int = -1,
|
||||||
@@ -218,6 +224,7 @@ class SglFunction:
|
|||||||
default_sampling_para = SglSamplingParams(
|
default_sampling_para = SglSamplingParams(
|
||||||
max_new_tokens=max_new_tokens,
|
max_new_tokens=max_new_tokens,
|
||||||
stop=stop,
|
stop=stop,
|
||||||
|
stop_token_ids=stop_token_ids,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
top_p=top_p,
|
top_p=top_p,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
@@ -397,6 +404,7 @@ class SglGen(SglExpr):
|
|||||||
name: Optional[str] = None,
|
name: Optional[str] = None,
|
||||||
max_new_tokens: Optional[int] = None,
|
max_new_tokens: Optional[int] = None,
|
||||||
stop: Optional[Union[str, List[str]]] = None,
|
stop: Optional[Union[str, List[str]]] = None,
|
||||||
|
stop_token_ids: Optional[List[int]] = None,
|
||||||
temperature: Optional[float] = None,
|
temperature: Optional[float] = None,
|
||||||
top_p: Optional[float] = None,
|
top_p: Optional[float] = None,
|
||||||
top_k: Optional[int] = None,
|
top_k: Optional[int] = None,
|
||||||
@@ -416,6 +424,7 @@ class SglGen(SglExpr):
|
|||||||
self.sampling_params = SglSamplingParams(
|
self.sampling_params = SglSamplingParams(
|
||||||
max_new_tokens=max_new_tokens,
|
max_new_tokens=max_new_tokens,
|
||||||
stop=stop,
|
stop=stop,
|
||||||
|
stop_token_ids=stop_token_ids,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
top_p=top_p,
|
top_p=top_p,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
|
|||||||
@@ -235,10 +235,12 @@ class Req:
|
|||||||
return
|
return
|
||||||
|
|
||||||
last_token_id = self.output_ids[-1]
|
last_token_id = self.output_ids[-1]
|
||||||
if self.tokenizer is None:
|
|
||||||
matched_eos = last_token_id in self.sampling_params.stop_token_ids
|
matched_eos = last_token_id in self.sampling_params.stop_token_ids
|
||||||
else:
|
|
||||||
matched_eos = last_token_id == self.tokenizer.eos_token_id
|
if self.tokenizer is not None:
|
||||||
|
matched_eos |= last_token_id == self.tokenizer.eos_token_id
|
||||||
|
|
||||||
if matched_eos and not self.sampling_params.ignore_eos:
|
if matched_eos and not self.sampling_params.ignore_eos:
|
||||||
self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
|
self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -106,13 +106,16 @@ def test_decode_json_regex():
|
|||||||
from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STR
|
from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STR
|
||||||
|
|
||||||
s += "Generate a JSON object to describe the basic city information of Paris.\n"
|
s += "Generate a JSON object to describe the basic city information of Paris.\n"
|
||||||
|
s += "Here are the JSON object:\n"
|
||||||
|
|
||||||
|
# NOTE: we recommend using dtype gen or whole regex string to control the output
|
||||||
|
|
||||||
with s.var_scope("json_output"):
|
with s.var_scope("json_output"):
|
||||||
s += "{\n"
|
s += "{\n"
|
||||||
s += ' "name": ' + sgl.gen(regex=REGEX_STR + ",") + "\n"
|
s += ' "name": ' + sgl.gen(regex=REGEX_STR) + ",\n"
|
||||||
s += ' "population": ' + sgl.gen(regex=REGEX_INT + ",") + "\n"
|
s += ' "population": ' + sgl.gen(regex=REGEX_INT, stop=[" ", "\n"]) + ",\n"
|
||||||
s += ' "area": ' + sgl.gen(regex=REGEX_INT + ",") + "\n"
|
s += ' "area": ' + sgl.gen(regex=REGEX_INT, stop=[" ", "\n"]) + ",\n"
|
||||||
s += ' "latitude": ' + sgl.gen(regex=REGEX_FLOAT) + "\n"
|
s += ' "latitude": ' + sgl.gen(regex=REGEX_FLOAT, stop=[" ", "\n"]) + "\n"
|
||||||
s += "}"
|
s += "}"
|
||||||
|
|
||||||
ret = decode_json.run(temperature=0.0)
|
ret = decode_json.run(temperature=0.0)
|
||||||
|
|||||||
@@ -84,7 +84,7 @@ class TestServingThroughput(unittest.TestCase):
|
|||||||
|
|
||||||
if os.getenv("SGLANG_IS_IN_CI", "false") == "true":
|
if os.getenv("SGLANG_IS_IN_CI", "false") == "true":
|
||||||
# A100 (PCIE) performance
|
# A100 (PCIE) performance
|
||||||
assert res["output_throughput"] > 940
|
assert res["output_throughput"] > 930
|
||||||
|
|
||||||
def test_default_with_chunked_prefill(self):
|
def test_default_with_chunked_prefill(self):
|
||||||
res = self.run_test(
|
res = self.run_test(
|
||||||
|
|||||||
Reference in New Issue
Block a user