Fix CI && python3.8 compatible (#920)
This commit is contained in:
@@ -19,7 +19,7 @@ import functools
|
||||
import json
|
||||
import os
|
||||
import warnings
|
||||
from typing import AbstractSet, Collection, Dict, Literal, Optional, Type, Union
|
||||
from typing import AbstractSet, Collection, Dict, List, Literal, Optional, Type, Union
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
from transformers import (
|
||||
@@ -259,7 +259,7 @@ class TiktokenTokenizer:
|
||||
Literal["all"], AbstractSet[str]
|
||||
] = set(), # noqa: B006
|
||||
disallowed_special: Union[Literal["all"], Collection[str]] = "all",
|
||||
) -> list[int]:
|
||||
) -> List[int]:
|
||||
if isinstance(allowed_special, set):
|
||||
allowed_special |= self._default_allowed_special
|
||||
return tiktoken.Encoding.encode(
|
||||
|
||||
@@ -7,7 +7,7 @@ import time
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from multiprocessing.pool import ThreadPool
|
||||
from typing import Any
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import httpx
|
||||
import jinja2
|
||||
@@ -24,8 +24,8 @@ OPENAI_SYSTEM_MESSAGE_CHATGPT = (
|
||||
)
|
||||
|
||||
|
||||
Message = dict[str, Any] # keys role, content
|
||||
MessageList = list[Message]
|
||||
Message = Dict[str, Any] # keys role, content
|
||||
MessageList = List[Message]
|
||||
|
||||
|
||||
class SamplerBase:
|
||||
@@ -45,9 +45,9 @@ class EvalResult:
|
||||
"""
|
||||
|
||||
score: float | None # top-line metric
|
||||
metrics: dict[str, float] | None # other metrics
|
||||
htmls: list[str] # strings of valid HTML
|
||||
convos: list[MessageList] # sampled conversations
|
||||
metrics: Dict[str, float] | None # other metrics
|
||||
htmls: List[str] # strings of valid HTML
|
||||
convos: List[MessageList] # sampled conversations
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -57,7 +57,7 @@ class SingleEvalResult:
|
||||
"""
|
||||
|
||||
score: float | None
|
||||
metrics: dict[str, float] = field(default_factory=dict)
|
||||
metrics: Dict[str, float] = field(default_factory=dict)
|
||||
html: str | None = None
|
||||
convo: MessageList | None = None # sampled conversation
|
||||
|
||||
@@ -270,9 +270,9 @@ def _compute_stat(values: list, stat: str):
|
||||
|
||||
|
||||
def aggregate_results(
|
||||
single_eval_results: list[SingleEvalResult],
|
||||
default_stats: tuple[str] = ("mean", "std"),
|
||||
name2stats: dict[str, tuple[str]] | None = None,
|
||||
single_eval_results: List[SingleEvalResult],
|
||||
default_stats: Tuple[str] = ("mean", "std"),
|
||||
name2stats: Dict[str, Tuple[str]] | None = None,
|
||||
) -> EvalResult:
|
||||
"""
|
||||
Aggregate results from multiple evaluations into a single EvalResult.
|
||||
@@ -302,7 +302,7 @@ def aggregate_results(
|
||||
)
|
||||
|
||||
|
||||
def map_with_progress(f: callable, xs: list[Any], num_threads: int):
|
||||
def map_with_progress(f: callable, xs: List[Any], num_threads: int):
|
||||
"""
|
||||
Apply f to each element of xs, using a ThreadPool, and show progress.
|
||||
"""
|
||||
@@ -422,7 +422,7 @@ def make_report(eval_result: EvalResult) -> str:
|
||||
)
|
||||
|
||||
|
||||
def make_report_from_example_htmls(htmls: list[str]):
|
||||
def make_report_from_example_htmls(htmls: List[str]):
|
||||
"""
|
||||
Create a standalone HTML report from a list of example htmls
|
||||
"""
|
||||
|
||||
@@ -14,7 +14,7 @@ import re
|
||||
from collections import Counter, defaultdict
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from io import BytesIO
|
||||
from typing import Any, Tuple
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import blobfile as bf
|
||||
import tqdm
|
||||
@@ -38,8 +38,8 @@ from sglang.test.simple_eval_common import (
|
||||
|
||||
|
||||
def evaluate_functional_correctness(
|
||||
sample: dict[str, str],
|
||||
completions: list[str],
|
||||
sample: Dict[str, str],
|
||||
completions: List[str],
|
||||
n_workers: int = 4,
|
||||
timeout: float = 3.0,
|
||||
):
|
||||
@@ -70,7 +70,7 @@ class HumanEval(Eval):
|
||||
num_examples: int | None,
|
||||
num_threads: int,
|
||||
num_samples_per_task: int = 5,
|
||||
ks_passes: list[int] = [1, 2, 5],
|
||||
ks_passes: List[int] = [1, 2, 5],
|
||||
timeout: int = 120,
|
||||
):
|
||||
self.seed = 0
|
||||
@@ -97,7 +97,7 @@ class HumanEval(Eval):
|
||||
] # remove signature
|
||||
return extracted_answer
|
||||
|
||||
def fn(sample: dict[str, str]):
|
||||
def fn(sample: Dict[str, str]):
|
||||
prompt_messages = [
|
||||
sampler._pack_message(
|
||||
role="user", content=instruction + sample["prompt"]
|
||||
|
||||
@@ -8,7 +8,7 @@ import threading
|
||||
import time
|
||||
import unittest
|
||||
from functools import partial
|
||||
from typing import Callable, Optional
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
@@ -457,7 +457,7 @@ def run_with_timeout(
|
||||
return ret_value[0]
|
||||
|
||||
|
||||
def run_unittest_files(files: list[str], timeout_per_file: float):
|
||||
def run_unittest_files(files: List[str], timeout_per_file: float):
|
||||
tic = time.time()
|
||||
success = True
|
||||
|
||||
|
||||
Reference in New Issue
Block a user