Support min_tokens in sgl.gen (#1573)
This commit is contained in:
@@ -517,3 +517,36 @@ def test_hellaswag_select():
|
||||
accuracy = np.mean(np.array(preds) == np.array(labels))
|
||||
|
||||
return accuracy, latency
|
||||
|
||||
|
||||
def test_gen_min_new_tokens():
    """
    Validate sgl.gen(min_tokens) functionality.

    The test asks a question where, without a min_tokens constraint, the generated answer is expected to be short.
    By enforcing the min_tokens parameter, we ensure the generated answer has at least the specified number of tokens.
    We verify that the number of tokens in the answer is >= the min_tokens threshold.

    The model name is read from ``sgl.global_config.default_backend``, so a
    default backend must be configured before this test runs.

    Raises:
        AssertionError: if the re-tokenized answer has fewer than MIN_TOKENS tokens.
    """
    import sglang as sgl
    from sglang.srt.hf_transformers_utils import get_tokenizer

    model_path = sgl.global_config.default_backend.endpoint.get_model_name()
    MIN_TOKENS, MAX_TOKENS = 64, 128

    @sgl.function
    def convo_1(s):
        s += sgl.user("What is the capital of the United States?")
        s += sgl.assistant(
            sgl.gen("answer", min_tokens=MIN_TOKENS, max_tokens=MAX_TOKENS)
        )

    def assert_min_tokens(tokenizer, text):
        # Count only the tokens of the generated text. With the default
        # add_special_tokens=True, tokenizers that prepend BOS (or append EOS)
        # inflate the count, letting an answer shorter than MIN_TOKENS pass.
        # NOTE(review): assumes get_tokenizer returns a Hugging Face-style
        # tokenizer whose encode() accepts add_special_tokens -- confirm.
        token_ids = tokenizer.encode(text, add_special_tokens=False)
        assert (
            len(token_ids) >= MIN_TOKENS
        ), f"Generated {len(token_ids)} tokens, min required: {MIN_TOKENS}. Text: {text}"

    tokenizer = get_tokenizer(model_path)

    state = convo_1.run()
    assert_min_tokens(tokenizer, state["answer"])
|
||||
|
||||
Reference in New Issue
Block a user