[Fix] Improve longbench prompt and other logics (#11474)
@@ -103,6 +103,7 @@ def run_eval(args):
     categories = args.categories.split(",") if args.categories else None
 
     eval_obj = LongBenchV2Eval(
+        model=args.model,
         data_source=data_source,
         num_examples=args.num_examples,
         num_threads=args.num_threads,
@@ -290,6 +290,9 @@ def aggregate_results(
     htmls = []
     convos = []
     for single_eval_result in single_eval_results:
+        # Skip None results
+        if single_eval_result is None:
+            continue
         for name, value in single_eval_result.metrics.items():
             name2values[name].append(value)
         if single_eval_result.score is not None:
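Why the aggregator needs this: the evaluation callable (last hunk below) now returns `None` for examples whose prompt falls outside the configured token bounds, so `aggregate_results` must tolerate `None` entries. A minimal, self-contained sketch of that pattern, with a simplified `SingleEvalResult` stand-in rather than the real sglang class:

```python
# Minimal sketch of the None-tolerant aggregation pattern introduced above.
# `SingleEvalResult` here is a simplified stand-in, not the real sglang class.
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Dict, Optional


@dataclass
class SingleEvalResult:
    metrics: Dict[str, float] = field(default_factory=dict)
    score: Optional[float] = None


def aggregate(results):
    name2values = defaultdict(list)
    for r in results:
        if r is None:  # example was filtered out by the length check
            continue
        for name, value in r.metrics.items():
            name2values[name].append(value)
    return {name: sum(v) / len(v) for name, v in name2values.items()}


print(aggregate([SingleEvalResult({"acc": 1.0}), None, SingleEvalResult({"acc": 0.0})]))
# -> {'acc': 0.5}
```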
@@ -12,6 +12,8 @@ import os
 import re
 from typing import Any, Dict, List, Optional
 
+from transformers import AutoTokenizer
+
 from sglang.test import simple_eval_common as common
 from sglang.test.simple_eval_common import (
     ANSWER_PATTERN_MULTICHOICE,
@@ -55,7 +57,11 @@ def format_longbench_v2_question(row: dict) -> str:
     choice_D = row.get("D", row.get("choice_D", ""))
 
     # Official LongBench-v2 template
-    prompt = f"""{context.strip()}
+    prompt = f"""
+Please read the following text and answer the question below.
+<text>
+{context.strip()}
+</text>
 
 What is the correct answer to this question: {question.strip()}
 Choices:
@@ -64,7 +70,7 @@ Choices:
 (C) {choice_C.strip()}
 (D) {choice_D.strip()}
 
-The correct answer is"""
+Format your response as follows: "The correct answer is (insert answer here)"."""
 
     return prompt
 
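The new closing instruction exists so the harness can reliably pull the chosen letter out of the response with `ANSWER_PATTERN_MULTICHOICE` (imported in the hunk above). A small sketch of that extraction step; the regex here is shaped like the common simple-eval multichoice pattern, but the exact `ANSWER_PATTERN_MULTICHOICE` in `simple_eval_common` may differ:

```python
import re

# Shaped like the simple-eval multichoice pattern; the exact pattern in
# sglang.test.simple_eval_common may differ.
ANSWER_PATTERN_MULTICHOICE = r"(?i)The correct answer is \(?([ABCD])\)?"

response = 'After weighing the choices, the correct answer is (B).'
match = re.search(ANSWER_PATTERN_MULTICHOICE, response)
extracted = match.group(1) if match else None
print(extracted)  # -> B
```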
@@ -106,6 +112,7 @@ class LongBenchV2Eval(Eval):
 
     def __init__(
         self,
+        model: str = None,
         data_source: str = DEFAULT_DATASET,
         num_examples: Optional[int] = None,
         num_threads: int = 1,
@@ -126,6 +133,9 @@ class LongBenchV2Eval(Eval):
             max_context_length: Maximum context length in characters
             min_context_length: Minimum context length in characters
         """
+        self.tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
+        self.min_context_length = min_context_length
+        self.max_context_length = max_context_length
         # Load dataset based on data source type
         examples = self._load_dataset(data_source)
 
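The tokenizer is now loaded once in `__init__` and reused for every example, which is what makes a token-based length check affordable. A hedged illustration of the character-vs-token distinction (the model id is only an example, and downloading it requires network access):

```python
from transformers import AutoTokenizer

# trust_remote_code=True allows models that ship custom tokenizer code.
# Example model id; substitute the model actually being evaluated.
tokenizer = AutoTokenizer.from_pretrained(
    "Qwen/Qwen2.5-7B-Instruct", trust_remote_code=True
)

text = "Please read the following text and answer the question below."
print(len(text))                     # length in characters (the old measure)
print(len(tokenizer.encode(text)))   # length in tokens (the new measure)
```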
@@ -133,11 +143,6 @@ class LongBenchV2Eval(Eval):
         if categories:
             examples = [ex for ex in examples if ex.get("category") in categories]
 
-        if min_context_length or max_context_length:
-            examples = self._filter_by_context_length(
-                examples, min_context_length, max_context_length
-            )
-
         # Sample examples if specified
         if num_examples:
             assert n_repeats == 1, "n_repeats only supported when not sampling examples"
@@ -246,26 +251,23 @@ class LongBenchV2Eval(Eval):
 
         return normalized
 
-    def _filter_by_context_length(
+    def _check_context_length(
         self,
-        examples: List[Dict[str, Any]],
+        formatted_question: str,
+        tokenizer: AutoTokenizer,
         min_length: Optional[int],
         max_length: Optional[int],
-    ) -> List[Dict[str, Any]]:
+    ) -> bool:
         """Filter examples by context length measured in characters."""
-        filtered = []
-        for example in examples:
-            context = example.get("context", "")
-            context_length = len(context)
+        input_ids = tokenizer.encode(formatted_question)
+        context_length = len(input_ids)
 
-            if min_length is not None and context_length < min_length:
-                continue
-            if max_length is not None and context_length > max_length:
-                continue
+        if min_length is not None and context_length < min_length:
+            return False
+        if max_length is not None and context_length > max_length:
+            return False
 
-            filtered.append(example)
-
-        return filtered
+        return True
 
     def __call__(self, sampler: SamplerBase) -> EvalResult:
         """Run the evaluation."""
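The refactor changes both shape and units: instead of filtering the raw `context` field by character count up front, the check now runs per example on the fully formatted prompt and bounds its token count. A self-contained check of the new predicate's behavior, using a whitespace "tokenizer" stand-in so it runs without downloading a model:

```python
from typing import List, Optional


class WhitespaceTokenizer:
    """Stand-in tokenizer: one token per whitespace-separated word."""

    def encode(self, text: str) -> List[int]:
        return list(range(len(text.split())))


def check_context_length(
    formatted_question: str,
    tokenizer,
    min_length: Optional[int],
    max_length: Optional[int],
) -> bool:
    # Mirrors _check_context_length above: bounds apply to the token count.
    n = len(tokenizer.encode(formatted_question))
    if min_length is not None and n < min_length:
        return False
    if max_length is not None and n > max_length:
        return False
    return True


tok = WhitespaceTokenizer()
print(check_context_length("a b c d", tok, min_length=2, max_length=10))    # True
print(check_context_length("a b c d", tok, min_length=5, max_length=None))  # False
```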
@@ -274,6 +276,16 @@ class LongBenchV2Eval(Eval):
         # Format the question using official template
         formatted_question = format_longbench_v2_question(row)
 
+        if self.min_context_length or self.max_context_length:
+            if not self._check_context_length(
+                formatted_question,
+                self.tokenizer,
+                self.min_context_length,
+                self.max_context_length,
+            ):
+                # Skip this example
+                return None
+
         prompt_messages = [
             sampler._pack_message(content=formatted_question, role="user")
         ]
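Putting the hunks together, a hedged usage sketch. The module path is inferred from the `sglang.test` imports above and the model id is only an example; running this for real needs sglang installed plus network access for the tokenizer and dataset:

```python
# Driver sketch -- constructor arguments are the ones visible in the hunks.
from sglang.test.simple_eval_longbench_v2 import LongBenchV2Eval  # path assumed

eval_obj = LongBenchV2Eval(
    model="Qwen/Qwen2.5-7B-Instruct",  # used only to load the tokenizer
    num_examples=50,
    num_threads=8,
    min_context_length=1_000,    # lower bound on prompt tokens
    max_context_length=128_000,  # upper bound on prompt tokens
)

# `sampler` is any SamplerBase implementation; examples whose prompts fall
# outside the bounds yield None and are skipped by aggregate_results.
# result = eval_obj(sampler)
```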