Fix CI and make the code Python 3.8 compatible (#920)
This commit is contained in:
2
.github/workflows/e2e-test.yml
vendored
2
.github/workflows/e2e-test.yml
vendored
@@ -57,4 +57,4 @@ jobs:
|
|||||||
cd $HOME && python3 -m sglang.bench_serving --backend sglang --port 8413 --dataset-name random --num-prompts 3000 --random-input 256 --random-output 512
|
cd $HOME && python3 -m sglang.bench_serving --backend sglang --port 8413 --dataset-name random --num-prompts 3000 --random-input 256 --random-output 512
|
||||||
|
|
||||||
echo "Stopping server..."
|
echo "Stopping server..."
|
||||||
kill -9 $(ps aux | grep sglang | grep Meta-Llama-3.1-8B-Instruct | grep -v grep | awk '{print $2}')
|
kill -9 $(ps aux | grep sglang | grep Meta-Llama-3.1-8B-Instruct | grep -- "--port 8413" | grep -v grep | awk '{print $2}')
|
||||||
|
|||||||
@@ -71,6 +71,7 @@
|
|||||||
"source": [
|
"source": [
|
||||||
"import json\n",
|
"import json\n",
|
||||||
"import os\n",
|
"import os\n",
|
||||||
|
"from typing import List\n",
|
||||||
"\n",
|
"\n",
|
||||||
"import chromadb\n",
|
"import chromadb\n",
|
||||||
"\n",
|
"\n",
|
||||||
@@ -148,7 +149,7 @@
|
|||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"@trace\n",
|
"@trace\n",
|
||||||
"def retrieval(question: str) -> list[str]:\n",
|
"def retrieval(question: str) -> List[str]:\n",
|
||||||
" return collection.query(\n",
|
" return collection.query(\n",
|
||||||
" query_texts=[question],\n",
|
" query_texts=[question],\n",
|
||||||
" n_results=1\n",
|
" n_results=1\n",
|
||||||
@@ -278,7 +279,7 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"@trace(eval_funcs=[context_relevancy_eval, percent_target_supported_by_context])\n",
|
"@trace(eval_funcs=[context_relevancy_eval, percent_target_supported_by_context])\n",
|
||||||
"def retrieval(question: str) -> list[str]:\n",
|
"def retrieval(question: str) -> List[str]:\n",
|
||||||
" return collection.query(\n",
|
" return collection.query(\n",
|
||||||
" query_texts=[question],\n",
|
" query_texts=[question],\n",
|
||||||
" n_results=1\n",
|
" n_results=1\n",
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ import functools
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
from typing import AbstractSet, Collection, Dict, Literal, Optional, Type, Union
|
from typing import AbstractSet, Collection, Dict, List, Literal, Optional, Type, Union
|
||||||
|
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
from transformers import (
|
from transformers import (
|
||||||
@@ -259,7 +259,7 @@ class TiktokenTokenizer:
|
|||||||
Literal["all"], AbstractSet[str]
|
Literal["all"], AbstractSet[str]
|
||||||
] = set(), # noqa: B006
|
] = set(), # noqa: B006
|
||||||
disallowed_special: Union[Literal["all"], Collection[str]] = "all",
|
disallowed_special: Union[Literal["all"], Collection[str]] = "all",
|
||||||
) -> list[int]:
|
) -> List[int]:
|
||||||
if isinstance(allowed_special, set):
|
if isinstance(allowed_special, set):
|
||||||
allowed_special |= self._default_allowed_special
|
allowed_special |= self._default_allowed_special
|
||||||
return tiktoken.Encoding.encode(
|
return tiktoken.Encoding.encode(
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import time
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from multiprocessing.pool import ThreadPool
|
from multiprocessing.pool import ThreadPool
|
||||||
from typing import Any
|
from typing import Any, Dict, List, Tuple
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
import jinja2
|
import jinja2
|
||||||
@@ -24,8 +24,8 @@ OPENAI_SYSTEM_MESSAGE_CHATGPT = (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
Message = dict[str, Any] # keys role, content
|
Message = Dict[str, Any] # keys role, content
|
||||||
MessageList = list[Message]
|
MessageList = List[Message]
|
||||||
|
|
||||||
|
|
||||||
class SamplerBase:
|
class SamplerBase:
|
||||||
@@ -45,9 +45,9 @@ class EvalResult:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
score: float | None # top-line metric
|
score: float | None # top-line metric
|
||||||
metrics: dict[str, float] | None # other metrics
|
metrics: Dict[str, float] | None # other metrics
|
||||||
htmls: list[str] # strings of valid HTML
|
htmls: List[str] # strings of valid HTML
|
||||||
convos: list[MessageList] # sampled conversations
|
convos: List[MessageList] # sampled conversations
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -57,7 +57,7 @@ class SingleEvalResult:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
score: float | None
|
score: float | None
|
||||||
metrics: dict[str, float] = field(default_factory=dict)
|
metrics: Dict[str, float] = field(default_factory=dict)
|
||||||
html: str | None = None
|
html: str | None = None
|
||||||
convo: MessageList | None = None # sampled conversation
|
convo: MessageList | None = None # sampled conversation
|
||||||
|
|
||||||
@@ -270,9 +270,9 @@ def _compute_stat(values: list, stat: str):
|
|||||||
|
|
||||||
|
|
||||||
def aggregate_results(
|
def aggregate_results(
|
||||||
single_eval_results: list[SingleEvalResult],
|
single_eval_results: List[SingleEvalResult],
|
||||||
default_stats: tuple[str] = ("mean", "std"),
|
default_stats: Tuple[str] = ("mean", "std"),
|
||||||
name2stats: dict[str, tuple[str]] | None = None,
|
name2stats: Dict[str, Tuple[str]] | None = None,
|
||||||
) -> EvalResult:
|
) -> EvalResult:
|
||||||
"""
|
"""
|
||||||
Aggregate results from multiple evaluations into a single EvalResult.
|
Aggregate results from multiple evaluations into a single EvalResult.
|
||||||
@@ -302,7 +302,7 @@ def aggregate_results(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def map_with_progress(f: callable, xs: list[Any], num_threads: int):
|
def map_with_progress(f: callable, xs: List[Any], num_threads: int):
|
||||||
"""
|
"""
|
||||||
Apply f to each element of xs, using a ThreadPool, and show progress.
|
Apply f to each element of xs, using a ThreadPool, and show progress.
|
||||||
"""
|
"""
|
||||||
@@ -422,7 +422,7 @@ def make_report(eval_result: EvalResult) -> str:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def make_report_from_example_htmls(htmls: list[str]):
|
def make_report_from_example_htmls(htmls: List[str]):
|
||||||
"""
|
"""
|
||||||
Create a standalone HTML report from a list of example htmls
|
Create a standalone HTML report from a list of example htmls
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ import re
|
|||||||
from collections import Counter, defaultdict
|
from collections import Counter, defaultdict
|
||||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from typing import Any, Tuple
|
from typing import Any, Dict, List, Tuple
|
||||||
|
|
||||||
import blobfile as bf
|
import blobfile as bf
|
||||||
import tqdm
|
import tqdm
|
||||||
@@ -38,8 +38,8 @@ from sglang.test.simple_eval_common import (
|
|||||||
|
|
||||||
|
|
||||||
def evaluate_functional_correctness(
|
def evaluate_functional_correctness(
|
||||||
sample: dict[str, str],
|
sample: Dict[str, str],
|
||||||
completions: list[str],
|
completions: List[str],
|
||||||
n_workers: int = 4,
|
n_workers: int = 4,
|
||||||
timeout: float = 3.0,
|
timeout: float = 3.0,
|
||||||
):
|
):
|
||||||
@@ -70,7 +70,7 @@ class HumanEval(Eval):
|
|||||||
num_examples: int | None,
|
num_examples: int | None,
|
||||||
num_threads: int,
|
num_threads: int,
|
||||||
num_samples_per_task: int = 5,
|
num_samples_per_task: int = 5,
|
||||||
ks_passes: list[int] = [1, 2, 5],
|
ks_passes: List[int] = [1, 2, 5],
|
||||||
timeout: int = 120,
|
timeout: int = 120,
|
||||||
):
|
):
|
||||||
self.seed = 0
|
self.seed = 0
|
||||||
@@ -97,7 +97,7 @@ class HumanEval(Eval):
|
|||||||
] # remove signature
|
] # remove signature
|
||||||
return extracted_answer
|
return extracted_answer
|
||||||
|
|
||||||
def fn(sample: dict[str, str]):
|
def fn(sample: Dict[str, str]):
|
||||||
prompt_messages = [
|
prompt_messages = [
|
||||||
sampler._pack_message(
|
sampler._pack_message(
|
||||||
role="user", content=instruction + sample["prompt"]
|
role="user", content=instruction + sample["prompt"]
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import threading
|
|||||||
import time
|
import time
|
||||||
import unittest
|
import unittest
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Callable, Optional
|
from typing import Callable, List, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import requests
|
import requests
|
||||||
@@ -457,7 +457,7 @@ def run_with_timeout(
|
|||||||
return ret_value[0]
|
return ret_value[0]
|
||||||
|
|
||||||
|
|
||||||
def run_unittest_files(files: list[str], timeout_per_file: float):
|
def run_unittest_files(files: List[str], timeout_per_file: float):
|
||||||
tic = time.time()
|
tic = time.time()
|
||||||
success = True
|
success = True
|
||||||
|
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ class TestAccuracy(unittest.TestCase):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
cls.model = MODEL_NAME_FOR_TEST
|
cls.model = MODEL_NAME_FOR_TEST
|
||||||
cls.base_url = f"http://localhost:30000"
|
cls.base_url = f"http://localhost:8157"
|
||||||
cls.process = popen_launch_server(
|
cls.process = popen_launch_server(
|
||||||
cls.model,
|
cls.model,
|
||||||
cls.base_url,
|
cls.base_url,
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ class TestAccuracy(unittest.TestCase):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
cls.model = MODEL_NAME_FOR_TEST
|
cls.model = MODEL_NAME_FOR_TEST
|
||||||
cls.base_url = f"http://localhost:30000"
|
cls.base_url = f"http://localhost:8157"
|
||||||
cls.process = popen_launch_server(cls.model, cls.base_url, timeout=300)
|
cls.process = popen_launch_server(cls.model, cls.base_url, timeout=300)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ class TestOpenAIServer(unittest.TestCase):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
cls.model = MODEL_NAME_FOR_TEST
|
cls.model = MODEL_NAME_FOR_TEST
|
||||||
cls.base_url = f"http://localhost:30000"
|
cls.base_url = f"http://localhost:8157"
|
||||||
cls.api_key = "sk-123456"
|
cls.api_key = "sk-123456"
|
||||||
cls.process = popen_launch_server(
|
cls.process = popen_launch_server(
|
||||||
cls.model, cls.base_url, timeout=300, api_key=cls.api_key
|
cls.model, cls.base_url, timeout=300, api_key=cls.api_key
|
||||||
|
|||||||
@@ -12,11 +12,9 @@ class TestSRTEndpoint(unittest.TestCase):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
port = 30000
|
|
||||||
|
|
||||||
cls.model = MODEL_NAME_FOR_TEST
|
cls.model = MODEL_NAME_FOR_TEST
|
||||||
cls.base_url = f"http://localhost:{port}"
|
cls.base_url = f"http://localhost:{8157}"
|
||||||
cls.process = popen_launch_server(cls.model, port, timeout=300)
|
cls.process = popen_launch_server(cls.model, cls.base_url, timeout=300)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def tearDownClass(cls):
|
def tearDownClass(cls):
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ class TestAccuracy(unittest.TestCase):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
cls.model = MODEL_NAME_FOR_TEST
|
cls.model = MODEL_NAME_FOR_TEST
|
||||||
cls.base_url = f"http://localhost:30000"
|
cls.base_url = f"http://localhost:8157"
|
||||||
cls.process = popen_launch_server(
|
cls.process = popen_launch_server(
|
||||||
cls.model, cls.base_url, timeout=300, other_args=["--enable-torch-compile"]
|
cls.model, cls.base_url, timeout=300, other_args=["--enable-torch-compile"]
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user