Fix CI && python3.8 compatible (#920)
This commit is contained in:
2
.github/workflows/e2e-test.yml
vendored
2
.github/workflows/e2e-test.yml
vendored
@@ -57,4 +57,4 @@ jobs:
|
||||
cd $HOME && python3 -m sglang.bench_serving --backend sglang --port 8413 --dataset-name random --num-prompts 3000 --random-input 256 --random-output 512
|
||||
|
||||
echo "Stopping server..."
|
||||
kill -9 $(ps aux | grep sglang | grep Meta-Llama-3.1-8B-Instruct | grep -v grep | awk '{print $2}')
|
||||
kill -9 $(ps aux | grep sglang | grep Meta-Llama-3.1-8B-Instruct | grep -- "--port 8413" | grep -v grep | awk '{print $2}')
|
||||
|
||||
@@ -71,6 +71,7 @@
|
||||
"source": [
|
||||
"import json\n",
|
||||
"import os\n",
|
||||
"from typing import List\n",
|
||||
"\n",
|
||||
"import chromadb\n",
|
||||
"\n",
|
||||
@@ -148,7 +149,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@trace\n",
|
||||
"def retrieval(question: str) -> list[str]:\n",
|
||||
"def retrieval(question: str) -> List[str]:\n",
|
||||
" return collection.query(\n",
|
||||
" query_texts=[question],\n",
|
||||
" n_results=1\n",
|
||||
@@ -278,7 +279,7 @@
|
||||
"\n",
|
||||
"\n",
|
||||
"@trace(eval_funcs=[context_relevancy_eval, percent_target_supported_by_context])\n",
|
||||
"def retrieval(question: str) -> list[str]:\n",
|
||||
"def retrieval(question: str) -> List[str]:\n",
|
||||
" return collection.query(\n",
|
||||
" query_texts=[question],\n",
|
||||
" n_results=1\n",
|
||||
|
||||
@@ -19,7 +19,7 @@ import functools
|
||||
import json
|
||||
import os
|
||||
import warnings
|
||||
from typing import AbstractSet, Collection, Dict, Literal, Optional, Type, Union
|
||||
from typing import AbstractSet, Collection, Dict, List, Literal, Optional, Type, Union
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
from transformers import (
|
||||
@@ -259,7 +259,7 @@ class TiktokenTokenizer:
|
||||
Literal["all"], AbstractSet[str]
|
||||
] = set(), # noqa: B006
|
||||
disallowed_special: Union[Literal["all"], Collection[str]] = "all",
|
||||
) -> list[int]:
|
||||
) -> List[int]:
|
||||
if isinstance(allowed_special, set):
|
||||
allowed_special |= self._default_allowed_special
|
||||
return tiktoken.Encoding.encode(
|
||||
|
||||
@@ -7,7 +7,7 @@ import time
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from multiprocessing.pool import ThreadPool
|
||||
from typing import Any
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import httpx
|
||||
import jinja2
|
||||
@@ -24,8 +24,8 @@ OPENAI_SYSTEM_MESSAGE_CHATGPT = (
|
||||
)
|
||||
|
||||
|
||||
Message = dict[str, Any] # keys role, content
|
||||
MessageList = list[Message]
|
||||
Message = Dict[str, Any] # keys role, content
|
||||
MessageList = List[Message]
|
||||
|
||||
|
||||
class SamplerBase:
|
||||
@@ -45,9 +45,9 @@ class EvalResult:
|
||||
"""
|
||||
|
||||
score: float | None # top-line metric
|
||||
metrics: dict[str, float] | None # other metrics
|
||||
htmls: list[str] # strings of valid HTML
|
||||
convos: list[MessageList] # sampled conversations
|
||||
metrics: Dict[str, float] | None # other metrics
|
||||
htmls: List[str] # strings of valid HTML
|
||||
convos: List[MessageList] # sampled conversations
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -57,7 +57,7 @@ class SingleEvalResult:
|
||||
"""
|
||||
|
||||
score: float | None
|
||||
metrics: dict[str, float] = field(default_factory=dict)
|
||||
metrics: Dict[str, float] = field(default_factory=dict)
|
||||
html: str | None = None
|
||||
convo: MessageList | None = None # sampled conversation
|
||||
|
||||
@@ -270,9 +270,9 @@ def _compute_stat(values: list, stat: str):
|
||||
|
||||
|
||||
def aggregate_results(
|
||||
single_eval_results: list[SingleEvalResult],
|
||||
default_stats: tuple[str] = ("mean", "std"),
|
||||
name2stats: dict[str, tuple[str]] | None = None,
|
||||
single_eval_results: List[SingleEvalResult],
|
||||
default_stats: Tuple[str] = ("mean", "std"),
|
||||
name2stats: Dict[str, Tuple[str]] | None = None,
|
||||
) -> EvalResult:
|
||||
"""
|
||||
Aggregate results from multiple evaluations into a single EvalResult.
|
||||
@@ -302,7 +302,7 @@ def aggregate_results(
|
||||
)
|
||||
|
||||
|
||||
def map_with_progress(f: callable, xs: list[Any], num_threads: int):
|
||||
def map_with_progress(f: callable, xs: List[Any], num_threads: int):
|
||||
"""
|
||||
Apply f to each element of xs, using a ThreadPool, and show progress.
|
||||
"""
|
||||
@@ -422,7 +422,7 @@ def make_report(eval_result: EvalResult) -> str:
|
||||
)
|
||||
|
||||
|
||||
def make_report_from_example_htmls(htmls: list[str]):
|
||||
def make_report_from_example_htmls(htmls: List[str]):
|
||||
"""
|
||||
Create a standalone HTML report from a list of example htmls
|
||||
"""
|
||||
|
||||
@@ -14,7 +14,7 @@ import re
|
||||
from collections import Counter, defaultdict
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from io import BytesIO
|
||||
from typing import Any, Tuple
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import blobfile as bf
|
||||
import tqdm
|
||||
@@ -38,8 +38,8 @@ from sglang.test.simple_eval_common import (
|
||||
|
||||
|
||||
def evaluate_functional_correctness(
|
||||
sample: dict[str, str],
|
||||
completions: list[str],
|
||||
sample: Dict[str, str],
|
||||
completions: List[str],
|
||||
n_workers: int = 4,
|
||||
timeout: float = 3.0,
|
||||
):
|
||||
@@ -70,7 +70,7 @@ class HumanEval(Eval):
|
||||
num_examples: int | None,
|
||||
num_threads: int,
|
||||
num_samples_per_task: int = 5,
|
||||
ks_passes: list[int] = [1, 2, 5],
|
||||
ks_passes: List[int] = [1, 2, 5],
|
||||
timeout: int = 120,
|
||||
):
|
||||
self.seed = 0
|
||||
@@ -97,7 +97,7 @@ class HumanEval(Eval):
|
||||
] # remove signature
|
||||
return extracted_answer
|
||||
|
||||
def fn(sample: dict[str, str]):
|
||||
def fn(sample: Dict[str, str]):
|
||||
prompt_messages = [
|
||||
sampler._pack_message(
|
||||
role="user", content=instruction + sample["prompt"]
|
||||
|
||||
@@ -8,7 +8,7 @@ import threading
|
||||
import time
|
||||
import unittest
|
||||
from functools import partial
|
||||
from typing import Callable, Optional
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
@@ -457,7 +457,7 @@ def run_with_timeout(
|
||||
return ret_value[0]
|
||||
|
||||
|
||||
def run_unittest_files(files: list[str], timeout_per_file: float):
|
||||
def run_unittest_files(files: List[str], timeout_per_file: float):
|
||||
tic = time.time()
|
||||
success = True
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ class TestAccuracy(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = MODEL_NAME_FOR_TEST
|
||||
cls.base_url = f"http://localhost:30000"
|
||||
cls.base_url = f"http://localhost:8157"
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
|
||||
@@ -11,7 +11,7 @@ class TestAccuracy(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = MODEL_NAME_FOR_TEST
|
||||
cls.base_url = f"http://localhost:30000"
|
||||
cls.base_url = f"http://localhost:8157"
|
||||
cls.process = popen_launch_server(cls.model, cls.base_url, timeout=300)
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -12,7 +12,7 @@ class TestOpenAIServer(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = MODEL_NAME_FOR_TEST
|
||||
cls.base_url = f"http://localhost:30000"
|
||||
cls.base_url = f"http://localhost:8157"
|
||||
cls.api_key = "sk-123456"
|
||||
cls.process = popen_launch_server(
|
||||
cls.model, cls.base_url, timeout=300, api_key=cls.api_key
|
||||
|
||||
@@ -12,11 +12,9 @@ class TestSRTEndpoint(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
port = 30000
|
||||
|
||||
cls.model = MODEL_NAME_FOR_TEST
|
||||
cls.base_url = f"http://localhost:{port}"
|
||||
cls.process = popen_launch_server(cls.model, port, timeout=300)
|
||||
cls.base_url = f"http://localhost:{8157}"
|
||||
cls.process = popen_launch_server(cls.model, cls.base_url, timeout=300)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
|
||||
@@ -11,7 +11,7 @@ class TestAccuracy(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = MODEL_NAME_FOR_TEST
|
||||
cls.base_url = f"http://localhost:30000"
|
||||
cls.base_url = f"http://localhost:8157"
|
||||
cls.process = popen_launch_server(
|
||||
cls.model, cls.base_url, timeout=300, other_args=["--enable-torch-compile"]
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user