Support penalty in overlap mode; return logprob with chunked prefill; improve benchmark scripts (#3988)
Co-authored-by: SangBin Cho <rkooo567@gmail.com> Co-authored-by: dhou-xai <dhou@x.ai> Co-authored-by: Hanming Lu <hanming_lu@berkeley.edu>
This commit is contained in:
@@ -8,6 +8,7 @@ import random
|
||||
import time
|
||||
import unittest
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from functools import partial
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
@@ -20,6 +21,7 @@ from sglang.test.test_utils import (
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
popen_launch_server,
|
||||
run_logprob_check,
|
||||
)
|
||||
|
||||
|
||||
@@ -35,7 +37,9 @@ class TestSRTEndpoint(unittest.TestCase):
|
||||
other_args=(
|
||||
"--enable-custom-logit-processor",
|
||||
"--mem-fraction-static",
|
||||
"0.8",
|
||||
"0.7",
|
||||
"--cuda-graph-max-bs",
|
||||
"8",
|
||||
),
|
||||
)
|
||||
|
||||
@@ -131,7 +135,7 @@ class TestSRTEndpoint(unittest.TestCase):
|
||||
for i, res in enumerate(response_json):
|
||||
self.assertEqual(
|
||||
res["meta_info"]["prompt_tokens"],
|
||||
logprob_start_len + 1 + len(res["meta_info"]["input_token_logprobs"]),
|
||||
logprob_start_len + len(res["meta_info"]["input_token_logprobs"]),
|
||||
)
|
||||
assert prompts[i].endswith(
|
||||
"".join([x[-1] for x in res["meta_info"]["input_token_logprobs"]])
|
||||
@@ -235,83 +239,15 @@ class TestSRTEndpoint(unittest.TestCase):
|
||||
|
||||
diff = np.abs(output_logprobs - output_logprobs_score)
|
||||
max_diff = np.max(diff)
|
||||
self.assertLess(max_diff, 0.25)
|
||||
|
||||
def run_logprob_check(self, arg):
    """Post one `/generate` request and validate the logprob metadata it returns.

    `arg` is a 6-tuple of (input_len, output_len, temperature,
    logprob_start_len, return_logprob, top_logprobs_num). The prompt sent to
    the server is the token id list 0..input_len-1.
    """
    (
        input_len,
        output_len,
        temperature,
        logprob_start_len,
        return_logprob,
        top_logprobs_num,
    ) = arg

    payload = {
        "input_ids": list(range(input_len)),
        "sampling_params": {
            "temperature": temperature,
            "max_new_tokens": output_len,
        },
        "return_logprob": return_logprob,
        "logprob_start_len": logprob_start_len,
        "top_logprobs_num": top_logprobs_num,
    }
    res = requests.post(self.base_url + "/generate", json=payload).json()
    meta = res["meta_info"]

    # Token counts must match the request regardless of logprob settings.
    self.assertEqual(meta["prompt_tokens"], input_len)
    self.assertEqual(meta["completion_tokens"], output_len)

    if not return_logprob:
        return

    # With logprob_start_len == 0 a padding entry is added for the first
    # token; for any other start there is no padding, so one more position
    # has to be accounted for.
    delta = 1 if logprob_start_len != 0 else 0
    skipped = logprob_start_len + delta

    self.assertEqual(
        len(meta["input_token_logprobs"]) + skipped,
        meta["prompt_tokens"],
    )
    self.assertEqual(len(meta["output_token_logprobs"]), output_len)

    if not top_logprobs_num:
        return

    self.assertEqual(
        len(meta["input_top_logprobs"]) + skipped,
        meta["prompt_tokens"],
    )
    self.assertEqual(len(meta["output_top_logprobs"]), output_len)

    for pos in range(output_len):
        top_entries = meta["output_top_logprobs"][pos]
        self.assertEqual(len(top_entries), top_logprobs_num)

        # With temperature == 0 the top-1 entry must be the emitted token.
        if temperature == 0:
            self.assertListEqual(
                meta["output_token_logprobs"][pos],
                top_entries[0],
            )
|
||||
self.assertLess(max_diff, 0.35)
|
||||
|
||||
def test_logprob_mixed(self):
|
||||
args = []
|
||||
temperature = 0
|
||||
# input_len, output_len, temperature, logprob_start_len, return_logprob, top_logprobs_num
|
||||
for input_len in [1000, 2000]:
|
||||
for input_len in [1000, 5000, 10000, 50000]:
|
||||
for output_len in [4, 8]:
|
||||
for logprob_start_len in [0, 500, 1000]:
|
||||
for logprob_start_len in [0, 500, 2500, 5000, 25000]:
|
||||
for return_logprob in [True, False]:
|
||||
for top_logprobs_num in [0, 5]:
|
||||
|
||||
@@ -331,8 +267,9 @@ class TestSRTEndpoint(unittest.TestCase):
|
||||
|
||||
random.shuffle(args)
|
||||
|
||||
func = partial(run_logprob_check, self)
|
||||
with ThreadPoolExecutor(8) as executor:
|
||||
list(executor.map(self.run_logprob_check, args))
|
||||
list(executor.map(func, args))
|
||||
|
||||
def test_logprob_grammar(self):
|
||||
prompts = "Question: Is Paris the Capital of France? Answer:"
|
||||
@@ -427,6 +364,77 @@ class TestSRTEndpoint(unittest.TestCase):
|
||||
f"{target_token_id=}\n{sampled_tokens=}\n{custom_response=}",
|
||||
)
|
||||
|
||||
def run_stateful_custom_logit_processor(
    self, first_token_id: Optional[int], delay: int = 2
):
    """Test custom logit processor with custom params and state.

    Should sample the first `delay` tokens normally, then output
    `first_token_id` and consecutive token ids after that.

    Args:
        first_token_id: Token id forced once the delay has elapsed. If None,
            the custom logit processor won't be passed in and nothing is
            asserted about the sampled tokens.
        delay: Number of leading tokens that are sampled without interference.
    """

    # Honor the caller-supplied delay; it is both sent to the processor and
    # used below to decide where the forced tokens start.
    custom_params = {"token_id": first_token_id, "delay": delay}

    class DeterministicStatefulLogitProcessor(CustomLogitProcessor):
        """A dummy logit processor that, after a per-request delay, forces
        the given token id and then consecutive token ids.
        """

        def __call__(self, logits, custom_param_list):
            assert logits.shape[0] == len(custom_param_list)

            for i, param_dict in enumerate(custom_param_list):
                if param_dict["delay"] > 0:
                    # Still inside the delay window: leave this row untouched.
                    param_dict["delay"] -= 1
                    continue
                if param_dict["delay"] == 0:
                    # First forced step; going negative marks it as consumed.
                    param_dict["delay"] -= 1
                    force_token = param_dict["token_id"]
                else:
                    # Later steps continue from the last emitted token id.
                    output_ids = param_dict["__req__"].output_ids
                    force_token = output_ids[-1] + 1
                # Mask all other tokens
                logits[i, :] = -float("inf")
                # Assign highest probability to the specified token
                logits[i, force_token] = 0.0
            return logits

    prompts = "Question: Is Paris the Capital of France? Answer:"

    # Base case json data to be posted to the server.
    base_json = {
        "text": prompts,
        "sampling_params": {"temperature": 0.0},
        "return_logprob": True,
    }

    # Custom json data with custom logit processor and params. The nested
    # sampling params are copied as well, so attaching custom_params below
    # does not also modify base_json.
    custom_json = {
        **base_json,
        "sampling_params": dict(base_json["sampling_params"]),
    }
    # Only set the custom logit processor if first_token_id is not None.
    if first_token_id is not None:
        custom_json["custom_logit_processor"] = (
            DeterministicStatefulLogitProcessor().to_str()
        )
        custom_json["sampling_params"]["custom_params"] = custom_params

    custom_response = requests.post(
        self.base_url + "/generate",
        json=custom_json,
    ).json()

    output_token_logprobs = custom_response["meta_info"]["output_token_logprobs"]
    # NOTE(review): index 1 of each entry is treated as the sampled token id;
    # confirm against the server's response format.
    sampled_tokens = [x[1] for x in output_token_logprobs]
    # After the delay, the k-th sampled token must be first_token_id + k.
    if first_token_id is not None:
        self.assertTrue(
            all(
                x == custom_params["token_id"] + k
                for k, x in enumerate(sampled_tokens[custom_params["delay"] :])
            ),
            # Print the detailed test case info if the test fails.
            f"{first_token_id=}\n{sampled_tokens=}\n{custom_response=}",
        )
