Support penalty in overlap mode; return logprob with chunked prefill; improve benchmark scripts (#3988)

Co-authored-by: SangBin Cho <rkooo567@gmail.com>
Co-authored-by: dhou-xai <dhou@x.ai>
Co-authored-by: Hanming Lu <hanming_lu@berkeley.edu>
This commit is contained in:
Lianmin Zheng
2025-03-03 00:12:04 -08:00
parent 0194948fd9
commit ac2387279e
86 changed files with 4116 additions and 2015 deletions

View File

@@ -8,6 +8,7 @@ import random
import time
import unittest
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from typing import Optional
import numpy as np
@@ -20,6 +21,7 @@ from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
popen_launch_server,
run_logprob_check,
)
@@ -35,7 +37,9 @@ class TestSRTEndpoint(unittest.TestCase):
other_args=(
"--enable-custom-logit-processor",
"--mem-fraction-static",
"0.8",
"0.7",
"--cuda-graph-max-bs",
"8",
),
)
@@ -131,7 +135,7 @@ class TestSRTEndpoint(unittest.TestCase):
for i, res in enumerate(response_json):
self.assertEqual(
res["meta_info"]["prompt_tokens"],
logprob_start_len + 1 + len(res["meta_info"]["input_token_logprobs"]),
logprob_start_len + len(res["meta_info"]["input_token_logprobs"]),
)
assert prompts[i].endswith(
"".join([x[-1] for x in res["meta_info"]["input_token_logprobs"]])
@@ -235,83 +239,15 @@ class TestSRTEndpoint(unittest.TestCase):
diff = np.abs(output_logprobs - output_logprobs_score)
max_diff = np.max(diff)
self.assertLess(max_diff, 0.25)
def run_logprob_check(self, arg):
    """Issue one /generate request and verify the logprob bookkeeping of the reply.

    `arg` is a 6-tuple of
    (input_len, output_len, temperature, logprob_start_len,
    return_logprob, top_logprobs_num).
    """
    input_len, output_len, temperature, start_len, want_logprob, top_k = arg

    payload = {
        "input_ids": list(range(input_len)),
        "sampling_params": {
            "temperature": temperature,
            "max_new_tokens": output_len,
        },
        "return_logprob": want_logprob,
        "logprob_start_len": start_len,
        "top_logprobs_num": top_k,
    }
    res = requests.post(self.base_url + "/generate", json=payload).json()
    meta = res["meta_info"]

    self.assertEqual(meta["prompt_tokens"], input_len)
    self.assertEqual(meta["completion_tokens"], output_len)

    # Everything below only applies when logprobs were requested.
    if not want_logprob:
        return

    # With logprob_start_len == 0 a padding entry is added for the first
    # token; for any other start position no padding is added, hence the
    # extra 1 when reconciling against the prompt length.
    delta = 1 if start_len != 0 else 0

    self.assertEqual(
        len(meta["input_token_logprobs"]) + start_len + delta,
        meta["prompt_tokens"],
    )
    self.assertEqual(len(meta["output_token_logprobs"]), output_len)

    if not top_k:
        return

    self.assertEqual(
        len(meta["input_top_logprobs"]) + start_len + delta,
        meta["prompt_tokens"],
    )
    self.assertEqual(len(meta["output_top_logprobs"]), output_len)

    # Both lists were just asserted to hold exactly output_len entries,
    # so walking them in lockstep visits every generated position.
    for token_entry, top_entries in zip(
        meta["output_token_logprobs"], meta["output_top_logprobs"]
    ):
        self.assertEqual(len(top_entries), top_k)

        # Under greedy decoding the sampled token must be the top-1 entry.
        if temperature == 0:
            self.assertListEqual(token_entry, top_entries[0])
self.assertLess(max_diff, 0.35)
def test_logprob_mixed(self):
args = []
temperature = 0
# input_len, output_len, temperature, logprob_start_len, return_logprob, top_logprobs_num
for input_len in [1000, 2000]:
for input_len in [1000, 5000, 10000, 50000]:
for output_len in [4, 8]:
for logprob_start_len in [0, 500, 1000]:
for logprob_start_len in [0, 500, 2500, 5000, 25000]:
for return_logprob in [True, False]:
for top_logprobs_num in [0, 5]:
@@ -331,8 +267,9 @@ class TestSRTEndpoint(unittest.TestCase):
random.shuffle(args)
func = partial(run_logprob_check, self)
with ThreadPoolExecutor(8) as executor:
list(executor.map(self.run_logprob_check, args))
list(executor.map(func, args))
def test_logprob_grammar(self):
prompts = "Question: Is Paris the Capital of France? Answer:"
@@ -427,6 +364,77 @@ class TestSRTEndpoint(unittest.TestCase):
f"{target_token_id=}\n{sampled_tokens=}\n{custom_response=}",
)
def run_stateful_custom_logit_processor(
    self, first_token_id: Optional[int], delay: int = 2
):
    """Test custom logit processor with custom params and state.

    Should sample the first `delay` tokens normally, then output
    `first_token_id` and consecutive token ids after that.
    If `first_token_id` is None, the custom logit processor won't be passed in.

    Args:
        first_token_id: Token id forced once the delay has elapsed, or None
            to send a plain request without a custom logit processor.
        delay: Number of leading tokens that are sampled without
            interference before forcing starts.
    """
    # Honor the caller-supplied delay instead of a hard-coded value so the
    # processor state and the final assertion always agree with `delay`.
    custom_params = {"token_id": first_token_id, "delay": delay}

    class DeterministicStatefulLogitProcessor(CustomLogitProcessor):
        """A dummy logit processor that, after a per-request delay, forces
        `token_id` and then the successor of the previously emitted token.
        """

        def __call__(self, logits, custom_param_list):
            assert logits.shape[0] == len(custom_param_list)

            for i, param_dict in enumerate(custom_param_list):
                # Still inside the delay window: leave the logits untouched.
                if param_dict["delay"] > 0:
                    param_dict["delay"] -= 1
                    continue
                if param_dict["delay"] == 0:
                    # First forced step: push the counter negative so later
                    # steps take the "previous token + 1" branch.
                    param_dict["delay"] -= 1
                    force_token = param_dict["token_id"]
                else:
                    output_ids = param_dict["__req__"].output_ids
                    force_token = output_ids[-1] + 1
                # Mask all other tokens
                logits[i, :] = -float("inf")
                # Assign highest probability to the specified token
                logits[i, force_token] = 0.0
            return logits

    prompts = "Question: Is Paris the Capital of France? Answer:"

    # Build the request from scratch so the nested sampling_params dict is
    # owned by this request only (a shallow dict.copy() would share it).
    custom_json = {
        "text": prompts,
        "sampling_params": {"temperature": 0.0},
        "return_logprob": True,
    }
    # Only set the custom logit processor if first_token_id is not None.
    if first_token_id is not None:
        custom_json["custom_logit_processor"] = (
            DeterministicStatefulLogitProcessor().to_str()
        )
        custom_json["sampling_params"]["custom_params"] = custom_params

    custom_response = requests.post(
        self.base_url + "/generate",
        json=custom_json,
    ).json()

    output_token_logprobs = custom_response["meta_info"]["output_token_logprobs"]
    sampled_tokens = [x[1] for x in output_token_logprobs]
    # The logit processor should always sample the given token as the logits is deterministic.
    if first_token_id is not None:
        self.assertTrue(
            all(
                x == first_token_id + k
                for k, x in enumerate(sampled_tokens[delay:])
            ),
            # Print the detailed test case info if the test fails.
            f"{first_token_id=}\n{sampled_tokens=}\n{custom_response=}",
        )
