Enhance the test case for chunked prefill (#1785)
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import os
|
||||
import random
|
||||
import subprocess
|
||||
import threading
|
||||
import time
|
||||
@@ -20,6 +21,7 @@ from sglang.global_config import global_config
|
||||
from sglang.lang.backend.openai import OpenAI
|
||||
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
|
||||
from sglang.srt.utils import kill_child_process
|
||||
from sglang.test.run_eval import run_eval
|
||||
from sglang.utils import get_exception_traceback
|
||||
|
||||
DEFAULT_FP8_MODEL_NAME_FOR_TEST = "neuralmagic/Meta-Llama-3.1-8B-FP8"
|
||||
@@ -400,7 +402,7 @@ def popen_launch_server(
|
||||
api_key: Optional[str] = None,
|
||||
other_args: tuple = (),
|
||||
env: Optional[dict] = None,
|
||||
return_stdout_stderr: bool = False,
|
||||
return_stdout_stderr: Optional[tuple] = None,
|
||||
):
|
||||
_, host, port = base_url.split(":")
|
||||
host = host[2:]
|
||||
@@ -423,8 +425,8 @@ def popen_launch_server(
|
||||
if return_stdout_stderr:
|
||||
process = subprocess.Popen(
|
||||
command,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
stdout=return_stdout_stderr[0],
|
||||
stderr=return_stdout_stderr[1],
|
||||
env=env,
|
||||
text=True,
|
||||
)
|
||||
@@ -631,3 +633,91 @@ def calculate_rouge_l(output_strs_list1, output_strs_list2):
|
||||
rouge_l_scores.append(fmeasure)
|
||||
|
||||
return rouge_l_scores
|
||||
|
||||
|
||||
# Files that receive the launched server's stdout / stderr. `read_output`
# tails STDERR_FILENAME while a test is running.
STDOUT_FILENAME = "stdout.txt"
STDERR_FILENAME = "stderr.txt"


def read_output(output_lines):
    """Echo STDERR_FILENAME to stdout until the file is removed.

    Intended to run in a helper thread. On every poll the whole file is
    re-read, ``output_lines`` is overwritten in place with the current list
    of lines, and any line not yet echoed is printed.

    Args:
        output_lines: A list that is updated in place (``output_lines[:]``)
            so the caller can inspect the captured lines afterwards.

    The loop ends when STDERR_FILENAME no longer exists (the caller removes
    the file to signal shutdown) or when the main thread has exited.
    """
    import threading  # local import keeps this block self-contained

    pt = 0  # number of lines that have already been echoed
    # Stop once the main thread is gone so a caller that never removes the
    # file cannot keep the interpreter alive through this non-daemon thread.
    while threading.main_thread().is_alive():
        try:
            # EAFP: opening directly avoids the exists()/open() race and the
            # context manager closes the handle on every iteration.
            with open(STDERR_FILENAME) as f:
                lines = f.readlines()
        except FileNotFoundError:
            # The file was removed: that is the stop signal.
            break

        output_lines[:] = lines
        for line in lines[pt:]:
            print(line, end="", flush=True)
            pt += 1

        # Poll at a modest rate instead of spinning on the file.
        time.sleep(0.1)
|
||||
|
||||
|
||||
def run_mmlu_test(
    disable_radix_cache,
    enable_mixed_chunk=False,
    enable_overlap=False,
    chunked_prefill_size=32,
):
    """Launch a server with chunked prefill and run a small MMLU eval on it.

    Args:
        disable_radix_cache: If truthy, pass ``--disable-radix-cache``.
        enable_mixed_chunk: If truthy, pass ``--enable-mixed-chunk``.
        enable_overlap: If truthy, pass ``--enable-overlap-scheduler``.
        chunked_prefill_size: Value passed to ``--chunked-prefill-size``.

    Raises:
        AssertionError: If the eval score is below 0.65, or if no captured
            stderr line contains "The server is fired".
    """
    other_args = ["--chunked-prefill-size", str(chunked_prefill_size)]
    if disable_radix_cache:
        other_args += ["--disable-radix-cache"]
    if enable_mixed_chunk:
        other_args += ["--enable-mixed-chunk"]
    if enable_overlap:
        other_args += ["--enable-overlap-scheduler"]

    model = DEFAULT_MODEL_NAME_FOR_TEST
    port = random.randint(4000, 5000)
    base_url = f"http://127.0.0.1:{port}"

    output_lines = []
    process = None
    reader = None

    # Create files that will receive the server's stdout / stderr.
    stdout = open(STDOUT_FILENAME, "w")
    stderr = open(STDERR_FILENAME, "w")
    try:
        process = popen_launch_server(
            model,
            base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=other_args,
            return_stdout_stderr=(stdout, stderr),
        )

        # Launch a thread to stream the output
        reader = threading.Thread(target=read_output, args=(output_lines,))
        reader.start()

        # Run the eval
        args = SimpleNamespace(
            base_url=base_url,
            model=model,
            eval_name="mmlu",
            num_examples=128,
            num_threads=128,
        )
        metrics = run_eval(args)
        print(f"{metrics=}")
        assert metrics["score"] >= 0.65
    finally:
        # Clean up everything, even when the launch, the eval or the score
        # assertion fails, so no server process, file handle or log file is
        # left behind.
        if process is not None:
            # Called twice, as in the original code; presumably a second pass
            # to catch stragglers -- TODO confirm whether one call suffices.
            kill_child_process(process.pid)
            kill_child_process(process.pid)
        stdout.close()
        stderr.close()
        os.remove(STDOUT_FILENAME)
        # Remove the stderr file before joining: the reader thread is joined
        # only after the file it tails is gone.
        os.remove(STDERR_FILENAME)
        if reader is not None:
            reader.join()

    # Assert success
    has_new_server = any("The server is fired" in line for line in output_lines)
    assert has_new_server
    # NOTE: lines containing "leak" are not checked; the corresponding
    # assertion was disabled in the original code.
|
||||
|
||||
Reference in New Issue
Block a user