|
||||
|
||||
def test_custom_logit_processor(self):
|
||||
"""Test custom logit processor with a single request."""
|
||||
self.run_custom_logit_processor(target_token_id=5)
|
||||
@@ -438,6 +446,19 @@ class TestSRTEndpoint(unittest.TestCase):
|
||||
with ThreadPoolExecutor(len(target_token_ids)) as executor:
|
||||
list(executor.map(self.run_custom_logit_processor, target_token_ids))
|
||||
|
||||
def test_stateful_custom_logit_processor(self):
    """Test the stateful custom logit processor with a single request."""
    self.run_stateful_custom_logit_processor(first_token_id=5)
|
||||
|
||||
def test_stateful_custom_logit_processor_batch_mixed(self):
    """Run concurrent requests, some with and some without the stateful custom logit processor."""
    # 32 requests each get a distinct first token id; the 16 None entries
    # are passed through as requests without a first token id.
    first_token_ids = [*range(32), *([None] * 16)]
    random.shuffle(first_token_ids)

    # One worker per request so the whole batch is in flight together.
    with ThreadPoolExecutor(len(first_token_ids)) as pool:
        outcomes = pool.map(
            self.run_stateful_custom_logit_processor, first_token_ids
        )
        # Drain the iterator so a failure in any worker is re-raised here.
        for _ in outcomes:
            pass
|
||||
|
||||
def test_cache_tokens(self):
|
||||
for _ in range(2):
|
||||
time.sleep(1)
|
||||
@@ -476,6 +497,21 @@ class TestSRTEndpoint(unittest.TestCase):
|
||||
version = response_json["version"]
|
||||
self.assertIsInstance(version, str)
|
||||
|
||||
def test_get_server_info_concurrent(self):
    """Make sure the concurrent get_server_info doesn't crash the server."""

    def fetch_server_info():
        server_info = requests.get(self.base_url + "/get_server_info")
        # Decoding raises if the body is not valid JSON, which fails the test.
        server_info.json()

    # The context manager shuts the pool down on exit, including when one of
    # the requests raises, so no worker threads are left behind.
    with ThreadPoolExecutor(max_workers=30) as tp:
        futures = [tp.submit(fetch_server_info) for _ in range(4)]

        # result() re-raises any exception from the worker thread.
        for f in futures:
            f.result()
|
||||
|
||||
|
||||
# Run the module's test cases when it is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user