feat: support custom task runner (#2407)
This commit is contained in:
26
.github/workflows/experiment-runner.yml
vendored
Normal file
26
.github/workflows/experiment-runner.yml
vendored
Normal file
@@ -0,0 +1,26 @@
|
||||
# CI workflow for the experiment runner; manually triggered only.
name: Experiment Runner

on:
  workflow_dispatch:

# Only one run per ref at a time; a newer dispatch cancels the older run.
concurrency:
  group: experiment-runner-${{ github.ref }}
  cancel-in-progress: true

jobs:
  experiment-runner-1-gpu:
    # Restrict to the upstream repo (or PR events) so forks don't consume GPU runners.
    if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
    runs-on: 1-gpu-runner
    steps:
      - name: Checkout code
        uses: actions/checkout@v3

      - name: Install dependencies
        run: |
          bash scripts/ci_install_dependency.sh

      - name: Test experiment runner
        timeout-minutes: 10
        run: |
          cd test/srt
          python3 experiment_runner.py --config configs/sharegpt_config.yaml
|
||||
7
test/srt/configs/sharegpt_config.yaml
Normal file
7
test/srt/configs/sharegpt_config.yaml
Normal file
@@ -0,0 +1,7 @@
|
||||
# Benchmark task matrix consumed by test/srt/experiment_runner.py.
# Each task pairs a server launch command with a client benchmark command;
# server_type is auto-detected from server_cmd when omitted.
tasks:
  - name: sglang-benchmark
    server_cmd: python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache
    client_cmd: python3 -m sglang.bench_serving --backend sglang --dataset-name sharegpt --request-rate 16
  - name: vllm-benchmark
    server_cmd: python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-3.1-8B-Instruct --disable-log-requests
    client_cmd: python3 -m sglang.bench_serving --backend vllm --dataset-name sharegpt --request-rate 16
|
||||
359
test/srt/experiment_runner.py
Normal file
359
test/srt/experiment_runner.py
Normal file
@@ -0,0 +1,359 @@
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import queue
|
||||
import re
|
||||
import subprocess
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import psutil
|
||||
import requests
|
||||
import yaml
|
||||
|
||||
|
||||
@dataclass
class ServerConfig:
    """Static defaults for a known inference-server type."""

    # Module path used to launch the server (also matched for type detection).
    command: str
    # Substrings matched against process cmdlines during cleanup sweeps.
    process_names: List[str]
    # Port used when the server command carries no explicit --port flag.
    default_port: int
|
||||
|
||||
|
||||
@dataclass
class TaskConfig:
    """One benchmark task parsed from the YAML config file."""

    # Shell command that launches the inference server.
    server_cmd: str
    # Shell command that runs the benchmark client against the server.
    client_cmd: str
    # Human-readable task name; load_config defaults it to "task-<idx>".
    name: Optional[str] = None
    # "sglang"/"vllm"; auto-detected from server_cmd when omitted.
    server_type: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
class TaskResult:
    """Outcome of one benchmark task, rendered into the step summary."""

    # Task name copied from the TaskConfig.
    name: str
    # True when the client exited 0; False on any failure/timeout.
    success: bool
    # Parsed key benchmark output on success, or the error message on failure.
    output: str
    # Wall-clock task duration in seconds.
    runtime: float
    # ISO-8601 completion timestamp.
    timestamp: str
|
||||
|
||||
|
||||
# Known server types keyed by name; provides launch-module detection strings,
# cleanup process-name patterns, and default ports.
SERVER_DEFAULTS = {
    "sglang": ServerConfig(
        command="sglang.launch_server",
        process_names=["sglang.launch_server"],
        default_port=30000,
    ),
    "vllm": ServerConfig(
        command="vllm.entrypoints.openai.api_server",
        process_names=["vllm.entrypoints.openai.api_server"],
        default_port=8000,
    ),
}
|
||||
|
||||
|
||||
def parse_key_info(output: str) -> str:
    """Extract and format key information from the client output.

    Collects the argparse ``Namespace(...)`` dump, the ``#Input/#Output
    tokens`` counters, and the "Serving Benchmark Result" table, joined by
    blank lines. Sections absent from the output are skipped.
    """
    key_info = []

    # Extract Args namespace (first Namespace(...) dump in the output)
    args_match = re.search(r"Namespace\(.*?\)", output, re.DOTALL)
    if args_match:
        key_info.append(args_match.group(0))

    # Extract input/output token counts. The group must be non-capturing:
    # with a capturing group, re.findall returns only "Input"/"Output"
    # instead of the full "#... tokens: N" lines.
    token_matches = re.findall(r"#(?:Input|Output) tokens: \d+", output)
    key_info.extend(token_matches)

    # Extract benchmark result section (header through the closing ==== bar)
    result_match = re.search(
        r"============ Serving Benchmark Result ============.*?={50,}",
        output,
        re.DOTALL,
    )
    if result_match:
        key_info.append(result_match.group(0))

    return "\n\n".join(key_info)
|
||||
|
||||
|
||||
def extract_port_from_command(cmd: str, server_type: str) -> int:
    """Return the port from an explicit --port flag, else the server-type default."""
    explicit = re.search(r"--port[= ](\d+)", cmd)
    if explicit is not None:
        return int(explicit.group(1))
    # No --port flag: fall back to this server type's configured default,
    # or 8000 for unknown types.
    fallback = SERVER_DEFAULTS.get(server_type, ServerConfig("", [], 8000))
    return fallback.default_port
|
||||
|
||||
|
||||
def detect_server_type(cmd: str) -> str:
    """Classify *cmd* by its launch module; "unknown" when no type matches."""
    return next(
        (
            kind
            for kind, defaults in SERVER_DEFAULTS.items()
            if defaults.command in cmd
        ),
        "unknown",
    )
|
||||
|
||||
|
||||
def stream_output(
    process: subprocess.Popen, prefix: str, logger: logging.Logger
) -> Tuple[queue.Queue, Tuple[threading.Thread, threading.Thread]]:
    """Tee a process's stdout/stderr to the logger on background threads.

    CLIENT lines are additionally collected into the returned queue so the
    caller can parse benchmark output after the process exits.

    Returns:
        ``(output_queue, (stdout_thread, stderr_thread))``. The original
        annotation claimed a bare Queue; the function has always returned
        this 2-tuple. The daemon threads run until each pipe hits EOF.
    """
    output_queue = queue.Queue()

    def stream_pipe(pipe, prefix):
        for line in iter(pipe.readline, ""):
            # Only client output is captured for parsing; server output is
            # surfaced via debug logging only.
            if prefix == "CLIENT":
                output_queue.put(line.rstrip())
            logger.debug(f"{prefix} | {line.rstrip()}")

    stdout_thread = threading.Thread(
        target=stream_pipe, args=(process.stdout, prefix), daemon=True
    )
    stderr_thread = threading.Thread(
        target=stream_pipe, args=(process.stderr, prefix), daemon=True
    )

    stdout_thread.start()
    stderr_thread.start()
    return output_queue, (stdout_thread, stderr_thread)
|
||||
|
||||
|
||||
class ProcessManager:
    """Owns the server/client subprocesses and terminates whole process trees."""

    def __init__(self):
        self.server_process: Optional[subprocess.Popen] = None
        self.client_process: Optional[subprocess.Popen] = None
        self.logger = logging.getLogger(__name__)

    def start_process(
        self, command: str, prefix: str
    ) -> Tuple[
        subprocess.Popen, queue.Queue, Tuple[threading.Thread, threading.Thread]
    ]:
        """Launch *command* in a shell and stream its output via stream_output.

        Returns:
            ``(process, output_queue, (stdout_thread, stderr_thread))``.
            The original annotation declared a 2-tuple, but three values
            have always been returned.
        """
        process = subprocess.Popen(
            command,
            shell=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            bufsize=1,  # line-buffered so output streams promptly
        )

        output_queue, threads = stream_output(process, prefix, self.logger)
        return process, output_queue, threads

    def kill_process_tree(self, process: subprocess.Popen):
        """Best-effort kill of *process* and all of its descendants."""
        try:
            parent = psutil.Process(process.pid)
            children = parent.children(recursive=True)

            # Kill children first so the parent cannot respawn/reap them.
            for child in children:
                try:
                    child.kill()
                except psutil.NoSuchProcess:
                    pass

            parent.kill()
            # Give the tree a moment to exit, then force-kill stragglers.
            _, alive = psutil.wait_procs(children + [parent], timeout=3)

            for p in alive:
                try:
                    p.kill()
                except psutil.NoSuchProcess:
                    pass

        except psutil.NoSuchProcess:
            # Process already gone — nothing to clean up.
            pass

    def cleanup(self, process_names: List[str]):
        """Kill tracked processes, then sweep for strays matching *process_names*."""
        if self.client_process:
            self.kill_process_tree(self.client_process)
            self.client_process = None

        if self.server_process:
            self.kill_process_tree(self.server_process)
            self.server_process = None

        # Sweep the whole system for leftover processes (e.g. from a previous
        # crashed run) whose cmdline matches a known server/client module.
        for proc in psutil.process_iter(["pid", "name", "cmdline"]):
            try:
                cmdline = " ".join(proc.cmdline())
                if any(name in cmdline for name in process_names):
                    proc.kill()
            except (psutil.NoSuchProcess, psutil.AccessDenied):
                continue
|
||||
|
||||
|
||||
class ExperimentRunner:
    """Runs each task: start server, wait for health, run client, clean up."""

    def __init__(self):
        self.process_manager = ProcessManager()
        self.logger = logging.getLogger(__name__)

    def wait_for_server(self, port: int, timeout: int = 300) -> bool:
        """Poll the server's /health endpoint until it returns 200 or *timeout* elapses."""
        start_time = time.time()

        while time.time() - start_time < timeout:
            try:
                # Bound each probe so a hung server cannot stall the loop.
                response = requests.get(f"http://localhost:{port}/health", timeout=5)
                if response.status_code == 200:
                    self.logger.debug(f"Server ready on port {port}")
                    return True
            except requests.RequestException:
                pass
            # Sleep on every iteration. The original slept only inside the
            # except branch, busy-spinning when the server answered non-200.
            time.sleep(2)
        return False

    def run_task(self, config: TaskConfig) -> TaskResult:
        """Execute one benchmark task; never raises — failures become TaskResults."""
        start_time = time.time()
        client_output = []

        try:
            if not config.server_type:
                config.server_type = detect_server_type(config.server_cmd)

            server_config = SERVER_DEFAULTS.get(config.server_type)
            if not server_config:
                raise ValueError(f"Unknown server type: {config.server_type}")

            port = extract_port_from_command(config.server_cmd, config.server_type)

            # Kill any stale processes from a previous task before starting.
            self.process_manager.cleanup(server_config.process_names)

            self.logger.debug(f"Starting server: {config.name}")
            self.process_manager.server_process, _, server_threads = (
                self.process_manager.start_process(config.server_cmd, "SERVER")
            )

            if not self.wait_for_server(port):
                raise TimeoutError("Server startup timeout")

            # Grace period for warmup after /health first answers 200.
            time.sleep(10)

            self.logger.debug("Starting client")
            self.process_manager.client_process, output_queue, client_threads = (
                self.process_manager.start_process(config.client_cmd, "CLIENT")
            )

            returncode = self.process_manager.client_process.wait()

            # Drain everything the reader threads captured.
            while True:
                try:
                    line = output_queue.get_nowait()
                    client_output.append(line)
                except queue.Empty:
                    break

            if returncode != 0:
                raise RuntimeError(f"Client failed with code {returncode}")

            # Parse and format the output
            full_output = "\n".join(client_output)
            formatted_output = parse_key_info(full_output)

            return TaskResult(
                name=config.name,
                success=True,
                output=formatted_output,
                runtime=time.time() - start_time,
                timestamp=datetime.now().isoformat(),
            )

        except Exception as e:
            # Report the failure as a result so subsequent tasks still run.
            return TaskResult(
                name=config.name,
                success=False,
                output=str(e),
                runtime=time.time() - start_time,
                timestamp=datetime.now().isoformat(),
            )

        finally:
            if config.server_type in SERVER_DEFAULTS:
                self.process_manager.cleanup(
                    SERVER_DEFAULTS[config.server_type].process_names
                )
            # Let the OS release ports/GPU memory before the next task.
            time.sleep(10)
|
||||
|
||||
|
||||
def load_config(config_path: str) -> List[TaskConfig]:
    """Parse the YAML task file at *config_path* into validated TaskConfigs."""
    with open(config_path, "r") as f:
        raw = yaml.safe_load(f)

    tasks = []
    for idx, entry in enumerate(raw.get("tasks", [])):
        if not isinstance(entry, dict):
            raise ValueError(f"Invalid entry at index {idx}")

        task = TaskConfig(
            server_cmd=entry.get("server_cmd"),
            client_cmd=entry.get("client_cmd"),
            name=entry.get("name", f"task-{idx+1}"),
            server_type=entry.get("server_type"),
        )

        # Both commands are mandatory; everything else has a default.
        if not task.server_cmd or not task.client_cmd:
            raise ValueError(f"Missing commands in {task.name}")

        tasks.append(task)

    return tasks
|
||||
|
||||
|
||||
def setup_logging(debug: bool = False):
    """Configure root logging to both stderr and ./experiment.log."""
    logging.basicConfig(
        level=logging.DEBUG if debug else logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
        handlers=[logging.StreamHandler(), logging.FileHandler("experiment.log")],
    )
|
||||
|
||||
|
||||
def format_results(results: List[TaskResult]) -> str:
|
||||
"""Format experiment results in Markdown for GitHub step summary."""
|
||||
output = ["# Experiment Results\n"]
|
||||
|
||||
for result in results:
|
||||
output.append(f"## {result.name}")
|
||||
output.append(f"**Status**: {'✅ Success' if result.success else '❌ Failed'}")
|
||||
output.append(f"**Runtime**: {result.runtime:.2f} seconds")
|
||||
output.append(f"**Timestamp**: {result.timestamp}")
|
||||
output.append("\n**Output**:\n```")
|
||||
output.append(result.output)
|
||||
output.append("```\n")
|
||||
|
||||
return "\n".join(output)
|
||||
|
||||
|
||||
def write_in_github_step_summary(results: List[TaskResult]):
|
||||
"""Write formatted results to GitHub step summary."""
|
||||
if not os.environ.get("GITHUB_STEP_SUMMARY"):
|
||||
logging.warning("GITHUB_STEP_SUMMARY environment variable not set")
|
||||
return
|
||||
|
||||
formatted_content = format_results(results)
|
||||
with open(os.environ["GITHUB_STEP_SUMMARY"], "a") as f:
|
||||
f.write(formatted_content)
|
||||
|
||||
|
||||
def main():
    """CLI entry point: load the config, run every task, publish the summary."""
    parser = argparse.ArgumentParser(description="Experiment Runner")
    parser.add_argument(
        "--config", type=str, required=True, help="Path to YAML config file"
    )
    parser.add_argument("--debug", action="store_true", help="Enable debug output")
    args = parser.parse_args()

    setup_logging(args.debug)
    logger = logging.getLogger(__name__)
    results = []

    try:
        task_configs = load_config(args.config)
        runner = ExperimentRunner()

        for task in task_configs:
            logger.info(f"Running {task.name}")
            results.append(runner.run_task(task))

        write_in_github_step_summary(results)
    except Exception as e:
        # Log and re-raise so CI sees a non-zero exit code.
        logger.error(f"Error: {e}")
        raise


if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user