Mixed style of chunked prefill (#1013)

This commit is contained in:
Liangsheng Yin
2024-08-16 02:13:00 -07:00
committed by GitHub
parent 5a261bd055
commit 3694f8f996
14 changed files with 195 additions and 59 deletions

View File

@@ -1,13 +1,12 @@
# Adapted from https://github.com/openai/simple-evals/
import base64
import os
import resource
import time
from collections import defaultdict
from dataclasses import dataclass, field
from multiprocessing.pool import ThreadPool
-from typing import Any, Dict, List, Tuple
+from typing import Any, Dict, List, Optional, Tuple
import httpx
import jinja2
@@ -44,8 +43,8 @@ class EvalResult:
Result of running an evaluation (usually consisting of many samples)
"""
-score: float | None # top-line metric
-metrics: Dict[str, float] | None # other metrics
+score: Optional[float] # top-line metric
+metrics: Optional[Dict[str, float]] # other metrics
htmls: List[str] # strings of valid HTML
convos: List[MessageList] # sampled conversations
@@ -56,10 +55,10 @@ class SingleEvalResult:
Result of evaluating a single sample
"""
-score: float | None
+score: Optional[float]
 metrics: Dict[str, float] = field(default_factory=dict)
-html: str | None = None
-convo: MessageList | None = None # sampled conversation
+html: Optional[str] = None
+convo: Optional[MessageList] = None # sampled conversation
class Eval:
@@ -89,8 +88,8 @@ class ChatCompletionSampler(SamplerBase):
def __init__(
self,
base_url: str = None,
-model: str | None = None,
-system_message: str | None = None,
+model: Optional[str] = None,
+system_message: Optional[str] = None,
temperature: float = 0.0,
max_tokens: int = 2048,
):
@@ -272,7 +271,7 @@ def _compute_stat(values: list, stat: str):
def aggregate_results(
single_eval_results: List[SingleEvalResult],
default_stats: Tuple[str] = ("mean", "std"),
-name2stats: Dict[str, Tuple[str]] | None = None,
+name2stats: Optional[Dict[str, Tuple[str]]] = None,
) -> EvalResult:
"""
Aggregate results from multiple evaluations into a single EvalResult.

View File

@@ -8,6 +8,7 @@ https://arxiv.org/abs/2311.12022
import random
import re
+from typing import Optional
import pandas
@@ -28,7 +29,7 @@ class GPQAEval(Eval):
def __init__(
self,
filename: str,
-num_examples: int | None,
+num_examples: Optional[int],
num_threads: int,
n_repeats: int = 1,
):

View File

@@ -9,7 +9,7 @@ https://arxiv.org/abs/2107.03374 https://github.com/openai/human-eval/
import random
import re
from concurrent.futures import ThreadPoolExecutor, as_completed
-from typing import Dict, List
+from typing import Dict, List, Optional
import tqdm
@@ -61,7 +61,7 @@ def evaluate_functional_correctness(
class HumanEval(Eval):
def __init__(
self,
-num_examples: int | None,
+num_examples: Optional[int],
num_threads: int,
num_samples_per_task: int = 5,
ks_passes: List[int] = [1, 2, 5],

View File

@@ -8,6 +8,7 @@ https://arxiv.org/abs/2103.03874
import random
import re
+from typing import Optional
import pandas
@@ -36,7 +37,7 @@ class MathEval(Eval):
self,
filename: str,
equality_checker: SamplerBase,
-num_examples: int | None,
+num_examples: Optional[int],
num_threads: int,
):
df = pandas.read_csv(filename)

View File

@@ -8,6 +8,7 @@ https://arxiv.org/abs/2009.03300
import random
import re
+from typing import Optional
import pandas
@@ -84,7 +85,7 @@ subject2category = {
class MMLUEval(Eval):
-def __init__(self, filename: str, num_examples: int | None, num_threads: int):
+def __init__(self, filename: str, num_examples: Optional[int], num_threads: int):
df = pandas.read_csv(filename)
examples = [row.to_dict() for _, row in df.iterrows()]
if num_examples: