[CI] test chunked prefill more (#5798)
This commit is contained in:
@@ -975,7 +975,7 @@ class ModelRunner:
|
||||
after_mem = get_available_gpu_memory(self.device, self.gpu_id)
|
||||
logger.info(
|
||||
f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s. "
|
||||
f"avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB."
|
||||
f"mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
|
||||
)
|
||||
|
||||
def apply_torch_tp(self):
|
||||
|
||||
@@ -426,7 +426,7 @@ class ServerArgs:
|
||||
parser.add_argument(
|
||||
"--skip-tokenizer-init",
|
||||
action="store_true",
|
||||
help="If set, skip init tokenizer and pass input_ids in generate request",
|
||||
help="If set, skip init tokenizer and pass input_ids in generate request.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-tokenizer-batch-encode",
|
||||
@@ -565,6 +565,7 @@ class ServerArgs:
|
||||
"name, a tag name, or a commit id. If unspecified, will use "
|
||||
"the default version.",
|
||||
)
|
||||
|
||||
# Memory and scheduling
|
||||
parser.add_argument(
|
||||
"--mem-fraction-static",
|
||||
|
||||
@@ -6,11 +6,56 @@ python3 -m sglang.test.send_one
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import dataclasses
|
||||
import json
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class BenchArgs:
|
||||
host: str = "localhost"
|
||||
port: int = 30000
|
||||
batch_size: int = 1
|
||||
temperature: float = 0.0
|
||||
max_new_tokens: int = 512
|
||||
frequency_penalty: float = 0.0
|
||||
presence_penalty: float = 0.0
|
||||
json: bool = False
|
||||
return_logprob: bool = False
|
||||
prompt: str = (
|
||||
"Human: Give me a fully functional FastAPI server. Show the python code.\n\nAssistant:"
|
||||
)
|
||||
image: bool = False
|
||||
stream: bool = False
|
||||
|
||||
@staticmethod
|
||||
def add_cli_args(parser: argparse.ArgumentParser):
|
||||
parser.add_argument("--host", type=str, default=BenchArgs.host)
|
||||
parser.add_argument("--port", type=int, default=BenchArgs.port)
|
||||
parser.add_argument("--batch-size", type=int, default=BenchArgs.batch_size)
|
||||
parser.add_argument("--temperature", type=float, default=BenchArgs.temperature)
|
||||
parser.add_argument(
|
||||
"--max-new-tokens", type=int, default=BenchArgs.max_new_tokens
|
||||
)
|
||||
parser.add_argument(
|
||||
"--frequency-penalty", type=float, default=BenchArgs.frequency_penalty
|
||||
)
|
||||
parser.add_argument(
|
||||
"--presence-penalty", type=float, default=BenchArgs.presence_penalty
|
||||
)
|
||||
parser.add_argument("--json", action="store_true")
|
||||
parser.add_argument("--return-logprob", action="store_true")
|
||||
parser.add_argument("--prompt", type=str, default=BenchArgs.prompt)
|
||||
parser.add_argument("--image", action="store_true")
|
||||
parser.add_argument("--stream", action="store_true")
|
||||
|
||||
@classmethod
|
||||
def from_cli_args(cls, args: argparse.Namespace):
|
||||
attrs = [attr.name for attr in dataclasses.fields(cls)]
|
||||
return cls(**{attr: getattr(args, attr) for attr in attrs})
|
||||
|
||||
|
||||
def send_one_prompt(args):
|
||||
if args.image:
|
||||
args.prompt = (
|
||||
@@ -20,20 +65,42 @@ def send_one_prompt(args):
|
||||
else:
|
||||
image_data = None
|
||||
|
||||
response = requests.post(
|
||||
"http://localhost:30000/generate",
|
||||
json={
|
||||
"text": args.prompt,
|
||||
"image_data": image_data,
|
||||
"sampling_params": {
|
||||
"temperature": args.temperature,
|
||||
"max_new_tokens": args.max_new_tokens,
|
||||
"frequency_penalty": args.frequency_penalty,
|
||||
"presence_penalty": args.presence_penalty,
|
||||
},
|
||||
"return_logprob": args.return_logprob,
|
||||
"stream": args.stream,
|
||||
prompt = args.prompt
|
||||
|
||||
if args.json:
|
||||
prompt = (
|
||||
"Human: What is the capital of France and how is that city like. "
|
||||
"Give me 3 trivial information about that city. "
|
||||
"Write in a format of json.\nAssistant:"
|
||||
)
|
||||
json_schema = "$$ANY$$"
|
||||
json_schema = (
|
||||
'{"type": "object", "properties": {"population": {"type": "integer"}}}'
|
||||
)
|
||||
else:
|
||||
json_schema = None
|
||||
|
||||
if args.batch_size > 1:
|
||||
prompt = [prompt] * args.batch_size
|
||||
|
||||
json_data = {
|
||||
"text": prompt,
|
||||
"image_data": image_data,
|
||||
"sampling_params": {
|
||||
"temperature": args.temperature,
|
||||
"max_new_tokens": args.max_new_tokens,
|
||||
"frequency_penalty": args.frequency_penalty,
|
||||
"presence_penalty": args.presence_penalty,
|
||||
"json_schema": json_schema,
|
||||
"stop": ["Question", "Assistant:", "<|separator|>", "<|eos|>"],
|
||||
},
|
||||
"return_logprob": args.return_logprob,
|
||||
"stream": args.stream,
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
f"http://{args.host}:{args.port}/generate",
|
||||
json=json_data,
|
||||
stream=args.stream,
|
||||
)
|
||||
|
||||
@@ -47,6 +114,9 @@ def send_one_prompt(args):
|
||||
else:
|
||||
ret = response.json()
|
||||
|
||||
if args.batch_size > 1:
|
||||
ret = ret[0]
|
||||
|
||||
latency = ret["meta_info"]["e2e_latency"]
|
||||
|
||||
if "spec_verify_ct" in ret["meta_info"]:
|
||||
@@ -68,21 +138,7 @@ def send_one_prompt(args):
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--temperature", type=float, default=0.0)
|
||||
parser.add_argument("--max-new-tokens", type=int, default=512)
|
||||
parser.add_argument("--frequency-penalty", type=float, default=0.0)
|
||||
parser.add_argument("--presence-penalty", type=float, default=0.0)
|
||||
parser.add_argument("--return-logprob", action="store_true")
|
||||
parser.add_argument(
|
||||
"--prompt",
|
||||
type=str,
|
||||
default="Human: Give me a fully functional FastAPI server. Show the python code.\n\nAssistant:",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--image",
|
||||
action="store_true",
|
||||
)
|
||||
parser.add_argument("--stream", action="store_true")
|
||||
BenchArgs.add_cli_args(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
send_one_prompt(args)
|
||||
|
||||
@@ -732,6 +732,44 @@ def run_bench_one_batch(model, other_args):
|
||||
return output_throughput
|
||||
|
||||
|
||||
def run_bench_offline_throughput(model, other_args):
|
||||
command = [
|
||||
"python3",
|
||||
"-m",
|
||||
"sglang.bench_offline_throughput",
|
||||
"--num-prompts",
|
||||
"1",
|
||||
"--dataset-name",
|
||||
"random",
|
||||
"--random-input-len",
|
||||
"256",
|
||||
"--random-output-len",
|
||||
"256",
|
||||
"--model-path",
|
||||
model,
|
||||
*[str(x) for x in other_args],
|
||||
]
|
||||
|
||||
print(f"{command=}")
|
||||
process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
|
||||
try:
|
||||
stdout, stderr = process.communicate()
|
||||
output = stdout.decode()
|
||||
error = stderr.decode()
|
||||
print(f"Output: {output}", flush=True)
|
||||
print(f"Error: {error}", flush=True)
|
||||
|
||||
output_throughput = -1
|
||||
for line in output.split("\n"):
|
||||
if "Last generation throughput (tok/s):" in line:
|
||||
output_throughput = float(line.split(":")[-1])
|
||||
finally:
|
||||
kill_process_tree(process.pid)
|
||||
|
||||
return output_throughput
|
||||
|
||||
|
||||
def lcs(X, Y):
|
||||
m = len(X)
|
||||
n = len(Y)
|
||||
|
||||
Reference in New Issue
Block a user