Fix CI and restore Python 3.8 compatibility (#920)

This commit is contained in:
Liangsheng Yin
2024-08-04 16:02:05 -07:00
committed by GitHub
parent 975adb802b
commit bb66cc4c52
11 changed files with 31 additions and 32 deletions

View File

@@ -7,7 +7,7 @@ import time
from collections import defaultdict
from dataclasses import dataclass, field
from multiprocessing.pool import ThreadPool
from typing import Any
from typing import Any, Dict, List, Tuple
import httpx
import jinja2
@@ -24,8 +24,8 @@ OPENAI_SYSTEM_MESSAGE_CHATGPT = (
)
Message = dict[str, Any] # keys role, content
MessageList = list[Message]
Message = Dict[str, Any] # keys role, content
MessageList = List[Message]
class SamplerBase:
@@ -45,9 +45,9 @@ class EvalResult:
"""
score: float | None # top-line metric
metrics: dict[str, float] | None # other metrics
htmls: list[str] # strings of valid HTML
convos: list[MessageList] # sampled conversations
metrics: Dict[str, float] | None # other metrics
htmls: List[str] # strings of valid HTML
convos: List[MessageList] # sampled conversations
@dataclass
@@ -57,7 +57,7 @@ class SingleEvalResult:
"""
score: float | None
metrics: dict[str, float] = field(default_factory=dict)
metrics: Dict[str, float] = field(default_factory=dict)
html: str | None = None
convo: MessageList | None = None # sampled conversation
@@ -270,9 +270,9 @@ def _compute_stat(values: list, stat: str):
def aggregate_results(
single_eval_results: list[SingleEvalResult],
default_stats: tuple[str] = ("mean", "std"),
name2stats: dict[str, tuple[str]] | None = None,
single_eval_results: List[SingleEvalResult],
default_stats: Tuple[str] = ("mean", "std"),
name2stats: Dict[str, Tuple[str]] | None = None,
) -> EvalResult:
"""
Aggregate results from multiple evaluations into a single EvalResult.
@@ -302,7 +302,7 @@ def aggregate_results(
)
def map_with_progress(f: callable, xs: list[Any], num_threads: int):
def map_with_progress(f: callable, xs: List[Any], num_threads: int):
"""
Apply f to each element of xs, using a ThreadPool, and show progress.
"""
@@ -422,7 +422,7 @@ def make_report(eval_result: EvalResult) -> str:
)
def make_report_from_example_htmls(htmls: list[str]):
def make_report_from_example_htmls(htmls: List[str]):
"""
Create a standalone HTML report from a list of example htmls
"""

View File

@@ -14,7 +14,7 @@ import re
from collections import Counter, defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed
from io import BytesIO
from typing import Any, Tuple
from typing import Any, Dict, List, Tuple
import blobfile as bf
import tqdm
@@ -38,8 +38,8 @@ from sglang.test.simple_eval_common import (
def evaluate_functional_correctness(
sample: dict[str, str],
completions: list[str],
sample: Dict[str, str],
completions: List[str],
n_workers: int = 4,
timeout: float = 3.0,
):
@@ -70,7 +70,7 @@ class HumanEval(Eval):
num_examples: int | None,
num_threads: int,
num_samples_per_task: int = 5,
ks_passes: list[int] = [1, 2, 5],
ks_passes: List[int] = [1, 2, 5],
timeout: int = 120,
):
self.seed = 0
@@ -97,7 +97,7 @@ class HumanEval(Eval):
] # remove signature
return extracted_answer
def fn(sample: dict[str, str]):
def fn(sample: Dict[str, str]):
prompt_messages = [
sampler._pack_message(
role="user", content=instruction + sample["prompt"]

View File

@@ -8,7 +8,7 @@ import threading
import time
import unittest
from functools import partial
from typing import Callable, Optional
from typing import Callable, List, Optional
import numpy as np
import requests
@@ -457,7 +457,7 @@ def run_with_timeout(
return ret_value[0]
def run_unittest_files(files: list[str], timeout_per_file: float):
def run_unittest_files(files: List[str], timeout_per_file: float):
tic = time.time()
success = True