diff --git a/examples/frontend_language/usage/sgl_gen_min_tokens.py b/examples/frontend_language/usage/sgl_gen_min_tokens.py new file mode 100644 index 000000000..a5088199b --- /dev/null +++ b/examples/frontend_language/usage/sgl_gen_min_tokens.py @@ -0,0 +1,35 @@ +""" +This example demonstrates how to use `min_tokens` to enforce sgl.gen to generate a longer sequence + +Usage: +python3 sgl_gen_min_tokens.py +""" + +import sglang as sgl + + +@sgl.function +def long_answer(s): + s += sgl.user("What is the capital of the United States?") + s += sgl.assistant(sgl.gen("answer", min_tokens=64, max_tokens=128)) + + +@sgl.function +def short_answer(s): + s += sgl.user("What is the capital of the United States?") + s += sgl.assistant(sgl.gen("answer")) + + +if __name__ == "__main__": + runtime = sgl.Runtime(model_path="meta-llama/Meta-Llama-3.1-8B-Instruct") + sgl.set_default_backend(runtime) + + state = long_answer.run() + print("=" * 20) + print("Longer Answer", state["answer"]) + + state = short_answer.run() + print("=" * 20) + print("Short Answer", state["answer"]) + + runtime.shutdown() diff --git a/python/sglang/api.py b/python/sglang/api.py index 5adda6022..4082deae1 100644 --- a/python/sglang/api.py +++ b/python/sglang/api.py @@ -69,6 +69,7 @@ def get_server_args(backend: Optional[BaseBackend] = None): def gen( name: Optional[str] = None, max_tokens: Optional[int] = None, + min_tokens: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None, stop_token_ids: Optional[List[int]] = None, temperature: Optional[float] = None, @@ -108,6 +109,7 @@ def gen( return SglGen( name, max_tokens, + min_tokens, stop, stop_token_ids, temperature, @@ -147,6 +149,7 @@ def gen_int( return SglGen( name, max_tokens, + None, stop, stop_token_ids, temperature, @@ -185,6 +188,7 @@ def gen_string( return SglGen( name, max_tokens, + None, stop, stop_token_ids, temperature, diff --git a/python/sglang/lang/interpreter.py b/python/sglang/lang/interpreter.py index 31c39d76a..44ea17f66 100644 --- a/python/sglang/lang/interpreter.py +++ b/python/sglang/lang/interpreter.py @@ -668,6 +668,7 @@ class StreamExecutor: for item in [ "max_new_tokens", + "min_new_tokens", "stop", "stop_token_ids", "temperature", diff --git a/python/sglang/lang/ir.py b/python/sglang/lang/ir.py index 75f4d0bb6..5c03db068 100644 --- a/python/sglang/lang/ir.py +++ b/python/sglang/lang/ir.py @@ -17,6 +17,7 @@ REGEX_STR = r"\"[\w\d\s]*\"" # bugs with regex r"\".*\"" in interegular pkg @dataclasses.dataclass class SglSamplingParams: max_new_tokens: int = 128 + min_new_tokens: int = 0 stop: Union[str, List[str]] = () stop_token_ids: Optional[List[int]] = () temperature: float = 1.0 @@ -39,6 +40,7 @@ class SglSamplingParams: def clone(self): return SglSamplingParams( self.max_new_tokens, + self.min_new_tokens, self.stop, self.stop_token_ids, self.temperature, @@ -113,6 +115,7 @@ class SglSamplingParams: def to_srt_kwargs(self): return { "max_new_tokens": self.max_new_tokens, + "min_new_tokens": self.min_new_tokens, "stop": self.stop, "stop_token_ids": self.stop_token_ids, "temperature": self.temperature, @@ -424,6 +427,7 @@ class SglGen(SglExpr): self, name: Optional[str] = None, max_new_tokens: Optional[int] = None, + min_new_tokens: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None, stop_token_ids: Optional[List[int]] = None, temperature: Optional[float] = None, @@ -446,6 +450,7 @@ class SglGen(SglExpr): self.name = name self.sampling_params = SglSamplingParams( max_new_tokens=max_new_tokens, + min_new_tokens=min_new_tokens, stop=stop, stop_token_ids=stop_token_ids, temperature=temperature, diff --git a/python/sglang/test/test_programs.py b/python/sglang/test/test_programs.py index 092c5369d..a251e0aca 100644 --- a/python/sglang/test/test_programs.py +++ b/python/sglang/test/test_programs.py @@ -517,3 +517,36 @@ def test_hellaswag_select(): accuracy = np.mean(np.array(preds) == np.array(labels)) return accuracy, latency + + +def test_gen_min_new_tokens(): + """ + Validate sgl.gen(min_tokens) functionality. + + The test asks a question where, without a min_tokens constraint, the generated answer is expected to be short. + By enforcing the min_tokens parameter, we ensure the generated answer has at least the specified number of tokens. + We verify that the number of tokens in the answer is >= the min_tokens threshold. + """ + import sglang as sgl + from sglang.srt.hf_transformers_utils import get_tokenizer + + model_path = sgl.global_config.default_backend.endpoint.get_model_name() + MIN_TOKENS, MAX_TOKENS = 64, 128 + + @sgl.function + def convo_1(s): + s += sgl.user("What is the capital of the United States?") + s += sgl.assistant( + sgl.gen("answer", min_tokens=MIN_TOKENS, max_tokens=MAX_TOKENS) + ) + + def assert_min_tokens(tokenizer, text): + token_ids = tokenizer.encode(text) + assert ( + len(token_ids) >= MIN_TOKENS + ), f"Generated {len(token_ids)} tokens, min required: {MIN_TOKENS}. Text: {text}" + + tokenizer = get_tokenizer(model_path) + + state = convo_1.run() + assert_min_tokens(tokenizer, state["answer"]) diff --git a/test/lang/test_srt_backend.py b/test/lang/test_srt_backend.py index 5bc565c18..106196a6a 100644 --- a/test/lang/test_srt_backend.py +++ b/test/lang/test_srt_backend.py @@ -7,6 +7,7 @@ from sglang.test.test_programs import ( test_dtype_gen, test_expert_answer, test_few_shot_qa, + test_gen_min_new_tokens, test_hellaswag_select, test_mt_bench, test_parallel_decoding, @@ -69,6 +70,9 @@ class TestSRTBackend(unittest.TestCase): accuracy, latency = test_hellaswag_select() assert accuracy > 0.71, f"{accuracy=}" + def test_gen_min_new_tokens(self): + test_gen_min_new_tokens() + if __name__ == "__main__": unittest.main()