Support penalty in overlap mode; return logprob with chunked prefill; improve benchmark scripts (#3988)
Co-authored-by: SangBin Cho <rkooo567@gmail.com> Co-authored-by: dhou-xai <dhou@x.ai> Co-authored-by: Hanming Lu <hanming_lu@berkeley.edu>
This commit is contained in:
@@ -70,7 +70,10 @@ class TestSessionControl(unittest.TestCase):
|
||||
|
||||
first_rid = None
|
||||
outputs_from_session = []
|
||||
logprobs_from_session = []
|
||||
cur_logprob_start_len = 0
|
||||
for i, chunk_ids in enumerate(chunks_ids):
|
||||
max_new_tokens = gen_len if i > 0 else 1 # prefill only for the first chunk
|
||||
response = requests.post(
|
||||
self.base_url + "/generate",
|
||||
json={
|
||||
@@ -83,12 +86,12 @@ class TestSessionControl(unittest.TestCase):
|
||||
},
|
||||
"sampling_params": {
|
||||
"temperature": 0,
|
||||
"max_new_tokens": (
|
||||
gen_len if i > 0 else 1
|
||||
), # prefill only for the first chunk
|
||||
"max_new_tokens": max_new_tokens,
|
||||
"no_stop_trim": True,
|
||||
"skip_special_tokens": False,
|
||||
},
|
||||
"return_logprob": True,
|
||||
"logprob_start_len": cur_logprob_start_len - 1,
|
||||
},
|
||||
).json()
|
||||
rid = response["meta_info"]["id"]
|
||||
@@ -96,8 +99,39 @@ class TestSessionControl(unittest.TestCase):
|
||||
first_rid = rid
|
||||
if i > 0:
|
||||
outputs_from_session.append(response["text"])
|
||||
logprobs_from_session.extend(
|
||||
[
|
||||
round(sublist[0], 2)
|
||||
for sublist in response["meta_info"]["output_token_logprobs"]
|
||||
]
|
||||
)
|
||||
cur_logprob_start_len += len(chunk_ids) + max_new_tokens
|
||||
|
||||
# query with a logprob_start_len longer than the request, should see error
|
||||
response = requests.post(
|
||||
self.base_url + "/generate",
|
||||
json={
|
||||
"input_ids": chunk_ids,
|
||||
"session_params": {
|
||||
"id": session_id,
|
||||
"rid": rid,
|
||||
"offset": -1,
|
||||
"replace": True,
|
||||
},
|
||||
"sampling_params": {
|
||||
"temperature": 0,
|
||||
"max_new_tokens": max_new_tokens,
|
||||
"no_stop_trim": True,
|
||||
"skip_special_tokens": False,
|
||||
},
|
||||
"return_logprob": True,
|
||||
"logprob_start_len": cur_logprob_start_len + len(chunk_ids),
|
||||
},
|
||||
).json()
|
||||
assert "Request with a lower logprob_start_len" in response["error"]["message"]
|
||||
|
||||
# backtrack to the first request and regenerate
|
||||
cur_logprob_start_len = 0
|
||||
response = requests.post(
|
||||
self.base_url + "/generate",
|
||||
json={
|
||||
@@ -114,9 +148,17 @@ class TestSessionControl(unittest.TestCase):
|
||||
"no_stop_trim": True,
|
||||
"skip_special_tokens": False,
|
||||
},
|
||||
"return_logprob": True,
|
||||
"logprob_start_len": cur_logprob_start_len,
|
||||
},
|
||||
).json()
|
||||
outputs_from_session.append(response["text"])
|
||||
logprobs_from_session.extend(
|
||||
[
|
||||
round(sublist[0], 2)
|
||||
for sublist in response["meta_info"]["output_token_logprobs"]
|
||||
]
|
||||
)
|
||||
|
||||
# query with a non-existing rid (the last one should have disappeared because of the backtrack), should see abort
|
||||
response = requests.post(
|
||||
@@ -135,6 +177,7 @@ class TestSessionControl(unittest.TestCase):
|
||||
"no_stop_trim": True,
|
||||
"skip_special_tokens": False,
|
||||
},
|
||||
"return_logprob": True,
|
||||
},
|
||||
).json()
|
||||
assert response["meta_info"]["finish_reason"]["type"] == "abort"
|
||||
@@ -162,6 +205,7 @@ class TestSessionControl(unittest.TestCase):
|
||||
"no_stop_trim": True,
|
||||
"skip_special_tokens": False,
|
||||
},
|
||||
"return_logprob": True,
|
||||
},
|
||||
).json()
|
||||
assert response["meta_info"]["finish_reason"]["type"] == "abort"
|
||||
@@ -172,6 +216,7 @@ class TestSessionControl(unittest.TestCase):
|
||||
input_ids_first_req = None
|
||||
input_ids = []
|
||||
outputs_normal = []
|
||||
logprobs_normal = []
|
||||
for i, chunk_ids in enumerate(chunks_ids):
|
||||
input_ids += chunk_ids
|
||||
response = requests.post(
|
||||
@@ -186,6 +231,7 @@ class TestSessionControl(unittest.TestCase):
|
||||
"no_stop_trim": True,
|
||||
"skip_special_tokens": False,
|
||||
},
|
||||
"return_logprob": True,
|
||||
},
|
||||
).json()
|
||||
if i > 0:
|
||||
@@ -194,6 +240,12 @@ class TestSessionControl(unittest.TestCase):
|
||||
output_ids = output_ids[1:]
|
||||
input_ids += output_ids[:-1]
|
||||
outputs_normal.append(response["text"])
|
||||
logprobs_normal.extend(
|
||||
[
|
||||
round(sublist[0], 2)
|
||||
for sublist in response["meta_info"]["output_token_logprobs"]
|
||||
]
|
||||
)
|
||||
if i == 0:
|
||||
input_ids_first_req = input_ids.copy()
|
||||
|
||||
@@ -208,17 +260,31 @@ class TestSessionControl(unittest.TestCase):
|
||||
"no_stop_trim": True,
|
||||
"skip_special_tokens": False,
|
||||
},
|
||||
"return_logprob": True,
|
||||
},
|
||||
).json()
|
||||
outputs_normal.append(response["text"])
|
||||
logprobs_normal.extend(
|
||||
[
|
||||
round(sublist[0], 2)
|
||||
for sublist in response["meta_info"]["output_token_logprobs"]
|
||||
]
|
||||
)
|
||||
|
||||
print("outputs from chunked queries with session control:")
|
||||
print(outputs_from_session)
|
||||
print("outputs from normal queries:")
|
||||
print(outputs_normal)
|
||||
assert (
|
||||
outputs_from_session == outputs_normal
|
||||
), f"outputs_from_session: {outputs_from_session}, outputs_normal: {outputs_normal}"
|
||||
assert outputs_from_session == outputs_normal
|
||||
print("logprobs from chunked queries with session control:")
|
||||
print(logprobs_from_session)
|
||||
print("logprobs from normal queries:")
|
||||
print(logprobs_normal)
|
||||
assert len(logprobs_from_session) == len(
|
||||
logprobs_normal
|
||||
), "logprobs must have equal length"
|
||||
for a, b in zip(logprobs_from_session, logprobs_normal):
|
||||
assert abs(a - b) <= 0.1, f"logprobs {a} and {b} differ by more than 0.1"
|
||||
|
||||
async def async_generate(self, payload):
|
||||
url = self.base_url + "/generate"
|
||||
|
||||
Reference in New Issue
Block a user