Fix CI and make code Python 3.8 compatible (#920)

This commit is contained in:
Liangsheng Yin
2024-08-04 16:02:05 -07:00
committed by GitHub
parent 975adb802b
commit bb66cc4c52
11 changed files with 31 additions and 32 deletions

View File

@@ -19,7 +19,7 @@ import functools
import json
import os
import warnings
from typing import AbstractSet, Collection, Dict, Literal, Optional, Type, Union
from typing import AbstractSet, Collection, Dict, List, Literal, Optional, Type, Union
from huggingface_hub import snapshot_download
from transformers import (
@@ -259,7 +259,7 @@ class TiktokenTokenizer:
Literal["all"], AbstractSet[str]
] = set(), # noqa: B006
disallowed_special: Union[Literal["all"], Collection[str]] = "all",
) -> list[int]:
) -> List[int]:
if isinstance(allowed_special, set):
allowed_special |= self._default_allowed_special
return tiktoken.Encoding.encode(

View File

@@ -7,7 +7,7 @@ import time
from collections import defaultdict
from dataclasses import dataclass, field
from multiprocessing.pool import ThreadPool
from typing import Any
from typing import Any, Dict, List, Tuple
import httpx
import jinja2
@@ -24,8 +24,8 @@ OPENAI_SYSTEM_MESSAGE_CHATGPT = (
)
Message = dict[str, Any] # keys role, content
MessageList = list[Message]
Message = Dict[str, Any] # keys role, content
MessageList = List[Message]
class SamplerBase:
@@ -45,9 +45,9 @@ class EvalResult:
"""
score: float | None # top-line metric
metrics: dict[str, float] | None # other metrics
htmls: list[str] # strings of valid HTML
convos: list[MessageList] # sampled conversations
metrics: Dict[str, float] | None # other metrics
htmls: List[str] # strings of valid HTML
convos: List[MessageList] # sampled conversations
@dataclass
@@ -57,7 +57,7 @@ class SingleEvalResult:
"""
score: float | None
metrics: dict[str, float] = field(default_factory=dict)
metrics: Dict[str, float] = field(default_factory=dict)
html: str | None = None
convo: MessageList | None = None # sampled conversation
@@ -270,9 +270,9 @@ def _compute_stat(values: list, stat: str):
def aggregate_results(
single_eval_results: list[SingleEvalResult],
default_stats: tuple[str] = ("mean", "std"),
name2stats: dict[str, tuple[str]] | None = None,
single_eval_results: List[SingleEvalResult],
default_stats: Tuple[str] = ("mean", "std"),
name2stats: Dict[str, Tuple[str]] | None = None,
) -> EvalResult:
"""
Aggregate results from multiple evaluations into a single EvalResult.
@@ -302,7 +302,7 @@ def aggregate_results(
)
def map_with_progress(f: callable, xs: list[Any], num_threads: int):
def map_with_progress(f: callable, xs: List[Any], num_threads: int):
"""
Apply f to each element of xs, using a ThreadPool, and show progress.
"""
@@ -422,7 +422,7 @@ def make_report(eval_result: EvalResult) -> str:
)
def make_report_from_example_htmls(htmls: list[str]):
def make_report_from_example_htmls(htmls: List[str]):
"""
Create a standalone HTML report from a list of example htmls
"""

View File

@@ -14,7 +14,7 @@ import re
from collections import Counter, defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed
from io import BytesIO
from typing import Any, Tuple
from typing import Any, Dict, List, Tuple
import blobfile as bf
import tqdm
@@ -38,8 +38,8 @@ from sglang.test.simple_eval_common import (
def evaluate_functional_correctness(
sample: dict[str, str],
completions: list[str],
sample: Dict[str, str],
completions: List[str],
n_workers: int = 4,
timeout: float = 3.0,
):
@@ -70,7 +70,7 @@ class HumanEval(Eval):
num_examples: int | None,
num_threads: int,
num_samples_per_task: int = 5,
ks_passes: list[int] = [1, 2, 5],
ks_passes: List[int] = [1, 2, 5],
timeout: int = 120,
):
self.seed = 0
@@ -97,7 +97,7 @@ class HumanEval(Eval):
] # remove signature
return extracted_answer
def fn(sample: dict[str, str]):
def fn(sample: Dict[str, str]):
prompt_messages = [
sampler._pack_message(
role="user", content=instruction + sample["prompt"]

View File

@@ -8,7 +8,7 @@ import threading
import time
import unittest
from functools import partial
from typing import Callable, Optional
from typing import Callable, List, Optional
import numpy as np
import requests
@@ -457,7 +457,7 @@ def run_with_timeout(
return ret_value[0]
def run_unittest_files(files: list[str], timeout_per_file: float):
def run_unittest_files(files: List[str], timeout_per_file: float):
tic = time.time()
success = True