[Fix] Improve longbench prompt and other logics (#11474)
This commit is contained in:
@@ -103,6 +103,7 @@ def run_eval(args):
|
|||||||
categories = args.categories.split(",") if args.categories else None
|
categories = args.categories.split(",") if args.categories else None
|
||||||
|
|
||||||
eval_obj = LongBenchV2Eval(
|
eval_obj = LongBenchV2Eval(
|
||||||
|
model=args.model,
|
||||||
data_source=data_source,
|
data_source=data_source,
|
||||||
num_examples=args.num_examples,
|
num_examples=args.num_examples,
|
||||||
num_threads=args.num_threads,
|
num_threads=args.num_threads,
|
||||||
|
|||||||
@@ -290,6 +290,9 @@ def aggregate_results(
|
|||||||
htmls = []
|
htmls = []
|
||||||
convos = []
|
convos = []
|
||||||
for single_eval_result in single_eval_results:
|
for single_eval_result in single_eval_results:
|
||||||
|
# Skip None results
|
||||||
|
if single_eval_result is None:
|
||||||
|
continue
|
||||||
for name, value in single_eval_result.metrics.items():
|
for name, value in single_eval_result.metrics.items():
|
||||||
name2values[name].append(value)
|
name2values[name].append(value)
|
||||||
if single_eval_result.score is not None:
|
if single_eval_result.score is not None:
|
||||||
|
|||||||
@@ -12,6 +12,8 @@ import os
|
|||||||
import re
|
import re
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
from sglang.test import simple_eval_common as common
|
from sglang.test import simple_eval_common as common
|
||||||
from sglang.test.simple_eval_common import (
|
from sglang.test.simple_eval_common import (
|
||||||
ANSWER_PATTERN_MULTICHOICE,
|
ANSWER_PATTERN_MULTICHOICE,
|
||||||
@@ -55,7 +57,11 @@ def format_longbench_v2_question(row: dict) -> str:
|
|||||||
choice_D = row.get("D", row.get("choice_D", ""))
|
choice_D = row.get("D", row.get("choice_D", ""))
|
||||||
|
|
||||||
# Official LongBench-v2 template
|
# Official LongBench-v2 template
|
||||||
prompt = f"""{context.strip()}
|
prompt = f"""
|
||||||
|
Please read the following text and answer the question below.
|
||||||
|
<text>
|
||||||
|
{context.strip()}
|
||||||
|
</text>
|
||||||
|
|
||||||
What is the correct answer to this question: {question.strip()}
|
What is the correct answer to this question: {question.strip()}
|
||||||
Choices:
|
Choices:
|
||||||
@@ -64,7 +70,7 @@ Choices:
|
|||||||
(C) {choice_C.strip()}
|
(C) {choice_C.strip()}
|
||||||
(D) {choice_D.strip()}
|
(D) {choice_D.strip()}
|
||||||
|
|
||||||
The correct answer is"""
|
Format your response as follows: "The correct answer is (insert answer here)"."""
|
||||||
|
|
||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
@@ -106,6 +112,7 @@ class LongBenchV2Eval(Eval):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
model: str = None,
|
||||||
data_source: str = DEFAULT_DATASET,
|
data_source: str = DEFAULT_DATASET,
|
||||||
num_examples: Optional[int] = None,
|
num_examples: Optional[int] = None,
|
||||||
num_threads: int = 1,
|
num_threads: int = 1,
|
||||||
@@ -126,6 +133,9 @@ class LongBenchV2Eval(Eval):
|
|||||||
max_context_length: Maximum context length in characters
|
max_context_length: Maximum context length in characters
|
||||||
min_context_length: Minimum context length in characters
|
min_context_length: Minimum context length in characters
|
||||||
"""
|
"""
|
||||||
|
self.tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
|
||||||
|
self.min_context_length = min_context_length
|
||||||
|
self.max_context_length = max_context_length
|
||||||
# Load dataset based on data source type
|
# Load dataset based on data source type
|
||||||
examples = self._load_dataset(data_source)
|
examples = self._load_dataset(data_source)
|
||||||
|
|
||||||
@@ -133,11 +143,6 @@ class LongBenchV2Eval(Eval):
|
|||||||
if categories:
|
if categories:
|
||||||
examples = [ex for ex in examples if ex.get("category") in categories]
|
examples = [ex for ex in examples if ex.get("category") in categories]
|
||||||
|
|
||||||
if min_context_length or max_context_length:
|
|
||||||
examples = self._filter_by_context_length(
|
|
||||||
examples, min_context_length, max_context_length
|
|
||||||
)
|
|
||||||
|
|
||||||
# Sample examples if specified
|
# Sample examples if specified
|
||||||
if num_examples:
|
if num_examples:
|
||||||
assert n_repeats == 1, "n_repeats only supported when not sampling examples"
|
assert n_repeats == 1, "n_repeats only supported when not sampling examples"
|
||||||
@@ -246,26 +251,23 @@ class LongBenchV2Eval(Eval):
|
|||||||
|
|
||||||
return normalized
|
return normalized
|
||||||
|
|
||||||
def _filter_by_context_length(
|
def _check_context_length(
|
||||||
self,
|
self,
|
||||||
examples: List[Dict[str, Any]],
|
formatted_question: str,
|
||||||
|
tokenizer: AutoTokenizer,
|
||||||
min_length: Optional[int],
|
min_length: Optional[int],
|
||||||
max_length: Optional[int],
|
max_length: Optional[int],
|
||||||
) -> List[Dict[str, Any]]:
|
) -> bool:
|
||||||
"""Filter examples by context length measured in characters."""
|
"""Filter examples by context length measured in characters."""
|
||||||
filtered = []
|
input_ids = tokenizer.encode(formatted_question)
|
||||||
for example in examples:
|
context_length = len(input_ids)
|
||||||
context = example.get("context", "")
|
|
||||||
context_length = len(context)
|
|
||||||
|
|
||||||
if min_length is not None and context_length < min_length:
|
if min_length is not None and context_length < min_length:
|
||||||
continue
|
return False
|
||||||
if max_length is not None and context_length > max_length:
|
if max_length is not None and context_length > max_length:
|
||||||
continue
|
return False
|
||||||
|
|
||||||
filtered.append(example)
|
return True
|
||||||
|
|
||||||
return filtered
|
|
||||||
|
|
||||||
def __call__(self, sampler: SamplerBase) -> EvalResult:
|
def __call__(self, sampler: SamplerBase) -> EvalResult:
|
||||||
"""Run the evaluation."""
|
"""Run the evaluation."""
|
||||||
@@ -274,6 +276,16 @@ class LongBenchV2Eval(Eval):
|
|||||||
# Format the question using official template
|
# Format the question using official template
|
||||||
formatted_question = format_longbench_v2_question(row)
|
formatted_question = format_longbench_v2_question(row)
|
||||||
|
|
||||||
|
if self.min_context_length or self.max_context_length:
|
||||||
|
if not self._check_context_length(
|
||||||
|
formatted_question,
|
||||||
|
self.tokenizer,
|
||||||
|
self.min_context_length,
|
||||||
|
self.max_context_length,
|
||||||
|
):
|
||||||
|
# Skip this example
|
||||||
|
return None
|
||||||
|
|
||||||
prompt_messages = [
|
prompt_messages = [
|
||||||
sampler._pack_message(content=formatted_question, role="user")
|
sampler._pack_message(content=formatted_question, role="user")
|
||||||
]
|
]
|
||||||
|
|||||||
Reference in New Issue
Block a user