Support penalty in overlap mode; return logprob with chunked prefill; improve benchmark scripts (#3988)

Co-authored-by: SangBin Cho <rkooo567@gmail.com>
Co-authored-by: dhou-xai <dhou@x.ai>
Co-authored-by: Hanming Lu <hanming_lu@berkeley.edu>
Lianmin Zheng
2025-03-03 00:12:04 -08:00
parent 0194948fd9
commit ac2387279e
86 changed files with 4116 additions and 2015 deletions
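The headline features can be exercised directly over the HTTP API. Below is a minimal sketch of a request that asks for prompt logprobs while also applying a penalty, mirroring the request shape used by run_logprob_check later in this diff; the server address, the prompt length, and the --chunked-prefill-size value are illustrative assumptions, not part of this commit.

```python
# Sketch only: exercises "return logprob with chunked prefill" and
# "penalty in overlap mode" from the commit title. Assumes a server
# was started with something like
#   python3 -m sglang.launch_server --model-path <model> --chunked-prefill-size 512
# so that a long prompt is prefilled in several chunks.
import requests

base_url = "http://127.0.0.1:30000"  # assumed local server address
long_prompt_ids = list(range(2048))  # long enough to span multiple prefill chunks

response = requests.post(
    base_url + "/generate",
    json={
        "input_ids": long_prompt_ids,
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": 8,
            "ignore_eos": True,
            # Penalties are applied here; per the commit title they now
            # also work with the overlap scheduler.
            "frequency_penalty": 0.5,
        },
        "return_logprob": True,
        "logprob_start_len": 0,  # return logprobs for the whole prompt
        "top_logprobs_num": 5,
    },
)
meta = response.json()["meta_info"]
print(len(meta["input_token_logprobs"]), len(meta["output_token_logprobs"]))
```

Per the commit title, the prompt logprobs are now returned correctly even when chunked prefill splits the prompt across multiple forward passes.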

python/sglang/test/test_utils.py

@@ -8,10 +8,11 @@ import random
 import subprocess
 import threading
 import time
+import unittest
 from concurrent.futures import ThreadPoolExecutor
 from functools import partial
 from types import SimpleNamespace
-from typing import Callable, List, Optional
+from typing import Callable, List, Optional, Tuple
 
 import numpy as np
 import requests
@@ -408,26 +409,49 @@ def popen_launch_server(
     other_args: list[str] = (),
     env: Optional[dict] = None,
     return_stdout_stderr: Optional[tuple] = None,
+    pd_seperated: bool = False,
 ):
     _, host, port = base_url.split(":")
     host = host[2:]
 
+    if pd_seperated:
+        command = "sglang.launch_pd_server"
+    else:
+        command = "sglang.launch_server"
+
     command = [
         "python3",
         "-m",
-        "sglang.launch_server",
+        command,
         "--model-path",
         model,
-        "--host",
-        host,
-        "--port",
-        port,
-        *other_args,
+        *[str(x) for x in other_args],
     ]
+
+    if pd_seperated:
+        command.extend(
+            [
+                "--lb-host",
+                host,
+                "--lb-port",
+                port,
+            ]
+        )
+    else:
+        command.extend(
+            [
+                "--host",
+                host,
+                "--port",
+                port,
+            ]
+        )
 
     if api_key:
         command += ["--api-key", api_key]
 
+    print(f"command={' '.join(command)}")
+
     if return_stdout_stderr:
         process = subprocess.Popen(
             command,
@@ -456,6 +480,8 @@ def popen_launch_server(
         except requests.RequestException:
             pass
 
         time.sleep(10)
 
+    kill_process_tree(process.pid)
+
     raise TimeoutError("Server failed to start within the timeout period.")
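The new pd_seperated flag (spelling as in the code) switches the helper from sglang.launch_server to sglang.launch_pd_server and passes --lb-host/--lb-port for the load balancer instead of --host/--port, per the hunk above. Here is a sketch of how a test might drive it, assuming the helpers live in sglang.test.test_utils and kill_process_tree comes from sglang.srt.utils; the model path and URL are placeholders.

```python
# Sketch: launching a prefill/decode-separated server from a test.
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    popen_launch_server,
)

process = popen_launch_server(
    model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
    base_url="http://127.0.0.1:30000",
    timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    other_args=["--tp", 2],  # non-string args are now stringified by the helper
    pd_seperated=True,       # launches sglang.launch_pd_server with --lb-host/--lb-port
)
try:
    ...  # send requests against the load balancer here
finally:
    kill_process_tree(process.pid)
```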
@@ -488,9 +514,11 @@ def run_unittest_files(files: List[str], timeout_per_file: float):
     success = True
 
     for filename in files:
-        global process
+        process = None
 
         def run_one_file(filename):
+            nonlocal process
             filename = os.path.join(os.getcwd(), filename)
             print(f"\n\nRun:\npython3 {filename}\n\n", flush=True)
             process = subprocess.Popen(
@@ -534,11 +562,14 @@ def get_benchmark_args(
     dataset_path="",
     tokenizer="",
     num_prompts=500,
+    sharegpt_output_len=None,
     random_input_len=4096,
     random_output_len=2048,
+    sharegpt_context_len=None,
     request_rate=float("inf"),
     disable_stream=False,
     disable_ignore_eos=False,
+    pd_seperated: bool = False,
 ):
     return SimpleNamespace(
         backend="sglang",
@@ -550,8 +581,8 @@
         model=None,
         tokenizer=tokenizer,
         num_prompts=num_prompts,
-        sharegpt_output_len=None,
-        sharegpt_context_len=None,
+        sharegpt_output_len=sharegpt_output_len,
+        sharegpt_context_len=sharegpt_context_len,
         random_input_len=random_input_len,
         random_output_len=random_output_len,
         random_range_ratio=0.0,
@@ -567,6 +598,8 @@
         apply_chat_template=False,
         profile=None,
         lora_name=None,
+        prompt_suffix="",
+        pd_seperated=pd_seperated,
     )
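With sharegpt_output_len and sharegpt_context_len now threaded through get_benchmark_args, and sharegpt_context_len exposed on run_bench_serving (next hunk), a benchmark run can cap ShareGPT sample lengths. A sketch follows; it assumes run_bench_serving's remaining parameters (model, num_prompts, request_rate, other_server_args, dataset_name) keep their existing meanings in this file, and the model name is a placeholder.

```python
# Sketch: a serving benchmark over ShareGPT with the new length cap.
from sglang.test.test_utils import run_bench_serving

res = run_bench_serving(
    model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder
    num_prompts=200,
    request_rate=float("inf"),
    other_server_args=[],
    dataset_name="sharegpt",
    sharegpt_context_len=2048,  # new argument forwarded to the benchmark script
    disable_ignore_eos=True,
)
print(res)  # benchmark result namespace/dict from bench_serving
```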
@@ -580,6 +613,7 @@ def run_bench_serving(
     tokenizer=None,
     random_input_len=4096,
     random_output_len=2048,
+    sharegpt_context_len=None,
     disable_stream=False,
     disable_ignore_eos=False,
     need_warmup=False,
@@ -602,6 +636,7 @@
         num_prompts=num_prompts,
         random_input_len=random_input_len,
         random_output_len=random_output_len,
+        sharegpt_context_len=sharegpt_context_len,
         request_rate=request_rate,
         disable_stream=disable_stream,
         disable_ignore_eos=disable_ignore_eos,
@@ -626,6 +661,7 @@ def run_bench_serving_multi(
     other_server_args,
     benchmark_args,
     need_warmup=False,
+    pd_seperated=False,
 ):
     # Launch the server
     process = popen_launch_server(
@@ -633,6 +669,7 @@
         base_url,
         timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
         other_args=other_server_args,
+        pd_seperated=pd_seperated,
     )
 
     # Run the benchmark for all configurations
@@ -665,7 +702,7 @@ def run_bench_one_batch(model, other_args):
         "128",
         "--output",
         "8",
-        *other_args,
+        *[str(x) for x in other_args],
     ]
 
     process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
@@ -816,7 +853,7 @@ def run_command_and_capture_output(command, env: Optional[dict] = None):
     stdout = open(STDOUT_FILENAME, "w")
    stderr = open(STDERR_FILENAME, "w")
     process = subprocess.Popen(
-        command, stdout=stdout, stderr=stderr, env=env, text=True
+        command, stdout=stdout, stderr=stdout, env=env, text=True
     )
 
     # Launch a thread to stream the output
@@ -914,3 +951,78 @@ def run_mulit_request_test(
 def write_github_step_summary(content):
     with open(os.environ["GITHUB_STEP_SUMMARY"], "a") as f:
         f.write(content)
+
+
+def run_logprob_check(self: unittest.TestCase, arg: Tuple):
+    (
+        input_len,
+        output_len,
+        temperature,
+        logprob_start_len,
+        return_logprob,
+        top_logprobs_num,
+    ) = arg
+    input_ids = list(range(input_len))
+
+    response = requests.post(
+        self.base_url + "/generate",
+        json={
+            "input_ids": input_ids,
+            "sampling_params": {
+                "temperature": temperature,
+                "max_new_tokens": output_len,
+                "ignore_eos": True,
+            },
+            "return_logprob": return_logprob,
+            "logprob_start_len": logprob_start_len,
+            "top_logprobs_num": top_logprobs_num,
+        },
+    )
+    response_json = response.json()
+    res = response_json
+
+    self.assertEqual(res["meta_info"]["prompt_tokens"], input_len)
+    self.assertEqual(res["meta_info"]["completion_tokens"], output_len)
+
+    # Test that the numbers of returned tokens are correct
+    if return_logprob:
+        self.assertEqual(
+            len(res["meta_info"]["input_token_logprobs"]) + logprob_start_len,
+            res["meta_info"]["prompt_tokens"],
+        )
+        self.assertEqual(len(res["meta_info"]["output_token_logprobs"]), output_len)
+
+        if top_logprobs_num:
+            self.assertEqual(
+                len(res["meta_info"]["input_top_logprobs"]) + logprob_start_len,
+                res["meta_info"]["prompt_tokens"],
+            )
+            self.assertEqual(len(res["meta_info"]["output_top_logprobs"]), output_len)
+
+            for i in range(output_len):
+                self.assertEqual(
+                    len(res["meta_info"]["output_top_logprobs"][i]),
+                    top_logprobs_num,
+                )
+
+                # Test that the top-1 token matches the output token when temperature == 0
+                if temperature == 0:
+                    rank = 0
+                    while rank < len(res["meta_info"]["output_top_logprobs"][i]):
+                        try:
+                            self.assertListEqual(
+                                res["meta_info"]["output_token_logprobs"][i],
+                                res["meta_info"]["output_top_logprobs"][i][rank],
+                            )
+                            break
+                        except AssertionError:
+                            # There is a tie; allow the next-ranked item in this case.
+                            if (
+                                res["meta_info"]["output_top_logprobs"][i][rank][0]
+                                == res["meta_info"]["output_top_logprobs"][i][rank + 1][0]
+                            ):
+                                rank += 1
+                            else:
+                                raise
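To close, here is a sketch of how a test case might sweep run_logprob_check over its argument tuple (input_len, output_len, temperature, logprob_start_len, return_logprob, top_logprobs_num). The test class, its base_url, and the grid values are illustrative assumptions; only the helper itself comes from this commit.

```python
# Sketch: parameter sweep over the new run_logprob_check helper.
import itertools
import unittest

from sglang.test.test_utils import run_logprob_check


class TestLogprob(unittest.TestCase):
    base_url = "http://127.0.0.1:30000"  # assumed running server

    def test_logprob_grid(self):
        for arg in itertools.product(
            [10, 500],    # input_len (a long prompt exercises chunked prefill)
            [4],          # output_len
            [0.0, 1.0],   # temperature (0.0 enables the top-1 consistency check)
            [0, 3],       # logprob_start_len
            [True],       # return_logprob
            [0, 5],       # top_logprobs_num
        ):
            with self.subTest(arg=arg):
                run_logprob_check(self, arg)


if __name__ == "__main__":
    unittest.main()
```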