Support penalty in overlap mode; return logprob with chunked prefill; improve benchmark scripts (#3988)

Co-authored-by: SangBin Cho <rkooo567@gmail.com>
Co-authored-by: dhou-xai <dhou@x.ai>
Co-authored-by: Hanming Lu <hanming_lu@berkeley.edu>
This commit is contained in:
Lianmin Zheng
2025-03-03 00:12:04 -08:00
parent 0194948fd9
commit ac2387279e
86 changed files with 4116 additions and 2015 deletions

View File

@@ -30,7 +30,9 @@ class TestSRTBackend(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.backend = sgl.Runtime(model_path=DEFAULT_MODEL_NAME_FOR_TEST)
cls.backend = sgl.Runtime(
model_path=DEFAULT_MODEL_NAME_FOR_TEST, cuda_graph_max_bs=4
)
sgl.set_default_backend(cls.backend)
@classmethod

View File

@@ -12,7 +12,6 @@ suites = {
"models/test_generation_models.py",
"models/test_qwen_models.py",
"models/test_reward_models.py",
"sampling/penaltylib",
"test_abort.py",
"test_chunked_prefill.py",
"test_custom_allreduce.py",
@@ -31,6 +30,7 @@ suites = {
"test_no_chunked_prefill.py",
"test_no_overlap_scheduler.py",
"test_openai_server.py",
"test_penalty.py",
"test_pytorch_sampling_backend.py",
"test_radix_attention.py",
"test_regex_constrained.py",
@@ -38,7 +38,8 @@ suites = {
"test_request_length_validation.py",
"test_retract_decode.py",
"test_server_args.py",
"test_session_control.py",
# Disabled temporarily
# "test_session_control.py",
"test_skip_tokenizer_init.py",
"test_srt_engine.py",
"test_srt_endpoint.py",
@@ -64,9 +65,6 @@ suites = {
# Disable temporarily
# "test_nightly_math_eval.py",
],
"sampling/penaltylib": glob.glob(
"sampling/penaltylib/**/test_*.py", recursive=True
),
}
# Expand suite
@@ -83,7 +81,7 @@ if __name__ == "__main__":
arg_parser.add_argument(
"--timeout-per-file",
type=int,
default=2000,
default=1800,
help="The time limit for running one file in seconds.",
)
arg_parser.add_argument(

View File

@@ -1,97 +0,0 @@
import unittest
from typing import List
import torch
from sglang.srt.sampling.penaltylib.penalizers.frequency_penalty import (
BatchedFrequencyPenalizer,
)
from sglang.test.srt.sampling.penaltylib.utils import (
BaseBatchedPenalizerTest,
MockSamplingParams,
Step,
StepType,
Subject,
)
class BaseBatchedFrequencyPenalizerTest(BaseBatchedPenalizerTest):
Penalizer = BatchedFrequencyPenalizer
frequency_penalty: float
def setUp(self):
if self.__class__ == BaseBatchedFrequencyPenalizerTest:
self.skipTest("Base class for frequency_penalty tests")
super().setUp()
def _create_subject(self, frequency_penalty: float) -> Subject:
return Subject(
sampling_params=MockSamplingParams(
frequency_penalty=frequency_penalty,
),
steps=[
Step(
type=StepType.INPUT,
token_ids=[0, 1, 2],
expected_tensors={
"frequency_penalties": self.tensor(
[[frequency_penalty] * self.vocab_size], dtype=torch.float32
),
"cumulated_frequency_penalties": self.tensor(
[[0.0] * self.vocab_size], dtype=torch.float32
),
},
expected_logits=self.tensor(
[[1] * self.vocab_size], dtype=torch.float32
),
),
Step(
type=StepType.OUTPUT,
token_ids=[
1,
2,
2,
], # This is the output ids of one request in three steps.
expected_tensors={
"frequency_penalties": self.tensor(
[[frequency_penalty] * self.vocab_size], dtype=torch.float32
),
"cumulated_frequency_penalties": self.tensor(
[
[
frequency_penalty * i if i in {1, 2} else 0.0
for i in range(self.vocab_size)
],
],
dtype=torch.float32,
),
},
expected_logits=self.tensor(
[
[
1.0 - frequency_penalty * i if i in {1, 2} else 1.0
for i in range(self.vocab_size)
],
],
dtype=torch.float32,
),
),
],
)
def create_test_subjects(self) -> List[Subject]:
self.enabled = self._create_subject(frequency_penalty=self.frequency_penalty)
self.disabled = self._create_subject(frequency_penalty=0.0)
class TestBatchedFrequencyPenalizerPositiveValue(BaseBatchedFrequencyPenalizerTest):
frequency_penalty = 0.12
class TestBatchedFrequencyPenalizerNegativeValue(BaseBatchedFrequencyPenalizerTest):
frequency_penalty = -0.12
if __name__ == "__main__":
unittest.main()

View File

@@ -1,152 +0,0 @@
import unittest
from typing import List
import torch
from sglang.srt.sampling.penaltylib.penalizers.min_new_tokens import (
BatchedMinNewTokensPenalizer,
)
from sglang.test.srt.sampling.penaltylib.utils import (
BaseBatchedPenalizerTest,
MockSamplingParams,
Step,
StepType,
Subject,
)
MIN_NEW_TOKENS = 2
EOS_TOKEN_ID = 4
STOP_TOKEN_ID = 3
ALL_STOP_TOKEN_IDS = {STOP_TOKEN_ID, EOS_TOKEN_ID}
class TestBatchedMinNewTokensPenalizer(BaseBatchedPenalizerTest):
Penalizer = BatchedMinNewTokensPenalizer
def _create_subject(self, min_new_tokens: int) -> Subject:
return Subject(
eos_token_id=EOS_TOKEN_ID,
sampling_params=MockSamplingParams(
min_new_tokens=min_new_tokens,
stop_token_ids={STOP_TOKEN_ID},
),
steps=[
Step(
type=StepType.INPUT,
token_ids=[0, 1, 2],
expected_tensors={
"min_new_tokens": self.tensor(
[[min_new_tokens]], dtype=torch.int32
),
"stop_token_penalties": self.tensor(
[
[
float("-inf") if i in ALL_STOP_TOKEN_IDS else 0
for i in range(self.vocab_size)
]
],
dtype=torch.float32,
),
"len_output_tokens": self.tensor([[0]], dtype=torch.int32),
},
expected_logits=(
self.tensor(
[
[
float("-inf") if i in ALL_STOP_TOKEN_IDS else 1
for i in range(self.vocab_size)
]
],
dtype=torch.float32,
)
if min_new_tokens > 0
else torch.ones(
(1, self.vocab_size),
dtype=torch.float32,
device=self.device,
)
),
),
Step(
type=StepType.OUTPUT,
token_ids=[0],
expected_tensors={
"min_new_tokens": self.tensor(
[[min_new_tokens]], dtype=torch.int32
),
"stop_token_penalties": self.tensor(
[
[
float("-inf") if i in ALL_STOP_TOKEN_IDS else 0
for i in range(self.vocab_size)
]
],
dtype=torch.float32,
),
"len_output_tokens": self.tensor([[1]], dtype=torch.int32),
},
expected_logits=(
self.tensor(
[
[
float("-inf") if i in ALL_STOP_TOKEN_IDS else 1
for i in range(self.vocab_size)
]
],
dtype=torch.float32,
)
if min_new_tokens > 1
else torch.ones(
(1, self.vocab_size),
dtype=torch.float32,
device=self.device,
)
),
),
Step(
type=StepType.OUTPUT,
token_ids=[0],
expected_tensors={
"min_new_tokens": self.tensor(
[[min_new_tokens]], dtype=torch.int32
),
"stop_token_penalties": self.tensor(
[
[
float("-inf") if i in ALL_STOP_TOKEN_IDS else 0
for i in range(self.vocab_size)
]
],
dtype=torch.float32,
),
"len_output_tokens": self.tensor([[2]], dtype=torch.int32),
},
expected_logits=(
self.tensor(
[
[
float("-inf") if i in ALL_STOP_TOKEN_IDS else 1
for i in range(self.vocab_size)
]
],
dtype=torch.float32,
)
if min_new_tokens > 2
else torch.ones(
(1, self.vocab_size),
dtype=torch.float32,
device=self.device,
)
),
),
],
)
def create_test_subjects(self) -> List[Subject]:
self.enabled = self._create_subject(min_new_tokens=MIN_NEW_TOKENS)
self.disabled = self._create_subject(min_new_tokens=0.0)
if __name__ == "__main__":
unittest.main()

View File

@@ -1,93 +0,0 @@
import unittest
from typing import List
import torch
from sglang.srt.sampling.penaltylib.penalizers.presence_penalty import (
BatchedPresencePenalizer,
)
from sglang.test.srt.sampling.penaltylib.utils import (
BaseBatchedPenalizerTest,
MockSamplingParams,
Step,
StepType,
Subject,
)
class BaseBatchedPresencePenalizerTest(BaseBatchedPenalizerTest):
Penalizer = BatchedPresencePenalizer
presence_penalty: float
def setUp(self):
if self.__class__ == BaseBatchedPresencePenalizerTest:
self.skipTest("Base class for presence_penalty tests")
super().setUp()
def _create_subject(self, presence_penalty: float) -> Subject:
return Subject(
sampling_params=MockSamplingParams(
presence_penalty=presence_penalty,
),
steps=[
Step(
type=StepType.INPUT,
token_ids=[0, 1, 2],
expected_tensors={
"presence_penalties": self.tensor(
[[presence_penalty] * self.vocab_size], dtype=torch.float32
),
"cumulated_presence_penalties": self.tensor(
[[0.0] * self.vocab_size], dtype=torch.float32
),
},
expected_logits=self.tensor(
[[1] * self.vocab_size], dtype=torch.float32
),
),
Step(
type=StepType.OUTPUT,
token_ids=[1, 2, 2],
expected_tensors={
"presence_penalties": self.tensor(
[[presence_penalty] * self.vocab_size], dtype=torch.float32
),
"cumulated_presence_penalties": self.tensor(
[
[
presence_penalty if i in {1, 2} else 0.0
for i in range(self.vocab_size)
],
],
dtype=torch.float32,
),
},
expected_logits=self.tensor(
[
[
1.0 - presence_penalty if i in {1, 2} else 1.0
for i in range(self.vocab_size)
],
],
dtype=torch.float32,
),
),
],
)
def create_test_subjects(self) -> List[Subject]:
self.enabled = self._create_subject(presence_penalty=self.presence_penalty)
self.disabled = self._create_subject(presence_penalty=0.0)
class TestBatchedPresencePenalizerPositiveValue(BaseBatchedPresencePenalizerTest):
presence_penalty = 0.12
class TestBatchedPresencePenalizerNegativeValue(BaseBatchedPresencePenalizerTest):
presence_penalty = -0.12
if __name__ == "__main__":
unittest.main()

View File

@@ -1,87 +0,0 @@
import unittest
from typing import List
import torch
from sglang.srt.sampling.penaltylib.penalizers.repetition_penalty import (
BatchedRepetitionPenalizer,
)
from sglang.test.srt.sampling.penaltylib.utils import (
BaseBatchedPenalizerTest,
MockSamplingParams,
Step,
StepType,
Subject,
)
REPETITION_PENALTY = 2.0
class TestBatchedRepetitionPenalizer(BaseBatchedPenalizerTest):
Penalizer = BatchedRepetitionPenalizer
def _create_subject(self, repetition_penalty: float) -> Subject:
l = 1.0 / repetition_penalty
return Subject(
sampling_params=MockSamplingParams(
repetition_penalty=repetition_penalty,
),
steps=[
Step(
type=StepType.INPUT,
token_ids=[0, 1, 2],
expected_tensors={
"repetition_penalties": self.tensor(
[[repetition_penalty] * self.vocab_size],
dtype=torch.float32,
),
"cumulated_repetition_penalties": (
self.tensor(
[[2.0, 2.0, 2.0, 1.0, 1.0]], dtype=torch.float32
)
if repetition_penalty != 1.0
else self.tensor(
[[1.0] * self.vocab_size], dtype=torch.float32
)
),
},
expected_logits=(
self.tensor([[l, l, l, 1.0, 1.0]], dtype=torch.float32)
if repetition_penalty != 1.0
else self.tensor([[1.0] * self.vocab_size], dtype=torch.float32)
),
),
Step(
type=StepType.OUTPUT,
token_ids=[0, 1, 3],
expected_tensors={
"repetition_penalties": self.tensor(
[[repetition_penalty] * self.vocab_size],
dtype=torch.float32,
),
"cumulated_repetition_penalties": (
self.tensor(
[[2.0, 2.0, 2.0, 2.0, 1.0]], dtype=torch.float32
)
if repetition_penalty != 1.0
else self.tensor(
[[1.0] * self.vocab_size], dtype=torch.float32
)
),
},
expected_logits=(
self.tensor([[l, l, l, l, 1.0]], dtype=torch.float32)
if repetition_penalty != 1.0
else self.tensor([[1.0] * self.vocab_size], dtype=torch.float32)
),
),
],
)
def create_test_subjects(self) -> List[Subject]:
self.enabled = self._create_subject(repetition_penalty=REPETITION_PENALTY)
self.disabled = self._create_subject(repetition_penalty=1.0)
if __name__ == "__main__":
unittest.main()

View File

@@ -1,114 +0,0 @@
import json
import unittest
from multiprocessing import Process
import requests
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
popen_launch_server,
)
class TestBatchPenalizerE2E(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=(
"--random-seed",
"0",
),
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def run_decode(
self,
return_logprob=True,
top_logprobs_num=5,
return_text=True,
n=1,
**sampling_params,
):
response = requests.post(
self.base_url + "/generate",
json={
# prompt that is supposed to generate < 32 tokens
"text": "<|start_header_id|>user<|end_header_id|>\n\nWhat is the answer for 1 + 1 = ?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
"sampling_params": {
"max_new_tokens": 32,
"n": n,
**sampling_params,
},
"stream": False,
"return_logprob": return_logprob,
"top_logprobs_num": top_logprobs_num,
"return_text_in_logprobs": return_text,
"logprob_start_len": 0,
},
)
assert response.status_code == 200, "Request failed: " + response.text
def test_default_values(self):
self.run_decode()
def test_mixed(self):
"""
Sends two requests with one with penalizers disabled, and the other with penalizers enabled.
This will cause two different {ScheduleBatch} to be initialized and eventually gets merged.
Merging batch with penalizers enabled with enabled, or disabled is trivial. However disabled + enabled is not.
This is because the penalizer will not be prepared if it is not required, then it will be prepared during the merge.
This test triggers the merge of disabled + enabled.
"""
processes = []
p = Process(
target=self.run_decode,
)
processes.append(p)
p.start()
p = Process(
target=self.run_decode,
kwargs={
"frequency_penalty": 2,
"min_new_tokens": 16,
"presence_penalty": 2,
"repetition_penalty": 2,
},
)
processes.append(p)
p.start()
for p in processes:
p.join()
def test_frequency_penalty(self):
self.run_decode(frequency_penalty=2)
def test_min_new_tokens(self):
self.run_decode(min_new_tokens=16)
def test_presence_penalty(self):
self.run_decode(presence_penalty=2)
def test_repetition_penalty(self):
self.run_decode(repetition_penalty=2)
if __name__ == "__main__":
unittest.main(verbosity=3)

View File

@@ -138,6 +138,7 @@ class TestBenchServing(unittest.TestCase):
model=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
num_prompts=50,
request_rate=1,
sharegpt_context_len=3072,
disable_ignore_eos=True,
dataset_name="sharegpt",
other_server_args=[
@@ -148,22 +149,23 @@ class TestBenchServing(unittest.TestCase):
"--speculative-num-steps",
"5",
"--speculative-eagle-topk",
"8",
"4",
"--speculative-num-draft-tokens",
"64",
"16",
"--mem-fraction-static",
"0.7",
"--cuda-graph-max-bs",
"32",
],
need_warmup=True,
)
if is_in_ci():
write_github_step_summary(
f"### test_online_latency_eagle\n"
f'median_e2e_latency_ms : {res["median_e2e_latency_ms"]:.2f} ms\n'
f'accept_length : {res["accept_length"]:.2f} \n'
)
self.assertLess(res["median_e2e_latency_ms"], 450)
self.assertLess(res["median_e2e_latency_ms"], 700)
self.assertGreater(res["accept_length"], 2.50)
def test_moe_offline_throughput_default(self):
res = run_bench_serving(

View File

@@ -12,7 +12,9 @@ from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
is_in_ci,
popen_launch_server,
write_github_step_summary,
)
@@ -44,6 +46,9 @@ class TestEvalAccuracyLarge(unittest.TestCase):
metrics = run_eval(args)
self.assertGreater(metrics["score"], 0.71)
if is_in_ci():
write_github_step_summary(f"### test_mmlu\n" f'{metrics["score"]=:.4f}\n')
def test_human_eval(self):
args = SimpleNamespace(
base_url=self.base_url,
@@ -56,6 +61,11 @@ class TestEvalAccuracyLarge(unittest.TestCase):
metrics = run_eval(args)
self.assertGreater(metrics["score"], 0.64)
if is_in_ci():
write_github_step_summary(
f"### test_human_eval\n" f'{metrics["score"]=:.4f}\n'
)
def test_mgsm_en(self):
args = SimpleNamespace(
base_url=self.base_url,
@@ -68,6 +78,11 @@ class TestEvalAccuracyLarge(unittest.TestCase):
metrics = run_eval(args)
self.assertGreater(metrics["score"], 0.835)
if is_in_ci():
write_github_step_summary(
f"### test_mgsm_en\n" f'{metrics["score"]=:.4f}\n'
)
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,27 @@
import unittest
from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_URL_FOR_TEST,
popen_launch_server,
)
class TestHealthCheck(unittest.TestCase):
    def test_health_check(self):
        """Launching a server with the model architecture overridden to a
        nonexistent class should fail to become healthy and raise TimeoutError.

        NOTE(review): the original docstring ("Test that metrics endpoint
        returns data when enabled") appears copy-pasted from a metrics test;
        the body actually asserts that ``popen_launch_server`` times out —
        presumably because the bogus architecture override prevents the
        server from ever passing its health check. Confirm intent.
        """
        with self.assertRaises(TimeoutError):
            popen_launch_server(
                DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
                DEFAULT_URL_FOR_TEST,
                timeout=60,  # short launch timeout; the launch is expected to fail
                other_args=[
                    "--disable-cuda-graph",
                    "--json-model-override-args",
                    '{"architectures": ["LlamaForCausalLMForHealthTest"]}',
                ],
            )
if __name__ == "__main__":
unittest.main()

View File

@@ -49,7 +49,7 @@ class TestHiddenState(unittest.TestCase):
with torch.inference_mode():
hf_out = model(
torch.tensor(
[input_id + output["token_ids"][:-1]], device=model.device
[input_id + output["output_ids"][:-1]], device=model.device
),
output_hidden_states=True,
)

View File

@@ -56,11 +56,13 @@ class TestEnableMetrics(unittest.TestCase):
"sglang:gen_throughput",
"sglang:num_queue_reqs",
"sglang:cache_hit_rate",
"sglang:spec_accept_length",
"sglang:prompt_tokens_total",
"sglang:generation_tokens_total",
"sglang:num_requests_total",
"sglang:time_to_first_token_seconds",
"sglang:time_per_output_token_seconds",
"sglang:inter_token_latency_seconds",
"sglang:e2e_request_latency_seconds",
]

View File

@@ -141,7 +141,7 @@ class TestDeepseekV3MTP(unittest.TestCase):
metrics = run_eval_few_shot_gsm8k(args)
print(metrics)
self.assertGreater(metrics["accuracy"], 0.62)
self.assertGreater(metrics["accuracy"], 0.60)
if __name__ == "__main__":

91
test/srt/test_penalty.py Normal file
View File

@@ -0,0 +1,91 @@
import json
import random
import unittest
from concurrent.futures import ThreadPoolExecutor
import requests
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
popen_launch_server,
)
class TestPenalty(unittest.TestCase):
    """End-to-end tests for sampling penalties (frequency penalty, presence
    penalty, min_new_tokens) against a live sglang server launched once for
    the whole class."""

    @classmethod
    def setUpClass(cls):
        # Launch one shared server for all tests in this class.
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        )

    @classmethod
    def tearDownClass(cls):
        # Kill the server and any children it spawned.
        kill_process_tree(cls.process.pid)

    def run_decode(self, sampling_params):
        """Send one /generate request with the given sampling params merged
        into the defaults, and assert the server responds with HTTP 200.

        Args:
            sampling_params: dict of extra sampling parameters
                (e.g. {"frequency_penalty": 2}) merged over the defaults.
        """
        return_logprob = True
        top_logprobs_num = 5
        return_text = True
        n = 1
        response = requests.post(
            self.base_url + "/generate",
            json={
                # prompt that is supposed to generate < 32 tokens
                "text": "<|start_header_id|>user<|end_header_id|>\n\nWhat is the answer for 1 + 1 = ?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
                "sampling_params": {
                    "max_new_tokens": 32,
                    "n": n,
                    **sampling_params,
                },
                "return_logprob": return_logprob,
                "top_logprobs_num": top_logprobs_num,
                "return_text_in_logprobs": return_text,
                "logprob_start_len": 0,
            },
        )
        self.assertEqual(response.status_code, 200)
        print(json.dumps(response.json()))
        print("=" * 100)

    def test_default_values(self):
        self.run_decode({})

    def test_frequency_penalty(self):
        self.run_decode({"frequency_penalty": 2})

    def test_min_new_tokens(self):
        self.run_decode({"min_new_tokens": 16})

    def test_presence_penalty(self):
        self.run_decode({"presence_penalty": 2})

    def test_mixed(self):
        """Fire a shuffled batch of mixed-penalty requests concurrently so
        requests with and without penalties end up merged into the same
        ScheduleBatch on the server."""
        args = [
            {},
            {},
            {},
            {"frequency_penalty": 2},
            {"min_new_tokens": 16},
            {"presence_penalty": 1},
            {"frequency_penalty": 0.2},
            {"min_new_tokens": 8},
            {"presence_penalty": 0.4},
            {"presence_penalty": 0.4, "frequency_penalty": 2},
            {"min_new_tokens": 12, "frequency_penalty": 2},
        ]
        # BUG FIX: the original `random.shuffle(args * 5)` shuffled a
        # temporary list and discarded it, so only the 11 base requests ran,
        # in fixed order. Repeat the list first, then shuffle in place.
        args = args * 5
        random.shuffle(args)
        with ThreadPoolExecutor(8) as executor:
            # Consume the iterator so exceptions from workers propagate.
            list(executor.map(self.run_decode, args))
if __name__ == "__main__":
unittest.main(verbosity=3)

View File

@@ -70,7 +70,10 @@ class TestSessionControl(unittest.TestCase):
first_rid = None
outputs_from_session = []
logprobs_from_session = []
cur_logprob_start_len = 0
for i, chunk_ids in enumerate(chunks_ids):
max_new_tokens = gen_len if i > 0 else 1 # prefill only for the first chunk
response = requests.post(
self.base_url + "/generate",
json={
@@ -83,12 +86,12 @@ class TestSessionControl(unittest.TestCase):
},
"sampling_params": {
"temperature": 0,
"max_new_tokens": (
gen_len if i > 0 else 1
), # prefill only for the first chunk
"max_new_tokens": max_new_tokens,
"no_stop_trim": True,
"skip_special_tokens": False,
},
"return_logprob": True,
"logprob_start_len": cur_logprob_start_len - 1,
},
).json()
rid = response["meta_info"]["id"]
@@ -96,8 +99,39 @@ class TestSessionControl(unittest.TestCase):
first_rid = rid
if i > 0:
outputs_from_session.append(response["text"])
logprobs_from_session.extend(
[
round(sublist[0], 2)
for sublist in response["meta_info"]["output_token_logprobs"]
]
)
cur_logprob_start_len += len(chunk_ids) + max_new_tokens
# query with a logprob_start_len longer than the request, should see error
response = requests.post(
self.base_url + "/generate",
json={
"input_ids": chunk_ids,
"session_params": {
"id": session_id,
"rid": rid,
"offset": -1,
"replace": True,
},
"sampling_params": {
"temperature": 0,
"max_new_tokens": max_new_tokens,
"no_stop_trim": True,
"skip_special_tokens": False,
},
"return_logprob": True,
"logprob_start_len": cur_logprob_start_len + len(chunk_ids),
},
).json()
assert "Request with a lower logprob_start_len" in response["error"]["message"]
# backtrack to the first request and regenerate
cur_logprob_start_len = 0
response = requests.post(
self.base_url + "/generate",
json={
@@ -114,9 +148,17 @@ class TestSessionControl(unittest.TestCase):
"no_stop_trim": True,
"skip_special_tokens": False,
},
"return_logprob": True,
"logprob_start_len": cur_logprob_start_len,
},
).json()
outputs_from_session.append(response["text"])
logprobs_from_session.extend(
[
round(sublist[0], 2)
for sublist in response["meta_info"]["output_token_logprobs"]
]
)
# query with a non-existing rid (the last one should have disappeared because of backtrack), should see abort
response = requests.post(
@@ -135,6 +177,7 @@ class TestSessionControl(unittest.TestCase):
"no_stop_trim": True,
"skip_special_tokens": False,
},
"return_logprob": True,
},
).json()
assert response["meta_info"]["finish_reason"]["type"] == "abort"
@@ -162,6 +205,7 @@ class TestSessionControl(unittest.TestCase):
"no_stop_trim": True,
"skip_special_tokens": False,
},
"return_logprob": True,
},
).json()
assert response["meta_info"]["finish_reason"]["type"] == "abort"
@@ -172,6 +216,7 @@ class TestSessionControl(unittest.TestCase):
input_ids_first_req = None
input_ids = []
outputs_normal = []
logprobs_normal = []
for i, chunk_ids in enumerate(chunks_ids):
input_ids += chunk_ids
response = requests.post(
@@ -186,6 +231,7 @@ class TestSessionControl(unittest.TestCase):
"no_stop_trim": True,
"skip_special_tokens": False,
},
"return_logprob": True,
},
).json()
if i > 0:
@@ -194,6 +240,12 @@ class TestSessionControl(unittest.TestCase):
output_ids = output_ids[1:]
input_ids += output_ids[:-1]
outputs_normal.append(response["text"])
logprobs_normal.extend(
[
round(sublist[0], 2)
for sublist in response["meta_info"]["output_token_logprobs"]
]
)
if i == 0:
input_ids_first_req = input_ids.copy()
@@ -208,17 +260,31 @@ class TestSessionControl(unittest.TestCase):
"no_stop_trim": True,
"skip_special_tokens": False,
},
"return_logprob": True,
},
).json()
outputs_normal.append(response["text"])
logprobs_normal.extend(
[
round(sublist[0], 2)
for sublist in response["meta_info"]["output_token_logprobs"]
]
)
print("outputs from chunked queries with session control:")
print(outputs_from_session)
print("outputs from normal queries:")
print(outputs_normal)
assert (
outputs_from_session == outputs_normal
), f"outputs_from_session: {outputs_from_session}, outputs_normal: {outputs_normal}"
assert outputs_from_session == outputs_normal
print("logprobs from chunked queries with session control:")
print(logprobs_from_session)
print("logprobs from normal queries:")
print(logprobs_normal)
assert len(logprobs_from_session) == len(
logprobs_normal
), "logprobs must have equal length"
for a, b in zip(logprobs_from_session, logprobs_normal):
assert abs(a - b) <= 0.1, f"logprobs {a} and {b} differ by more than 0.1"
async def async_generate(self, payload):
url = self.base_url + "/generate"

