Support penalty in overlap mode; return logprob with chunked prefill; improve benchmark scripts (#3988)

Co-authored-by: SangBin Cho <rkooo567@gmail.com>
Co-authored-by: dhou-xai <dhou@x.ai>
Co-authored-by: Hanming Lu <hanming_lu@berkeley.edu>
Lianmin Zheng
2025-03-03 00:12:04 -08:00
parent 0194948fd9
commit ac2387279e
86 changed files with 4116 additions and 2015 deletions


@@ -8,7 +8,6 @@ Usage:
python3 -m sglang.bench_serving --backend sglang --num-prompts 10
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 3000 --random-input 1024 --random-output 1024 --random-range-ratio 0.5
python3 -m sglang.bench_serving --backend sglang --dataset-name random --request-rate-range 1,2,4,8,16,32 --random-input 4096 --random-output 1024 --random-range-ratio 0.125 --multi
"""
import argparse
@@ -71,6 +70,10 @@ def remove_prefix(text: str, prefix: str) -> str:
return text[len(prefix) :] if text.startswith(prefix) else text
def remove_suffix(text: str, suffix: str) -> str:
return text[: -len(suffix)] if text.endswith(suffix) else text
def get_auth_headers() -> Dict[str, str]:
api_key = os.environ.get("OPENAI_API_KEY")
if api_key:
@@ -79,7 +82,7 @@ def get_auth_headers() -> Dict[str, str]:
return {}
# trt llm not support ignore_eos
# trt llm does not support ignore_eos
# https://github.com/triton-inference-server/tensorrtllm_backend/issues/505
async def async_request_trt_llm(
request_func_input: RequestFuncInput,
@@ -179,6 +182,7 @@ async def async_request_openai_completions(
output.prompt_len = request_func_input.prompt_len
generated_text = ""
output_len = request_func_input.output_len
ttft = 0.0
st = time.perf_counter()
most_recent_timestamp = st
@@ -215,11 +219,14 @@ async def async_request_openai_completions(
most_recent_timestamp = timestamp
generated_text += data["choices"][0]["text"]
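# Prefer the server-reported completion token count (usage.completion_tokens)
# when the backend includes it, falling back to the previous value otherwise.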
output_len = data.get("usage", {}).get(
"completion_tokens", output_len
)
output.generated_text = generated_text
output.success = True
output.latency = latency
output.output_len = request_func_input.output_len
output.output_len = output_len
else:
output.error = response.reason or ""
output.success = False
@@ -339,9 +346,11 @@ async def async_request_sglang_generate(
output.prompt_len = request_func_input.prompt_len
generated_text = ""
output_len = request_func_input.output_len
ttft = 0.0
st = time.perf_counter()
most_recent_timestamp = st
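# Completion tokens observed as of the previous streamed chunk.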
last_output_len = 0
try:
async with session.post(
url=api_url, json=payload, headers=headers
@@ -365,6 +374,9 @@ async def async_request_sglang_generate(
# want to check a token was generated
if data["text"]:
timestamp = time.perf_counter()
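# Keep the cumulative text and the server-reported token count for this chunk.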
generated_text = data["text"]
output_len = data["meta_info"]["completion_tokens"]
# First token
if ttft == 0.0:
ttft = time.perf_counter() - st
@@ -372,7 +384,13 @@ async def async_request_sglang_generate(
# Decoding phase
else:
output.itl.append(timestamp - most_recent_timestamp)
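# A streamed chunk can carry more than one new token; spread the observed
# inter-chunk latency evenly across them so ITL stays a per-token metric.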
num_new_tokens = output_len - last_output_len
if num_new_tokens == 0:
continue
adjust_itl = (
timestamp - most_recent_timestamp
) / num_new_tokens
output.itl.extend([adjust_itl] * num_new_tokens)
most_recent_timestamp = timestamp
generated_text = data["text"]
@@ -380,7 +398,7 @@ async def async_request_sglang_generate(
output.generated_text = generated_text
output.success = True
output.latency = latency
output.output_len = request_func_input.output_len
output.output_len = output_len
else:
output.error = response.reason or ""
output.success = False
@@ -388,6 +406,7 @@ async def async_request_sglang_generate(
output.success = False
exc_info = sys.exc_info()
output.error = "".join(traceback.format_exception(*exc_info))
print(f"{output.error=}")
if pbar:
pbar.update(1)
@@ -461,6 +480,7 @@ def get_dataset(args, tokenizer):
tokenizer=tokenizer,
fixed_output_len=args.sharegpt_output_len,
context_len=args.sharegpt_context_len,
prompt_suffix=args.prompt_suffix,
apply_chat_template=args.apply_chat_template,
)
elif args.dataset_name == "random":
@@ -521,7 +541,9 @@ class BenchmarkMetrics:
mean_itl_ms: float
median_itl_ms: float
std_itl_ms: float
p95_itl_ms: float
p99_itl_ms: float
max_itl_ms: float
mean_e2e_latency_ms: float
median_e2e_latency_ms: float
std_e2e_latency_ms: float
@@ -572,6 +594,7 @@ def sample_sharegpt_requests(
tokenizer: PreTrainedTokenizerBase,
fixed_output_len: Optional[int] = None,
context_len: Optional[int] = None,
prompt_suffix: Optional[str] = "",
apply_chat_template=False,
) -> List[Tuple[str, int, int]]:
if fixed_output_len is not None and fixed_output_len < 4:
@@ -584,11 +607,19 @@ def sample_sharegpt_requests(
# Load the dataset.
with open(dataset_path) as f:
dataset = json.load(f)
# Filter out the conversations with less than 2 turns.
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
dataset = [
data
for data in dataset
if len(data.get("conversations", data.get("conversation", []))) >= 2
]
# Only keep the first two turns of each conversation.
dataset = [
(data["conversations"][0]["value"], data["conversations"][1]["value"])
(
data.get("conversations", data.get("conversation", []))[0]["value"],
data.get("conversations", data.get("conversation", []))[1]["value"],
)
for data in dataset
]
@@ -603,6 +634,8 @@ def sample_sharegpt_requests(
# Tokenize the prompts and completions.
prompt = dataset[i][0]
if prompt_suffix:
prompt = prompt + prompt_suffix
if apply_chat_template:
prompt = tokenizer.apply_chat_template(
@@ -666,10 +699,17 @@ def sample_random_requests(
with open(dataset_path) as f:
dataset = json.load(f)
# Filter out the conversations with less than 2 turns.
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
dataset = [
data
for data in dataset
if len(data.get("conversations", data.get("conversation", []))) >= 2
]
# Only keep the first two turns of each conversation.
dataset = [
(data["conversations"][0]["value"], data["conversations"][1]["value"])
(
data.get("conversations", data.get("conversation", []))[0]["value"],
data.get("conversations", data.get("conversation", []))[1]["value"],
)
for data in dataset
]
# Shuffle the dataset.
@@ -895,7 +935,9 @@ def calculate_metrics(
mean_itl_ms=np.mean(itls or 0) * 1000,
median_itl_ms=np.median(itls or 0) * 1000,
std_itl_ms=np.std(itls or 0) * 1000,
p95_itl_ms=np.percentile(itls or 0, 95) * 1000,
p99_itl_ms=np.percentile(itls or 0, 99) * 1000,
max_itl_ms=np.max(itls or 0) * 1000,
mean_e2e_latency_ms=np.mean(e2e_latencies) * 1000,
median_e2e_latency_ms=np.median(e2e_latencies) * 1000,
std_e2e_latency_ms=np.std(e2e_latencies) * 1000,
@@ -919,6 +961,7 @@ async def benchmark(
lora_name: str,
extra_request_body: Dict[str, Any],
profile: bool,
pd_seperated: bool = False,
):
if backend in ASYNC_REQUEST_FUNCS:
request_func = ASYNC_REQUEST_FUNCS[backend]
@@ -1004,6 +1047,17 @@ async def benchmark(
if pbar is not None:
pbar.close()
if "sglang" in backend:
server_info = requests.get(base_url + "/get_server_info")
if pd_seperated:
accept_length = server_info.json()["decode"][0].get(
"avg_spec_accept_length", None
)
else:
accept_length = server_info.json().get("avg_spec_accept_length", None)
else:
accept_length = None
# Compute metrics and print results
benchmark_duration = time.perf_counter() - benchmark_start_time
metrics, output_lens = calculate_metrics(
@@ -1053,6 +1107,8 @@ async def benchmark(
)
)
print("{:<40} {:<10.2f}".format("Concurrency:", metrics.concurrency))
if accept_length:
print("{:<40} {:<10.2f}".format("Accept length:", accept_length))
print("{s:{c}^{n}}".format(s="End-to-End Latency", n=50, c="-"))
print(
"{:<40} {:<10.2f}".format("Mean E2E Latency (ms):", metrics.mean_e2e_latency_ms)
@@ -1066,16 +1122,12 @@ async def benchmark(
print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms))
print("{:<40} {:<10.2f}".format("Median TTFT (ms):", metrics.median_ttft_ms))
print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms))
print(
"{s:{c}^{n}}".format(s="Time per Output Token (excl. 1st token)", n=50, c="-")
)
print("{:<40} {:<10.2f}".format("Mean TPOT (ms):", metrics.mean_tpot_ms))
print("{:<40} {:<10.2f}".format("Median TPOT (ms):", metrics.median_tpot_ms))
print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms))
print("{s:{c}^{n}}".format(s="Inter-token Latency", n=50, c="-"))
print("{s:{c}^{n}}".format(s="Inter-Token Latency", n=50, c="-"))
print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms))
print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms))
print("{:<40} {:<10.2f}".format("P95 ITL (ms):", metrics.p95_itl_ms))
print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms))
print("{:<40} {:<10.2f}".format("Max ITL (ms):", metrics.max_itl_ms))
print("=" * 50)
if (
@@ -1117,8 +1169,10 @@ async def benchmark(
"mean_itl_ms": metrics.mean_itl_ms,
"median_itl_ms": metrics.median_itl_ms,
"std_itl_ms": metrics.std_itl_ms,
"p95_itl_ms": metrics.p95_itl_ms,
"p99_itl_ms": metrics.p99_itl_ms,
"concurrency": metrics.concurrency,
"accept_length": accept_length,
}
else:
print(f"Error running benchmark for request rate: {request_rate}")
@@ -1151,14 +1205,6 @@ async def benchmark(
return result
def parse_request_rate_range(request_rate_range):
if len(request_rate_range.split(",")) == 3:
start, stop, step = map(int, request_rate_range.split(","))
return list(range(start, stop, step))
else:
return list(map(int, request_rate_range.split(",")))
def check_chat_template(model_path):
try:
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
@@ -1168,6 +1214,12 @@ def check_chat_template(model_path):
return False
def set_global_args(args_: argparse.Namespace):
"""Set the global args."""
global args
args = args_
def run_benchmark(args_: argparse.Namespace):
global args
args = args_
@@ -1176,6 +1228,8 @@ def run_benchmark(args_: argparse.Namespace):
if not hasattr(args, "max_concurrency"):
args.max_concurrency = None
print(f"benchmark_args={args}")
# Set global environments
set_ulimit()
random.seed(args.seed)
@@ -1272,49 +1326,26 @@ def run_benchmark(args_: argparse.Namespace):
backend = args.backend
model_id = args.model
tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model
tokenizer = get_tokenizer(tokenizer_id)
input_requests = get_dataset(args, tokenizer)
if not args.multi:
return asyncio.run(
benchmark(
backend=backend,
api_url=api_url,
base_url=base_url,
model_id=model_id,
tokenizer=tokenizer,
input_requests=input_requests,
request_rate=args.request_rate,
max_concurrency=args.max_concurrency,
disable_tqdm=args.disable_tqdm,
lora_name=args.lora_name,
extra_request_body=extra_request_body,
profile=args.profile,
)
return asyncio.run(
benchmark(
backend=backend,
api_url=api_url,
base_url=base_url,
model_id=model_id,
tokenizer=tokenizer,
input_requests=input_requests,
request_rate=args.request_rate,
max_concurrency=args.max_concurrency,
disable_tqdm=args.disable_tqdm,
lora_name=args.lora_name,
extra_request_body=extra_request_body,
profile=args.profile,
pd_seperated=args.pd_seperated,
)
else:
# Benchmark multiple rps. TODO: use a fixed duration to compute num_prompts
request_rates = parse_request_rate_range(args.request_rate_range)
for rate in request_rates:
asyncio.run(
benchmark(
backend=backend,
api_url=api_url,
base_url=base_url,
model_id=model_id,
tokenizer=tokenizer,
input_requests=input_requests,
request_rate=rate,
max_concurrency=args.max_concurrency,
disable_tqdm=args.disable_tqdm,
lora_name=args.lora_name,
extra_request_body=extra_request_body,
profile=args.profile,
)
)
)
def set_ulimit(target_soft_limit=65535):
@@ -1428,17 +1459,6 @@ if __name__ == "__main__":
"actual request rate may be lower than specified with --request-rate, "
"if the server is not processing requests fast enough to keep up.",
)
parser.add_argument(
"--multi",
action="store_true",
help="Use request rate range rather than single value.",
)
parser.add_argument(
"--request-rate-range",
type=str,
default="2,34,2",
help="Range of request rates in the format start,stop,step. Default is 2,34,2. It also supports a list of request rates, requiring the parameters to not equal three.",
)
parser.add_argument("--output-file", type=str, help="Output JSONL file name.")
parser.add_argument(
"--disable-tqdm",
@@ -1485,6 +1505,17 @@ if __name__ == "__main__":
default=None,
help="The name of LoRA adapter",
)
parser.add_argument(
"--prompt-suffix",
type=str,
default="",
help="Suffix applied to the end of all user prompts, followed by assistant prompt suffix.",
)
parser.add_argument(
"--pd-seperated",
action="store_true",
help="Benchmark PD disaggregation server",
)
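# A possible invocation exercising the flags added above (illustrative only;
# the suffix text is a placeholder):
#   python3 -m sglang.bench_serving --backend sglang --num-prompts 10 \
#       --prompt-suffix " Answer briefly." --pd-seperated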
group = parser.add_argument_group("generated-shared-prefix dataset arguments")
group.add_argument(