Support penalty in overlap mode; return logprob with chunked prefill; improve benchmark scripts (#3988)
Co-authored-by: SangBin Cho <rkooo567@gmail.com> Co-authored-by: dhou-xai <dhou@x.ai> Co-authored-by: Hanming Lu <hanming_lu@berkeley.edu>
This commit is contained in:
@@ -8,7 +8,6 @@ Usage:
|
||||
python3 -m sglang.bench_serving --backend sglang --num-prompts 10
|
||||
|
||||
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 3000 --random-input 1024 --random-output 1024 --random-range-ratio 0.5
|
||||
python3 -m sglang.bench_serving --backend sglang --dataset-name random --request-rate-range 1,2,4,8,16,32 --random-input 4096 --random-output 1024 --random-range-ratio 0.125 --multi
|
||||
"""
|
||||
|
||||
import argparse
|
||||
@@ -71,6 +70,10 @@ def remove_prefix(text: str, prefix: str) -> str:
|
||||
return text[len(prefix) :] if text.startswith(prefix) else text
|
||||
|
||||
|
||||
def remove_suffix(text: str, suffix: str) -> str:
|
||||
return text[: -len(suffix)] if text.endswith(suffix) else text
|
||||
|
||||
|
||||
def get_auth_headers() -> Dict[str, str]:
|
||||
api_key = os.environ.get("OPENAI_API_KEY")
|
||||
if api_key:
|
||||
@@ -79,7 +82,7 @@ def get_auth_headers() -> Dict[str, str]:
|
||||
return {}
|
||||
|
||||
|
||||
# trt llm not support ignore_eos
|
||||
# trt llm does not support ignore_eos
|
||||
# https://github.com/triton-inference-server/tensorrtllm_backend/issues/505
|
||||
async def async_request_trt_llm(
|
||||
request_func_input: RequestFuncInput,
|
||||
@@ -179,6 +182,7 @@ async def async_request_openai_completions(
|
||||
output.prompt_len = request_func_input.prompt_len
|
||||
|
||||
generated_text = ""
|
||||
output_len = request_func_input.output_len
|
||||
ttft = 0.0
|
||||
st = time.perf_counter()
|
||||
most_recent_timestamp = st
|
||||
@@ -215,11 +219,14 @@ async def async_request_openai_completions(
|
||||
|
||||
most_recent_timestamp = timestamp
|
||||
generated_text += data["choices"][0]["text"]
|
||||
output_len = data.get("usage", {}).get(
|
||||
"completion_tokens", output_len
|
||||
)
|
||||
|
||||
output.generated_text = generated_text
|
||||
output.success = True
|
||||
output.latency = latency
|
||||
output.output_len = request_func_input.output_len
|
||||
output.output_len = output_len
|
||||
else:
|
||||
output.error = response.reason or ""
|
||||
output.success = False
|
||||
@@ -339,9 +346,11 @@ async def async_request_sglang_generate(
|
||||
output.prompt_len = request_func_input.prompt_len
|
||||
|
||||
generated_text = ""
|
||||
output_len = request_func_input.output_len
|
||||
ttft = 0.0
|
||||
st = time.perf_counter()
|
||||
most_recent_timestamp = st
|
||||
last_output_len = 0
|
||||
try:
|
||||
async with session.post(
|
||||
url=api_url, json=payload, headers=headers
|
||||
@@ -365,6 +374,9 @@ async def async_request_sglang_generate(
|
||||
# want to check a token was generated
|
||||
if data["text"]:
|
||||
timestamp = time.perf_counter()
|
||||
generated_text = data["text"]
|
||||
output_len = data["meta_info"]["completion_tokens"]
|
||||
|
||||
# First token
|
||||
if ttft == 0.0:
|
||||
ttft = time.perf_counter() - st
|
||||
@@ -372,7 +384,13 @@ async def async_request_sglang_generate(
|
||||
|
||||
# Decoding phase
|
||||
else:
|
||||
output.itl.append(timestamp - most_recent_timestamp)
|
||||
num_new_tokens = output_len - last_output_len
|
||||
if num_new_tokens == 0:
|
||||
continue
|
||||
adjust_itl = (
|
||||
timestamp - most_recent_timestamp
|
||||
) / num_new_tokens
|
||||
output.itl.extend([adjust_itl] * num_new_tokens)
|
||||
|
||||
most_recent_timestamp = timestamp
|
||||
generated_text = data["text"]
|
||||
@@ -380,7 +398,7 @@ async def async_request_sglang_generate(
|
||||
output.generated_text = generated_text
|
||||
output.success = True
|
||||
output.latency = latency
|
||||
output.output_len = request_func_input.output_len
|
||||
output.output_len = output_len
|
||||
else:
|
||||
output.error = response.reason or ""
|
||||
output.success = False
|
||||
@@ -388,6 +406,7 @@ async def async_request_sglang_generate(
|
||||
output.success = False
|
||||
exc_info = sys.exc_info()
|
||||
output.error = "".join(traceback.format_exception(*exc_info))
|
||||
print(f"{output.error=}")
|
||||
|
||||
if pbar:
|
||||
pbar.update(1)
|
||||
@@ -461,6 +480,7 @@ def get_dataset(args, tokenizer):
|
||||
tokenizer=tokenizer,
|
||||
fixed_output_len=args.sharegpt_output_len,
|
||||
context_len=args.sharegpt_context_len,
|
||||
prompt_suffix=args.prompt_suffix,
|
||||
apply_chat_template=args.apply_chat_template,
|
||||
)
|
||||
elif args.dataset_name == "random":
|
||||
@@ -521,7 +541,9 @@ class BenchmarkMetrics:
|
||||
mean_itl_ms: float
|
||||
median_itl_ms: float
|
||||
std_itl_ms: float
|
||||
p95_itl_ms: float
|
||||
p99_itl_ms: float
|
||||
max_itl_ms: float
|
||||
mean_e2e_latency_ms: float
|
||||
median_e2e_latency_ms: float
|
||||
std_e2e_latency_ms: float
|
||||
@@ -572,6 +594,7 @@ def sample_sharegpt_requests(
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
fixed_output_len: Optional[int] = None,
|
||||
context_len: Optional[int] = None,
|
||||
prompt_suffix: Optional[str] = "",
|
||||
apply_chat_template=False,
|
||||
) -> List[Tuple[str, int, int]]:
|
||||
if fixed_output_len is not None and fixed_output_len < 4:
|
||||
@@ -584,11 +607,19 @@ def sample_sharegpt_requests(
|
||||
# Load the dataset.
|
||||
with open(dataset_path) as f:
|
||||
dataset = json.load(f)
|
||||
|
||||
# Filter out the conversations with less than 2 turns.
|
||||
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
|
||||
dataset = [
|
||||
data
|
||||
for data in dataset
|
||||
if len(data.get("conversations", data.get("conversation", []))) >= 2
|
||||
]
|
||||
# Only keep the first two turns of each conversation.
|
||||
dataset = [
|
||||
(data["conversations"][0]["value"], data["conversations"][1]["value"])
|
||||
(
|
||||
data.get("conversations", data.get("conversation", []))[0]["value"],
|
||||
data.get("conversations", data.get("conversation", []))[1]["value"],
|
||||
)
|
||||
for data in dataset
|
||||
]
|
||||
|
||||
@@ -603,6 +634,8 @@ def sample_sharegpt_requests(
|
||||
|
||||
# Tokenize the prompts and completions.
|
||||
prompt = dataset[i][0]
|
||||
if prompt_suffix:
|
||||
prompt = prompt
|
||||
|
||||
if apply_chat_template:
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
@@ -666,10 +699,17 @@ def sample_random_requests(
|
||||
with open(dataset_path) as f:
|
||||
dataset = json.load(f)
|
||||
# Filter out the conversations with less than 2 turns.
|
||||
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
|
||||
dataset = [
|
||||
data
|
||||
for data in dataset
|
||||
if len(data.get("conversations", data.get("conversation", []))) >= 2
|
||||
]
|
||||
# Only keep the first two turns of each conversation.
|
||||
dataset = [
|
||||
(data["conversations"][0]["value"], data["conversations"][1]["value"])
|
||||
(
|
||||
data.get("conversations", data.get("conversation", []))[0]["value"],
|
||||
data.get("conversations", data.get("conversation", []))[1]["value"],
|
||||
)
|
||||
for data in dataset
|
||||
]
|
||||
# Shuffle the dataset.
|
||||
@@ -895,7 +935,9 @@ def calculate_metrics(
|
||||
mean_itl_ms=np.mean(itls or 0) * 1000,
|
||||
median_itl_ms=np.median(itls or 0) * 1000,
|
||||
std_itl_ms=np.std(itls or 0) * 1000,
|
||||
p95_itl_ms=np.percentile(itls or 0, 95) * 1000,
|
||||
p99_itl_ms=np.percentile(itls or 0, 99) * 1000,
|
||||
max_itl_ms=np.max(itls or 0) * 1000,
|
||||
mean_e2e_latency_ms=np.mean(e2e_latencies) * 1000,
|
||||
median_e2e_latency_ms=np.median(e2e_latencies) * 1000,
|
||||
std_e2e_latency_ms=np.std(e2e_latencies) * 1000,
|
||||
@@ -919,6 +961,7 @@ async def benchmark(
|
||||
lora_name: str,
|
||||
extra_request_body: Dict[str, Any],
|
||||
profile: bool,
|
||||
pd_seperated: bool = False,
|
||||
):
|
||||
if backend in ASYNC_REQUEST_FUNCS:
|
||||
request_func = ASYNC_REQUEST_FUNCS[backend]
|
||||
@@ -1004,6 +1047,17 @@ async def benchmark(
|
||||
if pbar is not None:
|
||||
pbar.close()
|
||||
|
||||
if "sglang" in backend:
|
||||
server_info = requests.get(base_url + "/get_server_info")
|
||||
if pd_seperated:
|
||||
accept_length = server_info.json()["decode"][0].get(
|
||||
"avg_spec_accept_length", None
|
||||
)
|
||||
else:
|
||||
accept_length = server_info.json().get("avg_spec_accept_length", None)
|
||||
else:
|
||||
accept_length = None
|
||||
|
||||
# Compute metrics and print results
|
||||
benchmark_duration = time.perf_counter() - benchmark_start_time
|
||||
metrics, output_lens = calculate_metrics(
|
||||
@@ -1053,6 +1107,8 @@ async def benchmark(
|
||||
)
|
||||
)
|
||||
print("{:<40} {:<10.2f}".format("Concurrency:", metrics.concurrency))
|
||||
if accept_length:
|
||||
print("{:<40} {:<10.2f}".format("Accept length:", accept_length))
|
||||
print("{s:{c}^{n}}".format(s="End-to-End Latency", n=50, c="-"))
|
||||
print(
|
||||
"{:<40} {:<10.2f}".format("Mean E2E Latency (ms):", metrics.mean_e2e_latency_ms)
|
||||
@@ -1066,16 +1122,12 @@ async def benchmark(
|
||||
print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms))
|
||||
print("{:<40} {:<10.2f}".format("Median TTFT (ms):", metrics.median_ttft_ms))
|
||||
print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms))
|
||||
print(
|
||||
"{s:{c}^{n}}".format(s="Time per Output Token (excl. 1st token)", n=50, c="-")
|
||||
)
|
||||
print("{:<40} {:<10.2f}".format("Mean TPOT (ms):", metrics.mean_tpot_ms))
|
||||
print("{:<40} {:<10.2f}".format("Median TPOT (ms):", metrics.median_tpot_ms))
|
||||
print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms))
|
||||
print("{s:{c}^{n}}".format(s="Inter-token Latency", n=50, c="-"))
|
||||
print("{s:{c}^{n}}".format(s="Inter-Token Latency", n=50, c="-"))
|
||||
print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms))
|
||||
print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms))
|
||||
print("{:<40} {:<10.2f}".format("P95 ITL (ms):", metrics.p95_itl_ms))
|
||||
print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms))
|
||||
print("{:<40} {:<10.2f}".format("Max ITL (ms):", metrics.max_itl_ms))
|
||||
print("=" * 50)
|
||||
|
||||
if (
|
||||
@@ -1117,8 +1169,10 @@ async def benchmark(
|
||||
"mean_itl_ms": metrics.mean_itl_ms,
|
||||
"median_itl_ms": metrics.median_itl_ms,
|
||||
"std_itl_ms": metrics.std_itl_ms,
|
||||
"p95_itl_ms": metrics.p95_itl_ms,
|
||||
"p99_itl_ms": metrics.p99_itl_ms,
|
||||
"concurrency": metrics.concurrency,
|
||||
"accept_length": accept_length,
|
||||
}
|
||||
else:
|
||||
print(f"Error running benchmark for request rate: {request_rate}")
|
||||
@@ -1151,14 +1205,6 @@ async def benchmark(
|
||||
return result
|
||||
|
||||
|
||||
def parse_request_rate_range(request_rate_range):
|
||||
if len(request_rate_range.split(",")) == 3:
|
||||
start, stop, step = map(int, request_rate_range.split(","))
|
||||
return list(range(start, stop, step))
|
||||
else:
|
||||
return list(map(int, request_rate_range.split(",")))
|
||||
|
||||
|
||||
def check_chat_template(model_path):
|
||||
try:
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
|
||||
@@ -1168,6 +1214,12 @@ def check_chat_template(model_path):
|
||||
return False
|
||||
|
||||
|
||||
def set_global_args(args_: argparse.Namespace):
|
||||
"""Set the global args."""
|
||||
global args
|
||||
args = args_
|
||||
|
||||
|
||||
def run_benchmark(args_: argparse.Namespace):
|
||||
global args
|
||||
args = args_
|
||||
@@ -1176,6 +1228,8 @@ def run_benchmark(args_: argparse.Namespace):
|
||||
if not hasattr(args, "max_concurrency"):
|
||||
args.max_concurrency = None
|
||||
|
||||
print(f"benchmark_args={args}")
|
||||
|
||||
# Set global environments
|
||||
set_ulimit()
|
||||
random.seed(args.seed)
|
||||
@@ -1272,49 +1326,26 @@ def run_benchmark(args_: argparse.Namespace):
|
||||
backend = args.backend
|
||||
model_id = args.model
|
||||
tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model
|
||||
|
||||
tokenizer = get_tokenizer(tokenizer_id)
|
||||
|
||||
input_requests = get_dataset(args, tokenizer)
|
||||
|
||||
if not args.multi:
|
||||
return asyncio.run(
|
||||
benchmark(
|
||||
backend=backend,
|
||||
api_url=api_url,
|
||||
base_url=base_url,
|
||||
model_id=model_id,
|
||||
tokenizer=tokenizer,
|
||||
input_requests=input_requests,
|
||||
request_rate=args.request_rate,
|
||||
max_concurrency=args.max_concurrency,
|
||||
disable_tqdm=args.disable_tqdm,
|
||||
lora_name=args.lora_name,
|
||||
extra_request_body=extra_request_body,
|
||||
profile=args.profile,
|
||||
)
|
||||
return asyncio.run(
|
||||
benchmark(
|
||||
backend=backend,
|
||||
api_url=api_url,
|
||||
base_url=base_url,
|
||||
model_id=model_id,
|
||||
tokenizer=tokenizer,
|
||||
input_requests=input_requests,
|
||||
request_rate=args.request_rate,
|
||||
max_concurrency=args.max_concurrency,
|
||||
disable_tqdm=args.disable_tqdm,
|
||||
lora_name=args.lora_name,
|
||||
extra_request_body=extra_request_body,
|
||||
profile=args.profile,
|
||||
pd_seperated=args.pd_seperated,
|
||||
)
|
||||
else:
|
||||
# Benchmark multiple rps. TODO: use a fixed duration to compute num_prompts
|
||||
request_rates = parse_request_rate_range(args.request_rate_range)
|
||||
|
||||
for rate in request_rates:
|
||||
asyncio.run(
|
||||
benchmark(
|
||||
backend=backend,
|
||||
api_url=api_url,
|
||||
base_url=base_url,
|
||||
model_id=model_id,
|
||||
tokenizer=tokenizer,
|
||||
input_requests=input_requests,
|
||||
request_rate=rate,
|
||||
max_concurrency=args.max_concurrency,
|
||||
disable_tqdm=args.disable_tqdm,
|
||||
lora_name=args.lora_name,
|
||||
extra_request_body=extra_request_body,
|
||||
profile=args.profile,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def set_ulimit(target_soft_limit=65535):
|
||||
@@ -1428,17 +1459,6 @@ if __name__ == "__main__":
|
||||
"actual request rate may be lower than specified with --request-rate, "
|
||||
"if the server is not processing requests fast enough to keep up.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--multi",
|
||||
action="store_true",
|
||||
help="Use request rate range rather than single value.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--request-rate-range",
|
||||
type=str,
|
||||
default="2,34,2",
|
||||
help="Range of request rates in the format start,stop,step. Default is 2,34,2. A comma-separated list of explicit request rates is also supported, provided the list does not contain exactly three values (a three-value list is always interpreted as start,stop,step).",
|
||||
)
|
||||
parser.add_argument("--output-file", type=str, help="Output JSONL file name.")
|
||||
parser.add_argument(
|
||||
"--disable-tqdm",
|
||||
@@ -1485,6 +1505,17 @@ if __name__ == "__main__":
|
||||
default=None,
|
||||
help="The name of LoRA adapter",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prompt-suffix",
|
||||
type=str,
|
||||
default="",
|
||||
help="Suffix applied to the end of all user prompts, followed by assistant prompt suffix.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pd-seperated",
|
||||
action="store_true",
|
||||
help="Benchmark PD disaggregation server",
|
||||
)
|
||||
|
||||
group = parser.add_argument_group("generated-shared-prefix dataset arguments")
|
||||
group.add_argument(
|
||||
|
||||
Reference in New Issue
Block a user