Enhance the test case for chunked prefill (#1785)

This commit is contained in:
Lianmin Zheng
2024-10-24 21:23:09 -07:00
committed by GitHub
parent 384d85ba35
commit 1701b0db31
6 changed files with 162 additions and 107 deletions

View File

@@ -3,6 +3,7 @@
import argparse
import asyncio
import os
import random
import subprocess
import threading
import time
@@ -20,6 +21,7 @@ from sglang.global_config import global_config
from sglang.lang.backend.openai import OpenAI
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.srt.utils import kill_child_process
from sglang.test.run_eval import run_eval
from sglang.utils import get_exception_traceback
DEFAULT_FP8_MODEL_NAME_FOR_TEST = "neuralmagic/Meta-Llama-3.1-8B-FP8"
@@ -400,7 +402,7 @@ def popen_launch_server(
api_key: Optional[str] = None,
other_args: tuple = (),
env: Optional[dict] = None,
return_stdout_stderr: bool = False,
return_stdout_stderr: Optional[tuple] = None,
):
_, host, port = base_url.split(":")
host = host[2:]
@@ -423,8 +425,8 @@ def popen_launch_server(
if return_stdout_stderr:
process = subprocess.Popen(
command,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
stdout=return_stdout_stderr[0],
stderr=return_stdout_stderr[1],
env=env,
text=True,
)
@@ -631,3 +633,91 @@ def calculate_rouge_l(output_strs_list1, output_strs_list2):
rouge_l_scores.append(fmeasure)
return rouge_l_scores
STDOUT_FILENAME = "stdout.txt"
STDERR_FILENAME = "stderr.txt"
def read_output(output_lines):
pt = 0
while pt >= 0:
if pt > 0 and os.path.exists(STDERR_FILENAME):
break
lines = open(STDERR_FILENAME).readlines()
output_lines[:] = lines
for line in lines[pt:]:
print(line, end="", flush=True)
pt += 1
def run_mmlu_test(
disable_radix_cache,
enable_mixed_chunk=False,
enable_overlap=False,
chunked_prefill_size=32,
):
other_args = ["--chunked-prefill-size", str(chunked_prefill_size)]
if disable_radix_cache:
other_args += ["--disable-radix-cache"]
if enable_mixed_chunk:
other_args += ["--enable-mixed-chunk"]
if enable_overlap:
other_args += ["--enable-overlap-scheduler"]
model = DEFAULT_MODEL_NAME_FOR_TEST
port = random.randint(4000, 5000)
base_url = f"http://127.0.0.1:{port}"
# Create files and launch the server
stdout = open(STDOUT_FILENAME, "w")
stderr = open(STDERR_FILENAME, "w")
process = popen_launch_server(
model,
base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=other_args,
return_stdout_stderr=(stdout, stderr),
)
# Launch a thread to stream the output
output_lines = []
t = threading.Thread(target=read_output, args=(output_lines,))
t.start()
# Run the eval
args = SimpleNamespace(
base_url=base_url,
model=model,
eval_name="mmlu",
num_examples=128,
num_threads=128,
)
try:
metrics = run_eval(args)
print(f"{metrics=}")
assert metrics["score"] >= 0.65
finally:
pass
# Clean up everything
kill_child_process(process.pid)
kill_child_process(process.pid)
stdout.close()
stderr.close()
os.remove(STDOUT_FILENAME)
os.remove(STDERR_FILENAME)
t.join()
# Assert success
has_new_server = False
has_leak = False
for line in output_lines:
if "The server is fired" in line:
has_new_server = True
if "leak" in line:
has_leak = True
assert has_new_server
# assert not has_leak