Fix CI and make the code Python 3.8 compatible (#920)

This commit is contained in:
Liangsheng Yin
2024-08-04 16:02:05 -07:00
committed by GitHub
parent 975adb802b
commit bb66cc4c52
11 changed files with 31 additions and 32 deletions

View File

@@ -57,4 +57,4 @@ jobs:
cd $HOME && python3 -m sglang.bench_serving --backend sglang --port 8413 --dataset-name random --num-prompts 3000 --random-input 256 --random-output 512
echo "Stopping server..."
kill -9 $(ps aux | grep sglang | grep Meta-Llama-3.1-8B-Instruct | grep -v grep | awk '{print $2}')
kill -9 $(ps aux | grep sglang | grep Meta-Llama-3.1-8B-Instruct | grep -- "--port 8413" | grep -v grep | awk '{print $2}')

View File

@@ -71,6 +71,7 @@
"source": [
"import json\n",
"import os\n",
"from typing import List\n",
"\n",
"import chromadb\n",
"\n",
@@ -148,7 +149,7 @@
"outputs": [],
"source": [
"@trace\n",
"def retrieval(question: str) -> list[str]:\n",
"def retrieval(question: str) -> List[str]:\n",
" return collection.query(\n",
" query_texts=[question],\n",
" n_results=1\n",
@@ -278,7 +279,7 @@
"\n",
"\n",
"@trace(eval_funcs=[context_relevancy_eval, percent_target_supported_by_context])\n",
"def retrieval(question: str) -> list[str]:\n",
"def retrieval(question: str) -> List[str]:\n",
" return collection.query(\n",
" query_texts=[question],\n",
" n_results=1\n",

View File

@@ -19,7 +19,7 @@ import functools
import json
import os
import warnings
from typing import AbstractSet, Collection, Dict, Literal, Optional, Type, Union
from typing import AbstractSet, Collection, Dict, List, Literal, Optional, Type, Union
from huggingface_hub import snapshot_download
from transformers import (
@@ -259,7 +259,7 @@ class TiktokenTokenizer:
Literal["all"], AbstractSet[str]
] = set(), # noqa: B006
disallowed_special: Union[Literal["all"], Collection[str]] = "all",
) -> list[int]:
) -> List[int]:
if isinstance(allowed_special, set):
allowed_special |= self._default_allowed_special
return tiktoken.Encoding.encode(

View File

@@ -7,7 +7,7 @@ import time
from collections import defaultdict
from dataclasses import dataclass, field
from multiprocessing.pool import ThreadPool
from typing import Any
from typing import Any, Dict, List, Tuple
import httpx
import jinja2
@@ -24,8 +24,8 @@ OPENAI_SYSTEM_MESSAGE_CHATGPT = (
)
Message = dict[str, Any] # keys role, content
MessageList = list[Message]
Message = Dict[str, Any] # keys role, content
MessageList = List[Message]
class SamplerBase:
@@ -45,9 +45,9 @@ class EvalResult:
"""
score: float | None # top-line metric
metrics: dict[str, float] | None # other metrics
htmls: list[str] # strings of valid HTML
convos: list[MessageList] # sampled conversations
metrics: Dict[str, float] | None # other metrics
htmls: List[str] # strings of valid HTML
convos: List[MessageList] # sampled conversations
@dataclass
@@ -57,7 +57,7 @@ class SingleEvalResult:
"""
score: float | None
metrics: dict[str, float] = field(default_factory=dict)
metrics: Dict[str, float] = field(default_factory=dict)
html: str | None = None
convo: MessageList | None = None # sampled conversation
@@ -270,9 +270,9 @@ def _compute_stat(values: list, stat: str):
def aggregate_results(
single_eval_results: list[SingleEvalResult],
default_stats: tuple[str] = ("mean", "std"),
name2stats: dict[str, tuple[str]] | None = None,
single_eval_results: List[SingleEvalResult],
default_stats: Tuple[str] = ("mean", "std"),
name2stats: Dict[str, Tuple[str]] | None = None,
) -> EvalResult:
"""
Aggregate results from multiple evaluations into a single EvalResult.
@@ -302,7 +302,7 @@ def aggregate_results(
)
def map_with_progress(f: callable, xs: list[Any], num_threads: int):
def map_with_progress(f: callable, xs: List[Any], num_threads: int):
"""
Apply f to each element of xs, using a ThreadPool, and show progress.
"""
@@ -422,7 +422,7 @@ def make_report(eval_result: EvalResult) -> str:
)
def make_report_from_example_htmls(htmls: list[str]):
def make_report_from_example_htmls(htmls: List[str]):
"""
Create a standalone HTML report from a list of example htmls
"""

View File

@@ -14,7 +14,7 @@ import re
from collections import Counter, defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed
from io import BytesIO
from typing import Any, Tuple
from typing import Any, Dict, List, Tuple
import blobfile as bf
import tqdm
@@ -38,8 +38,8 @@ from sglang.test.simple_eval_common import (
def evaluate_functional_correctness(
sample: dict[str, str],
completions: list[str],
sample: Dict[str, str],
completions: List[str],
n_workers: int = 4,
timeout: float = 3.0,
):
@@ -70,7 +70,7 @@ class HumanEval(Eval):
num_examples: int | None,
num_threads: int,
num_samples_per_task: int = 5,
ks_passes: list[int] = [1, 2, 5],
ks_passes: List[int] = [1, 2, 5],
timeout: int = 120,
):
self.seed = 0
@@ -97,7 +97,7 @@ class HumanEval(Eval):
] # remove signature
return extracted_answer
def fn(sample: dict[str, str]):
def fn(sample: Dict[str, str]):
prompt_messages = [
sampler._pack_message(
role="user", content=instruction + sample["prompt"]

View File

@@ -8,7 +8,7 @@ import threading
import time
import unittest
from functools import partial
from typing import Callable, Optional
from typing import Callable, List, Optional
import numpy as np
import requests
@@ -457,7 +457,7 @@ def run_with_timeout(
return ret_value[0]
def run_unittest_files(files: list[str], timeout_per_file: float):
def run_unittest_files(files: List[str], timeout_per_file: float):
tic = time.time()
success = True

View File

@@ -11,7 +11,7 @@ class TestAccuracy(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.model = MODEL_NAME_FOR_TEST
cls.base_url = f"http://localhost:30000"
cls.base_url = f"http://localhost:8157"
cls.process = popen_launch_server(
cls.model,
cls.base_url,

View File

@@ -11,7 +11,7 @@ class TestAccuracy(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.model = MODEL_NAME_FOR_TEST
cls.base_url = f"http://localhost:30000"
cls.base_url = f"http://localhost:8157"
cls.process = popen_launch_server(cls.model, cls.base_url, timeout=300)
@classmethod

View File

@@ -12,7 +12,7 @@ class TestOpenAIServer(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.model = MODEL_NAME_FOR_TEST
cls.base_url = f"http://localhost:30000"
cls.base_url = f"http://localhost:8157"
cls.api_key = "sk-123456"
cls.process = popen_launch_server(
cls.model, cls.base_url, timeout=300, api_key=cls.api_key

View File

@@ -12,11 +12,9 @@ class TestSRTEndpoint(unittest.TestCase):
@classmethod
def setUpClass(cls):
port = 30000
cls.model = MODEL_NAME_FOR_TEST
cls.base_url = f"http://localhost:{port}"
cls.process = popen_launch_server(cls.model, port, timeout=300)
cls.base_url = f"http://localhost:{8157}"
cls.process = popen_launch_server(cls.model, cls.base_url, timeout=300)
@classmethod
def tearDownClass(cls):

View File

@@ -11,7 +11,7 @@ class TestAccuracy(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.model = MODEL_NAME_FOR_TEST
cls.base_url = f"http://localhost:30000"
cls.base_url = f"http://localhost:8157"
cls.process = popen_launch_server(
cls.model, cls.base_url, timeout=300, other_args=["--enable-torch-compile"]
)