Support penalty in overlap mode; return logprob with chunked prefill; improve benchmark scripts (#3988)

Co-authored-by: SangBin Cho <rkooo567@gmail.com>
Co-authored-by: dhou-xai <dhou@x.ai>
Co-authored-by: Hanming Lu <hanming_lu@berkeley.edu>
This commit is contained in:
Lianmin Zheng
2025-03-03 00:12:04 -08:00
parent 0194948fd9
commit ac2387279e
86 changed files with 4116 additions and 2015 deletions

View File

@@ -70,7 +70,10 @@ class TestSessionControl(unittest.TestCase):
first_rid = None
outputs_from_session = []
logprobs_from_session = []
cur_logprob_start_len = 0
for i, chunk_ids in enumerate(chunks_ids):
max_new_tokens = gen_len if i > 0 else 1 # prefill only for the first chunk
response = requests.post(
self.base_url + "/generate",
json={
@@ -83,12 +86,12 @@ class TestSessionControl(unittest.TestCase):
},
"sampling_params": {
"temperature": 0,
"max_new_tokens": (
gen_len if i > 0 else 1
), # prefill only for the first chunk
"max_new_tokens": max_new_tokens,
"no_stop_trim": True,
"skip_special_tokens": False,
},
"return_logprob": True,
"logprob_start_len": cur_logprob_start_len - 1,
},
).json()
rid = response["meta_info"]["id"]
@@ -96,8 +99,39 @@ class TestSessionControl(unittest.TestCase):
first_rid = rid
if i > 0:
outputs_from_session.append(response["text"])
logprobs_from_session.extend(
[
round(sublist[0], 2)
for sublist in response["meta_info"]["output_token_logprobs"]
]
)
cur_logprob_start_len += len(chunk_ids) + max_new_tokens
# query with a logprob_start_len longer than the request, should see error
response = requests.post(
self.base_url + "/generate",
json={
"input_ids": chunk_ids,
"session_params": {
"id": session_id,
"rid": rid,
"offset": -1,
"replace": True,
},
"sampling_params": {
"temperature": 0,
"max_new_tokens": max_new_tokens,
"no_stop_trim": True,
"skip_special_tokens": False,
},
"return_logprob": True,
"logprob_start_len": cur_logprob_start_len + len(chunk_ids),
},
).json()
assert "Request with a lower logprob_start_len" in response["error"]["message"]
# backtrack to the first request and regenerate
cur_logprob_start_len = 0
response = requests.post(
self.base_url + "/generate",
json={
@@ -114,9 +148,17 @@ class TestSessionControl(unittest.TestCase):
"no_stop_trim": True,
"skip_special_tokens": False,
},
"return_logprob": True,
"logprob_start_len": cur_logprob_start_len,
},
).json()
outputs_from_session.append(response["text"])
logprobs_from_session.extend(
[
round(sublist[0], 2)
for sublist in response["meta_info"]["output_token_logprobs"]
]
)
# query with a non-existing rid (the last one should be disappeared becuase of backtrack), should see abort
response = requests.post(
@@ -135,6 +177,7 @@ class TestSessionControl(unittest.TestCase):
"no_stop_trim": True,
"skip_special_tokens": False,
},
"return_logprob": True,
},
).json()
assert response["meta_info"]["finish_reason"]["type"] == "abort"
@@ -162,6 +205,7 @@ class TestSessionControl(unittest.TestCase):
"no_stop_trim": True,
"skip_special_tokens": False,
},
"return_logprob": True,
},
).json()
assert response["meta_info"]["finish_reason"]["type"] == "abort"
@@ -172,6 +216,7 @@ class TestSessionControl(unittest.TestCase):
input_ids_first_req = None
input_ids = []
outputs_normal = []
logprobs_normal = []
for i, chunk_ids in enumerate(chunks_ids):
input_ids += chunk_ids
response = requests.post(
@@ -186,6 +231,7 @@ class TestSessionControl(unittest.TestCase):
"no_stop_trim": True,
"skip_special_tokens": False,
},
"return_logprob": True,
},
).json()
if i > 0:
@@ -194,6 +240,12 @@ class TestSessionControl(unittest.TestCase):
output_ids = output_ids[1:]
input_ids += output_ids[:-1]
outputs_normal.append(response["text"])
logprobs_normal.extend(
[
round(sublist[0], 2)
for sublist in response["meta_info"]["output_token_logprobs"]
]
)
if i == 0:
input_ids_first_req = input_ids.copy()
@@ -208,17 +260,31 @@ class TestSessionControl(unittest.TestCase):
"no_stop_trim": True,
"skip_special_tokens": False,
},
"return_logprob": True,
},
).json()
outputs_normal.append(response["text"])
logprobs_normal.extend(
[
round(sublist[0], 2)
for sublist in response["meta_info"]["output_token_logprobs"]
]
)
print("outputs from chunked queries with session control:")
print(outputs_from_session)
print("outputs from normal queries:")
print(outputs_normal)
assert (
outputs_from_session == outputs_normal
), f"outputs_from_session: {outputs_from_session}, outputs_normal: {outputs_normal}"
assert outputs_from_session == outputs_normal
print("logprobs from chunked queries with session control:")
print(logprobs_from_session)
print("logprobs from normal queries:")
print(logprobs_normal)
assert len(logprobs_from_session) == len(
logprobs_normal
), "logprobs must have equal length"
for a, b in zip(logprobs_from_session, logprobs_normal):
assert abs(a - b) <= 0.1, f"logprobs {a} and {b} differ by more than 0.1"
async def async_generate(self, payload):
url = self.base_url + "/generate"