"""Benchmark an OpenAI-compatible endpoint on LooGLE long-dependency QA and
score the cached responses with BERTScore."""

import argparse
import asyncio
import os
import pickle
from pathlib import Path
from typing import List

import openai
import torch
from bert_score import BERTScorer
from datasets import load_dataset
from tqdm import tqdm
def get_client(api_url: str) -> openai.AsyncOpenAI:
    """Build an async OpenAI client pointed at *api_url*.

    The openai SDK requires an API key to be present even for local,
    unauthenticated servers, so a dummy value is installed when the
    environment does not already provide one.
    """
    # setdefault leaves a real key untouched and fills in "EMPTY" otherwise —
    # same effect as the explicit getenv-is-None check.
    os.environ.setdefault("OPENAI_API_KEY", "EMPTY")
    return openai.AsyncOpenAI(base_url=api_url)
def get_dataset():
    """Load the LooGLE long-dependency QA test split from the Hugging Face hub."""
    dataset = load_dataset("bigai-nlco/LooGLE", "longdep_qa", split="test")
    return dataset
async def fetch_response(
    client: openai.AsyncOpenAI,
    context: str,
    question: str,
    semaphore: asyncio.Semaphore,
    index: int,
    model: str,
    output_dir: Path,
    *,
    temperature: float = 0.0,
    max_tokens: int = 512,
):
    """Query *model* once for *question* over *context* and pickle the result.

    The response (or, on a bad request, a ``{"error": str}`` marker) is cached
    at ``output_dir/response_{index}.pkl`` so interrupted runs can resume.

    Args:
        client: Async OpenAI-compatible client to issue the request with.
        context: Long document the question is about.
        question: Question to answer.
        semaphore: Concurrency limiter shared across all requests.
        index: Dataset index; determines the cache filename.
        model: Model name passed through to the API.
        output_dir: Directory holding the per-index pickle cache.
        temperature: Sampling temperature (default 0.0 = greedy).
        max_tokens: Completion length cap (default 512).
    """
    output_file = output_dir / f"response_{index}.pkl"
    # Skip work that a previous (possibly interrupted) run already finished.
    if output_file.exists():
        return

    prompt = (
        "Please answer the question based on the long texts below.\n"
        f"{context}\n"
        f"Question: {question}\n"
        "Answer:"
    )
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt},
    ]

    # The semaphore bounds in-flight requests; creation of the coroutine
    # itself is cheap, so all tasks can be spawned up front by the caller.
    async with semaphore:
        try:
            response = await client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
            )
        except openai.BadRequestError as e:
            # Typically a context-length overflow; record the error so the
            # analysis pass can skip this index instead of crashing on a
            # missing file.
            with open(output_file, "wb") as f:
                pickle.dump({"error": str(e)}, f)
            return

    with open(output_file, "wb") as f:
        pickle.dump(response, f)
async def benchmark(args):
    """Fan out one request per dataset example, bounded by a semaphore.

    Spawns a task for the first ``args.num_prompts`` examples and awaits them
    in completion order so the progress bar advances as responses arrive.
    """
    dataset = get_dataset()
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    client = get_client(args.api_url)
    semaphore = asyncio.Semaphore(args.max_concurrency)

    tasks: List[asyncio.Task] = []
    for idx, example in enumerate(dataset):
        if idx >= args.num_prompts:
            break
        coro = fetch_response(
            client,
            example["context"],
            example["question"],
            semaphore,
            idx,
            args.model,
            output_dir,
        )
        tasks.append(asyncio.create_task(coro))

    # as_completed yields futures as they finish, so tqdm reflects real
    # progress rather than submission order.
    progress = tqdm(
        asyncio.as_completed(tasks), total=len(tasks), desc="Running benchmark"
    )
    for finished in progress:
        await finished
def analyse(args):
    """Score the cached responses against reference answers with BERTScore F1.

    Reads ``response_{idx}.pkl`` for the first ``args.num_prompts`` examples,
    skips entries recorded as errors, and prints the average F1. Raises
    FileNotFoundError if a response file is missing (incomplete benchmark run).
    """
    dataset = get_dataset()
    output_dir = Path(args.output_dir)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    scorer = BERTScorer(lang="en", device=device)

    hyps: List[str] = []
    refs: List[str] = []
    for idx, ex in enumerate(tqdm(dataset, desc="Loading responses")):
        if idx >= args.num_prompts:
            break
        pkl_file = output_dir / f"response_{idx}.pkl"
        if not pkl_file.exists():
            # A missing file means the benchmark pass did not finish.
            raise FileNotFoundError(pkl_file)

        # read_bytes() opens and closes the handle immediately; the previous
        # pickle.load(open(...)) leaked a file descriptor per response.
        response = pickle.loads(pkl_file.read_bytes())
        if isinstance(response, dict) and "error" in response:
            # Request failed (e.g. context too long) — nothing to score.
            continue

        hyps.append(response.choices[0].message.content.strip())
        refs.append(ex["answer"])

    if not hyps:
        print("No valid responses to score!")
        return

    # Score in fixed-size batches to bound scorer memory use.
    batch_size = 64
    all_f1: List[float] = []
    for i in tqdm(range(0, len(hyps), batch_size), desc="Scoring batches"):
        h_batch = hyps[i : i + batch_size]
        r_batch = refs[i : i + batch_size]
        _, _, f1_scores = scorer.score(h_batch, r_batch, verbose=False)
        all_f1.extend(float(x) for x in f1_scores)

    avg = sum(all_f1) / len(all_f1)
    print(f"Average BERTScore (F1): {avg:.2%}")
if __name__ == "__main__":
    # CLI entry point: run the benchmark pass, then score its cached output.
    cli = argparse.ArgumentParser(
        description="Run benchmark and evaluation in one go."
    )
    cli.add_argument(
        "--api-url",
        default="http://127.0.0.1:30000/v1",
        help="OpenAI‑compatible API base URL",
    )
    cli.add_argument(
        "--model",
        default="meta-llama/Llama-4-Maverick-17B-128E-Instruct",
        help="Model name or ID, only used for model name",
    )
    cli.add_argument(
        "--max-concurrency",
        type=int,
        default=144,
        help="Maximum concurrent requests",
    )
    cli.add_argument(
        "--output-dir",
        default="tmp-output-dir",
        help="Directory for cached responses",
    )
    cli.add_argument(
        "--num-prompts",
        type=int,
        default=10000,
        help="Number of prompts to run",
    )
    args = cli.parse_args()

    # Phase 1: collect responses (cached per index, so reruns are cheap).
    asyncio.run(benchmark(args))
    # Phase 2: score whatever made it to disk.
    analyse(args)