Refactor logprob computation to return the real logprob used in sampling (#2664)

This commit is contained in:
Lianmin Zheng
2024-12-30 04:51:38 -08:00
committed by GitHub
parent b02da24a5b
commit 9c6ba2484f
9 changed files with 305 additions and 312 deletions

View File

@@ -6,7 +6,7 @@ import requests
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
popen_launch_server,
@@ -17,7 +17,7 @@ class TestBatchPenalizerE2E(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model,

View File

@@ -213,6 +213,41 @@ class TestSRTEndpoint(unittest.TestCase):
max_diff = np.max(diff)
self.assertLess(max_diff, 0.25)
def test_logprob_grammar(self):
    """Check that both allowed answer tokens appear in the top logprobs.

    Posts a single-token generation request to ``/generate`` with the
    sampling regex ``( Yes| No)``, asking for the top-5 logprobs together
    with their token text. The test passes only if a logprob was returned
    for each of " Yes" and " No".
    """
    prompts = "Question: Is Paris the Capital of France? Answer:"
    allowed_tokens = [" Yes", " No"]

    payload = {
        "text": prompts,
        "sampling_params": {
            "temperature": 1.0,
            "max_new_tokens": 1,
            "regex": "( Yes| No)",
        },
        "return_logprob": True,
        "top_logprobs_num": 5,
        "return_text_in_logprobs": True,
    }
    response = requests.post(self.base_url + "/generate", json=payload)
    response_json = response.json()

    # Top-logprob entries for the first (and only) generated position.
    output_top_logprobs = response_json["meta_info"]["output_top_logprobs"][0]
    print(f"{output_top_logprobs=}")

    # Parse results.
    # The grammar constraint allows all prefix tokens, so the returned
    # entries may include tokens other than the two allowed answers
    # (presumably why more than two top logprobs are requested); those
    # entries are skipped.
    logprobs = [None] * 2
    for entry in output_top_logprobs:
        # Each entry is indexed as [0] = logprob and [2] = token text.
        token_text = entry[2]
        if token_text not in allowed_tokens:
            continue
        logprobs[allowed_tokens.index(token_text)] = entry[0]

    self.assertTrue(all(value is not None for value in logprobs))
def test_get_server_info(self):
response = requests.get(self.base_url + "/get_server_info")
response_json = response.json()