[CI] test chunked prefill more (#5798)

This commit is contained in:
Lianmin Zheng
2025-04-28 10:57:17 -07:00
committed by GitHub
parent d73ddeb196
commit 849c83a0c0
15 changed files with 212 additions and 97 deletions

View File

@@ -975,7 +975,7 @@ class ModelRunner:
after_mem = get_available_gpu_memory(self.device, self.gpu_id)
logger.info(
f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s. "
f"avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB."
f"mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
)
def apply_torch_tp(self):

View File

@@ -426,7 +426,7 @@ class ServerArgs:
parser.add_argument(
"--skip-tokenizer-init",
action="store_true",
help="If set, skip init tokenizer and pass input_ids in generate request",
help="If set, skip init tokenizer and pass input_ids in generate request.",
)
parser.add_argument(
"--enable-tokenizer-batch-encode",
@@ -565,6 +565,7 @@ class ServerArgs:
"name, a tag name, or a commit id. If unspecified, will use "
"the default version.",
)
# Memory and scheduling
parser.add_argument(
"--mem-fraction-static",

View File

@@ -6,11 +6,56 @@ python3 -m sglang.test.send_one
"""
import argparse
import dataclasses
import json
import requests
@dataclasses.dataclass
class BenchArgs:
    """CLI-configurable settings for sending one benchmark request.

    Field defaults double as the argparse defaults, so the dataclass and the
    command-line interface stay in sync automatically.
    """

    host: str = "localhost"
    port: int = 30000
    batch_size: int = 1
    temperature: float = 0.0
    max_new_tokens: int = 512
    frequency_penalty: float = 0.0
    presence_penalty: float = 0.0
    json: bool = False
    return_logprob: bool = False
    prompt: str = (
        "Human: Give me a fully functional FastAPI server. Show the python code.\n\nAssistant:"
    )
    image: bool = False
    stream: bool = False

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        """Register one CLI flag per dataclass field on *parser*.

        Boolean fields become ``store_true`` flags (their default is always
        False); every other field takes a typed value whose default is the
        field default.
        """
        for field in dataclasses.fields(BenchArgs):
            flag = "--" + field.name.replace("_", "-")
            if field.type is bool:
                parser.add_argument(flag, action="store_true")
            else:
                parser.add_argument(flag, type=field.type, default=field.default)

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        """Build a BenchArgs from a parsed argparse namespace."""
        kwargs = {f.name: getattr(args, f.name) for f in dataclasses.fields(cls)}
        return cls(**kwargs)
def send_one_prompt(args):
if args.image:
args.prompt = (
@@ -20,20 +65,42 @@ def send_one_prompt(args):
else:
image_data = None
response = requests.post(
"http://localhost:30000/generate",
json={
"text": args.prompt,
"image_data": image_data,
"sampling_params": {
"temperature": args.temperature,
"max_new_tokens": args.max_new_tokens,
"frequency_penalty": args.frequency_penalty,
"presence_penalty": args.presence_penalty,
},
"return_logprob": args.return_logprob,
"stream": args.stream,
prompt = args.prompt
if args.json:
prompt = (
"Human: What is the capital of France and how is that city like. "
"Give me 3 trivial information about that city. "
"Write in a format of json.\nAssistant:"
)
json_schema = "$$ANY$$"
json_schema = (
'{"type": "object", "properties": {"population": {"type": "integer"}}}'
)
else:
json_schema = None
if args.batch_size > 1:
prompt = [prompt] * args.batch_size
json_data = {
"text": prompt,
"image_data": image_data,
"sampling_params": {
"temperature": args.temperature,
"max_new_tokens": args.max_new_tokens,
"frequency_penalty": args.frequency_penalty,
"presence_penalty": args.presence_penalty,
"json_schema": json_schema,
"stop": ["Question", "Assistant:", "<|separator|>", "<|eos|>"],
},
"return_logprob": args.return_logprob,
"stream": args.stream,
}
response = requests.post(
f"http://{args.host}:{args.port}/generate",
json=json_data,
stream=args.stream,
)
@@ -47,6 +114,9 @@ def send_one_prompt(args):
else:
ret = response.json()
if args.batch_size > 1:
ret = ret[0]
latency = ret["meta_info"]["e2e_latency"]
if "spec_verify_ct" in ret["meta_info"]:
@@ -68,21 +138,7 @@ def send_one_prompt(args):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--temperature", type=float, default=0.0)
parser.add_argument("--max-new-tokens", type=int, default=512)
parser.add_argument("--frequency-penalty", type=float, default=0.0)
parser.add_argument("--presence-penalty", type=float, default=0.0)
parser.add_argument("--return-logprob", action="store_true")
parser.add_argument(
"--prompt",
type=str,
default="Human: Give me a fully functional FastAPI server. Show the python code.\n\nAssistant:",
)
parser.add_argument(
"--image",
action="store_true",
)
parser.add_argument("--stream", action="store_true")
BenchArgs.add_cli_args(parser)
args = parser.parse_args()
send_one_prompt(args)

View File

@@ -732,6 +732,44 @@ def run_bench_one_batch(model, other_args):
return output_throughput
def run_bench_offline_throughput(model, other_args):
    """Run ``sglang.bench_offline_throughput`` in a subprocess and scrape its output.

    Launches the benchmark with one random 256/256 prompt against *model*,
    forwarding *other_args* verbatim, and returns the value printed after
    "Last generation throughput (tok/s):" — or -1 if that marker never
    appears in the subprocess stdout.
    """
    command = [
        "python3",
        "-m",
        "sglang.bench_offline_throughput",
        "--num-prompts",
        "1",
        "--dataset-name",
        "random",
        "--random-input-len",
        "256",
        "--random-output-len",
        "256",
        "--model-path",
        model,
        *[str(x) for x in other_args],
    ]
    print(f"{command=}")

    proc = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    try:
        out_bytes, err_bytes = proc.communicate()
        output = out_bytes.decode()
        error = err_bytes.decode()
        print(f"Output: {output}", flush=True)
        print(f"Error: {error}", flush=True)

        marker = "Last generation throughput (tok/s):"
        output_throughput = -1
        for line in output.split("\n"):
            if marker in line:
                output_throughput = float(line.rsplit(":", 1)[-1])
    finally:
        # Tear down the whole process tree even if parsing above raised.
        kill_process_tree(proc.pid)

    return output_throughput
def lcs(X, Y):
m = len(X)
n = len(Y)