Support min_tokens in sgl.gen (#1573)
This commit adds a new file: examples/frontend_language/usage/sgl_gen_min_tokens.py (35 lines),
plus supporting changes in the frontend API, interpreter, IR, and tests (diff hunks below).
|
||||
"""
This example demonstrates how to use `min_tokens` to enforce sgl.gen to generate a longer sequence

Usage:
python3 sgl_gen_min_tokens.py
"""

import sglang as sgl
@sgl.function
def long_answer(s):
    """Ask a short factual question, forcing a floor of 64 generated tokens.

    `min_tokens=64` makes the model keep generating past its natural stopping
    point, up to the `max_tokens=128` cap.
    """
    s += sgl.user("What is the capital of the United States?")
    constrained_gen = sgl.gen("answer", min_tokens=64, max_tokens=128)
    s += sgl.assistant(constrained_gen)
@sgl.function
def short_answer(s):
    """Ask the same factual question with default sampling parameters.

    Without a `min_tokens` floor the model is free to stop early, so the
    answer is expected to be short — the contrast with `long_answer`.
    """
    s += sgl.user("What is the capital of the United States?")
    unconstrained_gen = sgl.gen("answer")
    s += sgl.assistant(unconstrained_gen)
if __name__ == "__main__":
    # Launch a local model server and make it the default backend for
    # the @sgl.function programs defined above.
    runtime = sgl.Runtime(model_path="meta-llama/Meta-Llama-3.1-8B-Instruct")
    sgl.set_default_backend(runtime)

    # Fix: shut the runtime down even if a generation call raises, so the
    # launched server process does not linger after a failure.
    try:
        state = long_answer.run()
        print("=" * 20)
        print("Longer Answer", state["answer"])

        state = short_answer.run()
        print("=" * 20)
        print("Short Answer", state["answer"])
    finally:
        runtime.shutdown()
@@ -69,6 +69,7 @@ def get_server_args(backend: Optional[BaseBackend] = None):
|
||||
def gen(
|
||||
name: Optional[str] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
min_tokens: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stop_token_ids: Optional[List[int]] = None,
|
||||
temperature: Optional[float] = None,
|
||||
@@ -108,6 +109,7 @@ def gen(
|
||||
return SglGen(
|
||||
name,
|
||||
max_tokens,
|
||||
min_tokens,
|
||||
stop,
|
||||
stop_token_ids,
|
||||
temperature,
|
||||
@@ -147,6 +149,7 @@ def gen_int(
|
||||
return SglGen(
|
||||
name,
|
||||
max_tokens,
|
||||
None,
|
||||
stop,
|
||||
stop_token_ids,
|
||||
temperature,
|
||||
@@ -185,6 +188,7 @@ def gen_string(
|
||||
return SglGen(
|
||||
name,
|
||||
max_tokens,
|
||||
None,
|
||||
stop,
|
||||
stop_token_ids,
|
||||
temperature,
|
||||
|
||||
@@ -668,6 +668,7 @@ class StreamExecutor:
|
||||
|
||||
for item in [
|
||||
"max_new_tokens",
|
||||
"min_new_tokens",
|
||||
"stop",
|
||||
"stop_token_ids",
|
||||
"temperature",
|
||||
|
||||
@@ -17,6 +17,7 @@ REGEX_STR = r"\"[\w\d\s]*\"" # bugs with regex r"\".*\"" in interegular pkg
|
||||
@dataclasses.dataclass
|
||||
class SglSamplingParams:
|
||||
max_new_tokens: int = 128
|
||||
min_new_tokens: int = 0
|
||||
stop: Union[str, List[str]] = ()
|
||||
stop_token_ids: Optional[List[int]] = ()
|
||||
temperature: float = 1.0
|
||||
@@ -39,6 +40,7 @@ class SglSamplingParams:
|
||||
def clone(self):
|
||||
return SglSamplingParams(
|
||||
self.max_new_tokens,
|
||||
self.min_new_tokens,
|
||||
self.stop,
|
||||
self.stop_token_ids,
|
||||
self.temperature,
|
||||
@@ -113,6 +115,7 @@ class SglSamplingParams:
|
||||
def to_srt_kwargs(self):
|
||||
return {
|
||||
"max_new_tokens": self.max_new_tokens,
|
||||
"min_new_tokens": self.min_new_tokens,
|
||||
"stop": self.stop,
|
||||
"stop_token_ids": self.stop_token_ids,
|
||||
"temperature": self.temperature,
|
||||
@@ -424,6 +427,7 @@ class SglGen(SglExpr):
|
||||
self,
|
||||
name: Optional[str] = None,
|
||||
max_new_tokens: Optional[int] = None,
|
||||
min_new_tokens: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stop_token_ids: Optional[List[int]] = None,
|
||||
temperature: Optional[float] = None,
|
||||
@@ -446,6 +450,7 @@ class SglGen(SglExpr):
|
||||
self.name = name
|
||||
self.sampling_params = SglSamplingParams(
|
||||
max_new_tokens=max_new_tokens,
|
||||
min_new_tokens=min_new_tokens,
|
||||
stop=stop,
|
||||
stop_token_ids=stop_token_ids,
|
||||
temperature=temperature,
|
||||
|
||||
@@ -517,3 +517,36 @@ def test_hellaswag_select():
|
||||
accuracy = np.mean(np.array(preds) == np.array(labels))
|
||||
|
||||
return accuracy, latency
|
||||
|
||||
|
||||
def test_gen_min_new_tokens():
    """
    Validate sgl.gen(min_tokens) functionality.

    A question is asked whose natural answer is short; enforcing
    `min_tokens` must push the generated answer to at least that many
    tokens. The check re-tokenizes the answer and compares its length
    against the threshold.
    """
    import sglang as sgl
    from sglang.srt.hf_transformers_utils import get_tokenizer

    MIN_TOKENS, MAX_TOKENS = 64, 128
    model_path = sgl.global_config.default_backend.endpoint.get_model_name()

    @sgl.function
    def convo_1(s):
        s += sgl.user("What is the capital of the United States?")
        s += sgl.assistant(
            sgl.gen("answer", min_tokens=MIN_TOKENS, max_tokens=MAX_TOKENS)
        )

    tokenizer = get_tokenizer(model_path)
    state = convo_1.run()

    # Re-encode the generated text and verify the min_tokens floor held.
    text = state["answer"]
    token_ids = tokenizer.encode(text)
    assert (
        len(token_ids) >= MIN_TOKENS
    ), f"Generated {len(token_ids)} tokens, min required: {MIN_TOKENS}. Text: {text}"
@@ -7,6 +7,7 @@ from sglang.test.test_programs import (
|
||||
test_dtype_gen,
|
||||
test_expert_answer,
|
||||
test_few_shot_qa,
|
||||
test_gen_min_new_tokens,
|
||||
test_hellaswag_select,
|
||||
test_mt_bench,
|
||||
test_parallel_decoding,
|
||||
@@ -69,6 +70,9 @@ class TestSRTBackend(unittest.TestCase):
|
||||
accuracy, latency = test_hellaswag_select()
|
||||
assert accuracy > 0.71, f"{accuracy=}"
|
||||
|
||||
def test_gen_min_new_tokens(self):
    # Delegates to the shared program-level test; presumably runs against
    # the default backend set up elsewhere in this class — confirm there.
    test_gen_min_new_tokens()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user