Refactor logprob computation to return the real logprob used in sampling (#2664)
This commit is contained in:
@@ -6,7 +6,7 @@ import requests
|
||||
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
popen_launch_server,
|
||||
@@ -17,7 +17,7 @@ class TestBatchPenalizerE2E(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
|
||||
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
|
||||
@@ -213,6 +213,41 @@ class TestSRTEndpoint(unittest.TestCase):
|
||||
max_diff = np.max(diff)
|
||||
self.assertLess(max_diff, 0.25)
|
||||
|
||||
def test_logprob_grammar(self):
    """Check that top logprobs are reported for every grammar-allowed token.

    Sends a single-token generation request constrained by the regex
    "( Yes| No)" and asserts that each string in ``allowed_tokens``
    appears in the returned top-logprob list for the first output token.
    """
    prompts = "Question: Is Paris the Capital of France? Answer:"
    # Must stay in sync with the "regex" sampling parameter below.
    allowed_tokens = [" Yes", " No"]

    response = requests.post(
        self.base_url + "/generate",
        json={
            "text": prompts,
            "sampling_params": {
                "temperature": 1.0,
                "max_new_tokens": 1,
                "regex": "( Yes| No)",
            },
            "return_logprob": True,
            "top_logprobs_num": 5,
            "return_text_in_logprobs": True,
        },
    )
    response_json = response.json()
    # Top-logprob entries for the first (and only) generated position.
    output_top_logprobs = response_json["meta_info"]["output_top_logprobs"][0]
    print(f"{output_top_logprobs=}")

    # Parse results.
    # The grammar constraint allows all prefix tokens, so the top-k list may
    # contain entries other than the allowed tokens; those are skipped.
    # NOTE(review): each entry is read as entry[0] = logprob and
    # entry[2] = token text -- confirm against the server response schema.
    logprobs = [None] * len(allowed_tokens)
    for entry in output_top_logprobs:
        try:
            idx = allowed_tokens.index(entry[2])
        except ValueError:
            # Not one of the allowed tokens.
            continue
        logprobs[idx] = entry[0]

    # Every allowed token must have been found among the top logprobs.
    self.assertTrue(all(x is not None for x in logprobs))
|
||||
|
||||
def test_get_server_info(self):
|
||||
response = requests.get(self.base_url + "/get_server_info")
|
||||
response_json = response.json()
|
||||
|
||||
Reference in New Issue
Block a user