def test_custom_logit_processor(self):
    """Run the custom logit processor check once, for a single request."""
    self.run_custom_logit_processor(target_token_id=5)
@@ -438,6 +446,19 @@ class TestSRTEndpoint(unittest.TestCase):
with ThreadPoolExecutor(len(target_token_ids)) as executor:
list(executor.map(self.run_custom_logit_processor, target_token_ids))
def test_stateful_custom_logit_processor(self):
    """Run the stateful custom logit processor check once, for a single request."""
    self.run_stateful_custom_logit_processor(first_token_id=5)
def test_stateful_custom_logit_processor_batch_mixed(self):
    """Fire concurrent requests, some with and some without the stateful processor."""
    # 32 requests that force a token id, plus 16 plain requests (None).
    first_token_ids = [*range(32), *([None] * 16)]
    random.shuffle(first_token_ids)

    # One worker per request so that all of them are in flight together.
    with ThreadPoolExecutor(len(first_token_ids)) as pool:
        outcomes = pool.map(
            self.run_stateful_custom_logit_processor, first_token_ids
        )
        # Drain the iterator so an exception raised in a worker surfaces here.
        for _ in outcomes:
            pass
def test_cache_tokens(self):
for _ in range(2):
time.sleep(1)
@@ -476,6 +497,21 @@ class TestSRTEndpoint(unittest.TestCase):
version = response_json["version"]
self.assertIsInstance(version, str)
def test_get_server_info_concurrent(self):
    """Make sure the concurrent get_server_info doesn't crash the server."""
    num_requests = 4

    def fetch_server_info():
        server_info = requests.get(self.base_url + "/get_server_info")
        # Decoding the body fails loudly if the server returned garbage.
        server_info.json()

    # The context manager shuts the pool down even when a future raises,
    # so worker threads are never left behind by a failing test.
    with ThreadPoolExecutor(max_workers=num_requests) as pool:
        futures = [pool.submit(fetch_server_info) for _ in range(num_requests)]
        for future in futures:
            # result() re-raises any exception from the worker thread.
            future.result()
if __name__ == "__main__":
unittest.main()