Support penalty in overlap mode; return logprob with chunked prefill; improve benchmark scripts (#3988)
Co-authored-by: SangBin Cho <rkooo567@gmail.com>
Co-authored-by: dhou-xai <dhou@x.ai>
Co-authored-by: Hanming Lu <hanming_lu@berkeley.edu>
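The "improve benchmark scripts" part of this change threads a new `pd_seperated` flag through the test helpers so they can launch `sglang.launch_pd_server` behind a load balancer instead of a plain `sglang.launch_server`, and it stringifies `other_args` before building the command. A rough usage sketch follows; the model name, URL, and extra server args are placeholders rather than values from this commit, and the import path assumes these helpers live in `sglang.test.test_utils`:

from sglang.test.test_utils import (
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    popen_launch_server,
)

# Placeholder model and URL, not taken from this commit.
model = "meta-llama/Llama-3.1-8B-Instruct"
base_url = "http://127.0.0.1:30000"

# With pd_seperated=True the helper runs sglang.launch_pd_server and passes the
# host/port as --lb-host/--lb-port; other_args may contain non-strings because
# they are now converted with str() before being appended to the command.
process = popen_launch_server(
    model,
    base_url,
    timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    other_args=["--tp", 2],
    pd_seperated=True,
)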
@@ -8,10 +8,11 @@ import random
 import subprocess
 import threading
 import time
+import unittest
 from concurrent.futures import ThreadPoolExecutor
 from functools import partial
 from types import SimpleNamespace
-from typing import Callable, List, Optional
+from typing import Callable, List, Optional, Tuple
 
 import numpy as np
 import requests
@@ -408,26 +409,49 @@ def popen_launch_server(
     other_args: list[str] = (),
     env: Optional[dict] = None,
     return_stdout_stderr: Optional[tuple] = None,
+    pd_seperated: bool = False,
 ):
     _, host, port = base_url.split(":")
     host = host[2:]
 
+    if pd_seperated:
+        command = "sglang.launch_pd_server"
+    else:
+        command = "sglang.launch_server"
+
     command = [
         "python3",
         "-m",
-        "sglang.launch_server",
+        command,
         "--model-path",
         model,
-        "--host",
-        host,
-        "--port",
-        port,
-        *other_args,
+        *[str(x) for x in other_args],
     ]
 
+    if pd_seperated:
+        command.extend(
+            [
+                "--lb-host",
+                host,
+                "--lb-port",
+                port,
+            ]
+        )
+    else:
+        command.extend(
+            [
+                "--host",
+                host,
+                "--port",
+                port,
+            ]
+        )
+
     if api_key:
         command += ["--api-key", api_key]
 
+    print(f"command={' '.join(command)}")
+
     if return_stdout_stderr:
         process = subprocess.Popen(
             command,
@@ -456,6 +480,8 @@ def popen_launch_server(
             except requests.RequestException:
                 pass
             time.sleep(10)
+
+    kill_process_tree(process.pid)
     raise TimeoutError("Server failed to start within the timeout period.")
 
 
@@ -488,9 +514,11 @@ def run_unittest_files(files: List[str], timeout_per_file: float):
     success = True
 
     for filename in files:
-        global process
+        process = None
 
         def run_one_file(filename):
+            nonlocal process
+
             filename = os.path.join(os.getcwd(), filename)
             print(f"\n\nRun:\npython3 {filename}\n\n", flush=True)
             process = subprocess.Popen(
@@ -534,11 +562,14 @@ def get_benchmark_args(
     dataset_path="",
     tokenizer="",
     num_prompts=500,
+    sharegpt_output_len=None,
     random_input_len=4096,
     random_output_len=2048,
+    sharegpt_context_len=None,
     request_rate=float("inf"),
     disable_stream=False,
     disable_ignore_eos=False,
+    pd_seperated: bool = False,
 ):
     return SimpleNamespace(
         backend="sglang",
@@ -550,8 +581,8 @@ def get_benchmark_args(
         model=None,
         tokenizer=tokenizer,
         num_prompts=num_prompts,
-        sharegpt_output_len=None,
-        sharegpt_context_len=None,
+        sharegpt_output_len=sharegpt_output_len,
+        sharegpt_context_len=sharegpt_context_len,
         random_input_len=random_input_len,
         random_output_len=random_output_len,
         random_range_ratio=0.0,
@@ -567,6 +598,8 @@ def get_benchmark_args(
         apply_chat_template=False,
         profile=None,
         lora_name=None,
+        prompt_suffix="",
+        pd_seperated=pd_seperated,
     )
 
 
@@ -580,6 +613,7 @@ def run_bench_serving(
     tokenizer=None,
     random_input_len=4096,
     random_output_len=2048,
+    sharegpt_context_len=None,
     disable_stream=False,
     disable_ignore_eos=False,
     need_warmup=False,
@@ -602,6 +636,7 @@ def run_bench_serving(
         num_prompts=num_prompts,
         random_input_len=random_input_len,
         random_output_len=random_output_len,
+        sharegpt_context_len=sharegpt_context_len,
         request_rate=request_rate,
         disable_stream=disable_stream,
         disable_ignore_eos=disable_ignore_eos,
@@ -626,6 +661,7 @@ def run_bench_serving_multi(
     other_server_args,
     benchmark_args,
     need_warmup=False,
+    pd_seperated=False,
 ):
     # Launch the server
     process = popen_launch_server(
@@ -633,6 +669,7 @@ def run_bench_serving_multi(
         base_url,
         timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
         other_args=other_server_args,
+        pd_seperated=pd_seperated,
     )
 
     # run benchmark for all
@@ -665,7 +702,7 @@ def run_bench_one_batch(model, other_args):
         "128",
         "--output",
         "8",
-        *other_args,
+        *[str(x) for x in other_args],
     ]
     process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
 
@@ -816,7 +853,7 @@ def run_command_and_capture_output(command, env: Optional[dict] = None):
     stdout = open(STDOUT_FILENAME, "w")
     stderr = open(STDERR_FILENAME, "w")
     process = subprocess.Popen(
-        command, stdout=stdout, stderr=stderr, env=env, text=True
+        command, stdout=stdout, stderr=stdout, env=env, text=True
     )
 
     # Launch a thread to stream the output
@@ -914,3 +951,78 @@ def run_mulit_request_test(
 def write_github_step_summary(content):
     with open(os.environ["GITHUB_STEP_SUMMARY"], "a") as f:
         f.write(content)
+
+
+def run_logprob_check(self: unittest.TestCase, arg: Tuple):
+    (
+        input_len,
+        output_len,
+        temperature,
+        logprob_start_len,
+        return_logprob,
+        top_logprobs_num,
+    ) = arg
+    input_ids = list(range(input_len))
+
+    response = requests.post(
+        self.base_url + "/generate",
+        json={
+            "input_ids": input_ids,
+            "sampling_params": {
+                "temperature": temperature,
+                "max_new_tokens": output_len,
+                "ignore_eos": True,
+            },
+            "return_logprob": return_logprob,
+            "logprob_start_len": logprob_start_len,
+            "top_logprobs_num": top_logprobs_num,
+        },
+    )
+    response_json = response.json()
+
+    res = response_json
+    self.assertEqual(res["meta_info"]["prompt_tokens"], input_len)
+    self.assertEqual(res["meta_info"]["completion_tokens"], output_len)
+
+    # Test the number of tokens are correct
+    if return_logprob:
+        self.assertEqual(
+            len(res["meta_info"]["input_token_logprobs"]) + logprob_start_len,
+            res["meta_info"]["prompt_tokens"],
+        )
+        self.assertEqual(len(res["meta_info"]["output_token_logprobs"]), output_len)
+
+        if top_logprobs_num:
+            self.assertEqual(
+                len(res["meta_info"]["input_top_logprobs"]) + logprob_start_len,
+                res["meta_info"]["prompt_tokens"],
+            )
+            self.assertEqual(len(res["meta_info"]["output_top_logprobs"]), output_len)
+
+            for i in range(output_len):
+                self.assertEqual(
+                    len(res["meta_info"]["output_top_logprobs"][i]),
+                    top_logprobs_num,
+                )
+
+                # Test the top-1 tokens are the same as output tokens if temperature == 0
+                if temperature == 0:
+                    rank = 0
+                    while rank < len(res["meta_info"]["output_top_logprobs"][i]):
+                        try:
+                            self.assertListEqual(
+                                res["meta_info"]["output_token_logprobs"][i],
+                                res["meta_info"]["output_top_logprobs"][i][rank],
+                            )
+                            break
+                        except AssertionError:
+                            # There's a tie. Allow the second item in this case.
+                            if (
+                                res["meta_info"]["output_top_logprobs"][i][rank][0]
+                                == res["meta_info"]["output_top_logprobs"][i][rank + 1][
+                                    0
+                                ]
+                            ):
+                                rank += 1
+                            else:
+                                raise
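The new `run_logprob_check` helper above takes all of its settings as a single tuple, so a test case can sweep combinations (for example different `logprob_start_len` values to exercise logprob return under chunked prefill) and fan them out with the `ThreadPoolExecutor`/`partial` imports this file already pulls in. A minimal sketch of such a caller; the test class, parameter grid, and server URL are illustrative assumptions, not part of this commit:

import itertools
import unittest
from concurrent.futures import ThreadPoolExecutor
from functools import partial

# Assumes this diff is sglang.test.test_utils; adjust the import to the actual module.
from sglang.test.test_utils import run_logprob_check


class LogprobCheckExample(unittest.TestCase):  # hypothetical test case
    # run_logprob_check reads self.base_url, so the test class must provide it
    # (normally the URL of a server started via popen_launch_server).
    base_url = "http://127.0.0.1:30000"

    def test_logprob_grid(self):
        # (input_len, output_len, temperature, logprob_start_len, return_logprob, top_logprobs_num)
        args = list(
            itertools.product(
                [1001],      # input_len
                [16],        # output_len
                [0.0],       # temperature == 0 also checks top-1 tokens against outputs
                [0, 500],    # logprob_start_len
                [True],      # return_logprob
                [0, 3],      # top_logprobs_num
            )
        )
        with ThreadPoolExecutor(4) as executor:
            # Any assertion failure inside run_logprob_check surfaces when the map is consumed.
            list(executor.map(partial(run_logprob_check, self), args))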