From bb66cc4c52b1440a8e85247b706b2b3d645e902d Mon Sep 17 00:00:00 2001 From: Liangsheng Yin Date: Sun, 4 Aug 2024 16:02:05 -0700 Subject: [PATCH] Fix CI && python3.8 compatible (#920) --- .github/workflows/e2e-test.yml | 2 +- .../trace_and_evaluate_rag_using_parea.ipynb | 5 ++-- python/sglang/srt/hf_transformers_utils.py | 4 ++-- python/sglang/test/simple_eval_common.py | 24 +++++++++---------- python/sglang/test/simple_eval_humaneval.py | 10 ++++---- python/sglang/test/test_utils.py | 4 ++-- test/srt/test_chunked_prefill.py | 2 +- test/srt/test_eval_accuracy.py | 2 +- test/srt/test_openai_server.py | 2 +- test/srt/test_srt_endpoint.py | 6 ++--- test/srt/test_torch_compile.py | 2 +- 11 files changed, 31 insertions(+), 32 deletions(-) diff --git a/.github/workflows/e2e-test.yml b/.github/workflows/e2e-test.yml index 1c5852436..38651d45b 100644 --- a/.github/workflows/e2e-test.yml +++ b/.github/workflows/e2e-test.yml @@ -57,4 +57,4 @@ jobs: cd $HOME && python3 -m sglang.bench_serving --backend sglang --port 8413 --dataset-name random --num-prompts 3000 --random-input 256 --random-output 512 echo "Stopping server..." - kill -9 $(ps aux | grep sglang | grep Meta-Llama-3.1-8B-Instruct | grep -v grep | awk '{print $2}') + kill -9 $(ps aux | grep sglang | grep Meta-Llama-3.1-8B-Instruct | grep -- "--port 8413" | grep -v grep | awk '{print $2}') diff --git a/examples/usage/rag_using_parea/trace_and_evaluate_rag_using_parea.ipynb b/examples/usage/rag_using_parea/trace_and_evaluate_rag_using_parea.ipynb index ce90e2186..25b91b7d1 100644 --- a/examples/usage/rag_using_parea/trace_and_evaluate_rag_using_parea.ipynb +++ b/examples/usage/rag_using_parea/trace_and_evaluate_rag_using_parea.ipynb @@ -71,6 +71,7 @@ "source": [ "import json\n", "import os\n", + "from typing import List\n", "\n", "import chromadb\n", "\n", @@ -148,7 +149,7 @@ "outputs": [], "source": [ "@trace\n", - "def retrieval(question: str) -> list[str]:\n", + "def retrieval(question: str) -> List[str]:\n", " return collection.query(\n", " query_texts=[question],\n", " n_results=1\n", @@ -278,7 +279,7 @@ "\n", "\n", "@trace(eval_funcs=[context_relevancy_eval, percent_target_supported_by_context])\n", - "def retrieval(question: str) -> list[str]:\n", + "def retrieval(question: str) -> List[str]:\n", " return collection.query(\n", " query_texts=[question],\n", " n_results=1\n", diff --git a/python/sglang/srt/hf_transformers_utils.py b/python/sglang/srt/hf_transformers_utils.py index 9f681fe88..508843a39 100644 --- a/python/sglang/srt/hf_transformers_utils.py +++ b/python/sglang/srt/hf_transformers_utils.py @@ -19,7 +19,7 @@ import functools import json import os import warnings -from typing import AbstractSet, Collection, Dict, Literal, Optional, Type, Union +from typing import AbstractSet, Collection, Dict, List, Literal, Optional, Type, Union from huggingface_hub import snapshot_download from transformers import ( @@ -259,7 +259,7 @@ class TiktokenTokenizer: Literal["all"], AbstractSet[str] ] = set(), # noqa: B006 disallowed_special: Union[Literal["all"], Collection[str]] = "all", - ) -> list[int]: + ) -> List[int]: if isinstance(allowed_special, set): allowed_special |= self._default_allowed_special return tiktoken.Encoding.encode( diff --git a/python/sglang/test/simple_eval_common.py b/python/sglang/test/simple_eval_common.py index af1671694..4cfd3515f 100644 --- a/python/sglang/test/simple_eval_common.py +++ b/python/sglang/test/simple_eval_common.py @@ -7,7 +7,7 @@ import time from collections import defaultdict from dataclasses import dataclass, field from multiprocessing.pool import ThreadPool -from typing import Any +from typing import Any, Dict, List, Tuple import httpx import jinja2 @@ -24,8 +24,8 @@ OPENAI_SYSTEM_MESSAGE_CHATGPT = ( ) -Message = dict[str, Any] # keys role, content -MessageList = list[Message] +Message = Dict[str, Any] # keys role, content +MessageList = List[Message] class SamplerBase: @@ -45,9 +45,9 @@ class EvalResult: """ score: float | None # top-line metric - metrics: dict[str, float] | None # other metrics - htmls: list[str] # strings of valid HTML - convos: list[MessageList] # sampled conversations + metrics: Dict[str, float] | None # other metrics + htmls: List[str] # strings of valid HTML + convos: List[MessageList] # sampled conversations @dataclass @@ -57,7 +57,7 @@ class SingleEvalResult: """ score: float | None - metrics: dict[str, float] = field(default_factory=dict) + metrics: Dict[str, float] = field(default_factory=dict) html: str | None = None convo: MessageList | None = None # sampled conversation @@ -270,9 +270,9 @@ def _compute_stat(values: list, stat: str): def aggregate_results( - single_eval_results: list[SingleEvalResult], - default_stats: tuple[str] = ("mean", "std"), - name2stats: dict[str, tuple[str]] | None = None, + single_eval_results: List[SingleEvalResult], + default_stats: Tuple[str] = ("mean", "std"), + name2stats: Dict[str, Tuple[str]] | None = None, ) -> EvalResult: """ Aggregate results from multiple evaluations into a single EvalResult. @@ -302,7 +302,7 @@ def aggregate_results( ) -def map_with_progress(f: callable, xs: list[Any], num_threads: int): +def map_with_progress(f: callable, xs: List[Any], num_threads: int): """ Apply f to each element of xs, using a ThreadPool, and show progress. """ @@ -422,7 +422,7 @@ def make_report(eval_result: EvalResult) -> str: ) -def make_report_from_example_htmls(htmls: list[str]): +def make_report_from_example_htmls(htmls: List[str]): """ Create a standalone HTML report from a list of example htmls """ diff --git a/python/sglang/test/simple_eval_humaneval.py b/python/sglang/test/simple_eval_humaneval.py index f693cb7f8..7a0f90c46 100644 --- a/python/sglang/test/simple_eval_humaneval.py +++ b/python/sglang/test/simple_eval_humaneval.py @@ -14,7 +14,7 @@ import re from collections import Counter, defaultdict from concurrent.futures import ThreadPoolExecutor, as_completed from io import BytesIO -from typing import Any, Tuple +from typing import Any, Dict, List, Tuple import blobfile as bf import tqdm @@ -38,8 +38,8 @@ from sglang.test.simple_eval_common import ( def evaluate_functional_correctness( - sample: dict[str, str], - completions: list[str], + sample: Dict[str, str], + completions: List[str], n_workers: int = 4, timeout: float = 3.0, ): @@ -70,7 +70,7 @@ class HumanEval(Eval): num_examples: int | None, num_threads: int, num_samples_per_task: int = 5, - ks_passes: list[int] = [1, 2, 5], + ks_passes: List[int] = [1, 2, 5], timeout: int = 120, ): self.seed = 0 @@ -97,7 +97,7 @@ class HumanEval(Eval): ] # remove signature return extracted_answer - def fn(sample: dict[str, str]): + def fn(sample: Dict[str, str]): prompt_messages = [ sampler._pack_message( role="user", content=instruction + sample["prompt"] diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index 1fe237a2f..be1bdb966 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -8,7 +8,7 @@ import threading import time import unittest from functools import partial -from typing import Callable, Optional +from typing import Callable, List, Optional import numpy as np import requests @@ -457,7 +457,7 @@ def run_with_timeout( return ret_value[0] -def run_unittest_files(files: list[str], timeout_per_file: float): +def run_unittest_files(files: List[str], timeout_per_file: float): tic = time.time() success = True diff --git a/test/srt/test_chunked_prefill.py b/test/srt/test_chunked_prefill.py index 3380f6aa8..e98c713e8 100644 --- a/test/srt/test_chunked_prefill.py +++ b/test/srt/test_chunked_prefill.py @@ -11,7 +11,7 @@ class TestAccuracy(unittest.TestCase): @classmethod def setUpClass(cls): cls.model = MODEL_NAME_FOR_TEST - cls.base_url = f"http://localhost:30000" + cls.base_url = f"http://localhost:8157" cls.process = popen_launch_server( cls.model, cls.base_url, diff --git a/test/srt/test_eval_accuracy.py b/test/srt/test_eval_accuracy.py index dc3f8266b..a6911785e 100644 --- a/test/srt/test_eval_accuracy.py +++ b/test/srt/test_eval_accuracy.py @@ -11,7 +11,7 @@ class TestAccuracy(unittest.TestCase): @classmethod def setUpClass(cls): cls.model = MODEL_NAME_FOR_TEST - cls.base_url = f"http://localhost:30000" + cls.base_url = f"http://localhost:8157" cls.process = popen_launch_server(cls.model, cls.base_url, timeout=300) @classmethod diff --git a/test/srt/test_openai_server.py b/test/srt/test_openai_server.py index 269664e14..5e37b1b4d 100644 --- a/test/srt/test_openai_server.py +++ b/test/srt/test_openai_server.py @@ -12,7 +12,7 @@ class TestOpenAIServer(unittest.TestCase): @classmethod def setUpClass(cls): cls.model = MODEL_NAME_FOR_TEST - cls.base_url = f"http://localhost:30000" + cls.base_url = f"http://localhost:8157" cls.api_key = "sk-123456" cls.process = popen_launch_server( cls.model, cls.base_url, timeout=300, api_key=cls.api_key diff --git a/test/srt/test_srt_endpoint.py b/test/srt/test_srt_endpoint.py index 345467858..76637b2f6 100644 --- a/test/srt/test_srt_endpoint.py +++ b/test/srt/test_srt_endpoint.py @@ -12,11 +12,9 @@ class TestSRTEndpoint(unittest.TestCase): @classmethod def setUpClass(cls): - port = 30000 - cls.model = MODEL_NAME_FOR_TEST - cls.base_url = f"http://localhost:{port}" - cls.process = popen_launch_server(cls.model, port, timeout=300) + cls.base_url = f"http://localhost:{8157}" + cls.process = popen_launch_server(cls.model, cls.base_url, timeout=300) @classmethod def tearDownClass(cls): diff --git a/test/srt/test_torch_compile.py b/test/srt/test_torch_compile.py index efd9c4698..126ee91ef 100644 --- a/test/srt/test_torch_compile.py +++ b/test/srt/test_torch_compile.py @@ -11,7 +11,7 @@ class TestAccuracy(unittest.TestCase): @classmethod def setUpClass(cls): cls.model = MODEL_NAME_FOR_TEST - cls.base_url = f"http://localhost:30000" + cls.base_url = f"http://localhost:8157" cls.process = popen_launch_server( cls.model, cls.base_url, timeout=300, other_args=["--enable-torch-compile"] )