Files
sglang/test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py

76 lines
2.1 KiB
Python

import json
import unittest
import requests
from sglang.srt.utils import kill_child_process
from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST, popen_launch_server
class TestBatchPenalizerE2E(unittest.TestCase):
    """End-to-end smoke tests for penalizer-related sampling parameters.

    A single server process is launched for the whole class; each test sends
    one ``/generate`` request with a different sampling parameter and checks
    that the server answers with HTTP 200.
    """

    @classmethod
    def setUpClass(cls):
        """Launch the server once, shared by every test in this class."""
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_url = f"http://127.0.0.1:{8157}"
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=300,
            other_args=(
                "--random-seed",
                "0",
            ),
        )

    @classmethod
    def tearDownClass(cls):
        """Kill the server process started in ``setUpClass``."""
        kill_child_process(cls.process.pid)

    def run_decode(
        self,
        return_logprob=True,
        top_logprobs_num=5,
        return_text=True,
        n=1,
        **sampling_params,
    ):
        """POST one request to ``/generate`` and verify that it succeeded.

        Args:
            return_logprob: Sent as the ``return_logprob`` request field.
            top_logprobs_num: Sent as the ``top_logprobs_num`` request field.
            return_text: Sent as the ``return_text_in_logprobs`` request field.
            n: Sent as ``sampling_params["n"]``.
            **sampling_params: Extra entries merged into ``sampling_params``
                (e.g. ``frequency_penalty=2``); they override the defaults
                set here when the keys collide.

        Raises:
            AssertionError: If the server does not respond with HTTP 200.
        """
        response = requests.post(
            self.base_url + "/generate",
            json={
                # prompt that is supposed to generate < 32 tokens
                "text": "<|start_header_id|>user<|end_header_id|>\n\nWhat is the answer for 1 + 1 = ?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
                "sampling_params": {
                    "max_new_tokens": 32,
                    "n": n,
                    **sampling_params,
                },
                "stream": False,
                "return_logprob": return_logprob,
                "top_logprobs_num": top_logprobs_num,
                "return_text_in_logprobs": return_text,
                "logprob_start_len": 0,
            },
        )
        # Check the status before decoding the body: without this the tests
        # would pass even when the server rejects or fails the request, and an
        # error response is not guaranteed to be valid JSON.
        self.assertEqual(
            response.status_code, 200, "Request failed: " + response.text
        )
        print(json.dumps(response.json()))
        print("=" * 100)

    def test_default_values(self):
        """Request with no extra sampling parameters."""
        self.run_decode()

    def test_frequency_penalty(self):
        """Request with ``frequency_penalty=2``."""
        self.run_decode(frequency_penalty=2)

    def test_min_new_tokens(self):
        """Request with ``min_new_tokens=16``."""
        self.run_decode(min_new_tokens=16)

    def test_presence_penalty(self):
        """Request with ``presence_penalty=2``."""
        self.run_decode(presence_penalty=2)

    def test_repetition_penalty(self):
        """Request with ``repetition_penalty=2``."""
        self.run_decode(repetition_penalty=2)
if __name__ == "__main__":
    # Script entry point: run this module's tests with warnings suppressed.
    unittest.main(warnings="ignore")