Support penalty in overlap mode; return logprob with chunked prefill; improve benchmark scripts (#3988)

Co-authored-by: SangBin Cho <rkooo567@gmail.com>
Co-authored-by: dhou-xai <dhou@x.ai>
Co-authored-by: Hanming Lu <hanming_lu@berkeley.edu>
This commit is contained in:
Lianmin Zheng
2025-03-03 00:12:04 -08:00
parent 0194948fd9
commit ac2387279e
86 changed files with 4116 additions and 2015 deletions

View File

@@ -15,7 +15,7 @@
import multiprocessing as mp
import os
from dataclasses import dataclass
from typing import List, Union
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
@@ -56,6 +56,13 @@ def get_top_logprobs(logits, k):
return logprobs
def get_token_ids_logprobs(logits, token_ids):
logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
del logits
logprobs = logprobs[..., token_ids]
return logprobs
def _get_sentence_transformer_embedding_model(model_path, torch_dtype):
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import is_sentence_transformer_model
@@ -84,8 +91,13 @@ class ModelOutput:
output_ids: List[int] = None
top_input_logprobs: List[torch.Tensor] = None
top_output_logprobs: List[torch.Tensor] = None
top_output_logprob_idx: List[List[int]] = None
embed_logits: List[torch.Tensor] = None
scores: List[float] = None
input_token_logprobs_lst: List[List[Tuple[float, int, None]]] = None
output_token_logprobs_lst: List[List[Tuple[float, int, None]]] = None
token_ids_input_logprobs: List[torch.Tensor] = None
token_ids_output_logprobs: List[torch.Tensor] = None
class HFRunner:
@@ -157,7 +169,7 @@ class HFRunner:
# Run forward
while True:
prompts, max_new_tokens, lora_paths = in_queue.get()
prompts, max_new_tokens, lora_paths, token_ids_logprob = in_queue.get()
if lora_paths is not None:
assert len(prompts) == len(lora_paths)
@@ -165,16 +177,16 @@ class HFRunner:
if self.model_type == "generation":
out_queue.put(
self.forward_generation_raw(
base_model=self.base_model,
prompts=prompts,
max_new_tokens=max_new_tokens,
base_model=self.base_model,
tokenizer=self.tokenizer,
lora_paths=lora_paths,
torch_dtype=torch_dtype,
output_str_only=self.output_str_only,
token_ids_logprob=token_ids_logprob,
)
)
elif self.model_type == "embedding":
assert not self.output_str_only
logits = self.model.encode(prompts).tolist()
@@ -199,10 +211,11 @@ class HFRunner:
def forward(
self,
prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
max_new_tokens=8,
lora_paths=None,
max_new_tokens: int = 8,
lora_paths: Optional[List[str]] = None,
token_ids_logprob: Optional[int] = None,
):
self.in_queue.put((prompts, max_new_tokens, lora_paths))
self.in_queue.put((prompts, max_new_tokens, lora_paths, token_ids_logprob))
return self.out_queue.get()
def terminate(self):
@@ -218,17 +231,24 @@ class HFRunner:
@staticmethod
def forward_generation_raw(
prompts: Union[List[str], List[torch.Tensor]],
max_new_tokens,
base_model,
prompts: Union[List[str], List[torch.Tensor]],
max_new_tokens: int,
tokenizer,
lora_paths,
torch_dtype: torch.dtype,
output_str_only: bool,
lora_paths: Optional[List[str]] = None,
output_str_only: bool = False,
token_ids_logprob: Optional[int] = None,
) -> ModelOutput:
output_strs = []
top_input_logprobs = []
top_output_logprobs = []
if token_ids_logprob is not None:
token_ids_input_logprobs = []
token_ids_output_logprobs = []
else:
token_ids_input_logprobs = token_ids_output_logprobs = None
for i, p in enumerate(prompts):
if isinstance(p, str):
input_ids = tokenizer.encode(p, return_tensors="pt").cuda()
@@ -275,18 +295,33 @@ class HFRunner:
for logits in outputs.scores
]
)
if token_ids_logprob is not None:
token_ids_output_logprobs.append(
[
get_token_ids_logprobs(
logits[0], token_ids_logprob
).tolist()
for logits in outputs.scores
]
)
del outputs
input_logits = model.forward(input_ids).logits[0]
top_input_logprobs.append(
get_top_logprobs(input_logits, NUM_TOP_LOGPROBS).tolist()
)
if token_ids_logprob is not None:
token_ids_input_logprobs.append(
get_token_ids_logprobs(input_logits, token_ids_logprob).tolist()
)
del input_logits
return ModelOutput(
output_strs=output_strs,
top_input_logprobs=top_input_logprobs,
top_output_logprobs=top_output_logprobs,
token_ids_input_logprobs=token_ids_input_logprobs,
token_ids_output_logprobs=token_ids_output_logprobs,
)
@@ -303,11 +338,31 @@ class SRTRunner:
lora_backend: str = "triton",
disable_cuda_graph: bool = False,
disable_radix_cache: bool = False,
chunked_prefill_size: Optional[int] = None,
dp_size: int = 1,
tokenizer_path: Optional[str] = None,
enable_ep_moe: bool = False,
mem_fraction_static: float = 0.65,
trust_remote_code: bool = False,
speculative_draft_model_path: Optional[str] = None,
speculative_algorithm: Optional[str] = None,
speculative_num_steps: Optional[int] = None,
speculative_eagle_topk: Optional[int] = None,
speculative_num_draft_tokens: Optional[int] = None,
disable_overlap_schedule: bool = False,
):
self.model_type = model_type
self.is_generation = model_type == "generation"
enable_dp_attention = dp_size > 1
spec_kwargs = {}
if speculative_draft_model_path:
spec_kwargs["speculative_draft_model_path"] = speculative_draft_model_path
spec_kwargs["speculative_algorithm"] = speculative_algorithm
spec_kwargs["speculative_num_steps"] = speculative_num_steps
spec_kwargs["speculative_eagle_topk"] = speculative_eagle_topk
spec_kwargs["speculative_num_draft_tokens"] = speculative_num_draft_tokens
self.engine = Engine(
model_path=model_path,
tp_size=tp_size,
@@ -321,21 +376,41 @@ class SRTRunner:
lora_backend=lora_backend,
disable_cuda_graph=disable_cuda_graph,
disable_radix_cache=disable_radix_cache,
chunked_prefill_size=chunked_prefill_size,
enable_dp_attention=enable_dp_attention,
dp_size=dp_size,
tokenizer_path=tokenizer_path,
enable_ep_moe=enable_ep_moe,
disable_overlap_schedule=disable_overlap_schedule,
cuda_graph_max_bs=4,
**spec_kwargs,
)
self.tokenizer = get_tokenizer(model_path, trust_remote_code=trust_remote_code)
if tokenizer_path is None:
self.tokenizer = get_tokenizer(
model_path, trust_remote_code=trust_remote_code
)
else:
self.tokenizer = None
def forward(
self,
prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
max_new_tokens=8,
lora_paths=None,
max_new_tokens: int = 8,
lora_paths: Optional[List[str]] = None,
logprob_start_len: int = 0,
top_k: Optional[int] = None,
token_ids_logprob: Optional[List[int]] = None,
):
if self.is_generation:
return self.forward_generation_raw(
engine=self.engine,
prompts=prompts,
max_new_tokens=max_new_tokens,
lora_paths=lora_paths,
engine=self.engine,
logprob_start_len=logprob_start_len,
top_k=top_k,
token_ids_logprob=token_ids_logprob,
)
else:
response = self.engine.encode(prompts)
@@ -358,10 +433,10 @@ class SRTRunner:
"""
if self.is_generation:
return self.batch_forward_generation_raw(
engine=self.engine,
prompts=prompts,
max_new_tokens=max_new_tokens,
lora_paths=lora_paths,
engine=self.engine,
)
else:
response = self.engine.encode(prompts)
@@ -381,24 +456,43 @@ class SRTRunner:
@staticmethod
def forward_generation_raw(
engine: Engine,
prompts: Union[List[str], List[torch.Tensor]],
max_new_tokens,
lora_paths,
engine,
max_new_tokens: int = 8,
lora_paths: Optional[List[str]] = None,
logprob_start_len: int = 0,
top_k: Optional[int] = None,
token_ids_logprob: Optional[List[int]] = None,
):
# the return value contains logprobs from prefill
output_strs = []
output_ids = []
# Input logprobs. Note that the last item in input logprob is equivalent to
# the first item in the output logprob.
top_input_logprobs = []
input_token_logprobs_lst = []
top_output_logprobs = []
output_token_logprobs_lst = []
top_output_logprob_idx = []
if token_ids_logprob is not None:
token_ids_input_logprobs = []
token_ids_output_logprobs = []
else:
token_ids_input_logprobs = token_ids_output_logprobs = None
sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
if top_k:
sampling_params["top_k"] = top_k
for i, prompt in enumerate(prompts):
response = engine.generate(
prompt,
lora_path=lora_paths[i] if lora_paths else None,
sampling_params=sampling_params,
return_logprob=True,
logprob_start_len=0,
logprob_start_len=logprob_start_len,
top_logprobs_num=NUM_TOP_LOGPROBS,
token_ids_logprob=token_ids_logprob,
)
text = response["text"]
@@ -408,12 +502,36 @@ class SRTRunner:
"Received an empty text response. Please verify your input or model configuration."
)
output_strs.append(text)
# output_ids.append(response["output_ids"])
input_token_logprobs = response["meta_info"]["input_token_logprobs"]
output_token_logprobs = response["meta_info"]["output_token_logprobs"]
# print(i, input_token_logprobs)
# print(i, output_token_logprobs)
logprobs = response["meta_info"]["input_top_logprobs"]
if token_ids_logprob is not None:
input_token_ids_logprobs = response["meta_info"][
"input_token_ids_logprobs"
][1:]
else:
input_token_ids_logprobs = None
num_prompt_tokens = response["meta_info"]["prompt_tokens"]
assert len(input_token_logprobs) == num_prompt_tokens - logprob_start_len
assert len(logprobs) == num_prompt_tokens - logprob_start_len
# The first token logprob has no meaning in sglang.
input_token_logprobs = input_token_logprobs[1:]
logprobs = logprobs[1:]
assert len(input_token_logprobs) == len(logprobs)
input_token_logprobs_lst.append(
input_token_logprobs + [output_token_logprobs[0]]
)
output_token_logprobs_lst.append(output_token_logprobs)
top_input_logprobs.append(
[
[tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
for x in response["meta_info"]["input_top_logprobs"][1:]
]
[[tup[0] for tup in x[:NUM_TOP_LOGPROBS]] for x in logprobs]
+ [
[
tup[0]
@@ -429,11 +547,41 @@ class SRTRunner:
for x in response["meta_info"]["output_top_logprobs"]
]
)
top_output_logprob_idx.append(
[
[tup[1] for tup in x[:NUM_TOP_LOGPROBS]]
for x in response["meta_info"]["output_top_logprobs"]
]
)
if token_ids_logprob is not None:
token_ids_input_logprobs.append(
[[tup[0] for tup in x] for x in input_token_ids_logprobs]
+ [
[
tup[0]
for tup in response["meta_info"][
"output_token_ids_logprobs"
][0]
]
]
)
token_ids_output_logprobs.append(
[
[tup[0] for tup in x]
for x in response["meta_info"]["output_token_ids_logprobs"]
]
)
return ModelOutput(
output_strs=output_strs,
output_ids=output_ids,
top_input_logprobs=top_input_logprobs,
top_output_logprobs=top_output_logprobs,
input_token_logprobs_lst=input_token_logprobs_lst,
output_token_logprobs_lst=output_token_logprobs_lst,
top_output_logprob_idx=top_output_logprob_idx,
token_ids_input_logprobs=token_ids_input_logprobs,
token_ids_output_logprobs=token_ids_output_logprobs,
)
@staticmethod

View File

@@ -0,0 +1,88 @@
"""
Run one test prompt.
Usage:
python3 -m sglang.test.send_one
"""
import argparse
import json
import requests
def send_one_prompt(args):
if args.image:
args.prompt = (
"Human: Describe this image in a very short sentence.\n\nAssistant:"
)
image_data = "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png"
else:
image_data = None
response = requests.post(
"http://localhost:30000/generate",
json={
"text": args.prompt,
"image_data": image_data,
"sampling_params": {
"temperature": args.temperature,
"max_new_tokens": args.max_new_tokens,
"frequency_penalty": args.frequency_penalty,
"presence_penalty": args.presence_penalty,
},
"return_logprob": args.return_logprob,
"stream": args.stream,
},
stream=args.stream,
)
if args.stream:
for chunk in response.iter_lines(decode_unicode=False):
chunk = chunk.decode("utf-8")
if chunk and chunk.startswith("data:"):
if chunk == "data: [DONE]":
break
ret = json.loads(chunk[5:].strip("\n"))
else:
ret = response.json()
latency = ret["meta_info"]["e2e_latency"]
if "spec_verify_ct" in ret["meta_info"]:
acc_length = (
ret["meta_info"]["completion_tokens"] / ret["meta_info"]["spec_verify_ct"]
)
else:
acc_length = 1.0
speed = ret["meta_info"]["completion_tokens"] / latency
print(ret["text"])
print()
print(f"{acc_length=:.2f}")
print(f"{speed=:.2f} token/s")
return acc_length, speed
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--temperature", type=float, default=0.0)
parser.add_argument("--max-new-tokens", type=int, default=512)
parser.add_argument("--frequency-penalty", type=float, default=0.0)
parser.add_argument("--presence-penalty", type=float, default=0.0)
parser.add_argument("--return-logprob", action="store_true")
parser.add_argument(
"--prompt",
type=str,
default="Human: Give me a fully functional FastAPI server. Show the python code.\n\nAssistant:",
)
parser.add_argument(
"--image",
action="store_true",
)
parser.add_argument("--stream", action="store_true")
args = parser.parse_args()
send_one_prompt(args)

View File

@@ -8,10 +8,11 @@ import random
import subprocess
import threading
import time
import unittest
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from types import SimpleNamespace
from typing import Callable, List, Optional
from typing import Callable, List, Optional, Tuple
import numpy as np
import requests
@@ -408,26 +409,49 @@ def popen_launch_server(
other_args: list[str] = (),
env: Optional[dict] = None,
return_stdout_stderr: Optional[tuple] = None,
pd_seperated: bool = False,
):
_, host, port = base_url.split(":")
host = host[2:]
if pd_seperated:
command = "sglang.launch_pd_server"
else:
command = "sglang.launch_server"
command = [
"python3",
"-m",
"sglang.launch_server",
command,
"--model-path",
model,
"--host",
host,
"--port",
port,
*other_args,
*[str(x) for x in other_args],
]
if pd_seperated:
command.extend(
[
"--lb-host",
host,
"--lb-port",
port,
]
)
else:
command.extend(
[
"--host",
host,
"--port",
port,
]
)
if api_key:
command += ["--api-key", api_key]
print(f"command={' '.join(command)}")
if return_stdout_stderr:
process = subprocess.Popen(
command,
@@ -456,6 +480,8 @@ def popen_launch_server(
except requests.RequestException:
pass
time.sleep(10)
kill_process_tree(process.pid)
raise TimeoutError("Server failed to start within the timeout period.")
@@ -488,9 +514,11 @@ def run_unittest_files(files: List[str], timeout_per_file: float):
success = True
for filename in files:
global process
process = None
def run_one_file(filename):
nonlocal process
filename = os.path.join(os.getcwd(), filename)
print(f"\n\nRun:\npython3 {filename}\n\n", flush=True)
process = subprocess.Popen(
@@ -534,11 +562,14 @@ def get_benchmark_args(
dataset_path="",
tokenizer="",
num_prompts=500,
sharegpt_output_len=None,
random_input_len=4096,
random_output_len=2048,
sharegpt_context_len=None,
request_rate=float("inf"),
disable_stream=False,
disable_ignore_eos=False,
pd_seperated: bool = False,
):
return SimpleNamespace(
backend="sglang",
@@ -550,8 +581,8 @@ def get_benchmark_args(
model=None,
tokenizer=tokenizer,
num_prompts=num_prompts,
sharegpt_output_len=None,
sharegpt_context_len=None,
sharegpt_output_len=sharegpt_output_len,
sharegpt_context_len=sharegpt_context_len,
random_input_len=random_input_len,
random_output_len=random_output_len,
random_range_ratio=0.0,
@@ -567,6 +598,8 @@ def get_benchmark_args(
apply_chat_template=False,
profile=None,
lora_name=None,
prompt_suffix="",
pd_seperated=pd_seperated,
)
@@ -580,6 +613,7 @@ def run_bench_serving(
tokenizer=None,
random_input_len=4096,
random_output_len=2048,
sharegpt_context_len=None,
disable_stream=False,
disable_ignore_eos=False,
need_warmup=False,
@@ -602,6 +636,7 @@ def run_bench_serving(
num_prompts=num_prompts,
random_input_len=random_input_len,
random_output_len=random_output_len,
sharegpt_context_len=sharegpt_context_len,
request_rate=request_rate,
disable_stream=disable_stream,
disable_ignore_eos=disable_ignore_eos,
@@ -626,6 +661,7 @@ def run_bench_serving_multi(
other_server_args,
benchmark_args,
need_warmup=False,
pd_seperated=False,
):
# Launch the server
process = popen_launch_server(
@@ -633,6 +669,7 @@ def run_bench_serving_multi(
base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=other_server_args,
pd_seperated=pd_seperated,
)
# run benchmark for all
@@ -665,7 +702,7 @@ def run_bench_one_batch(model, other_args):
"128",
"--output",
"8",
*other_args,
*[str(x) for x in other_args],
]
process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
@@ -816,7 +853,7 @@ def run_command_and_capture_output(command, env: Optional[dict] = None):
stdout = open(STDOUT_FILENAME, "w")
stderr = open(STDERR_FILENAME, "w")
process = subprocess.Popen(
command, stdout=stdout, stderr=stderr, env=env, text=True
command, stdout=stdout, stderr=stdout, env=env, text=True
)
# Launch a thread to stream the output
@@ -914,3 +951,78 @@ def run_mulit_request_test(
def write_github_step_summary(content):
with open(os.environ["GITHUB_STEP_SUMMARY"], "a") as f:
f.write(content)
def run_logprob_check(self: unittest.TestCase, arg: Tuple):
(
input_len,
output_len,
temperature,
logprob_start_len,
return_logprob,
top_logprobs_num,
) = arg
input_ids = list(range(input_len))
response = requests.post(
self.base_url + "/generate",
json={
"input_ids": input_ids,
"sampling_params": {
"temperature": temperature,
"max_new_tokens": output_len,
"ignore_eos": True,
},
"return_logprob": return_logprob,
"logprob_start_len": logprob_start_len,
"top_logprobs_num": top_logprobs_num,
},
)
response_json = response.json()
res = response_json
self.assertEqual(res["meta_info"]["prompt_tokens"], input_len)
self.assertEqual(res["meta_info"]["completion_tokens"], output_len)
# Test the number of tokens are correct
if return_logprob:
self.assertEqual(
len(res["meta_info"]["input_token_logprobs"]) + logprob_start_len,
res["meta_info"]["prompt_tokens"],
)
self.assertEqual(len(res["meta_info"]["output_token_logprobs"]), output_len)
if top_logprobs_num:
self.assertEqual(
len(res["meta_info"]["input_top_logprobs"]) + logprob_start_len,
res["meta_info"]["prompt_tokens"],
)
self.assertEqual(len(res["meta_info"]["output_top_logprobs"]), output_len)
for i in range(output_len):
self.assertEqual(
len(res["meta_info"]["output_top_logprobs"][i]),
top_logprobs_num,
)
# Test the top-1 tokens are the same as output tokens if temperature == 0
if temperature == 0:
rank = 0
while rank < len(res["meta_info"]["output_top_logprobs"][i]):
try:
self.assertListEqual(
res["meta_info"]["output_token_logprobs"][i],
res["meta_info"]["output_top_logprobs"][i][rank],
)
break
except AssertionError:
# There's a tie. Allow the second item in this case.
if (
res["meta_info"]["output_top_logprobs"][i][rank][0]
== res["meta_info"]["output_top_logprobs"][i][rank + 1][
0
]
):
rank += 1
else:
raise