# File: sglang/test/srt/test_srt_endpoint.py (525 lines, 19 KiB, Python)

"""
python3 -m unittest test_srt_endpoint.TestSRTEndpoint.test_simple_decode
python3 -m unittest test_srt_endpoint.TestSRTEndpoint.test_logprob_with_chunked_prefill
"""
import json
import random
import time
import unittest
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from typing import Optional
import numpy as np
import requests
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
run_logprob_check,
)
class TestSRTEndpoint(CustomTestCase):
    """End-to-end tests for the HTTP endpoints of a locally launched server.

    A single server process is started in ``setUpClass`` and shared by every
    test in the class; the tests talk to it over HTTP at ``cls.base_url``.
    """

    @classmethod
    def setUpClass(cls):
        """Launch the server under test once for the whole class."""
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=(
                # Presumably required by the custom logit processor tests,
                # which send a "custom_logit_processor" payload.
                "--enable-custom-logit-processor",
                "--mem-fraction-static",
                "0.7",
                "--cuda-graph-max-bs",
                "8",
            ),
        )

    @classmethod
    def tearDownClass(cls):
        """Terminate the process tree of the server started in ``setUpClass``."""
        kill_process_tree(cls.process.pid)

    def run_decode(
        self,
        return_logprob=False,
        top_logprobs_num=0,
        return_text=False,
        n=1,
        stream=False,
        batch=False,
    ):
        """POST a short fixed prompt to ``/generate`` and print the result.

        No assertion is made on the content: the calling test passes as long
        as the request and the JSON decoding do not raise.

        Args:
            return_logprob: Forwarded as ``return_logprob``.
            top_logprobs_num: Forwarded as ``top_logprobs_num``.
            return_text: Forwarded as ``return_text_in_logprobs``.
            n: Number of samples requested via ``sampling_params["n"]``.
            stream: Forwarded as ``stream``; when true the response body is
                parsed line by line instead of as a single JSON document.
            batch: Send the prompt wrapped in a one-element list.
        """
        prompt = "The capital of France is"
        text = [prompt] if batch else prompt

        response = requests.post(
            self.base_url + "/generate",
            json={
                "text": text,
                "sampling_params": {
                    # Greedy for a single sample; a non-zero temperature is
                    # used for n > 1 (presumably so the samples can differ).
                    "temperature": 0 if n == 1 else 0.5,
                    "max_new_tokens": 16,
                    "n": n,
                },
                "stream": stream,
                "return_logprob": return_logprob,
                "top_logprobs_num": top_logprobs_num,
                "return_text_in_logprobs": return_text,
                "logprob_start_len": 0,
            },
        )

        if not stream:
            response_json = response.json()
        else:
            # Keep every "data: <json>" line except the "[DONE]" terminator.
            response_json = []
            for line in response.iter_lines():
                if line.startswith(b"data: ") and line[6:] != b"[DONE]":
                    response_json.append(json.loads(line[6:]))

        print(json.dumps(response_json, indent=2))
        print("=" * 100)

    def test_simple_decode(self):
        """Single prompt, default options."""
        self.run_decode()

    def test_simple_decode_batch(self):
        """Prompt sent as a one-element batch."""
        self.run_decode(batch=True)

    def test_parallel_sample(self):
        """Three samples for one prompt."""
        self.run_decode(n=3)

    def test_parallel_sample_stream(self):
        """Three samples for one prompt with streaming enabled."""
        self.run_decode(n=3, stream=True)

    def test_logprob(self):
        """Request logprobs, top-5 logprobs and token text."""
        self.run_decode(
            return_logprob=True,
            top_logprobs_num=5,
            return_text=True,
        )

    def test_logprob_start_len(self):
        """Check that input logprobs start at ``logprob_start_len``."""
        logprob_start_len = 4
        new_tokens = 4
        prompts = [
            "I have a very good idea on",
            "Today is a sunndy day and",
        ]

        response = requests.post(
            self.base_url + "/generate",
            json={
                "text": prompts,
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": new_tokens,
                },
                "return_logprob": True,
                "top_logprobs_num": 5,
                "return_text_in_logprobs": True,
                "logprob_start_len": logprob_start_len,
            },
        )
        response_json = response.json()
        print(json.dumps(response_json, indent=2))

        for i, res in enumerate(response_json):
            meta_info = res["meta_info"]
            input_logprobs = meta_info["input_token_logprobs"]
            output_logprobs = meta_info["output_token_logprobs"]

            # The first `logprob_start_len` prompt tokens carry no logprob entry.
            self.assertEqual(
                meta_info["prompt_tokens"],
                logprob_start_len + len(input_logprobs),
            )
            # The last element of each entry is the token text, so the joined
            # texts must be a suffix of the prompt.
            input_text = "".join([x[-1] for x in input_logprobs])
            self.assertTrue(
                prompts[i].endswith(input_text),
                f"{prompts[i]=}\n{input_text=}",
            )

            self.assertEqual(meta_info["completion_tokens"], new_tokens)
            self.assertEqual(len(output_logprobs), new_tokens)
            self.assertEqual(
                res["text"],
                "".join([x[-1] for x in output_logprobs]),
            )

    def test_logprob_with_chunked_prefill(self):
        """Test a long prompt that requests output logprobs will not hit OOM."""
        new_tokens = 4
        prompts = "I have a very good idea on this. " * 8000

        response = requests.post(
            self.base_url + "/generate",
            json={
                "text": prompts,
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": new_tokens,
                },
                "return_logprob": True,
                "logprob_start_len": -1,
                "top_logprobs_num": 5,
            },
        )
        res = response.json()
        meta_info = res["meta_info"]

        self.assertEqual(meta_info["completion_tokens"], new_tokens)

        # Test the number of tokens are correct
        self.assertEqual(len(meta_info["output_token_logprobs"]), new_tokens)
        self.assertEqual(len(meta_info["output_top_logprobs"]), new_tokens)

        # Test the top-1 tokens are the same as output tokens (because temp = 0.0)
        for i in range(new_tokens):
            self.assertListEqual(
                meta_info["output_token_logprobs"][i],
                meta_info["output_top_logprobs"][i][0],
            )
            self.assertEqual(len(meta_info["output_top_logprobs"][i]), 5)

    def test_logprob_match(self):
        """Test the output logprobs are close to the input logprobs if we run a prefill again."""

        def run_generate(
            prompt, return_logprob=False, max_new_tokens=512, logprob_start_len=-1
        ):
            # A string is sent as "text"; anything else as "input_ids".
            if isinstance(prompt, str):
                prompt_kwargs = {"text": prompt}
            else:
                prompt_kwargs = {"input_ids": prompt}

            response = requests.post(
                self.base_url + "/generate",
                json={
                    **prompt_kwargs,
                    "sampling_params": {
                        "temperature": 1.0,
                        "max_new_tokens": max_new_tokens,
                        "ignore_eos": True,
                    },
                    "return_logprob": return_logprob,
                    "return_text_in_logprobs": True,
                    "logprob_start_len": logprob_start_len,
                },
            )
            return response.json()

        prompt = "I have a very good idea on how to"

        # Step 1: generate and record the logprob of every sampled token.
        gen = run_generate(prompt, return_logprob=True, logprob_start_len=0)
        output_logprobs = np.array(
            [x[0] for x in gen["meta_info"]["output_token_logprobs"]]
        )
        num_prompts_tokens = gen["meta_info"]["prompt_tokens"]

        # Step 2: feed prompt + generated token ids back as a pure prefill
        # (max_new_tokens=0) and read the logprobs of the generated part.
        input_tokens = [x[1] for x in gen["meta_info"]["input_token_logprobs"]]
        output_tokens = [x[1] for x in gen["meta_info"]["output_token_logprobs"]]
        new_prompt = input_tokens + output_tokens
        score = run_generate(
            new_prompt, return_logprob=True, logprob_start_len=0, max_new_tokens=0
        )
        output_logprobs_score = np.array(
            [
                x[0]
                for x in score["meta_info"]["input_token_logprobs"][num_prompts_tokens:]
            ]
        )

        print(f"{output_logprobs[-10:]=}")
        print(f"{output_logprobs_score[-10:]=}")

        diff = np.abs(output_logprobs - output_logprobs_score)
        max_diff = np.max(diff)
        self.assertLess(max_diff, 0.35)

    def test_logprob_mixed(self):
        """Run ``run_logprob_check`` concurrently over a grid of settings."""
        args = []
        temperature = 0
        # input_len, output_len, temperature, logprob_start_len, return_logprob, top_logprobs_num
        for input_len in [1000, 5000, 10000, 50000]:
            for output_len in [4, 8]:
                for logprob_start_len in [0, 500, 2500, 5000, 25000]:
                    # Skip start offsets that are not inside the input.
                    if logprob_start_len >= input_len:
                        continue
                    for return_logprob in [True, False]:
                        for top_logprobs_num in [0, 5]:
                            args.append(
                                (
                                    input_len,
                                    output_len,
                                    temperature,
                                    logprob_start_len,
                                    return_logprob,
                                    top_logprobs_num,
                                )
                            )

        random.shuffle(args)

        func = partial(run_logprob_check, self)
        with ThreadPoolExecutor(8) as executor:
            # list() drains the iterator so worker exceptions are re-raised here.
            list(executor.map(func, args))

    def test_logprob_grammar(self):
        """Check that both regex-allowed tokens appear in the top logprobs."""
        prompts = "Question: Is Paris the Capital of France? Answer:"
        allowed_tokens = [" Yes", " No"]

        response = requests.post(
            self.base_url + "/generate",
            json={
                "text": prompts,
                "sampling_params": {
                    "temperature": 1.0,
                    "max_new_tokens": 1,
                    "regex": "( Yes| No)",
                },
                "return_logprob": True,
                # The grammar constraint allows all prefix tokens so we need
                # to use a larger top_k.
                "top_logprobs_num": 5,
                "return_text_in_logprobs": True,
            },
        )
        response_json = response.json()
        output_top_logprobs = response_json["meta_info"]["output_top_logprobs"][0]
        print(f"{output_top_logprobs=}")

        # Pick out the logprob of each allowed token. Other entries are
        # skipped, because the grammar constraint allows all prefix tokens.
        logprobs = [None] * len(allowed_tokens)
        for entry in output_top_logprobs:
            token_text = entry[2]
            if token_text not in allowed_tokens:
                continue
            logprobs[allowed_tokens.index(token_text)] = entry[0]

        self.assertTrue(all(x is not None for x in logprobs))

    def run_custom_logit_processor(self, target_token_id: Optional[int] = None):
        """Test custom logit processor with custom params.

        If target_token_id is None, the custom logit processor won't be passed in.
        """

        class DeterministicLogitProcessor(CustomLogitProcessor):
            """A dummy logit processor that changes the logits to always
            sample the given token id.
            """

            def __call__(self, logits, custom_param_list):
                assert logits.shape[0] == len(custom_param_list)
                key = "token_id"

                for i, param_dict in enumerate(custom_param_list):
                    # Mask all other tokens
                    logits[i, :] = -float("inf")
                    # Assign highest probability to the specified token
                    logits[i, param_dict[key]] = 0.0
                return logits

        prompts = "Question: Is Paris the Capital of France? Answer:"

        # Build the request from scratch on every call so the nested
        # "sampling_params" dict is never shared with another dict.
        request_json = {
            "text": prompts,
            "sampling_params": {"temperature": 0.0},
            "return_logprob": True,
        }

        # Only set the custom logit processor if target_token_id is not None.
        if target_token_id is not None:
            request_json["custom_logit_processor"] = (
                DeterministicLogitProcessor.to_str()
            )
            request_json["sampling_params"]["custom_params"] = {
                "token_id": target_token_id
            }

        custom_response = requests.post(
            self.base_url + "/generate",
            json=request_json,
        ).json()

        output_token_logprobs = custom_response["meta_info"]["output_token_logprobs"]
        sampled_tokens = [x[1] for x in output_token_logprobs]

        # The logit processor should always sample the given token as the logits is deterministic.
        if target_token_id is not None:
            self.assertTrue(
                all(x == target_token_id for x in sampled_tokens),
                # Print the detailed test case info if the test fails.
                f"{target_token_id=}\n{sampled_tokens=}\n{custom_response=}",
            )

    def run_stateful_custom_logit_processor(
        self, first_token_id: Optional[int], delay: int = 2
    ):
        """Test custom logit processor with custom params and state.

        Should sample the first `delay` tokens normally, then output first_token_id and consecutive tokens after that.
        If first_token_id is None, the custom logit processor won't be passed in.

        Args:
            first_token_id: Token id forced once the delay has elapsed, or
                None to send a plain request.
            delay: Number of leading tokens that are left unconstrained.
        """

        class DeterministicStatefulLogitProcessor(CustomLogitProcessor):
            """A dummy stateful logit processor.

            For each request it leaves the logits untouched while the
            request's "delay" counter is positive, then forces "token_id"
            once, and afterwards forces (last output id + 1) on every call.
            """

            def __call__(self, logits, custom_param_list):
                assert logits.shape[0] == len(custom_param_list)

                for i, param_dict in enumerate(custom_param_list):
                    if param_dict["delay"] > 0:
                        # Still in the unconstrained warm-up phase.
                        param_dict["delay"] -= 1
                        continue
                    if param_dict["delay"] == 0:
                        # First constrained step; push the counter to -1 so
                        # later calls take the `else` branch.
                        param_dict["delay"] -= 1
                        force_token = param_dict["token_id"]
                    else:
                        output_ids = param_dict["__req__"].output_ids
                        force_token = output_ids[-1] + 1
                    # Mask all other tokens
                    logits[i, :] = -float("inf")
                    # Assign highest probability to the specified token
                    logits[i, force_token] = 0.0
                return logits

        prompts = "Question: Is Paris the Capital of France? Answer:"

        # Build the request from scratch on every call so the nested
        # "sampling_params" dict is never shared with another dict.
        request_json = {
            "text": prompts,
            "sampling_params": {"temperature": 0.0},
            "return_logprob": True,
        }

        # Only set the custom logit processor if first_token_id is not None.
        if first_token_id is not None:
            request_json["custom_logit_processor"] = (
                DeterministicStatefulLogitProcessor().to_str()
            )
            # Honor the caller's `delay` rather than a hard-coded value.
            request_json["sampling_params"]["custom_params"] = {
                "token_id": first_token_id,
                "delay": delay,
            }

        custom_response = requests.post(
            self.base_url + "/generate",
            json=request_json,
        ).json()

        output_token_logprobs = custom_response["meta_info"]["output_token_logprobs"]
        sampled_tokens = [x[1] for x in output_token_logprobs]

        # After the first `delay` tokens the output must count up from first_token_id.
        if first_token_id is not None:
            self.assertTrue(
                all(
                    x == first_token_id + k
                    for k, x in enumerate(sampled_tokens[delay:])
                ),
                # Print the detailed test case info if the test fails.
                f"{first_token_id=}\n{sampled_tokens=}\n{custom_response=}",
            )

    def test_custom_logit_processor(self):
        """Test custom logit processor with a single request."""
        self.run_custom_logit_processor(target_token_id=5)

    def test_custom_logit_processor_batch_mixed(self):
        """Test a batch of requests mixed of requests with and without custom logit processor."""
        target_token_ids = list(range(32)) + [None] * 16
        random.shuffle(target_token_ids)
        # One worker per request so all of them are in flight together.
        with ThreadPoolExecutor(len(target_token_ids)) as executor:
            list(executor.map(self.run_custom_logit_processor, target_token_ids))

    @unittest.skip("Skip this test because this feature has a bug. See comments below.")
    def test_stateful_custom_logit_processor(self):
        """Test custom logit processor with a single request."""
        # NOTE: This feature has a race condition bug.
        # This line https://github.com/sgl-project/sglang/blob/ef8ec07b2ce4c70c2a33ec5acda4ce529bc3cda4/test/srt/test_srt_endpoint.py#L395-L396
        # can be accessed by two concurrent threads at the same time. The
        # access order is not guaranteed.
        # In sglang, we use two python threads to overlap the GPU computation
        # and CPU scheduling.
        # Thread 1 (the CPU scheduling thread) will update the
        # `param_dict["__req__"].output_ids`.
        # Thread 2 (the GPU computation thread) will call
        # `DeterministicStatefulLogitProcessor` because sampling is considered
        # as GPU computation.
        # We can fix this by moving the call of
        # DeterministicStatefulLogitProcessor to the CPU scheduling thread.
        self.run_stateful_custom_logit_processor(first_token_id=5)

    @unittest.skip("Skip this test because this feature has a bug. See comments above.")
    def test_stateful_custom_logit_processor_batch_mixed(self):
        """Test a batch of requests mixed of requests with and without custom logit processor."""
        target_token_ids = list(range(32)) + [None] * 16
        random.shuffle(target_token_ids)
        with ThreadPoolExecutor(len(target_token_ids)) as executor:
            list(
                executor.map(self.run_stateful_custom_logit_processor, target_token_ids)
            )

    def test_cache_tokens(self):
        """Check the ``cached_tokens`` count reported for overlapping prompts."""
        # Flush the cache twice, one second apart, before measuring.
        for _ in range(2):
            time.sleep(1)
            response = requests.post(self.base_url + "/flush_cache")
            self.assertEqual(response.status_code, 200)

        def send_and_check_cached_tokens(input_ids):
            response = requests.post(
                self.base_url + "/generate",
                json={
                    "input_ids": list(input_ids),
                    "sampling_params": {
                        "max_new_tokens": 1,
                    },
                },
            )
            response_json = response.json()
            return response_json["meta_info"]["cached_tokens"]

        # The requests are order-dependent: each one changes what is cached
        # for the next.
        # Nothing is cached right after the flush.
        self.assertEqual(send_and_check_cached_tokens(range(0, 100)), 0)
        # Shares its first 100 tokens with the previous request.
        self.assertEqual(send_and_check_cached_tokens(range(0, 10000)), 100)
        # Identical prompt: everything but the final token is reported cached.
        self.assertEqual(send_and_check_cached_tokens(range(0, 10000)), 9999)
        # A shorter prefix of an already seen prompt.
        self.assertEqual(send_and_check_cached_tokens(range(0, 1000)), 999)
        # Extends the longest prompt seen so far by 1000 tokens.
        self.assertEqual(send_and_check_cached_tokens(range(0, 11000)), 10000)

    def test_get_server_info(self):
        """``/get_server_info`` reports an int token budget and a str version."""
        response = requests.get(self.base_url + "/get_server_info")
        response_json = response.json()

        max_total_num_tokens = response_json["max_total_num_tokens"]
        self.assertIsInstance(max_total_num_tokens, int)

        version = response_json["version"]
        self.assertIsInstance(version, str)

    def test_get_server_info_concurrent(self):
        """Make sure the concurrent get_server_info doesn't crash the server."""

        def s():
            server_info = requests.get(self.base_url + "/get_server_info")
            server_info.json()

        # The context manager shuts the pool down even if a request fails.
        with ThreadPoolExecutor(max_workers=30) as tp:
            futures = [tp.submit(s) for _ in range(4)]

            # result() re-raises any exception raised inside a worker.
            for f in futures:
                f.result()
# Run the test suite when this file is executed directly.
if __name__ == "__main__":
    unittest.main()