Sync from v0.13
This commit is contained in:
262
tests/evals/gsm8k/gsm8k_eval.py
Normal file
262
tests/evals/gsm8k/gsm8k_eval.py
Normal file
@@ -0,0 +1,262 @@
|
||||
#!/usr/bin/env python3
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Isolated GSM8K evaluation script for vLLM serve endpoint.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import ast
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
|
||||
import aiohttp
|
||||
import numpy as np
|
||||
import regex as re
|
||||
import requests
|
||||
from tqdm.asyncio import tqdm
|
||||
|
||||
INVALID = -9999999
|
||||
|
||||
|
||||
def download_and_cache_file(url: str, filename: str | None = None) -> str:
|
||||
"""Download and cache a file from a URL."""
|
||||
if filename is None:
|
||||
filename = os.path.join("/tmp", url.split("/")[-1])
|
||||
|
||||
if os.path.exists(filename):
|
||||
return filename
|
||||
|
||||
print(f"Downloading from {url} to {filename}")
|
||||
response = requests.get(url, stream=True)
|
||||
response.raise_for_status()
|
||||
|
||||
with open(filename, "wb") as f:
|
||||
for chunk in response.iter_content(chunk_size=1024):
|
||||
f.write(chunk)
|
||||
|
||||
return filename
|
||||
|
||||
|
||||
def load_gsm8k_data() -> tuple[list[dict], list[dict]]:
|
||||
"""Load GSM8K train and test data"""
|
||||
train_url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/train.jsonl"
|
||||
test_url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"
|
||||
|
||||
train_file = download_and_cache_file(train_url)
|
||||
test_file = download_and_cache_file(test_url)
|
||||
|
||||
train_data = list(read_jsonl(train_file))
|
||||
test_data = list(read_jsonl(test_file))
|
||||
|
||||
return train_data, test_data
|
||||
|
||||
|
||||
def read_jsonl(filename: str) -> Generator[dict, None, None]:
|
||||
"""Read a JSONL file."""
|
||||
with open(filename) as fin:
|
||||
for line in fin:
|
||||
if not line.startswith("#"):
|
||||
yield json.loads(line)
|
||||
|
||||
|
||||
def get_answer_value(answer_str: str) -> int:
|
||||
"""Extract the numerical answer from the response."""
|
||||
answer_str = answer_str.replace(",", "")
|
||||
numbers = re.findall(r"\d+", answer_str)
|
||||
if len(numbers) < 1:
|
||||
return INVALID
|
||||
try:
|
||||
return ast.literal_eval(numbers[-1])
|
||||
except SyntaxError:
|
||||
return INVALID
|
||||
|
||||
|
||||
async def call_vllm_api(
|
||||
session: aiohttp.ClientSession,
|
||||
prompt: str,
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
stop: list[str] | None = None,
|
||||
url: str | None = None,
|
||||
seed: int | None = None,
|
||||
) -> tuple[str, int]:
|
||||
"""Call vLLM's OpenAI-compatible completions endpoint.
|
||||
|
||||
Returns:
|
||||
Tuple of (response_text, completion_tokens)
|
||||
"""
|
||||
data = {
|
||||
"prompt": prompt,
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
"stop": stop,
|
||||
}
|
||||
if seed is not None:
|
||||
data["seed"] = seed
|
||||
|
||||
try:
|
||||
async with session.post(f"{url}/v1/completions", json=data) as response:
|
||||
response.raise_for_status()
|
||||
result = await response.json()
|
||||
text = result["choices"][0]["text"]
|
||||
completion_tokens = result.get("usage", {}).get("completion_tokens", 0)
|
||||
return text, completion_tokens
|
||||
except Exception as e:
|
||||
print(f"Error calling vLLM API: {e}")
|
||||
return "", 0
|
||||
|
||||
|
||||
def evaluate_gsm8k(
|
||||
num_questions: int = 1319,
|
||||
num_shots: int = 5,
|
||||
max_tokens: int = 256,
|
||||
host: str = "http://127.0.0.1",
|
||||
port: int = 8000,
|
||||
temperature: float = 0.0,
|
||||
seed: int | None = 42,
|
||||
) -> dict[str, float | int]:
|
||||
"""
|
||||
Evaluate GSM8K accuracy using vLLM serve endpoint.
|
||||
|
||||
Returns dict with accuracy, invalid_rate, latency, etc.
|
||||
"""
|
||||
base_url = f"{host}:{port}"
|
||||
|
||||
# Load GSM8K train and test data
|
||||
train_data, test_data = load_gsm8k_data()
|
||||
|
||||
# Limit to available test questions
|
||||
num_questions = min(num_questions, len(test_data))
|
||||
|
||||
# Build few-shot examples from train split (like lm-eval does)
|
||||
few_shot_examples = ""
|
||||
for i in range(num_shots):
|
||||
few_shot_examples += (
|
||||
f"Question: {train_data[i]['question']}\n"
|
||||
f"Answer: {train_data[i]['answer']}\n\n"
|
||||
)
|
||||
|
||||
# Prepare test questions and labels from test split
|
||||
questions = []
|
||||
labels = []
|
||||
for i in range(num_questions):
|
||||
questions.append(f"Question: {test_data[i]['question']}\nAnswer:")
|
||||
labels.append(get_answer_value(test_data[i]["answer"]))
|
||||
|
||||
assert all(label != INVALID for label in labels), "Some labels are invalid"
|
||||
|
||||
# Run evaluation
|
||||
async def run_async_evaluation():
|
||||
states: list[str] = [""] * num_questions
|
||||
output_tokens: list[int] = [0] * num_questions
|
||||
|
||||
async def get_answer(session: aiohttp.ClientSession, i: int) -> tuple[str, int]:
|
||||
prompt = few_shot_examples + questions[i]
|
||||
answer, tokens = await call_vllm_api(
|
||||
session=session,
|
||||
prompt=prompt,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
stop=["Question", "Assistant:", "<|separator|>"],
|
||||
url=base_url,
|
||||
seed=seed,
|
||||
)
|
||||
states[i] = answer
|
||||
output_tokens[i] = tokens
|
||||
return answer, tokens
|
||||
|
||||
async with aiohttp.ClientSession(
|
||||
timeout=aiohttp.ClientTimeout(total=600)
|
||||
) as session:
|
||||
tasks = [get_answer(session, i) for i in range(num_questions)]
|
||||
await tqdm.gather(*tasks, desc="Evaluating")
|
||||
|
||||
return states, output_tokens
|
||||
|
||||
print(f"Running GSM8K evaluation: {num_questions} questions, {num_shots}-shot")
|
||||
|
||||
tic = time.perf_counter()
|
||||
states, output_tokens = asyncio.run(run_async_evaluation())
|
||||
latency = time.perf_counter() - tic
|
||||
|
||||
# Compute metrics
|
||||
preds = [get_answer_value(state) for state in states]
|
||||
accuracy = np.mean(np.array(preds) == np.array(labels))
|
||||
invalid_rate = np.mean(np.array(preds) == INVALID)
|
||||
total_output_tokens = sum(output_tokens)
|
||||
tokens_per_second = total_output_tokens / latency if latency > 0 else 0.0
|
||||
|
||||
result = {
|
||||
"accuracy": accuracy,
|
||||
"invalid_rate": invalid_rate,
|
||||
"latency": latency,
|
||||
"questions_per_second": num_questions / latency,
|
||||
"total_output_tokens": total_output_tokens,
|
||||
"tokens_per_second": tokens_per_second,
|
||||
"num_questions": num_questions,
|
||||
"num_shots": num_shots,
|
||||
"max_tokens": max_tokens,
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="GSM8K evaluation for vLLM serve")
|
||||
parser.add_argument(
|
||||
"--num-shots", type=int, default=5, help="Number of few-shot examples"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-questions",
|
||||
type=int,
|
||||
default=1319,
|
||||
help="Number of questions to evaluate",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-tokens", type=int, default=256, help="Max tokens for generation"
|
||||
)
|
||||
parser.add_argument("--host", type=str, default="http://127.0.0.1", help="Host URL")
|
||||
parser.add_argument("--port", type=int, default=8000, help="Port number")
|
||||
parser.add_argument(
|
||||
"--temperature", type=float, default=0.0, help="Temperature for generation"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--seed", type=int, default=42, help="Random seed for reproducibility"
|
||||
)
|
||||
parser.add_argument("--save-results", type=str, help="Save results to JSON file")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
result = evaluate_gsm8k(
|
||||
num_questions=args.num_questions,
|
||||
num_shots=args.num_shots,
|
||||
max_tokens=args.max_tokens,
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
temperature=args.temperature,
|
||||
seed=args.seed,
|
||||
)
|
||||
|
||||
# Print results to terminal
|
||||
print("\nResults:")
|
||||
print(f"Accuracy: {result['accuracy']:.3f}")
|
||||
print(f"Invalid responses: {result['invalid_rate']:.3f}")
|
||||
print(f"Total latency: {result['latency']:.3f} s")
|
||||
print(f"Questions per second: {result['questions_per_second']:.3f}")
|
||||
print(f"Total output tokens: {result['total_output_tokens']}")
|
||||
print(f"Output tokens per second: {result['tokens_per_second']:.3f}")
|
||||
|
||||
# Optional file saving
|
||||
if args.save_results:
|
||||
with open(args.save_results, "w") as f:
|
||||
json.dump(result, f, indent=2)
|
||||
print(f"Results saved to {args.save_results}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user