Refactor logprob computation to return the real logprob used in sampling (#2664)
This commit is contained in:
@@ -213,6 +213,41 @@ class TestSRTEndpoint(unittest.TestCase):
|
||||
max_diff = np.max(diff)
|
||||
self.assertLess(max_diff, 0.25)
|
||||
|
||||
def test_logprob_grammar(self):
    """Check that both regex-allowed tokens appear in the top logprobs.

    Sends one single-token generation request constrained by the regex
    "( Yes| No)" and asks for the top 5 logprobs with token text. The
    test passes when a logprob was reported for every entry of
    ``allowed_tokens``.
    """
    prompts = "Question: Is Paris the Capital of France? Answer:"
    allowed_tokens = [" Yes", " No"]

    sampling_params = {
        "temperature": 1.0,
        "max_new_tokens": 1,
        "regex": "( Yes| No)",
    }
    payload = {
        "text": prompts,
        "sampling_params": sampling_params,
        "return_logprob": True,
        "top_logprobs_num": 5,
        "return_text_in_logprobs": True,
    }
    response = requests.post(self.base_url + "/generate", json=payload)

    meta_info = response.json()["meta_info"]
    # Only one token is generated, so only position 0 is inspected.
    output_top_logprobs = meta_info["output_top_logprobs"][0]
    print(f"{output_top_logprobs=}")

    # Parse results.
    # More than two candidates are requested because (per the original
    # note) the grammar constraint allows all prefix tokens, so entries
    # other than the two allowed tokens may show up and are skipped.
    # NOTE(review): each entry is indexed as [0] -> logprob and
    # [2] -> token text; presumably (logprob, token_id, text) - confirm.
    logprobs = [None for _ in allowed_tokens]
    for entry in output_top_logprobs:
        token_text = entry[2]
        if token_text not in allowed_tokens:
            # Not one of the tokens we are looking for.
            continue
        logprobs[allowed_tokens.index(token_text)] = entry[0]

    self.assertTrue(all(value is not None for value in logprobs))
|
||||
|
||||
def test_get_server_info(self):
|
||||
response = requests.get(self.base_url + "/get_server_info")
|
||||
response_json = response.json()
|
||||
|
||||
Reference in New Issue
Block a user