View File

@@ -1,3 +1,8 @@
"""
python3 -m unittest test_skip_tokenizer_init.TestSkipTokenizerInit.test_parallel_sample
python3 -m unittest test_skip_tokenizer_init.TestSkipTokenizerInit.run_decode_stream
"""
import json
import unittest
@@ -12,42 +17,26 @@ from sglang.test.test_utils import (
popen_launch_server,
)
_server_process = None
_base_url = None
_tokenizer = None
def setUpModule():
"""
Launch the server once before all tests and initialize the tokenizer.
"""
global _server_process, _base_url, _tokenizer
_server_process = popen_launch_server(
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_URL_FOR_TEST,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=["--skip-tokenizer-init"],
)
_base_url = DEFAULT_URL_FOR_TEST
_tokenizer = AutoTokenizer.from_pretrained(
DEFAULT_SMALL_MODEL_NAME_FOR_TEST, use_fast=False
)
print(">>> setUpModule: Server launched, tokenizer ready")
def tearDownModule():
"""
Terminate the server once after all tests have completed.
"""
global _server_process
if _server_process is not None:
kill_process_tree(_server_process.pid)
_server_process = None
print(">>> tearDownModule: Server terminated")
class TestSkipTokenizerInit(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=["--skip-tokenizer-init", "--stream-output"],
)
cls.tokenizer = AutoTokenizer.from_pretrained(
DEFAULT_SMALL_MODEL_NAME_FOR_TEST, use_fast=False
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def run_decode(
self,
prompt_text="The capital of France is",
@@ -56,19 +45,19 @@ class TestSkipTokenizerInit(unittest.TestCase):
top_logprobs_num=0,
n=1,
):
input_ids = _tokenizer(prompt_text, return_tensors="pt")["input_ids"][
input_ids = self.tokenizer(prompt_text, return_tensors="pt")["input_ids"][
0
].tolist()
response = requests.post(
_base_url + "/generate",
self.base_url + "/generate",
json={
"input_ids": input_ids,
"sampling_params": {
"temperature": 0 if n == 1 else 0.5,
"max_new_tokens": max_new_tokens,
"n": n,
"stop_token_ids": [_tokenizer.eos_token_id],
"stop_token_ids": [self.tokenizer.eos_token_id],
},
"stream": False,
"return_logprob": return_logprob,
@@ -83,13 +72,13 @@ class TestSkipTokenizerInit(unittest.TestCase):
if item["meta_info"]["finish_reason"]["type"] == "stop":
self.assertEqual(
item["meta_info"]["finish_reason"]["matched"],
_tokenizer.eos_token_id,
self.tokenizer.eos_token_id,
)
elif item["meta_info"]["finish_reason"]["type"] == "length":
self.assertEqual(
len(item["token_ids"]), item["meta_info"]["completion_tokens"]
len(item["output_ids"]), item["meta_info"]["completion_tokens"]
)
self.assertEqual(len(item["token_ids"]), max_new_tokens)
self.assertEqual(len(item["output_ids"]), max_new_tokens)
self.assertEqual(item["meta_info"]["prompt_tokens"], len(input_ids))
if return_logprob:
@@ -113,6 +102,63 @@ class TestSkipTokenizerInit(unittest.TestCase):
print("=" * 100)
def run_decode_stream(self, return_logprob=False, top_logprobs_num=0, n=1):
max_new_tokens = 32
input_ids = [128000, 791, 6864, 315, 9822, 374] # The capital of France is
requests.post(self.base_url + "/flush_cache")
response = requests.post(
self.base_url + "/generate",
json={
"input_ids": input_ids,
"sampling_params": {
"temperature": 0 if n == 1 else 0.5,
"max_new_tokens": max_new_tokens,
"n": n,
"stop_token_ids": [119690],
},
"stream": False,
"return_logprob": return_logprob,
"top_logprobs_num": top_logprobs_num,
"logprob_start_len": 0,
},
)
ret = response.json()
print(json.dumps(ret))
output_ids = ret["output_ids"]
requests.post(self.base_url + "/flush_cache")
response_stream = requests.post(
self.base_url + "/generate",
json={
"input_ids": input_ids,
"sampling_params": {
"temperature": 0 if n == 1 else 0.5,
"max_new_tokens": max_new_tokens,
"n": n,
"stop_token_ids": [119690],
},
"stream": True,
"return_logprob": return_logprob,
"top_logprobs_num": top_logprobs_num,
"logprob_start_len": 0,
},
)
ret = response.json()
output_ids = ret["output_ids"]
print("output from non-streaming request:")
print(output_ids)
response_stream_json = []
for line in response_stream.iter_lines():
if line.startswith(b"data: ") and line[6:] != b"[DONE]":
response_stream_json.append(json.loads(line[6:]))
out_stream_ids = []
for x in response_stream_json:
out_stream_ids += x["output_ids"]
print("output from streaming request:")
print(out_stream_ids)
assert output_ids == out_stream_ids
def test_simple_decode(self):
self.run_decode()
@@ -126,6 +172,9 @@ class TestSkipTokenizerInit(unittest.TestCase):
def test_eos_behavior(self):
self.run_decode(max_new_tokens=256)
def test_simple_decode_stream(self):
self.run_decode_stream()
if __name__ == "__main__":
unittest.main()

View File

@@ -8,6 +8,7 @@ import random
import time
import unittest
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from typing import Optional
import numpy as np
@@ -20,6 +21,7 @@ from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
popen_launch_server,
run_logprob_check,
)
@@ -35,7 +37,9 @@ class TestSRTEndpoint(unittest.TestCase):
other_args=(
"--enable-custom-logit-processor",
"--mem-fraction-static",
"0.8",
"0.7",
"--cuda-graph-max-bs",
"8",
),
)
@@ -131,7 +135,7 @@ class TestSRTEndpoint(unittest.TestCase):
for i, res in enumerate(response_json):
self.assertEqual(
res["meta_info"]["prompt_tokens"],
logprob_start_len + 1 + len(res["meta_info"]["input_token_logprobs"]),
logprob_start_len + len(res["meta_info"]["input_token_logprobs"]),
)
assert prompts[i].endswith(
"".join([x[-1] for x in res["meta_info"]["input_token_logprobs"]])
@@ -235,83 +239,15 @@ class TestSRTEndpoint(unittest.TestCase):
diff = np.abs(output_logprobs - output_logprobs_score)
max_diff = np.max(diff)
self.assertLess(max_diff, 0.25)
def run_logprob_check(self, arg):
(
input_len,
output_len,
temperature,
logprob_start_len,
return_logprob,
top_logprobs_num,
) = arg
input_ids = list(range(input_len))
response = requests.post(
self.base_url + "/generate",
json={
"input_ids": input_ids,
"sampling_params": {
"temperature": temperature,
"max_new_tokens": output_len,
},
"return_logprob": return_logprob,
"logprob_start_len": logprob_start_len,
"top_logprobs_num": top_logprobs_num,
},
)
response_json = response.json()
res = response_json
self.assertEqual(res["meta_info"]["prompt_tokens"], input_len)
self.assertEqual(res["meta_info"]["completion_tokens"], output_len)
# Test the number of tokens are correct
if return_logprob:
# This is because if logprob_start_len == 0, we added a padding for the first token.
# In other cases, we do not add the padding
delta = 0 if logprob_start_len == 0 else 1
self.assertEqual(
len(res["meta_info"]["input_token_logprobs"])
+ logprob_start_len
+ delta,
res["meta_info"]["prompt_tokens"],
)
self.assertEqual(len(res["meta_info"]["output_token_logprobs"]), output_len)
if top_logprobs_num:
self.assertEqual(
len(res["meta_info"]["input_top_logprobs"])
+ logprob_start_len
+ delta,
res["meta_info"]["prompt_tokens"],
)
self.assertEqual(
len(res["meta_info"]["output_top_logprobs"]), output_len
)
for i in range(output_len):
self.assertEqual(
len(res["meta_info"]["output_top_logprobs"][i]),
top_logprobs_num,
)
# Test the top-1 tokens are the same as output tokens if temperature == 0
if temperature == 0:
self.assertListEqual(
res["meta_info"]["output_token_logprobs"][i],
res["meta_info"]["output_top_logprobs"][i][0],
)
self.assertLess(max_diff, 0.35)
def test_logprob_mixed(self):
args = []
temperature = 0
# input_len, output_len, temperature, logprob_start_len, return_logprob, top_logprobs_num
for input_len in [1000, 2000]:
for input_len in [1000, 5000, 10000, 50000]:
for output_len in [4, 8]:
for logprob_start_len in [0, 500, 1000]:
for logprob_start_len in [0, 500, 2500, 5000, 25000]:
for return_logprob in [True, False]:
for top_logprobs_num in [0, 5]:
@@ -331,8 +267,9 @@ class TestSRTEndpoint(unittest.TestCase):
random.shuffle(args)
func = partial(run_logprob_check, self)
with ThreadPoolExecutor(8) as executor:
list(executor.map(self.run_logprob_check, args))
list(executor.map(func, args))
def test_logprob_grammar(self):
prompts = "Question: Is Paris the Capital of France? Answer:"
@@ -427,6 +364,77 @@ class TestSRTEndpoint(unittest.TestCase):
f"{target_token_id=}\n{sampled_tokens=}\n{custom_response=}",
)
def run_stateful_custom_logit_processor(
self, first_token_id: int | None, delay: int = 2
):
"""Test custom logit processor with custom params and state.
Should sample the first `delay` tokens normally, then output first_token_id and consecutive tokens after that.
If first_token_id is None, the custom logit processor won't be passed in.
"""
custom_params = {"token_id": first_token_id, "delay": 2}
class DeterministicStatefulLogitProcessor(CustomLogitProcessor):
"""A dummy logit processor that changes the logits to always
sample the given token id.
"""
def __call__(self, logits, custom_param_list):
assert logits.shape[0] == len(custom_param_list)
for i, param_dict in enumerate(custom_param_list):
if param_dict["delay"] > 0:
param_dict["delay"] -= 1
continue
if param_dict["delay"] == 0:
param_dict["delay"] -= 1
force_token = param_dict["token_id"]
else:
output_ids = param_dict["__req__"].output_ids
force_token = output_ids[-1] + 1
# Mask all other tokens
logits[i, :] = -float("inf")
# Assign highest probability to the specified token
logits[i, force_token] = 0.0
return logits
prompts = "Question: Is Paris the Capital of France? Answer:"
# Base case json data to be posted to the server.
base_json = {
"text": prompts,
"sampling_params": {"temperature": 0.0},
"return_logprob": True,
}
# Custom json data with custom logit processor and params.
custom_json = base_json.copy()
# Only set the custom logit processor if target_token_id is not None.
if first_token_id is not None:
custom_json["custom_logit_processor"] = (
DeterministicStatefulLogitProcessor().to_str()
)
custom_json["sampling_params"]["custom_params"] = custom_params
custom_response = requests.post(
self.base_url + "/generate",
json=custom_json,
).json()
output_token_logprobs = custom_response["meta_info"]["output_token_logprobs"]
sampled_tokens = [x[1] for x in output_token_logprobs]
# The logit processor should always sample the given token as the logits is deterministic.
if first_token_id is not None:
self.assertTrue(
all(
x == custom_params["token_id"] + k
for k, x in enumerate(sampled_tokens[custom_params["delay"] :])
),
# Print the detailed test case info if the test fails.
f"{first_token_id=}\n{sampled_tokens=}\n{custom_response=}",
)
def test_custom_logit_processor(self):
"""Test custom logit processor with a single request."""
self.run_custom_logit_processor(target_token_id=5)
@@ -438,6 +446,19 @@ class TestSRTEndpoint(unittest.TestCase):
with ThreadPoolExecutor(len(target_token_ids)) as executor:
list(executor.map(self.run_custom_logit_processor, target_token_ids))
def test_stateful_custom_logit_processor(self):
"""Test custom logit processor with a single request."""
self.run_stateful_custom_logit_processor(first_token_id=5)
def test_stateful_custom_logit_processor_batch_mixed(self):
"""Test a batch of requests mixed of requests with and without custom logit processor."""
target_token_ids = list(range(32)) + [None] * 16
random.shuffle(target_token_ids)
with ThreadPoolExecutor(len(target_token_ids)) as executor:
list(
executor.map(self.run_stateful_custom_logit_processor, target_token_ids)
)
def test_cache_tokens(self):
for _ in range(2):
time.sleep(1)
@@ -476,6 +497,21 @@ class TestSRTEndpoint(unittest.TestCase):
version = response_json["version"]
self.assertIsInstance(version, str)
def test_get_server_info_concurrent(self):
"""Make sure the concurrent get_server_info doesn't crash the server."""
tp = ThreadPoolExecutor(max_workers=30)
def s():
server_info = requests.get(self.base_url + "/get_server_info")
server_info.json()
futures = []
for _ in range(4):
futures.append(tp.submit(s))
for f in futures:
f.result()
if __name__ == "__main__":
unittest.main()

View File

@@ -168,9 +168,9 @@ def _run_subprocess(
hf_tokenizer = get_tokenizer(model_path, trust_remote_code=True)
hf_outputs = HFRunner.forward_generation_raw(
base_model=hf_model,
prompts=_PROMPTS,
max_new_tokens=_MAX_NEW_TOKENS,
base_model=hf_model,
tokenizer=hf_tokenizer,
lora_paths=None,
torch_dtype=_TORCH_DTYPE,