Log if cuda graph is used & extend cuda graph capture to cuda-graph-max-bs (#6201)
Co-authored-by: SangBin Cho <rkooo567@gmail.com>
This commit is contained in:
@@ -25,6 +25,7 @@ import requests
|
||||
from sglang.srt.entrypoints.http_server import launch_server
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.test_utils import is_in_ci, write_github_step_summary
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
@@ -33,9 +34,13 @@ class BenchArgs:
|
||||
batch_size: Tuple[int] = (1,)
|
||||
input_len: Tuple[int] = (1024,)
|
||||
output_len: Tuple[int] = (16,)
|
||||
temperature: float = 0.0
|
||||
return_logprob: bool = False
|
||||
input_len_step_percentage: float = 0.0
|
||||
result_filename: str = "result.jsonl"
|
||||
base_url: str = ""
|
||||
skip_warmup: bool = False
|
||||
show_report: bool = False
|
||||
|
||||
@staticmethod
|
||||
def add_cli_args(parser: argparse.ArgumentParser):
|
||||
@@ -49,11 +54,19 @@ class BenchArgs:
|
||||
parser.add_argument(
|
||||
"--output-len", type=int, nargs="+", default=BenchArgs.output_len
|
||||
)
|
||||
parser.add_argument("--temperature", type=float, default=BenchArgs.temperature)
|
||||
parser.add_argument("--return-logprob", action="store_true")
|
||||
parser.add_argument(
|
||||
"--input-len-step-percentage",
|
||||
type=float,
|
||||
default=BenchArgs.input_len_step_percentage,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--result-filename", type=str, default=BenchArgs.result_filename
|
||||
)
|
||||
parser.add_argument("--base-url", type=str, default=BenchArgs.base_url)
|
||||
parser.add_argument("--skip-warmup", action="store_true")
|
||||
parser.add_argument("--show-report", action="store_true")
|
||||
|
||||
@classmethod
|
||||
def from_cli_args(cls, args: argparse.Namespace):
|
||||
@@ -99,36 +112,89 @@ def run_one_case(
|
||||
batch_size: int,
|
||||
input_len: int,
|
||||
output_len: int,
|
||||
temperature: float,
|
||||
return_logprob: bool,
|
||||
input_len_step_percentage: float,
|
||||
run_name: str,
|
||||
result_filename: str,
|
||||
):
|
||||
input_ids = [
|
||||
[int(x) for x in np.random.randint(0, high=16384, size=(input_len,))]
|
||||
for _ in range(batch_size)
|
||||
requests.post(url + "/flush_cache")
|
||||
input_lens = [
|
||||
int(input_len * (1 + (i - (batch_size - 1) / 2) * input_len_step_percentage))
|
||||
for i in range(batch_size)
|
||||
]
|
||||
input_ids = [
|
||||
[int(x) for x in np.random.randint(0, high=16384, size=(input_lens[i],))]
|
||||
for i in range(batch_size)
|
||||
]
|
||||
|
||||
use_structured_outputs = False
|
||||
if use_structured_outputs:
|
||||
texts = []
|
||||
for _ in range(batch_size):
|
||||
texts.append(
|
||||
"Human: What is the capital city of france? can you give as many trivial information as possible about that city? answer in json.\n"
|
||||
* 50
|
||||
+ "Assistant:"
|
||||
)
|
||||
json_schema = "$$ANY$$"
|
||||
else:
|
||||
json_schema = None
|
||||
|
||||
tic = time.time()
|
||||
response = requests.post(
|
||||
url + "/generate",
|
||||
json={
|
||||
# "text": texts,
|
||||
"input_ids": input_ids,
|
||||
"sampling_params": {
|
||||
"temperature": 0,
|
||||
"temperature": temperature,
|
||||
"max_new_tokens": output_len,
|
||||
"ignore_eos": True,
|
||||
"json_schema": json_schema,
|
||||
},
|
||||
"return_logprob": return_logprob,
|
||||
"stream": True,
|
||||
},
|
||||
stream=True,
|
||||
)
|
||||
latency = time.time() - tic
|
||||
|
||||
_ = response.json()
|
||||
output_throughput = batch_size * output_len / latency
|
||||
# The TTFT of the last request in the batch
|
||||
ttft = 0.0
|
||||
for chunk in response.iter_lines(decode_unicode=False):
|
||||
chunk = chunk.decode("utf-8")
|
||||
if chunk and chunk.startswith("data:"):
|
||||
if chunk == "data: [DONE]":
|
||||
break
|
||||
data = json.loads(chunk[5:].strip("\n"))
|
||||
if "error" in data:
|
||||
raise RuntimeError(f"Request has failed. {data}.")
|
||||
|
||||
assert (
|
||||
data["meta_info"]["finish_reason"] is None
|
||||
or data["meta_info"]["finish_reason"]["type"] == "length"
|
||||
)
|
||||
if data["meta_info"]["completion_tokens"] == 1:
|
||||
ttft = time.time() - tic
|
||||
|
||||
latency = time.time() - tic
|
||||
input_throughput = batch_size * input_len / ttft
|
||||
output_throughput = batch_size * output_len / (latency - ttft)
|
||||
overall_throughput = batch_size * (input_len + output_len) / latency
|
||||
|
||||
server_info = requests.get(url + "/get_server_info").json()
|
||||
acc_length = server_info["internal_states"][0].get("avg_spec_accept_length", None)
|
||||
last_gen_throughput = server_info["internal_states"][0]["last_gen_throughput"]
|
||||
|
||||
print(f"batch size: {batch_size}")
|
||||
print(f"input_len: {input_len}")
|
||||
print(f"output_len: {output_len}")
|
||||
print(f"latency: {latency:.2f} s")
|
||||
print(f"output throughput: {output_throughput:.2f} token/s")
|
||||
print(f"(input + output) throughput: {overall_throughput:.2f} token/s")
|
||||
print(f"ttft: {ttft:.2f} s")
|
||||
print(f"Last generation throughput: {last_gen_throughput:.2f} tok/s")
|
||||
print(f"Input throughput: {input_throughput:.2f} tok/s")
|
||||
if output_len != 1:
|
||||
print(f"output throughput: {output_throughput:.2f} tok/s")
|
||||
|
||||
if result_filename:
|
||||
with open(result_filename, "a") as fout:
|
||||
@@ -140,9 +206,21 @@ def run_one_case(
|
||||
"latency": round(latency, 4),
|
||||
"output_throughput": round(output_throughput, 2),
|
||||
"overall_throughput": round(overall_throughput, 2),
|
||||
"last_gen_throughput": round(last_gen_throughput, 2),
|
||||
}
|
||||
fout.write(json.dumps(res) + "\n")
|
||||
|
||||
return (
|
||||
batch_size,
|
||||
latency,
|
||||
ttft,
|
||||
input_throughput,
|
||||
output_throughput,
|
||||
overall_throughput,
|
||||
last_gen_throughput,
|
||||
acc_length,
|
||||
)
|
||||
|
||||
|
||||
def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
|
||||
if bench_args.base_url:
|
||||
@@ -152,27 +230,38 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
|
||||
|
||||
# warmup
|
||||
if not bench_args.skip_warmup:
|
||||
print("=" * 8 + " Warmup Begin " + "=" * 8)
|
||||
run_one_case(
|
||||
base_url,
|
||||
batch_size=16,
|
||||
input_len=1024,
|
||||
output_len=16,
|
||||
temperature=bench_args.temperature,
|
||||
return_logprob=bench_args.return_logprob,
|
||||
input_len_step_percentage=bench_args.input_len_step_percentage,
|
||||
run_name="",
|
||||
result_filename="",
|
||||
)
|
||||
print("=" * 8 + " Warmup End " + "=" * 8 + "\n")
|
||||
|
||||
# benchmark
|
||||
result = []
|
||||
try:
|
||||
for bs, il, ol in itertools.product(
|
||||
bench_args.batch_size, bench_args.input_len, bench_args.output_len
|
||||
):
|
||||
run_one_case(
|
||||
base_url,
|
||||
bs,
|
||||
il,
|
||||
ol,
|
||||
bench_args.run_name,
|
||||
bench_args.result_filename,
|
||||
result.append(
|
||||
run_one_case(
|
||||
base_url,
|
||||
bs,
|
||||
il,
|
||||
ol,
|
||||
temperature=bench_args.temperature,
|
||||
return_logprob=bench_args.return_logprob,
|
||||
input_len_step_percentage=bench_args.input_len_step_percentage,
|
||||
run_name=bench_args.run_name,
|
||||
result_filename=bench_args.result_filename,
|
||||
)
|
||||
)
|
||||
finally:
|
||||
if proc:
|
||||
@@ -180,6 +269,45 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
|
||||
|
||||
print(f"\nResults are saved to {bench_args.result_filename}")
|
||||
|
||||
if not bench_args.show_report:
|
||||
return
|
||||
|
||||
summary = " | batch size | latency (s) | input throughput (tok/s) | output throughput (tok/s) | acc length | ITL (ms) | input price ($/1M) | output price ($/1M) |\n"
|
||||
summary += "| ---------- | ----------- | ------------------------- | ------------------------- | ---------- | -------- | ------------------ | ------------------- |\n"
|
||||
|
||||
for (
|
||||
batch_size,
|
||||
latency,
|
||||
ttft,
|
||||
input_throughput,
|
||||
output_throughput,
|
||||
overall_throughput,
|
||||
last_gen_throughput,
|
||||
acc_length,
|
||||
) in result:
|
||||
hourly_cost = 2 * server_args.tp_size # $2/hour for one H100
|
||||
input_util = 0.7
|
||||
accept_length = round(acc_length, 2) if acc_length is not None else "n/a"
|
||||
line = (
|
||||
f"| {batch_size} | "
|
||||
f"{latency:.2f} | "
|
||||
f"{input_throughput:.2f} | "
|
||||
f"{output_throughput:.2f} | "
|
||||
f"{accept_length} | "
|
||||
f"{1 / (output_throughput/batch_size) * 1000:.2f} | "
|
||||
f"{1e6 / (input_throughput * input_util) / 3600 * hourly_cost:.2f} | "
|
||||
f"{1e6 / output_throughput / 3600 * hourly_cost:.2f} |\n"
|
||||
)
|
||||
summary += line
|
||||
|
||||
# print metrics table
|
||||
print(summary)
|
||||
|
||||
if is_in_ci():
|
||||
write_github_step_summary(
|
||||
f"### Test Nightly Benchmark (bench_one_batch) \n{summary}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
Reference in New Issue
Block a user