Sync from v0.13
This commit is contained in:
2
tests/evals/gpt_oss/__init__.py
Normal file
2
tests/evals/gpt_oss/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
16
tests/evals/gpt_oss/conftest.py
Normal file
16
tests/evals/gpt_oss/conftest.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Pytest configuration for GPT-OSS evaluation tests.
|
||||
"""
|
||||
|
||||
|
||||
def pytest_addoption(parser):
|
||||
"""Add command line options for pytest."""
|
||||
parser.addoption("--model", action="store", help="Model name to evaluate")
|
||||
parser.addoption(
|
||||
"--metric", action="store", type=float, help="Expected metric threshold"
|
||||
)
|
||||
parser.addoption(
|
||||
"--server-args", action="store", default="", help="Additional server arguments"
|
||||
)
|
||||
118
tests/evals/gpt_oss/test_gpqa_correctness.py
Normal file
118
tests/evals/gpt_oss/test_gpqa_correctness.py
Normal file
@@ -0,0 +1,118 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
GPQA evaluation using vLLM server and GPT-OSS evaluation package.
|
||||
|
||||
Usage:
|
||||
pytest -s -v tests/evals/gpt_oss/test_gpqa_correctness.py \
|
||||
--model openai/gpt-oss-20b \
|
||||
--metric 0.58 \
|
||||
--server-args "--tensor-parallel-size 2"
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
import regex as re
|
||||
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
|
||||
TOL = 0.05 # Absolute tolerance for accuracy comparison
|
||||
|
||||
|
||||
def run_gpqa_eval(model_name: str, base_url: str) -> float:
|
||||
"""Run GPQA evaluation using the gpt-oss evaluation package."""
|
||||
|
||||
# Build the command to run the evaluation
|
||||
cmd = [
|
||||
sys.executable,
|
||||
"-m",
|
||||
"gpt_oss.evals",
|
||||
"--eval",
|
||||
"gpqa",
|
||||
"--model",
|
||||
model_name,
|
||||
"--reasoning-effort",
|
||||
"low",
|
||||
"--base-url",
|
||||
base_url,
|
||||
"--n-threads",
|
||||
"200",
|
||||
]
|
||||
|
||||
try:
|
||||
# Run the evaluation
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
text=True,
|
||||
capture_output=True,
|
||||
timeout=1800, # 30 minute timeout
|
||||
env={"OPENAI_API_KEY": "dummy"},
|
||||
)
|
||||
|
||||
print("Evaluation process output:\n", result.stdout)
|
||||
|
||||
# Parse the output to extract the score
|
||||
match = re.search(r"'metric':\s*([\d.]+)", result.stdout)
|
||||
if match:
|
||||
return float(match.group(1))
|
||||
|
||||
# If we still can't find it, raise an error
|
||||
raise ValueError(
|
||||
f"Could not parse score from evaluation output:\n{result.stdout}"
|
||||
)
|
||||
|
||||
except subprocess.TimeoutExpired as e:
|
||||
raise RuntimeError("Evaluation timed out") from e
|
||||
except subprocess.CalledProcessError as e:
|
||||
raise RuntimeError(
|
||||
f"Evaluation failed with exit code {e.returncode}:\n"
|
||||
f"stdout: {e.stdout}\nstderr: {e.stderr}"
|
||||
) from e
|
||||
|
||||
|
||||
def test_gpqa_correctness(request):
|
||||
"""Test GPQA correctness for GPT-OSS model."""
|
||||
|
||||
# Get command line arguments
|
||||
model_name = request.config.getoption("--model")
|
||||
expected_metric = request.config.getoption("--metric")
|
||||
server_args_str = request.config.getoption("--server-args")
|
||||
|
||||
# Parse server arguments
|
||||
server_args = []
|
||||
if server_args_str:
|
||||
server_args = server_args_str.split()
|
||||
|
||||
# Add standard server arguments
|
||||
server_args.extend(
|
||||
[
|
||||
"--trust-remote-code",
|
||||
]
|
||||
)
|
||||
|
||||
print(f"Starting GPQA evaluation for model: {model_name}")
|
||||
print(f"Expected metric threshold: {expected_metric}")
|
||||
print(f"Server args: {' '.join(server_args)}")
|
||||
|
||||
# Launch server and run evaluation
|
||||
with RemoteOpenAIServer(
|
||||
model_name, server_args, max_wait_seconds=1800
|
||||
) as remote_server:
|
||||
base_url = remote_server.url_for("v1")
|
||||
print(f"Server started at: {base_url}")
|
||||
|
||||
measured_metric = run_gpqa_eval(model_name, base_url)
|
||||
|
||||
print(f"GPQA Results for {model_name}:")
|
||||
print(f" Measured metric: {measured_metric:.4f}")
|
||||
print(f" Expected metric: {expected_metric:.4f}")
|
||||
print(f" Tolerance: {TOL:.4f}")
|
||||
|
||||
# Verify metric is within tolerance
|
||||
assert measured_metric >= expected_metric - TOL, (
|
||||
f"GPQA metric too low: {measured_metric:.4f} < "
|
||||
f"{expected_metric:.4f} - {TOL:.4f} = {expected_metric - TOL:.4f}"
|
||||
)
|
||||
|
||||
print(f"✅ GPQA test passed for {model_name}")
|
||||
35
tests/evals/gsm8k/README.md
Normal file
35
tests/evals/gsm8k/README.md
Normal file
@@ -0,0 +1,35 @@
|
||||
# GSM8K Accuracy Evaluation
|
||||
|
||||
This directory contains a replacement for the lm-eval-harness GSM8K evaluation, using an isolated GSM8K script and vLLM server for better performance and control.
|
||||
|
||||
## Usage
|
||||
|
||||
### Run tests with pytest (like buildkite)
|
||||
|
||||
```bash
|
||||
pytest -s -v tests/gsm8k/test_gsm8k_correctness.py \
|
||||
--config-list-file=configs/models-small.txt \
|
||||
--tp-size=1
|
||||
```
|
||||
|
||||
### Run standalone evaluation script
|
||||
|
||||
```bash
|
||||
# Start vLLM server first
|
||||
vllm serve Qwen/Qwen2.5-1.5B-Instruct --port 8000
|
||||
|
||||
# Run evaluation
|
||||
python tests/evals/gsm8k/gsm8k_eval.py --port 8000
|
||||
```
|
||||
|
||||
## Configuration Format
|
||||
|
||||
Model configs in `configs/` directory use this YAML format:
|
||||
|
||||
```yaml
|
||||
model_name: "Qwen/Qwen2.5-1.5B-Instruct"
|
||||
accuracy_threshold: 0.54 # Minimum expected accuracy
|
||||
num_questions: 1319 # Number of questions (default: full test set)
|
||||
num_fewshot: 5 # Few-shot examples from train set
|
||||
max_model_len: 4096 # Model context length
|
||||
```
|
||||
2
tests/evals/gsm8k/__init__.py
Normal file
2
tests/evals/gsm8k/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
@@ -0,0 +1,6 @@
|
||||
model_name: "RedHatAI/DeepSeek-Coder-V2-Lite-Instruct-FP8"
|
||||
accuracy_threshold: 0.72
|
||||
num_questions: 1319
|
||||
num_fewshot: 5
|
||||
max_model_len: 4096
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
model_name: "nm-testing/Meta-Llama-3-8B-Instruct-nonuniform-test"
|
||||
accuracy_threshold: 0.74
|
||||
num_questions: 1319
|
||||
num_fewshot: 5
|
||||
max_model_len: 4096
|
||||
@@ -0,0 +1,5 @@
|
||||
model_name: "RedHatAI/Llama-3.2-1B-Instruct-quantized.w8a8"
|
||||
accuracy_threshold: 0.31
|
||||
num_questions: 1319
|
||||
num_fewshot: 5
|
||||
max_model_len: 4096
|
||||
5
tests/evals/gsm8k/configs/Qwen1.5-MoE-W4A16-CT.yaml
Normal file
5
tests/evals/gsm8k/configs/Qwen1.5-MoE-W4A16-CT.yaml
Normal file
@@ -0,0 +1,5 @@
|
||||
model_name: "nm-testing/Qwen1.5-MoE-A2.7B-Chat-quantized.w4a16"
|
||||
accuracy_threshold: 0.45
|
||||
num_questions: 1319
|
||||
num_fewshot: 5
|
||||
max_model_len: 4096
|
||||
@@ -0,0 +1,5 @@
|
||||
model_name: "RedHatAI/Qwen2.5-VL-3B-Instruct-FP8-Dynamic"
|
||||
accuracy_threshold: 0.60
|
||||
num_questions: 1319
|
||||
num_fewshot: 5
|
||||
max_model_len: 4096
|
||||
5
tests/evals/gsm8k/configs/Qwen3-0.6B-FP8.yaml
Normal file
5
tests/evals/gsm8k/configs/Qwen3-0.6B-FP8.yaml
Normal file
@@ -0,0 +1,5 @@
|
||||
model_name: "Qwen/Qwen3-0.6B-FP8"
|
||||
accuracy_threshold: 0.375
|
||||
num_questions: 1319
|
||||
num_fewshot: 5
|
||||
max_model_len: 4096
|
||||
6
tests/evals/gsm8k/configs/Qwen3-30B-A3B-NVFP4.yaml
Normal file
6
tests/evals/gsm8k/configs/Qwen3-30B-A3B-NVFP4.yaml
Normal file
@@ -0,0 +1,6 @@
|
||||
model_name: "nvidia/Qwen3-30B-A3B-FP4"
|
||||
accuracy_threshold: 0.89
|
||||
num_questions: 1319
|
||||
num_fewshot: 5
|
||||
max_model_len: 4096
|
||||
|
||||
5
tests/evals/gsm8k/configs/models-blackwell.txt
Normal file
5
tests/evals/gsm8k/configs/models-blackwell.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
Qwen3-0.6B-FP8.yaml
|
||||
Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml
|
||||
Qwen1.5-MoE-W4A16-CT.yaml
|
||||
DeepSeek-V2-Lite-Instruct-FP8.yaml
|
||||
Qwen3-30B-A3B-NVFP4.yaml
|
||||
6
tests/evals/gsm8k/configs/models-small.txt
Normal file
6
tests/evals/gsm8k/configs/models-small.txt
Normal file
@@ -0,0 +1,6 @@
|
||||
Qwen3-0.6B-FP8.yaml
|
||||
Llama-3.2-1B-Instruct-INT8-CT.yaml
|
||||
Llama-3-8B-Instruct-nonuniform-CT.yaml
|
||||
Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml
|
||||
Qwen1.5-MoE-W4A16-CT.yaml
|
||||
DeepSeek-V2-Lite-Instruct-FP8.yaml
|
||||
63
tests/evals/gsm8k/conftest.py
Normal file
63
tests/evals/gsm8k/conftest.py
Normal file
@@ -0,0 +1,63 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def pytest_addoption(parser):
|
||||
"""Add custom command line options."""
|
||||
parser.addoption(
|
||||
"--config-list-file",
|
||||
default="configs/models-small.txt",
|
||||
help="File containing list of config files to test",
|
||||
)
|
||||
parser.addoption("--tp-size", default=1, type=int, help="Tensor parallel size")
|
||||
|
||||
|
||||
def pytest_generate_tests(metafunc):
|
||||
"""Generate test parameters from config files."""
|
||||
if "config_filename" in metafunc.fixturenames:
|
||||
config_list_file = metafunc.config.getoption("--config-list-file")
|
||||
tp_size = metafunc.config.getoption("--tp-size")
|
||||
|
||||
# Handle both relative and absolute paths
|
||||
config_list_path = Path(config_list_file)
|
||||
if not config_list_path.is_absolute():
|
||||
# If relative, try relative to test directory first
|
||||
test_dir_path = Path(__file__).parent / config_list_file
|
||||
if test_dir_path.exists():
|
||||
config_list_path = test_dir_path
|
||||
else:
|
||||
# Try relative to current working directory
|
||||
config_list_path = Path.cwd() / config_list_file
|
||||
|
||||
print(f"Looking for config list at: {config_list_path}")
|
||||
|
||||
config_files = []
|
||||
if config_list_path.exists():
|
||||
# Determine config directory (same directory as the list file)
|
||||
config_dir = config_list_path.parent
|
||||
|
||||
with open(config_list_path) as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line and not line.startswith("#"):
|
||||
config_path = config_dir / line
|
||||
print(f"Checking config file: {config_path}")
|
||||
if config_path.exists():
|
||||
config_files.append(config_path)
|
||||
print(f" ✓ Found: {config_path}")
|
||||
else:
|
||||
print(f" ✗ Missing: {config_path}")
|
||||
else:
|
||||
print(f"Config list file not found: {config_list_path}")
|
||||
|
||||
# Generate test parameters
|
||||
if config_files:
|
||||
metafunc.parametrize(
|
||||
["config_filename", "tp_size"],
|
||||
[(config_file, int(tp_size)) for config_file in config_files],
|
||||
ids=[f"{config_file.stem}-tp{tp_size}" for config_file in config_files],
|
||||
)
|
||||
else:
|
||||
print("No config files found, test will be skipped")
|
||||
262
tests/evals/gsm8k/gsm8k_eval.py
Normal file
262
tests/evals/gsm8k/gsm8k_eval.py
Normal file
@@ -0,0 +1,262 @@
|
||||
#!/usr/bin/env python3
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Isolated GSM8K evaluation script for vLLM serve endpoint.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import ast
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
|
||||
import aiohttp
|
||||
import numpy as np
|
||||
import regex as re
|
||||
import requests
|
||||
from tqdm.asyncio import tqdm
|
||||
|
||||
INVALID = -9999999
|
||||
|
||||
|
||||
def download_and_cache_file(url: str, filename: str | None = None) -> str:
|
||||
"""Download and cache a file from a URL."""
|
||||
if filename is None:
|
||||
filename = os.path.join("/tmp", url.split("/")[-1])
|
||||
|
||||
if os.path.exists(filename):
|
||||
return filename
|
||||
|
||||
print(f"Downloading from {url} to {filename}")
|
||||
response = requests.get(url, stream=True)
|
||||
response.raise_for_status()
|
||||
|
||||
with open(filename, "wb") as f:
|
||||
for chunk in response.iter_content(chunk_size=1024):
|
||||
f.write(chunk)
|
||||
|
||||
return filename
|
||||
|
||||
|
||||
def load_gsm8k_data() -> tuple[list[dict], list[dict]]:
|
||||
"""Load GSM8K train and test data"""
|
||||
train_url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/train.jsonl"
|
||||
test_url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"
|
||||
|
||||
train_file = download_and_cache_file(train_url)
|
||||
test_file = download_and_cache_file(test_url)
|
||||
|
||||
train_data = list(read_jsonl(train_file))
|
||||
test_data = list(read_jsonl(test_file))
|
||||
|
||||
return train_data, test_data
|
||||
|
||||
|
||||
def read_jsonl(filename: str) -> Generator[dict, None, None]:
|
||||
"""Read a JSONL file."""
|
||||
with open(filename) as fin:
|
||||
for line in fin:
|
||||
if not line.startswith("#"):
|
||||
yield json.loads(line)
|
||||
|
||||
|
||||
def get_answer_value(answer_str: str) -> int:
|
||||
"""Extract the numerical answer from the response."""
|
||||
answer_str = answer_str.replace(",", "")
|
||||
numbers = re.findall(r"\d+", answer_str)
|
||||
if len(numbers) < 1:
|
||||
return INVALID
|
||||
try:
|
||||
return ast.literal_eval(numbers[-1])
|
||||
except SyntaxError:
|
||||
return INVALID
|
||||
|
||||
|
||||
async def call_vllm_api(
|
||||
session: aiohttp.ClientSession,
|
||||
prompt: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
stop: list[str] | None = None,
|
||||
url: str | None = None,
|
||||
seed: int | None = None,
|
||||
) -> tuple[str, int]:
|
||||
"""Call vLLM's OpenAI-compatible completions endpoint.
|
||||
|
||||
Returns:
|
||||
Tuple of (response_text, completion_tokens)
|
||||
"""
|
||||
data = {
|
||||
"prompt": prompt,
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
"stop": stop,
|
||||
}
|
||||
if seed is not None:
|
||||
data["seed"] = seed
|
||||
|
||||
try:
|
||||
async with session.post(f"{url}/v1/completions", json=data) as response:
|
||||
response.raise_for_status()
|
||||
result = await response.json()
|
||||
text = result["choices"][0]["text"]
|
||||
completion_tokens = result.get("usage", {}).get("completion_tokens", 0)
|
||||
return text, completion_tokens
|
||||
except Exception as e:
|
||||
print(f"Error calling vLLM API: {e}")
|
||||
return "", 0
|
||||
|
||||
|
||||
def evaluate_gsm8k(
|
||||
num_questions: int = 1319,
|
||||
num_shots: int = 5,
|
||||
max_tokens: int = 256,
|
||||
host: str = "http://127.0.0.1",
|
||||
port: int = 8000,
|
||||
temperature: float = 0.0,
|
||||
seed: int | None = 42,
|
||||
) -> dict[str, float | int]:
|
||||
"""
|
||||
Evaluate GSM8K accuracy using vLLM serve endpoint.
|
||||
|
||||
Returns dict with accuracy, invalid_rate, latency, etc.
|
||||
"""
|
||||
base_url = f"{host}:{port}"
|
||||
|
||||
# Load GSM8K train and test data
|
||||
train_data, test_data = load_gsm8k_data()
|
||||
|
||||
# Limit to available test questions
|
||||
num_questions = min(num_questions, len(test_data))
|
||||
|
||||
# Build few-shot examples from train split (like lm-eval does)
|
||||
few_shot_examples = ""
|
||||
for i in range(num_shots):
|
||||
few_shot_examples += (
|
||||
f"Question: {train_data[i]['question']}\n"
|
||||
f"Answer: {train_data[i]['answer']}\n\n"
|
||||
)
|
||||
|
||||
# Prepare test questions and labels from test split
|
||||
questions = []
|
||||
labels = []
|
||||
for i in range(num_questions):
|
||||
questions.append(f"Question: {test_data[i]['question']}\nAnswer:")
|
||||
labels.append(get_answer_value(test_data[i]["answer"]))
|
||||
|
||||
assert all(label != INVALID for label in labels), "Some labels are invalid"
|
||||
|
||||
# Run evaluation
|
||||
async def run_async_evaluation():
|
||||
states: list[str] = [""] * num_questions
|
||||
output_tokens: list[int] = [0] * num_questions
|
||||
|
||||
async def get_answer(session: aiohttp.ClientSession, i: int) -> tuple[str, int]:
|
||||
prompt = few_shot_examples + questions[i]
|
||||
answer, tokens = await call_vllm_api(
|
||||
session=session,
|
||||
prompt=prompt,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
stop=["Question", "Assistant:", "<|separator|>"],
|
||||
url=base_url,
|
||||
seed=seed,
|
||||
)
|
||||
states[i] = answer
|
||||
output_tokens[i] = tokens
|
||||
return answer, tokens
|
||||
|
||||
async with aiohttp.ClientSession(
|
||||
timeout=aiohttp.ClientTimeout(total=600)
|
||||
) as session:
|
||||
tasks = [get_answer(session, i) for i in range(num_questions)]
|
||||
await tqdm.gather(*tasks, desc="Evaluating")
|
||||
|
||||
return states, output_tokens
|
||||
|
||||
print(f"Running GSM8K evaluation: {num_questions} questions, {num_shots}-shot")
|
||||
|
||||
tic = time.perf_counter()
|
||||
states, output_tokens = asyncio.run(run_async_evaluation())
|
||||
latency = time.perf_counter() - tic
|
||||
|
||||
# Compute metrics
|
||||
preds = [get_answer_value(state) for state in states]
|
||||
accuracy = np.mean(np.array(preds) == np.array(labels))
|
||||
invalid_rate = np.mean(np.array(preds) == INVALID)
|
||||
total_output_tokens = sum(output_tokens)
|
||||
tokens_per_second = total_output_tokens / latency if latency > 0 else 0.0
|
||||
|
||||
result = {
|
||||
"accuracy": accuracy,
|
||||
"invalid_rate": invalid_rate,
|
||||
"latency": latency,
|
||||
"questions_per_second": num_questions / latency,
|
||||
"total_output_tokens": total_output_tokens,
|
||||
"tokens_per_second": tokens_per_second,
|
||||
"num_questions": num_questions,
|
||||
"num_shots": num_shots,
|
||||
"max_tokens": max_tokens,
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="GSM8K evaluation for vLLM serve")
|
||||
parser.add_argument(
|
||||
"--num-shots", type=int, default=5, help="Number of few-shot examples"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-questions",
|
||||
type=int,
|
||||
default=1319,
|
||||
help="Number of questions to evaluate",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-tokens", type=int, default=256, help="Max tokens for generation"
|
||||
)
|
||||
parser.add_argument("--host", type=str, default="http://127.0.0.1", help="Host URL")
|
||||
parser.add_argument("--port", type=int, default=8000, help="Port number")
|
||||
parser.add_argument(
|
||||
"--temperature", type=float, default=0.0, help="Temperature for generation"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--seed", type=int, default=42, help="Random seed for reproducibility"
|
||||
)
|
||||
parser.add_argument("--save-results", type=str, help="Save results to JSON file")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
result = evaluate_gsm8k(
|
||||
num_questions=args.num_questions,
|
||||
num_shots=args.num_shots,
|
||||
max_tokens=args.max_tokens,
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
temperature=args.temperature,
|
||||
seed=args.seed,
|
||||
)
|
||||
|
||||
# Print results to terminal
|
||||
print("\nResults:")
|
||||
print(f"Accuracy: {result['accuracy']:.3f}")
|
||||
print(f"Invalid responses: {result['invalid_rate']:.3f}")
|
||||
print(f"Total latency: {result['latency']:.3f} s")
|
||||
print(f"Questions per second: {result['questions_per_second']:.3f}")
|
||||
print(f"Total output tokens: {result['total_output_tokens']}")
|
||||
print(f"Output tokens per second: {result['tokens_per_second']:.3f}")
|
||||
|
||||
# Optional file saving
|
||||
if args.save_results:
|
||||
with open(args.save_results, "w") as f:
|
||||
json.dump(result, f, indent=2)
|
||||
print(f"Results saved to {args.save_results}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
93
tests/evals/gsm8k/test_gsm8k_correctness.py
Normal file
93
tests/evals/gsm8k/test_gsm8k_correctness.py
Normal file
@@ -0,0 +1,93 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
GSM8K evaluation using vLLM server and isolated GSM8K script.
|
||||
Replacement for lm-eval-harness with better performance and control.
|
||||
|
||||
Usage:
|
||||
pytest -s -v test_gsm8k_correctness.py \
|
||||
--config-list-file=configs/models-small.txt \
|
||||
--tp-size=1
|
||||
"""
|
||||
|
||||
import yaml
|
||||
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
|
||||
from .gsm8k_eval import evaluate_gsm8k
|
||||
|
||||
RTOL = 0.08 # Relative tolerance for accuracy comparison
|
||||
|
||||
|
||||
def launch_gsm8k_eval(eval_config, server_url, tp_size):
|
||||
"""Launch GSM8K evaluation using our isolated script."""
|
||||
# Extract host and port from server URL
|
||||
if "://" in server_url:
|
||||
server_url = server_url.split("://")[1]
|
||||
|
||||
host_port = server_url.split("/")[0] # Remove path if present
|
||||
if ":" in host_port:
|
||||
host, port = host_port.split(":")
|
||||
port = int(port)
|
||||
else:
|
||||
host = host_port
|
||||
port = 8000
|
||||
|
||||
# Add http:// prefix if not present
|
||||
if not host.startswith("http"):
|
||||
host = f"http://{host}"
|
||||
|
||||
# Run GSM8K evaluation
|
||||
results = evaluate_gsm8k(
|
||||
num_questions=eval_config["num_questions"],
|
||||
num_shots=eval_config["num_fewshot"],
|
||||
host=host,
|
||||
port=port,
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def test_gsm8k_correctness_param(config_filename, tp_size):
|
||||
"""Test GSM8K correctness for a given model configuration."""
|
||||
eval_config = yaml.safe_load(config_filename.read_text(encoding="utf-8"))
|
||||
|
||||
# Server arguments
|
||||
server_args = [
|
||||
"--max-model-len",
|
||||
str(eval_config.get("max_model_len", 4096)),
|
||||
"--enforce-eager",
|
||||
"--trust-remote-code",
|
||||
"--tensor-parallel-size",
|
||||
str(tp_size),
|
||||
]
|
||||
|
||||
env_dict = eval_config.get("env", None)
|
||||
|
||||
# Launch server and run evaluation
|
||||
with RemoteOpenAIServer(
|
||||
eval_config["model_name"], server_args, env_dict=env_dict, max_wait_seconds=480
|
||||
) as remote_server:
|
||||
server_url = remote_server.url_for("v1")
|
||||
|
||||
results = launch_gsm8k_eval(eval_config, server_url, tp_size)
|
||||
|
||||
# Check accuracy against threshold
|
||||
measured_accuracy = results["accuracy"]
|
||||
expected_accuracy = eval_config["accuracy_threshold"]
|
||||
|
||||
print(f"GSM8K Results for {eval_config['model_name']}:")
|
||||
print(f" Accuracy: {measured_accuracy:.3f}")
|
||||
print(f" Expected: {expected_accuracy:.3f}")
|
||||
print(f" Questions: {results['num_questions']}")
|
||||
print(f" Invalid rate: {results['invalid_rate']:.3f}")
|
||||
print(f" Latency: {results['latency']:.1f}s")
|
||||
print(f" QPS: {results['questions_per_second']:.1f}")
|
||||
|
||||
# Verify accuracy is within tolerance
|
||||
assert measured_accuracy >= expected_accuracy - RTOL, (
|
||||
f"Accuracy too low: {measured_accuracy:.3f} < "
|
||||
f"{expected_accuracy:.3f} - {RTOL:.3f}"
|
||||
)
|
||||
|
||||
print(f"✅ GSM8K test passed for {eval_config['model_name']}")
|
||||
Reference in New Issue
Block a user