Support min_tokens in sgl.gen (#1573)

This commit is contained in:
Byron Hsu
2024-10-05 21:51:12 -07:00
committed by GitHub
parent 521f862d90
commit 2422de5193
6 changed files with 82 additions and 0 deletions

View File

@@ -69,6 +69,7 @@ def get_server_args(backend: Optional[BaseBackend] = None):
def gen(
name: Optional[str] = None,
max_tokens: Optional[int] = None,
min_tokens: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
temperature: Optional[float] = None,
@@ -108,6 +109,7 @@ def gen(
return SglGen(
name,
max_tokens,
min_tokens,
stop,
stop_token_ids,
temperature,
@@ -147,6 +149,7 @@ def gen_int(
return SglGen(
name,
max_tokens,
None,
stop,
stop_token_ids,
temperature,
@@ -185,6 +188,7 @@ def gen_string(
return SglGen(
name,
max_tokens,
None,
stop,
stop_token_ids,
temperature,

View File

@@ -668,6 +668,7 @@ class StreamExecutor:
for item in [
"max_new_tokens",
"min_new_tokens",
"stop",
"stop_token_ids",
"temperature",

View File

@@ -17,6 +17,7 @@ REGEX_STR = r"\"[\w\d\s]*\"" # bugs with regex r"\".*\"" in interegular pkg
@dataclasses.dataclass
class SglSamplingParams:
max_new_tokens: int = 128
min_new_tokens: int = 0
stop: Union[str, List[str]] = ()
stop_token_ids: Optional[List[int]] = ()
temperature: float = 1.0
@@ -39,6 +40,7 @@ class SglSamplingParams:
def clone(self):
return SglSamplingParams(
self.max_new_tokens,
self.min_new_tokens,
self.stop,
self.stop_token_ids,
self.temperature,
@@ -113,6 +115,7 @@ class SglSamplingParams:
def to_srt_kwargs(self):
return {
"max_new_tokens": self.max_new_tokens,
"min_new_tokens": self.min_new_tokens,
"stop": self.stop,
"stop_token_ids": self.stop_token_ids,
"temperature": self.temperature,
@@ -424,6 +427,7 @@ class SglGen(SglExpr):
self,
name: Optional[str] = None,
max_new_tokens: Optional[int] = None,
min_new_tokens: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
temperature: Optional[float] = None,
@@ -446,6 +450,7 @@ class SglGen(SglExpr):
self.name = name
self.sampling_params = SglSamplingParams(
max_new_tokens=max_new_tokens,
min_new_tokens=min_new_tokens,
stop=stop,
stop_token_ids=stop_token_ids,
temperature=temperature,

View File

@@ -517,3 +517,36 @@ def test_hellaswag_select():
accuracy = np.mean(np.array(preds) == np.array(labels))
return accuracy, latency
def test_gen_min_new_tokens():
    """
    Check that sgl.gen honors its min_tokens argument.

    The prompt is a question whose natural answer is short. The generation is
    requested with min_tokens=MIN_TOKENS and max_tokens=MAX_TOKENS, the
    returned answer is re-encoded with the tokenizer of the default backend's
    model, and the resulting token count must be at least MIN_TOKENS.
    """
    import sglang as sgl
    from sglang.srt.hf_transformers_utils import get_tokenizer

    MIN_TOKENS = 64
    MAX_TOKENS = 128

    @sgl.function
    def convo_1(s):
        # Single user turn followed by one constrained assistant generation.
        s += sgl.user("What is the capital of the United States?")
        s += sgl.assistant(
            sgl.gen("answer", min_tokens=MIN_TOKENS, max_tokens=MAX_TOKENS)
        )

    # Resolve the model name from the default backend and load its tokenizer
    # before running the program, so the answer can be measured in tokens.
    model_path = sgl.global_config.default_backend.endpoint.get_model_name()
    tokenizer = get_tokenizer(model_path)

    state = convo_1.run()
    text = state["answer"]

    token_ids = tokenizer.encode(text)
    assert (
        len(token_ids) >= MIN_TOKENS
    ), f"Generated {len(token_ids)} tokens, min required: {MIN_TOKENS}. Text: {text}"