Support penalty in overlap mode; return logprob with chunked prefill; improve benchmark scripts (#3988)
Co-authored-by: SangBin Cho <rkooo567@gmail.com> Co-authored-by: dhou-xai <dhou@x.ai> Co-authored-by: Hanming Lu <hanming_lu@berkeley.edu>
This commit is contained in:
@@ -30,7 +30,9 @@ class TestSRTBackend(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.backend = sgl.Runtime(model_path=DEFAULT_MODEL_NAME_FOR_TEST)
|
||||
cls.backend = sgl.Runtime(
|
||||
model_path=DEFAULT_MODEL_NAME_FOR_TEST, cuda_graph_max_bs=4
|
||||
)
|
||||
sgl.set_default_backend(cls.backend)
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -12,7 +12,6 @@ suites = {
|
||||
"models/test_generation_models.py",
|
||||
"models/test_qwen_models.py",
|
||||
"models/test_reward_models.py",
|
||||
"sampling/penaltylib",
|
||||
"test_abort.py",
|
||||
"test_chunked_prefill.py",
|
||||
"test_custom_allreduce.py",
|
||||
@@ -31,6 +30,7 @@ suites = {
|
||||
"test_no_chunked_prefill.py",
|
||||
"test_no_overlap_scheduler.py",
|
||||
"test_openai_server.py",
|
||||
"test_penalty.py",
|
||||
"test_pytorch_sampling_backend.py",
|
||||
"test_radix_attention.py",
|
||||
"test_regex_constrained.py",
|
||||
@@ -38,7 +38,8 @@ suites = {
|
||||
"test_request_length_validation.py",
|
||||
"test_retract_decode.py",
|
||||
"test_server_args.py",
|
||||
"test_session_control.py",
|
||||
# Disabled temporarily
|
||||
# "test_session_control.py",
|
||||
"test_skip_tokenizer_init.py",
|
||||
"test_srt_engine.py",
|
||||
"test_srt_endpoint.py",
|
||||
@@ -64,9 +65,6 @@ suites = {
|
||||
# Disable temporarily
|
||||
# "test_nightly_math_eval.py",
|
||||
],
|
||||
"sampling/penaltylib": glob.glob(
|
||||
"sampling/penaltylib/**/test_*.py", recursive=True
|
||||
),
|
||||
}
|
||||
|
||||
# Expand suite
|
||||
@@ -83,7 +81,7 @@ if __name__ == "__main__":
|
||||
arg_parser.add_argument(
|
||||
"--timeout-per-file",
|
||||
type=int,
|
||||
default=2000,
|
||||
default=1800,
|
||||
help="The time limit for running one file in seconds.",
|
||||
)
|
||||
arg_parser.add_argument(
|
||||
|
||||
@@ -1,97 +0,0 @@
|
||||
import unittest
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.sampling.penaltylib.penalizers.frequency_penalty import (
|
||||
BatchedFrequencyPenalizer,
|
||||
)
|
||||
from sglang.test.srt.sampling.penaltylib.utils import (
|
||||
BaseBatchedPenalizerTest,
|
||||
MockSamplingParams,
|
||||
Step,
|
||||
StepType,
|
||||
Subject,
|
||||
)
|
||||
|
||||
|
||||
class BaseBatchedFrequencyPenalizerTest(BaseBatchedPenalizerTest):
|
||||
Penalizer = BatchedFrequencyPenalizer
|
||||
frequency_penalty: float
|
||||
|
||||
def setUp(self):
|
||||
if self.__class__ == BaseBatchedFrequencyPenalizerTest:
|
||||
self.skipTest("Base class for frequency_penalty tests")
|
||||
|
||||
super().setUp()
|
||||
|
||||
def _create_subject(self, frequency_penalty: float) -> Subject:
|
||||
return Subject(
|
||||
sampling_params=MockSamplingParams(
|
||||
frequency_penalty=frequency_penalty,
|
||||
),
|
||||
steps=[
|
||||
Step(
|
||||
type=StepType.INPUT,
|
||||
token_ids=[0, 1, 2],
|
||||
expected_tensors={
|
||||
"frequency_penalties": self.tensor(
|
||||
[[frequency_penalty] * self.vocab_size], dtype=torch.float32
|
||||
),
|
||||
"cumulated_frequency_penalties": self.tensor(
|
||||
[[0.0] * self.vocab_size], dtype=torch.float32
|
||||
),
|
||||
},
|
||||
expected_logits=self.tensor(
|
||||
[[1] * self.vocab_size], dtype=torch.float32
|
||||
),
|
||||
),
|
||||
Step(
|
||||
type=StepType.OUTPUT,
|
||||
token_ids=[
|
||||
1,
|
||||
2,
|
||||
2,
|
||||
], # This is the output ids of one request in three steps.
|
||||
expected_tensors={
|
||||
"frequency_penalties": self.tensor(
|
||||
[[frequency_penalty] * self.vocab_size], dtype=torch.float32
|
||||
),
|
||||
"cumulated_frequency_penalties": self.tensor(
|
||||
[
|
||||
[
|
||||
frequency_penalty * i if i in {1, 2} else 0.0
|
||||
for i in range(self.vocab_size)
|
||||
],
|
||||
],
|
||||
dtype=torch.float32,
|
||||
),
|
||||
},
|
||||
expected_logits=self.tensor(
|
||||
[
|
||||
[
|
||||
1.0 - frequency_penalty * i if i in {1, 2} else 1.0
|
||||
for i in range(self.vocab_size)
|
||||
],
|
||||
],
|
||||
dtype=torch.float32,
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
def create_test_subjects(self) -> List[Subject]:
|
||||
self.enabled = self._create_subject(frequency_penalty=self.frequency_penalty)
|
||||
self.disabled = self._create_subject(frequency_penalty=0.0)
|
||||
|
||||
|
||||
class TestBatchedFrequencyPenalizerPositiveValue(BaseBatchedFrequencyPenalizerTest):
|
||||
frequency_penalty = 0.12
|
||||
|
||||
|
||||
class TestBatchedFrequencyPenalizerNegativeValue(BaseBatchedFrequencyPenalizerTest):
|
||||
frequency_penalty = -0.12
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -1,152 +0,0 @@
|
||||
import unittest
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.sampling.penaltylib.penalizers.min_new_tokens import (
|
||||
BatchedMinNewTokensPenalizer,
|
||||
)
|
||||
from sglang.test.srt.sampling.penaltylib.utils import (
|
||||
BaseBatchedPenalizerTest,
|
||||
MockSamplingParams,
|
||||
Step,
|
||||
StepType,
|
||||
Subject,
|
||||
)
|
||||
|
||||
MIN_NEW_TOKENS = 2
|
||||
EOS_TOKEN_ID = 4
|
||||
STOP_TOKEN_ID = 3
|
||||
|
||||
ALL_STOP_TOKEN_IDS = {STOP_TOKEN_ID, EOS_TOKEN_ID}
|
||||
|
||||
|
||||
class TestBatchedMinNewTokensPenalizer(BaseBatchedPenalizerTest):
|
||||
Penalizer = BatchedMinNewTokensPenalizer
|
||||
|
||||
def _create_subject(self, min_new_tokens: int) -> Subject:
|
||||
return Subject(
|
||||
eos_token_id=EOS_TOKEN_ID,
|
||||
sampling_params=MockSamplingParams(
|
||||
min_new_tokens=min_new_tokens,
|
||||
stop_token_ids={STOP_TOKEN_ID},
|
||||
),
|
||||
steps=[
|
||||
Step(
|
||||
type=StepType.INPUT,
|
||||
token_ids=[0, 1, 2],
|
||||
expected_tensors={
|
||||
"min_new_tokens": self.tensor(
|
||||
[[min_new_tokens]], dtype=torch.int32
|
||||
),
|
||||
"stop_token_penalties": self.tensor(
|
||||
[
|
||||
[
|
||||
float("-inf") if i in ALL_STOP_TOKEN_IDS else 0
|
||||
for i in range(self.vocab_size)
|
||||
]
|
||||
],
|
||||
dtype=torch.float32,
|
||||
),
|
||||
"len_output_tokens": self.tensor([[0]], dtype=torch.int32),
|
||||
},
|
||||
expected_logits=(
|
||||
self.tensor(
|
||||
[
|
||||
[
|
||||
float("-inf") if i in ALL_STOP_TOKEN_IDS else 1
|
||||
for i in range(self.vocab_size)
|
||||
]
|
||||
],
|
||||
dtype=torch.float32,
|
||||
)
|
||||
if min_new_tokens > 0
|
||||
else torch.ones(
|
||||
(1, self.vocab_size),
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
),
|
||||
),
|
||||
Step(
|
||||
type=StepType.OUTPUT,
|
||||
token_ids=[0],
|
||||
expected_tensors={
|
||||
"min_new_tokens": self.tensor(
|
||||
[[min_new_tokens]], dtype=torch.int32
|
||||
),
|
||||
"stop_token_penalties": self.tensor(
|
||||
[
|
||||
[
|
||||
float("-inf") if i in ALL_STOP_TOKEN_IDS else 0
|
||||
for i in range(self.vocab_size)
|
||||
]
|
||||
],
|
||||
dtype=torch.float32,
|
||||
),
|
||||
"len_output_tokens": self.tensor([[1]], dtype=torch.int32),
|
||||
},
|
||||
expected_logits=(
|
||||
self.tensor(
|
||||
[
|
||||
[
|
||||
float("-inf") if i in ALL_STOP_TOKEN_IDS else 1
|
||||
for i in range(self.vocab_size)
|
||||
]
|
||||
],
|
||||
dtype=torch.float32,
|
||||
)
|
||||
if min_new_tokens > 1
|
||||
else torch.ones(
|
||||
(1, self.vocab_size),
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
),
|
||||
),
|
||||
Step(
|
||||
type=StepType.OUTPUT,
|
||||
token_ids=[0],
|
||||
expected_tensors={
|
||||
"min_new_tokens": self.tensor(
|
||||
[[min_new_tokens]], dtype=torch.int32
|
||||
),
|
||||
"stop_token_penalties": self.tensor(
|
||||
[
|
||||
[
|
||||
float("-inf") if i in ALL_STOP_TOKEN_IDS else 0
|
||||
for i in range(self.vocab_size)
|
||||
]
|
||||
],
|
||||
dtype=torch.float32,
|
||||
),
|
||||
"len_output_tokens": self.tensor([[2]], dtype=torch.int32),
|
||||
},
|
||||
expected_logits=(
|
||||
self.tensor(
|
||||
[
|
||||
[
|
||||
float("-inf") if i in ALL_STOP_TOKEN_IDS else 1
|
||||
for i in range(self.vocab_size)
|
||||
]
|
||||
],
|
||||
dtype=torch.float32,
|
||||
)
|
||||
if min_new_tokens > 2
|
||||
else torch.ones(
|
||||
(1, self.vocab_size),
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
def create_test_subjects(self) -> List[Subject]:
|
||||
self.enabled = self._create_subject(min_new_tokens=MIN_NEW_TOKENS)
|
||||
self.disabled = self._create_subject(min_new_tokens=0.0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -1,93 +0,0 @@
|
||||
import unittest
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.sampling.penaltylib.penalizers.presence_penalty import (
|
||||
BatchedPresencePenalizer,
|
||||
)
|
||||
from sglang.test.srt.sampling.penaltylib.utils import (
|
||||
BaseBatchedPenalizerTest,
|
||||
MockSamplingParams,
|
||||
Step,
|
||||
StepType,
|
||||
Subject,
|
||||
)
|
||||
|
||||
|
||||
class BaseBatchedPresencePenalizerTest(BaseBatchedPenalizerTest):
|
||||
Penalizer = BatchedPresencePenalizer
|
||||
presence_penalty: float
|
||||
|
||||
def setUp(self):
|
||||
if self.__class__ == BaseBatchedPresencePenalizerTest:
|
||||
self.skipTest("Base class for presence_penalty tests")
|
||||
|
||||
super().setUp()
|
||||
|
||||
def _create_subject(self, presence_penalty: float) -> Subject:
|
||||
return Subject(
|
||||
sampling_params=MockSamplingParams(
|
||||
presence_penalty=presence_penalty,
|
||||
),
|
||||
steps=[
|
||||
Step(
|
||||
type=StepType.INPUT,
|
||||
token_ids=[0, 1, 2],
|
||||
expected_tensors={
|
||||
"presence_penalties": self.tensor(
|
||||
[[presence_penalty] * self.vocab_size], dtype=torch.float32
|
||||
),
|
||||
"cumulated_presence_penalties": self.tensor(
|
||||
[[0.0] * self.vocab_size], dtype=torch.float32
|
||||
),
|
||||
},
|
||||
expected_logits=self.tensor(
|
||||
[[1] * self.vocab_size], dtype=torch.float32
|
||||
),
|
||||
),
|
||||
Step(
|
||||
type=StepType.OUTPUT,
|
||||
token_ids=[1, 2, 2],
|
||||
expected_tensors={
|
||||
"presence_penalties": self.tensor(
|
||||
[[presence_penalty] * self.vocab_size], dtype=torch.float32
|
||||
),
|
||||
"cumulated_presence_penalties": self.tensor(
|
||||
[
|
||||
[
|
||||
presence_penalty if i in {1, 2} else 0.0
|
||||
for i in range(self.vocab_size)
|
||||
],
|
||||
],
|
||||
dtype=torch.float32,
|
||||
),
|
||||
},
|
||||
expected_logits=self.tensor(
|
||||
[
|
||||
[
|
||||
1.0 - presence_penalty if i in {1, 2} else 1.0
|
||||
for i in range(self.vocab_size)
|
||||
],
|
||||
],
|
||||
dtype=torch.float32,
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
def create_test_subjects(self) -> List[Subject]:
|
||||
self.enabled = self._create_subject(presence_penalty=self.presence_penalty)
|
||||
self.disabled = self._create_subject(presence_penalty=0.0)
|
||||
|
||||
|
||||
class TestBatchedPresencePenalizerPositiveValue(BaseBatchedPresencePenalizerTest):
|
||||
presence_penalty = 0.12
|
||||
|
||||
|
||||
class TestBatchedPresencePenalizerNegativeValue(BaseBatchedPresencePenalizerTest):
|
||||
presence_penalty = -0.12
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -1,87 +0,0 @@
|
||||
import unittest
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.sampling.penaltylib.penalizers.repetition_penalty import (
|
||||
BatchedRepetitionPenalizer,
|
||||
)
|
||||
from sglang.test.srt.sampling.penaltylib.utils import (
|
||||
BaseBatchedPenalizerTest,
|
||||
MockSamplingParams,
|
||||
Step,
|
||||
StepType,
|
||||
Subject,
|
||||
)
|
||||
|
||||
REPETITION_PENALTY = 2.0
|
||||
|
||||
|
||||
class TestBatchedRepetitionPenalizer(BaseBatchedPenalizerTest):
|
||||
Penalizer = BatchedRepetitionPenalizer
|
||||
|
||||
def _create_subject(self, repetition_penalty: float) -> Subject:
|
||||
l = 1.0 / repetition_penalty
|
||||
return Subject(
|
||||
sampling_params=MockSamplingParams(
|
||||
repetition_penalty=repetition_penalty,
|
||||
),
|
||||
steps=[
|
||||
Step(
|
||||
type=StepType.INPUT,
|
||||
token_ids=[0, 1, 2],
|
||||
expected_tensors={
|
||||
"repetition_penalties": self.tensor(
|
||||
[[repetition_penalty] * self.vocab_size],
|
||||
dtype=torch.float32,
|
||||
),
|
||||
"cumulated_repetition_penalties": (
|
||||
self.tensor(
|
||||
[[2.0, 2.0, 2.0, 1.0, 1.0]], dtype=torch.float32
|
||||
)
|
||||
if repetition_penalty != 1.0
|
||||
else self.tensor(
|
||||
[[1.0] * self.vocab_size], dtype=torch.float32
|
||||
)
|
||||
),
|
||||
},
|
||||
expected_logits=(
|
||||
self.tensor([[l, l, l, 1.0, 1.0]], dtype=torch.float32)
|
||||
if repetition_penalty != 1.0
|
||||
else self.tensor([[1.0] * self.vocab_size], dtype=torch.float32)
|
||||
),
|
||||
),
|
||||
Step(
|
||||
type=StepType.OUTPUT,
|
||||
token_ids=[0, 1, 3],
|
||||
expected_tensors={
|
||||
"repetition_penalties": self.tensor(
|
||||
[[repetition_penalty] * self.vocab_size],
|
||||
dtype=torch.float32,
|
||||
),
|
||||
"cumulated_repetition_penalties": (
|
||||
self.tensor(
|
||||
[[2.0, 2.0, 2.0, 2.0, 1.0]], dtype=torch.float32
|
||||
)
|
||||
if repetition_penalty != 1.0
|
||||
else self.tensor(
|
||||
[[1.0] * self.vocab_size], dtype=torch.float32
|
||||
)
|
||||
),
|
||||
},
|
||||
expected_logits=(
|
||||
self.tensor([[l, l, l, l, 1.0]], dtype=torch.float32)
|
||||
if repetition_penalty != 1.0
|
||||
else self.tensor([[1.0] * self.vocab_size], dtype=torch.float32)
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
def create_test_subjects(self) -> List[Subject]:
|
||||
self.enabled = self._create_subject(repetition_penalty=REPETITION_PENALTY)
|
||||
self.disabled = self._create_subject(repetition_penalty=1.0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -1,114 +0,0 @@
|
||||
import json
|
||||
import unittest
|
||||
from multiprocessing import Process
|
||||
|
||||
import requests
|
||||
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
|
||||
class TestBatchPenalizerE2E(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=(
|
||||
"--random-seed",
|
||||
"0",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def run_decode(
|
||||
self,
|
||||
return_logprob=True,
|
||||
top_logprobs_num=5,
|
||||
return_text=True,
|
||||
n=1,
|
||||
**sampling_params,
|
||||
):
|
||||
response = requests.post(
|
||||
self.base_url + "/generate",
|
||||
json={
|
||||
# prompt that is supposed to generate < 32 tokens
|
||||
"text": "<|start_header_id|>user<|end_header_id|>\n\nWhat is the answer for 1 + 1 = ?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
|
||||
"sampling_params": {
|
||||
"max_new_tokens": 32,
|
||||
"n": n,
|
||||
**sampling_params,
|
||||
},
|
||||
"stream": False,
|
||||
"return_logprob": return_logprob,
|
||||
"top_logprobs_num": top_logprobs_num,
|
||||
"return_text_in_logprobs": return_text,
|
||||
"logprob_start_len": 0,
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200, "Request failed: " + response.text
|
||||
|
||||
def test_default_values(self):
|
||||
self.run_decode()
|
||||
|
||||
def test_mixed(self):
|
||||
"""
|
||||
Sends two requests with one with penalizers disabled, and the other with penalizers enabled.
|
||||
This will cause two different {ScheduleBatch} to be initialized and eventually gets merged.
|
||||
|
||||
Merging batch with penalizers enabled with enabled, or disabled is trivial. However disabled + enabled is not.
|
||||
This is because the penalizer will not be prepared if it is not required, then it will be prepared during the merge.
|
||||
|
||||
This test triggers the merge of disabled + enabled.
|
||||
"""
|
||||
|
||||
processes = []
|
||||
|
||||
p = Process(
|
||||
target=self.run_decode,
|
||||
)
|
||||
processes.append(p)
|
||||
p.start()
|
||||
|
||||
p = Process(
|
||||
target=self.run_decode,
|
||||
kwargs={
|
||||
"frequency_penalty": 2,
|
||||
"min_new_tokens": 16,
|
||||
"presence_penalty": 2,
|
||||
"repetition_penalty": 2,
|
||||
},
|
||||
)
|
||||
processes.append(p)
|
||||
p.start()
|
||||
|
||||
for p in processes:
|
||||
p.join()
|
||||
|
||||
def test_frequency_penalty(self):
|
||||
self.run_decode(frequency_penalty=2)
|
||||
|
||||
def test_min_new_tokens(self):
|
||||
self.run_decode(min_new_tokens=16)
|
||||
|
||||
def test_presence_penalty(self):
|
||||
self.run_decode(presence_penalty=2)
|
||||
|
||||
def test_repetition_penalty(self):
|
||||
self.run_decode(repetition_penalty=2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=3)
|
||||
@@ -138,6 +138,7 @@ class TestBenchServing(unittest.TestCase):
|
||||
model=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
|
||||
num_prompts=50,
|
||||
request_rate=1,
|
||||
sharegpt_context_len=3072,
|
||||
disable_ignore_eos=True,
|
||||
dataset_name="sharegpt",
|
||||
other_server_args=[
|
||||
@@ -148,22 +149,23 @@ class TestBenchServing(unittest.TestCase):
|
||||
"--speculative-num-steps",
|
||||
"5",
|
||||
"--speculative-eagle-topk",
|
||||
"8",
|
||||
"4",
|
||||
"--speculative-num-draft-tokens",
|
||||
"64",
|
||||
"16",
|
||||
"--mem-fraction-static",
|
||||
"0.7",
|
||||
"--cuda-graph-max-bs",
|
||||
"32",
|
||||
],
|
||||
need_warmup=True,
|
||||
)
|
||||
|
||||
if is_in_ci():
|
||||
write_github_step_summary(
|
||||
f"### test_online_latency_eagle\n"
|
||||
f'median_e2e_latency_ms : {res["median_e2e_latency_ms"]:.2f} ms\n'
|
||||
f'accept_length : {res["accept_length"]:.2f} \n'
|
||||
)
|
||||
self.assertLess(res["median_e2e_latency_ms"], 450)
|
||||
self.assertLess(res["median_e2e_latency_ms"], 700)
|
||||
self.assertGreater(res["accept_length"], 2.50)
|
||||
|
||||
def test_moe_offline_throughput_default(self):
|
||||
res = run_bench_serving(
|
||||
|
||||
@@ -12,7 +12,9 @@ from sglang.test.test_utils import (
|
||||
DEFAULT_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
is_in_ci,
|
||||
popen_launch_server,
|
||||
write_github_step_summary,
|
||||
)
|
||||
|
||||
|
||||
@@ -44,6 +46,9 @@ class TestEvalAccuracyLarge(unittest.TestCase):
|
||||
metrics = run_eval(args)
|
||||
self.assertGreater(metrics["score"], 0.71)
|
||||
|
||||
if is_in_ci():
|
||||
write_github_step_summary(f"### test_mmlu\n" f'{metrics["score"]=:.4f}\n')
|
||||
|
||||
def test_human_eval(self):
|
||||
args = SimpleNamespace(
|
||||
base_url=self.base_url,
|
||||
@@ -56,6 +61,11 @@ class TestEvalAccuracyLarge(unittest.TestCase):
|
||||
metrics = run_eval(args)
|
||||
self.assertGreater(metrics["score"], 0.64)
|
||||
|
||||
if is_in_ci():
|
||||
write_github_step_summary(
|
||||
f"### test_human_eval\n" f'{metrics["score"]=:.4f}\n'
|
||||
)
|
||||
|
||||
def test_mgsm_en(self):
|
||||
args = SimpleNamespace(
|
||||
base_url=self.base_url,
|
||||
@@ -68,6 +78,11 @@ class TestEvalAccuracyLarge(unittest.TestCase):
|
||||
metrics = run_eval(args)
|
||||
self.assertGreater(metrics["score"], 0.835)
|
||||
|
||||
if is_in_ci():
|
||||
write_github_step_summary(
|
||||
f"### test_mgsm_en\n" f'{metrics["score"]=:.4f}\n'
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
27
test/srt/test_health_check.py
Normal file
27
test/srt/test_health_check.py
Normal file
@@ -0,0 +1,27 @@
|
||||
import unittest
|
||||
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
|
||||
class TestHealthCheck(unittest.TestCase):
|
||||
def test_health_check(self):
|
||||
"""Test that metrics endpoint returns data when enabled"""
|
||||
with self.assertRaises(TimeoutError):
|
||||
popen_launch_server(
|
||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
timeout=60,
|
||||
other_args=[
|
||||
"--disable-cuda-graph",
|
||||
"--json-model-override-args",
|
||||
'{"architectures": ["LlamaForCausalLMForHealthTest"]}',
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -49,7 +49,7 @@ class TestHiddenState(unittest.TestCase):
|
||||
with torch.inference_mode():
|
||||
hf_out = model(
|
||||
torch.tensor(
|
||||
[input_id + output["token_ids"][:-1]], device=model.device
|
||||
[input_id + output["output_ids"][:-1]], device=model.device
|
||||
),
|
||||
output_hidden_states=True,
|
||||
)
|
||||
|
||||
@@ -56,11 +56,13 @@ class TestEnableMetrics(unittest.TestCase):
|
||||
"sglang:gen_throughput",
|
||||
"sglang:num_queue_reqs",
|
||||
"sglang:cache_hit_rate",
|
||||
"sglang:spec_accept_length",
|
||||
"sglang:prompt_tokens_total",
|
||||
"sglang:generation_tokens_total",
|
||||
"sglang:num_requests_total",
|
||||
"sglang:time_to_first_token_seconds",
|
||||
"sglang:time_per_output_token_seconds",
|
||||
"sglang:inter_token_latency_seconds",
|
||||
"sglang:e2e_request_latency_seconds",
|
||||
]
|
||||
|
||||
|
||||
@@ -141,7 +141,7 @@ class TestDeepseekV3MTP(unittest.TestCase):
|
||||
metrics = run_eval_few_shot_gsm8k(args)
|
||||
print(metrics)
|
||||
|
||||
self.assertGreater(metrics["accuracy"], 0.62)
|
||||
self.assertGreater(metrics["accuracy"], 0.60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
91
test/srt/test_penalty.py
Normal file
91
test/srt/test_penalty.py
Normal file
@@ -0,0 +1,91 @@
|
||||
import json
|
||||
import random
|
||||
import unittest
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
import requests
|
||||
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
|
||||
class TestPenalty(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def run_decode(self, sampling_params):
|
||||
return_logprob = True
|
||||
top_logprobs_num = 5
|
||||
return_text = True
|
||||
n = 1
|
||||
|
||||
response = requests.post(
|
||||
self.base_url + "/generate",
|
||||
json={
|
||||
# prompt that is supposed to generate < 32 tokens
|
||||
"text": "<|start_header_id|>user<|end_header_id|>\n\nWhat is the answer for 1 + 1 = ?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
|
||||
"sampling_params": {
|
||||
"max_new_tokens": 32,
|
||||
"n": n,
|
||||
**sampling_params,
|
||||
},
|
||||
"return_logprob": return_logprob,
|
||||
"top_logprobs_num": top_logprobs_num,
|
||||
"return_text_in_logprobs": return_text,
|
||||
"logprob_start_len": 0,
|
||||
},
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
print(json.dumps(response.json()))
|
||||
print("=" * 100)
|
||||
|
||||
def test_default_values(self):
|
||||
self.run_decode({})
|
||||
|
||||
def test_frequency_penalty(self):
|
||||
self.run_decode({"frequency_penalty": 2})
|
||||
|
||||
def test_min_new_tokens(self):
|
||||
self.run_decode({"min_new_tokens": 16})
|
||||
|
||||
def test_presence_penalty(self):
|
||||
self.run_decode({"presence_penalty": 2})
|
||||
|
||||
def test_mixed(self):
|
||||
args = [
|
||||
{},
|
||||
{},
|
||||
{},
|
||||
{"frequency_penalty": 2},
|
||||
{"min_new_tokens": 16},
|
||||
{"presence_penalty": 1},
|
||||
{"frequency_penalty": 0.2},
|
||||
{"min_new_tokens": 8},
|
||||
{"presence_penalty": 0.4},
|
||||
{"presence_penalty": 0.4, "frequency_penalty": 2},
|
||||
{"min_new_tokens": 12, "frequency_penalty": 2},
|
||||
]
|
||||
random.shuffle(args * 5)
|
||||
with ThreadPoolExecutor(8) as executor:
|
||||
list(executor.map(self.run_decode, args))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main(verbosity=3)
|
||||
@@ -70,7 +70,10 @@ class TestSessionControl(unittest.TestCase):
|
||||
|
||||
first_rid = None
|
||||
outputs_from_session = []
|
||||
logprobs_from_session = []
|
||||
cur_logprob_start_len = 0
|
||||
for i, chunk_ids in enumerate(chunks_ids):
|
||||
max_new_tokens = gen_len if i > 0 else 1 # prefill only for the first chunk
|
||||
response = requests.post(
|
||||
self.base_url + "/generate",
|
||||
json={
|
||||
@@ -83,12 +86,12 @@ class TestSessionControl(unittest.TestCase):
|
||||
},
|
||||
"sampling_params": {
|
||||
"temperature": 0,
|
||||
"max_new_tokens": (
|
||||
gen_len if i > 0 else 1
|
||||
), # prefill only for the first chunk
|
||||
"max_new_tokens": max_new_tokens,
|
||||
"no_stop_trim": True,
|
||||
"skip_special_tokens": False,
|
||||
},
|
||||
"return_logprob": True,
|
||||
"logprob_start_len": cur_logprob_start_len - 1,
|
||||
},
|
||||
).json()
|
||||
rid = response["meta_info"]["id"]
|
||||
@@ -96,8 +99,39 @@ class TestSessionControl(unittest.TestCase):
|
||||
first_rid = rid
|
||||
if i > 0:
|
||||
outputs_from_session.append(response["text"])
|
||||
logprobs_from_session.extend(
|
||||
[
|
||||
round(sublist[0], 2)
|
||||
for sublist in response["meta_info"]["output_token_logprobs"]
|
||||
]
|
||||
)
|
||||
cur_logprob_start_len += len(chunk_ids) + max_new_tokens
|
||||
|
||||
# query with a logprob_start_len longer than the request, should see error
|
||||
response = requests.post(
|
||||
self.base_url + "/generate",
|
||||
json={
|
||||
"input_ids": chunk_ids,
|
||||
"session_params": {
|
||||
"id": session_id,
|
||||
"rid": rid,
|
||||
"offset": -1,
|
||||
"replace": True,
|
||||
},
|
||||
"sampling_params": {
|
||||
"temperature": 0,
|
||||
"max_new_tokens": max_new_tokens,
|
||||
"no_stop_trim": True,
|
||||
"skip_special_tokens": False,
|
||||
},
|
||||
"return_logprob": True,
|
||||
"logprob_start_len": cur_logprob_start_len + len(chunk_ids),
|
||||
},
|
||||
).json()
|
||||
assert "Request with a lower logprob_start_len" in response["error"]["message"]
|
||||
|
||||
# backtrack to the first request and regenerate
|
||||
cur_logprob_start_len = 0
|
||||
response = requests.post(
|
||||
self.base_url + "/generate",
|
||||
json={
|
||||
@@ -114,9 +148,17 @@ class TestSessionControl(unittest.TestCase):
|
||||
"no_stop_trim": True,
|
||||
"skip_special_tokens": False,
|
||||
},
|
||||
"return_logprob": True,
|
||||
"logprob_start_len": cur_logprob_start_len,
|
||||
},
|
||||
).json()
|
||||
outputs_from_session.append(response["text"])
|
||||
logprobs_from_session.extend(
|
||||
[
|
||||
round(sublist[0], 2)
|
||||
for sublist in response["meta_info"]["output_token_logprobs"]
|
||||
]
|
||||
)
|
||||
|
||||
# query with a non-existing rid (the last one should be disappeared becuase of backtrack), should see abort
|
||||
response = requests.post(
|
||||
@@ -135,6 +177,7 @@ class TestSessionControl(unittest.TestCase):
|
||||
"no_stop_trim": True,
|
||||
"skip_special_tokens": False,
|
||||
},
|
||||
"return_logprob": True,
|
||||
},
|
||||
).json()
|
||||
assert response["meta_info"]["finish_reason"]["type"] == "abort"
|
||||
@@ -162,6 +205,7 @@ class TestSessionControl(unittest.TestCase):
|
||||
"no_stop_trim": True,
|
||||
"skip_special_tokens": False,
|
||||
},
|
||||
"return_logprob": True,
|
||||
},
|
||||
).json()
|
||||
assert response["meta_info"]["finish_reason"]["type"] == "abort"
|
||||
@@ -172,6 +216,7 @@ class TestSessionControl(unittest.TestCase):
|
||||
input_ids_first_req = None
|
||||
input_ids = []
|
||||
outputs_normal = []
|
||||
logprobs_normal = []
|
||||
for i, chunk_ids in enumerate(chunks_ids):
|
||||
input_ids += chunk_ids
|
||||
response = requests.post(
|
||||
@@ -186,6 +231,7 @@ class TestSessionControl(unittest.TestCase):
|
||||
"no_stop_trim": True,
|
||||
"skip_special_tokens": False,
|
||||
},
|
||||
"return_logprob": True,
|
||||
},
|
||||
).json()
|
||||
if i > 0:
|
||||
@@ -194,6 +240,12 @@ class TestSessionControl(unittest.TestCase):
|
||||
output_ids = output_ids[1:]
|
||||
input_ids += output_ids[:-1]
|
||||
outputs_normal.append(response["text"])
|
||||
logprobs_normal.extend(
|
||||
[
|
||||
round(sublist[0], 2)
|
||||
for sublist in response["meta_info"]["output_token_logprobs"]
|
||||
]
|
||||
)
|
||||
if i == 0:
|
||||
input_ids_first_req = input_ids.copy()
|
||||
|
||||
@@ -208,17 +260,31 @@ class TestSessionControl(unittest.TestCase):
|
||||
"no_stop_trim": True,
|
||||
"skip_special_tokens": False,
|
||||
},
|
||||
"return_logprob": True,
|
||||
},
|
||||
).json()
|
||||
outputs_normal.append(response["text"])
|
||||
logprobs_normal.extend(
|
||||
[
|
||||
round(sublist[0], 2)
|
||||
for sublist in response["meta_info"]["output_token_logprobs"]
|
||||
]
|
||||
)
|
||||
|
||||
print("outputs from chunked queries with session control:")
|
||||
print(outputs_from_session)
|
||||
print("outputs from normal queries:")
|
||||
print(outputs_normal)
|
||||
assert (
|
||||
outputs_from_session == outputs_normal
|
||||
), f"outputs_from_session: {outputs_from_session}, outputs_normal: {outputs_normal}"
|
||||
assert outputs_from_session == outputs_normal
|
||||
print("logprobs from chunked queries with session control:")
|
||||
print(logprobs_from_session)
|
||||
print("logprobs from normal queries:")
|
||||
print(logprobs_normal)
|
||||
assert len(logprobs_from_session) == len(
|
||||
logprobs_normal
|
||||
), "logprobs must have equal length"
|
||||
for a, b in zip(logprobs_from_session, logprobs_normal):
|
||||
assert abs(a - b) <= 0.1, f"logprobs {a} and {b} differ by more than 0.1"
|
||||
|
||||
async def async_generate(self, payload):
|
||||
url = self.base_url + "/generate"
|
||||
|
||||
@@ -1,3 +1,8 @@
|
||||
"""
|
||||
python3 -m unittest test_skip_tokenizer_init.TestSkipTokenizerInit.test_parallel_sample
|
||||
python3 -m unittest test_skip_tokenizer_init.TestSkipTokenizerInit.run_decode_stream
|
||||
"""
|
||||
|
||||
import json
|
||||
import unittest
|
||||
|
||||
@@ -12,42 +17,26 @@ from sglang.test.test_utils import (
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
_server_process = None
|
||||
_base_url = None
|
||||
_tokenizer = None
|
||||
|
||||
|
||||
def setUpModule():
|
||||
"""
|
||||
Launch the server once before all tests and initialize the tokenizer.
|
||||
"""
|
||||
global _server_process, _base_url, _tokenizer
|
||||
_server_process = popen_launch_server(
|
||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=["--skip-tokenizer-init"],
|
||||
)
|
||||
_base_url = DEFAULT_URL_FOR_TEST
|
||||
|
||||
_tokenizer = AutoTokenizer.from_pretrained(
|
||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST, use_fast=False
|
||||
)
|
||||
print(">>> setUpModule: Server launched, tokenizer ready")
|
||||
|
||||
|
||||
def tearDownModule():
|
||||
"""
|
||||
Terminate the server once after all tests have completed.
|
||||
"""
|
||||
global _server_process
|
||||
if _server_process is not None:
|
||||
kill_process_tree(_server_process.pid)
|
||||
_server_process = None
|
||||
print(">>> tearDownModule: Server terminated")
|
||||
|
||||
|
||||
class TestSkipTokenizerInit(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=["--skip-tokenizer-init", "--stream-output"],
|
||||
)
|
||||
cls.tokenizer = AutoTokenizer.from_pretrained(
|
||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST, use_fast=False
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def run_decode(
|
||||
self,
|
||||
prompt_text="The capital of France is",
|
||||
@@ -56,19 +45,19 @@ class TestSkipTokenizerInit(unittest.TestCase):
|
||||
top_logprobs_num=0,
|
||||
n=1,
|
||||
):
|
||||
input_ids = _tokenizer(prompt_text, return_tensors="pt")["input_ids"][
|
||||
input_ids = self.tokenizer(prompt_text, return_tensors="pt")["input_ids"][
|
||||
0
|
||||
].tolist()
|
||||
|
||||
response = requests.post(
|
||||
_base_url + "/generate",
|
||||
self.base_url + "/generate",
|
||||
json={
|
||||
"input_ids": input_ids,
|
||||
"sampling_params": {
|
||||
"temperature": 0 if n == 1 else 0.5,
|
||||
"max_new_tokens": max_new_tokens,
|
||||
"n": n,
|
||||
"stop_token_ids": [_tokenizer.eos_token_id],
|
||||
"stop_token_ids": [self.tokenizer.eos_token_id],
|
||||
},
|
||||
"stream": False,
|
||||
"return_logprob": return_logprob,
|
||||
@@ -83,13 +72,13 @@ class TestSkipTokenizerInit(unittest.TestCase):
|
||||
if item["meta_info"]["finish_reason"]["type"] == "stop":
|
||||
self.assertEqual(
|
||||
item["meta_info"]["finish_reason"]["matched"],
|
||||
_tokenizer.eos_token_id,
|
||||
self.tokenizer.eos_token_id,
|
||||
)
|
||||
elif item["meta_info"]["finish_reason"]["type"] == "length":
|
||||
self.assertEqual(
|
||||
len(item["token_ids"]), item["meta_info"]["completion_tokens"]
|
||||
len(item["output_ids"]), item["meta_info"]["completion_tokens"]
|
||||
)
|
||||
self.assertEqual(len(item["token_ids"]), max_new_tokens)
|
||||
self.assertEqual(len(item["output_ids"]), max_new_tokens)
|
||||
self.assertEqual(item["meta_info"]["prompt_tokens"], len(input_ids))
|
||||
|
||||
if return_logprob:
|
||||
@@ -113,6 +102,63 @@ class TestSkipTokenizerInit(unittest.TestCase):
|
||||
|
||||
print("=" * 100)
|
||||
|
||||
def run_decode_stream(self, return_logprob=False, top_logprobs_num=0, n=1):
|
||||
max_new_tokens = 32
|
||||
input_ids = [128000, 791, 6864, 315, 9822, 374] # The capital of France is
|
||||
requests.post(self.base_url + "/flush_cache")
|
||||
response = requests.post(
|
||||
self.base_url + "/generate",
|
||||
json={
|
||||
"input_ids": input_ids,
|
||||
"sampling_params": {
|
||||
"temperature": 0 if n == 1 else 0.5,
|
||||
"max_new_tokens": max_new_tokens,
|
||||
"n": n,
|
||||
"stop_token_ids": [119690],
|
||||
},
|
||||
"stream": False,
|
||||
"return_logprob": return_logprob,
|
||||
"top_logprobs_num": top_logprobs_num,
|
||||
"logprob_start_len": 0,
|
||||
},
|
||||
)
|
||||
ret = response.json()
|
||||
print(json.dumps(ret))
|
||||
output_ids = ret["output_ids"]
|
||||
|
||||
requests.post(self.base_url + "/flush_cache")
|
||||
response_stream = requests.post(
|
||||
self.base_url + "/generate",
|
||||
json={
|
||||
"input_ids": input_ids,
|
||||
"sampling_params": {
|
||||
"temperature": 0 if n == 1 else 0.5,
|
||||
"max_new_tokens": max_new_tokens,
|
||||
"n": n,
|
||||
"stop_token_ids": [119690],
|
||||
},
|
||||
"stream": True,
|
||||
"return_logprob": return_logprob,
|
||||
"top_logprobs_num": top_logprobs_num,
|
||||
"logprob_start_len": 0,
|
||||
},
|
||||
)
|
||||
ret = response.json()
|
||||
output_ids = ret["output_ids"]
|
||||
print("output from non-streaming request:")
|
||||
print(output_ids)
|
||||
|
||||
response_stream_json = []
|
||||
for line in response_stream.iter_lines():
|
||||
if line.startswith(b"data: ") and line[6:] != b"[DONE]":
|
||||
response_stream_json.append(json.loads(line[6:]))
|
||||
out_stream_ids = []
|
||||
for x in response_stream_json:
|
||||
out_stream_ids += x["output_ids"]
|
||||
print("output from streaming request:")
|
||||
print(out_stream_ids)
|
||||
assert output_ids == out_stream_ids
|
||||
|
||||
def test_simple_decode(self):
|
||||
self.run_decode()
|
||||
|
||||
@@ -126,6 +172,9 @@ class TestSkipTokenizerInit(unittest.TestCase):
|
||||
def test_eos_behavior(self):
|
||||
self.run_decode(max_new_tokens=256)
|
||||
|
||||
def test_simple_decode_stream(self):
|
||||
self.run_decode_stream()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -8,6 +8,7 @@ import random
|
||||
import time
|
||||
import unittest
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from functools import partial
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
@@ -20,6 +21,7 @@ from sglang.test.test_utils import (
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
popen_launch_server,
|
||||
run_logprob_check,
|
||||
)
|
||||
|
||||
|
||||
@@ -35,7 +37,9 @@ class TestSRTEndpoint(unittest.TestCase):
|
||||
other_args=(
|
||||
"--enable-custom-logit-processor",
|
||||
"--mem-fraction-static",
|
||||
"0.8",
|
||||
"0.7",
|
||||
"--cuda-graph-max-bs",
|
||||
"8",
|
||||
),
|
||||
)
|
||||
|
||||
@@ -131,7 +135,7 @@ class TestSRTEndpoint(unittest.TestCase):
|
||||
for i, res in enumerate(response_json):
|
||||
self.assertEqual(
|
||||
res["meta_info"]["prompt_tokens"],
|
||||
logprob_start_len + 1 + len(res["meta_info"]["input_token_logprobs"]),
|
||||
logprob_start_len + len(res["meta_info"]["input_token_logprobs"]),
|
||||
)
|
||||
assert prompts[i].endswith(
|
||||
"".join([x[-1] for x in res["meta_info"]["input_token_logprobs"]])
|
||||
@@ -235,83 +239,15 @@ class TestSRTEndpoint(unittest.TestCase):
|
||||
|
||||
diff = np.abs(output_logprobs - output_logprobs_score)
|
||||
max_diff = np.max(diff)
|
||||
self.assertLess(max_diff, 0.25)
|
||||
|
||||
def run_logprob_check(self, arg):
|
||||
(
|
||||
input_len,
|
||||
output_len,
|
||||
temperature,
|
||||
logprob_start_len,
|
||||
return_logprob,
|
||||
top_logprobs_num,
|
||||
) = arg
|
||||
input_ids = list(range(input_len))
|
||||
|
||||
response = requests.post(
|
||||
self.base_url + "/generate",
|
||||
json={
|
||||
"input_ids": input_ids,
|
||||
"sampling_params": {
|
||||
"temperature": temperature,
|
||||
"max_new_tokens": output_len,
|
||||
},
|
||||
"return_logprob": return_logprob,
|
||||
"logprob_start_len": logprob_start_len,
|
||||
"top_logprobs_num": top_logprobs_num,
|
||||
},
|
||||
)
|
||||
response_json = response.json()
|
||||
|
||||
res = response_json
|
||||
self.assertEqual(res["meta_info"]["prompt_tokens"], input_len)
|
||||
self.assertEqual(res["meta_info"]["completion_tokens"], output_len)
|
||||
|
||||
# Test the number of tokens are correct
|
||||
if return_logprob:
|
||||
# This is because if logprob_start_len == 0, we added a padding for the first token.
|
||||
# In other cases, we do not add the padding
|
||||
delta = 0 if logprob_start_len == 0 else 1
|
||||
|
||||
self.assertEqual(
|
||||
len(res["meta_info"]["input_token_logprobs"])
|
||||
+ logprob_start_len
|
||||
+ delta,
|
||||
res["meta_info"]["prompt_tokens"],
|
||||
)
|
||||
self.assertEqual(len(res["meta_info"]["output_token_logprobs"]), output_len)
|
||||
|
||||
if top_logprobs_num:
|
||||
self.assertEqual(
|
||||
len(res["meta_info"]["input_top_logprobs"])
|
||||
+ logprob_start_len
|
||||
+ delta,
|
||||
res["meta_info"]["prompt_tokens"],
|
||||
)
|
||||
self.assertEqual(
|
||||
len(res["meta_info"]["output_top_logprobs"]), output_len
|
||||
)
|
||||
|
||||
for i in range(output_len):
|
||||
self.assertEqual(
|
||||
len(res["meta_info"]["output_top_logprobs"][i]),
|
||||
top_logprobs_num,
|
||||
)
|
||||
|
||||
# Test the top-1 tokens are the same as output tokens if temperature == 0
|
||||
if temperature == 0:
|
||||
self.assertListEqual(
|
||||
res["meta_info"]["output_token_logprobs"][i],
|
||||
res["meta_info"]["output_top_logprobs"][i][0],
|
||||
)
|
||||
self.assertLess(max_diff, 0.35)
|
||||
|
||||
def test_logprob_mixed(self):
|
||||
args = []
|
||||
temperature = 0
|
||||
# input_len, output_len, temperature, logprob_start_len, return_logprob, top_logprobs_num
|
||||
for input_len in [1000, 2000]:
|
||||
for input_len in [1000, 5000, 10000, 50000]:
|
||||
for output_len in [4, 8]:
|
||||
for logprob_start_len in [0, 500, 1000]:
|
||||
for logprob_start_len in [0, 500, 2500, 5000, 25000]:
|
||||
for return_logprob in [True, False]:
|
||||
for top_logprobs_num in [0, 5]:
|
||||
|
||||
@@ -331,8 +267,9 @@ class TestSRTEndpoint(unittest.TestCase):
|
||||
|
||||
random.shuffle(args)
|
||||
|
||||
func = partial(run_logprob_check, self)
|
||||
with ThreadPoolExecutor(8) as executor:
|
||||
list(executor.map(self.run_logprob_check, args))
|
||||
list(executor.map(func, args))
|
||||
|
||||
def test_logprob_grammar(self):
|
||||
prompts = "Question: Is Paris the Capital of France? Answer:"
|
||||
@@ -427,6 +364,77 @@ class TestSRTEndpoint(unittest.TestCase):
|
||||
f"{target_token_id=}\n{sampled_tokens=}\n{custom_response=}",
|
||||
)
|
||||
|
||||
def run_stateful_custom_logit_processor(
|
||||
self, first_token_id: int | None, delay: int = 2
|
||||
):
|
||||
"""Test custom logit processor with custom params and state.
|
||||
|
||||
Should sample the first `delay` tokens normally, then output first_token_id and consecutive tokens after that.
|
||||
If first_token_id is None, the custom logit processor won't be passed in.
|
||||
"""
|
||||
|
||||
custom_params = {"token_id": first_token_id, "delay": 2}
|
||||
|
||||
class DeterministicStatefulLogitProcessor(CustomLogitProcessor):
|
||||
"""A dummy logit processor that changes the logits to always
|
||||
sample the given token id.
|
||||
"""
|
||||
|
||||
def __call__(self, logits, custom_param_list):
|
||||
assert logits.shape[0] == len(custom_param_list)
|
||||
|
||||
for i, param_dict in enumerate(custom_param_list):
|
||||
if param_dict["delay"] > 0:
|
||||
param_dict["delay"] -= 1
|
||||
continue
|
||||
if param_dict["delay"] == 0:
|
||||
param_dict["delay"] -= 1
|
||||
force_token = param_dict["token_id"]
|
||||
else:
|
||||
output_ids = param_dict["__req__"].output_ids
|
||||
force_token = output_ids[-1] + 1
|
||||
# Mask all other tokens
|
||||
logits[i, :] = -float("inf")
|
||||
# Assign highest probability to the specified token
|
||||
logits[i, force_token] = 0.0
|
||||
return logits
|
||||
|
||||
prompts = "Question: Is Paris the Capital of France? Answer:"
|
||||
|
||||
# Base case json data to be posted to the server.
|
||||
base_json = {
|
||||
"text": prompts,
|
||||
"sampling_params": {"temperature": 0.0},
|
||||
"return_logprob": True,
|
||||
}
|
||||
|
||||
# Custom json data with custom logit processor and params.
|
||||
custom_json = base_json.copy()
|
||||
# Only set the custom logit processor if target_token_id is not None.
|
||||
if first_token_id is not None:
|
||||
custom_json["custom_logit_processor"] = (
|
||||
DeterministicStatefulLogitProcessor().to_str()
|
||||
)
|
||||
custom_json["sampling_params"]["custom_params"] = custom_params
|
||||
|
||||
custom_response = requests.post(
|
||||
self.base_url + "/generate",
|
||||
json=custom_json,
|
||||
).json()
|
||||
|
||||
output_token_logprobs = custom_response["meta_info"]["output_token_logprobs"]
|
||||
sampled_tokens = [x[1] for x in output_token_logprobs]
|
||||
# The logit processor should always sample the given token as the logits is deterministic.
|
||||
if first_token_id is not None:
|
||||
self.assertTrue(
|
||||
all(
|
||||
x == custom_params["token_id"] + k
|
||||
for k, x in enumerate(sampled_tokens[custom_params["delay"] :])
|
||||
),
|
||||
# Print the detailed test case info if the test fails.
|
||||
f"{first_token_id=}\n{sampled_tokens=}\n{custom_response=}",
|
||||
)
|
||||
|
||||
def test_custom_logit_processor(self):
|
||||
"""Test custom logit processor with a single request."""
|
||||
self.run_custom_logit_processor(target_token_id=5)
|
||||
@@ -438,6 +446,19 @@ class TestSRTEndpoint(unittest.TestCase):
|
||||
with ThreadPoolExecutor(len(target_token_ids)) as executor:
|
||||
list(executor.map(self.run_custom_logit_processor, target_token_ids))
|
||||
|
||||
def test_stateful_custom_logit_processor(self):
|
||||
"""Test custom logit processor with a single request."""
|
||||
self.run_stateful_custom_logit_processor(first_token_id=5)
|
||||
|
||||
def test_stateful_custom_logit_processor_batch_mixed(self):
|
||||
"""Test a batch of requests mixed of requests with and without custom logit processor."""
|
||||
target_token_ids = list(range(32)) + [None] * 16
|
||||
random.shuffle(target_token_ids)
|
||||
with ThreadPoolExecutor(len(target_token_ids)) as executor:
|
||||
list(
|
||||
executor.map(self.run_stateful_custom_logit_processor, target_token_ids)
|
||||
)
|
||||
|
||||
def test_cache_tokens(self):
|
||||
for _ in range(2):
|
||||
time.sleep(1)
|
||||
@@ -476,6 +497,21 @@ class TestSRTEndpoint(unittest.TestCase):
|
||||
version = response_json["version"]
|
||||
self.assertIsInstance(version, str)
|
||||
|
||||
def test_get_server_info_concurrent(self):
|
||||
"""Make sure the concurrent get_server_info doesn't crash the server."""
|
||||
tp = ThreadPoolExecutor(max_workers=30)
|
||||
|
||||
def s():
|
||||
server_info = requests.get(self.base_url + "/get_server_info")
|
||||
server_info.json()
|
||||
|
||||
futures = []
|
||||
for _ in range(4):
|
||||
futures.append(tp.submit(s))
|
||||
|
||||
for f in futures:
|
||||
f.result()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -168,9 +168,9 @@ def _run_subprocess(
|
||||
hf_tokenizer = get_tokenizer(model_path, trust_remote_code=True)
|
||||
|
||||
hf_outputs = HFRunner.forward_generation_raw(
|
||||
base_model=hf_model,
|
||||
prompts=_PROMPTS,
|
||||
max_new_tokens=_MAX_NEW_TOKENS,
|
||||
base_model=hf_model,
|
||||
tokenizer=hf_tokenizer,
|
||||
lora_paths=None,
|
||||
torch_dtype=_TORCH_DTYPE,
|
||||
|
||||
Reference in New Issue
Block a user