Clean up metrics code (#1972)

@@ -391,8 +391,12 @@ class TokenizerManager:
         async with self.model_update_lock:
             # wait for the previous generation requests to finish
-            while len(self.rid_to_state) > 0:
-                await asyncio.sleep(0.001)
+            for i in range(3):
+                while len(self.rid_to_state) > 0:
+                    await asyncio.sleep(0.001)
+                # FIXME: We add some sleep here to avoid some race conditions.
+                # We can use a read-write lock as a better fix.
+                await asyncio.sleep(0.01)
             self.send_to_scheduler.send_pyobj(obj)
             self.model_update_result = asyncio.Future()
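
Note: the FIXME above concedes that fixed sleeps only paper over the race
between in-flight generation requests and a weight update. The comment's own
suggestion, a reader-writer lock, would let generation requests share the lock
while update_weights takes it exclusively. A minimal sketch of that
alternative, not part of this commit (the RWLock class is hypothetical):

    import asyncio

    class RWLock:
        # Hypothetical reader-writer lock: many concurrent readers
        # (generation requests), one exclusive writer (weight updates).
        def __init__(self):
            self._readers = 0
            self._no_readers = asyncio.Event()
            self._no_readers.set()
            self._writer = asyncio.Lock()

        async def acquire_read(self):
            # Briefly takes the writer lock so new readers block
            # while a writer is active or waiting.
            async with self._writer:
                self._readers += 1
                self._no_readers.clear()

        def release_read(self):
            self._readers -= 1
            if self._readers == 0:
                self._no_readers.set()

        async def acquire_write(self):
            await self._writer.acquire()   # stop new readers from entering
            await self._no_readers.wait()  # drain readers already in flight

        def release_write(self):
            self._writer.release()

With such a lock, update_weights would take the write side instead of polling
rid_to_state, and each generation request would hold the read side for its
lifetime.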

@@ -25,20 +25,16 @@ import json
 import logging
 import multiprocessing as mp
 import os
-import re
-import tempfile
 import threading
 import time
 from http import HTTPStatus
 from typing import AsyncIterator, Dict, List, Optional, Union

-from starlette.routing import Mount
+import orjson

 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)

 import aiohttp
-import orjson
 import requests
 import uvicorn
 import uvloop

@@ -77,6 +73,7 @@ from sglang.srt.openai_api.protocol import ModelCard, ModelList
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
     add_api_key_middleware,
+    add_prometheus_middleware,
     assert_pkg_version,
     configure_logger,
     delete_directory,

@@ -84,16 +81,13 @@ from sglang.srt.utils import (
     kill_child_process,
     maybe_set_triton_cache_manager,
     prepare_model_and_tokenizer,
+    set_prometheus_multiproc_dir,
     set_ulimit,
 )
 from sglang.utils import get_exception_traceback

 logger = logging.getLogger(__name__)

-# Temporary directory for prometheus multiprocess mode
-# Cleaned up automatically when this object is garbage collected
-prometheus_multiproc_dir: tempfile.TemporaryDirectory
-
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

@@ -445,10 +439,6 @@ def launch_server(
     1. The HTTP server and Tokenizer Manager both run in the main process.
     2. Inter-process communication is done through ICP (each process uses a different port) via the ZMQ library.
     """
-
-    if server_args.enable_metrics:
-        _set_prometheus_env()
-
     launch_engine(server_args=server_args)

     # Add api key authorization

@@ -487,36 +477,6 @@ launch_server(
     t.join()


-def add_prometheus_middleware(app: FastAPI):
-    # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.1/vllm/entrypoints/openai/api_server.py#L216
-    from prometheus_client import CollectorRegistry, make_asgi_app, multiprocess
-
-    registry = CollectorRegistry()
-    multiprocess.MultiProcessCollector(registry)
-    metrics_route = Mount("/metrics", make_asgi_app(registry=registry))
-
-    # Workaround for 307 Redirect for /metrics
-    metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$")
-    app.routes.append(metrics_route)
-
-
-def _set_prometheus_env():
-    # Set prometheus multiprocess directory
-    # sglang uses prometheus multiprocess mode
-    # we need to set this before importing prometheus_client
-    # https://prometheus.github.io/client_python/multiprocess/
-    global prometheus_multiproc_dir
-    if "PROMETHEUS_MULTIPROC_DIR" in os.environ:
-        logger.debug(f"User set PROMETHEUS_MULTIPROC_DIR detected.")
-        prometheus_multiproc_dir = tempfile.TemporaryDirectory(
-            dir=os.environ["PROMETHEUS_MULTIPROC_DIR"]
-        )
-    else:
-        prometheus_multiproc_dir = tempfile.TemporaryDirectory()
-        os.environ["PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name
-    logger.debug(f"PROMETHEUS_MULTIPROC_DIR: {os.environ['PROMETHEUS_MULTIPROC_DIR']}")
-
-
 def _set_envs_and_config(server_args: ServerArgs):
     # Set global environments
     os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

@@ -543,6 +503,10 @@ def _set_envs_and_config(server_args: ServerArgs):
         "at https://docs.flashinfer.ai/installation.html.",
     )

+    # Set prometheus env vars
+    if server_args.enable_metrics:
+        set_prometheus_multiproc_dir()
+
     mp.set_start_method("spawn", force=True)
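
Note: the placement is deliberate. set_prometheus_multiproc_dir() runs before
mp.set_start_method("spawn", force=True) creates any worker, because
prometheus_client selects its value-storage backend from
PROMETHEUS_MULTIPROC_DIR when it is imported, and spawned children inherit the
variable through the environment. A minimal illustration of the required
order (the directory name here is only an example):

    import os
    import tempfile

    # 1. Choose the shared directory first, in the parent process.
    os.environ["PROMETHEUS_MULTIPROC_DIR"] = tempfile.mkdtemp()

    # 2. Only then import the client; in multiprocess mode it writes
    #    per-process .db files into that directory instead of keeping
    #    metric values in process-local memory.
    from prometheus_client import Counter

    requests_total = Counter("example_requests_total", "Requests handled")
    requests_total.inc()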

@@ -22,10 +22,12 @@ import logging
 import os
 import pickle
 import random
+import re
 import resource
 import shutil
 import signal
 import socket
+import tempfile
 import time
 import warnings
 from importlib.metadata import PackageNotFoundError, version

@@ -41,6 +43,7 @@ import triton
 import zmq
 from fastapi.responses import ORJSONResponse
 from packaging import version as pkg_version
+from starlette.routing import Mount
 from torch import nn
 from torch.profiler import ProfilerActivity, profile, record_function
 from triton.runtime.cache import (

@@ -752,3 +755,38 @@ def delete_directory(dirpath):
         shutil.rmtree(dirpath)
     except OSError as e:
         print(f"Warning: {dirpath} : {e.strerror}")
+
+
+# Temporary directory for prometheus multiprocess mode
+# Cleaned up automatically when this object is garbage collected
+prometheus_multiproc_dir: tempfile.TemporaryDirectory
+
+
+def set_prometheus_multiproc_dir():
+    # Set prometheus multiprocess directory
+    # sglang uses prometheus multiprocess mode
+    # we need to set this before importing prometheus_client
+    # https://prometheus.github.io/client_python/multiprocess/
+    global prometheus_multiproc_dir
+
+    if "PROMETHEUS_MULTIPROC_DIR" in os.environ:
+        logger.debug("User set PROMETHEUS_MULTIPROC_DIR detected.")
+        prometheus_multiproc_dir = tempfile.TemporaryDirectory(
+            dir=os.environ["PROMETHEUS_MULTIPROC_DIR"]
+        )
+    else:
+        prometheus_multiproc_dir = tempfile.TemporaryDirectory()
+        os.environ["PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name
+    logger.debug(f"PROMETHEUS_MULTIPROC_DIR: {os.environ['PROMETHEUS_MULTIPROC_DIR']}")
+
+
+def add_prometheus_middleware(app):
+    from prometheus_client import CollectorRegistry, make_asgi_app, multiprocess
+
+    registry = CollectorRegistry()
+    multiprocess.MultiProcessCollector(registry)
+    metrics_route = Mount("/metrics", make_asgi_app(registry=registry))
+
+    # Workaround for 307 Redirect for /metrics
+    metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$")
+    app.routes.append(metrics_route)
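
Note: two details of add_prometheus_middleware are worth spelling out.
MultiProcessCollector aggregates the per-process files written under
PROMETHEUS_MULTIPROC_DIR, so one scrape reflects every worker. And a bare
Starlette Mount("/metrics", ...) only matches paths under /metrics/, answering
GET /metrics itself with a 307 redirect; overriding path_regex makes the exact
path match directly. A sketch of wiring this into an app, assuming the env var
is prepared as in set_prometheus_multiproc_dir above:

    import os
    import tempfile

    os.environ.setdefault("PROMETHEUS_MULTIPROC_DIR", tempfile.mkdtemp())

    from fastapi import FastAPI

    from sglang.srt.utils import add_prometheus_middleware

    app = FastAPI()
    add_prometheus_middleware(app)
    # GET /metrics now returns the aggregated exposition text directly,
    # without a 307 redirect hop to /metrics/.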

@@ -27,6 +27,7 @@ from sglang.utils import get_exception_traceback

 DEFAULT_FP8_MODEL_NAME_FOR_TEST = "neuralmagic/Meta-Llama-3.1-8B-FP8"
 DEFAULT_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.1-8B-Instruct"
+DEFAULT_SMALL_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.2-1B-Instruct"
 DEFAULT_MOE_MODEL_NAME_FOR_TEST = "mistralai/Mixtral-8x7B-Instruct-v0.1"
 DEFAULT_MLA_MODEL_NAME_FOR_TEST = "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"
 DEFAULT_MLA_FP8_MODEL_NAME_FOR_TEST = "neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8"

@@ -404,7 +405,6 @@ def popen_launch_server(
     other_args: tuple = (),
     env: Optional[dict] = None,
     return_stdout_stderr: Optional[tuple] = None,
-    enable_metrics: bool = False,
 ):
     _, host, port = base_url.split(":")
     host = host[2:]

@@ -423,8 +423,6 @@
     ]
     if api_key:
         command += ["--api-key", api_key]
-    if enable_metrics:
-        command += ["--enable-metrics"]

     if return_stdout_stderr:
         process = subprocess.Popen(
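
Note: with the dedicated enable_metrics keyword removed, tests opt in through
other_args, the same path as any other server flag. Based on the updated
test_metrics.py below, usage looks like:

    from sglang.test.test_utils import (
        DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
        DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        DEFAULT_URL_FOR_TEST,
        popen_launch_server,
    )

    process = popen_launch_server(
        DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
        DEFAULT_URL_FOR_TEST,
        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        other_args=["--enable-metrics"],
    )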

@@ -4,5 +4,5 @@ Install the dependency in CI.

 pip install --upgrade pip
 pip install -e "python[all]"
-pip install transformers==4.45.2 sentence_transformers
+pip install transformers==4.45.2 sentence_transformers accelerate peft
 pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall

@@ -16,6 +16,7 @@ suites = {
         "test_eval_accuracy_mini.py",
         "test_json_constrained.py",
         "test_large_max_new_tokens.py",
+        "test_metrics.py",
         "test_openai_server.py",
         "test_overlap_schedule.py",
         "test_pytorch_sampling_backend.py",

@@ -1,7 +1,5 @@
-import subprocess
 import unittest

 from sglang.srt.utils import kill_child_process
 from sglang.test.test_utils import (
     DEFAULT_MODEL_NAME_FOR_TEST,
-    DEFAULT_MOE_MODEL_NAME_FOR_TEST,

@@ -6,7 +6,7 @@ import requests

 from sglang.srt.utils import kill_child_process
 from sglang.test.test_utils import (
-    DEFAULT_MODEL_NAME_FOR_TEST,
+    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
     DEFAULT_URL_FOR_TEST,
     popen_launch_server,
 )

@@ -15,7 +15,7 @@ from sglang.test.test_utils import (
 class TestCacheReport(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
-        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
         cls.base_url = DEFAULT_URL_FOR_TEST
         cls.min_cached = 5
         cls.process = popen_launch_server(

@@ -3,6 +3,7 @@ python3 -m unittest test_large_max_new_tokens.TestLargeMaxNewTokens.test_chat_co
 """

+import os
 import time
 import unittest
 from concurrent.futures import ThreadPoolExecutor

@@ -11,7 +12,7 @@ import openai
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.utils import kill_child_process
 from sglang.test.test_utils import (
-    DEFAULT_MODEL_NAME_FOR_TEST,
+    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
     DEFAULT_URL_FOR_TEST,
     popen_launch_server,

@@ -21,7 +22,7 @@ from sglang.test.test_utils import (
 class TestLargeMaxNewTokens(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
-        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
         cls.base_url = DEFAULT_URL_FOR_TEST
         cls.api_key = "sk-123456"

@@ -33,12 +34,19 @@ class TestLargeMaxNewTokens(unittest.TestCase):
             cls.base_url,
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
             api_key=cls.api_key,
-            other_args=("--max-total-token", "1024", "--context-len", "8192"),
+            other_args=(
+                "--max-total-token",
+                "1024",
+                "--context-len",
+                "8192",
+                "--decode-log-interval",
+                "2",
+            ),
             env={"SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION": "256", **os.environ},
             return_stdout_stderr=(cls.stdout, cls.stderr),
         )
         cls.base_url += "/v1"
-        cls.tokenizer = get_tokenizer(DEFAULT_MODEL_NAME_FOR_TEST)
+        cls.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)

     @classmethod
     def tearDownClass(cls):
|
||||
# Ensure that they are running concurrently
|
||||
pt = 0
|
||||
while pt >= 0:
|
||||
time.sleep(5)
|
||||
lines = open("stderr.txt").readlines()
|
||||
for line in lines[pt:]:
|
||||
print(line, end="", flush=True)
|
||||
|
||||

@@ -1,31 +1,24 @@
 import unittest
-from types import SimpleNamespace

 import requests

 from sglang.srt.utils import kill_child_process
-from sglang.test.run_eval import run_eval
 from sglang.test.test_utils import (
-    DEFAULT_MODEL_NAME_FOR_TEST,
+    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
     DEFAULT_URL_FOR_TEST,
     popen_launch_server,
 )

-TEST_MODEL = (
-    DEFAULT_MODEL_NAME_FOR_TEST  # I used "google/gemma-2-2b-it" for testing locally
-)
-

 class TestEnableMetrics(unittest.TestCase):
     def test_metrics_enabled(self):
         """Test that metrics endpoint returns data when enabled"""
         # Launch server with metrics enabled
         process = popen_launch_server(
-            model=TEST_MODEL,
-            base_url=DEFAULT_URL_FOR_TEST,
+            DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
+            DEFAULT_URL_FOR_TEST,
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
-            enable_metrics=True,
+            other_args=["--enable-metrics"],
         )

         try:

@@ -38,6 +31,8 @@ class TestEnableMetrics(unittest.TestCase):
             self.assertEqual(metrics_response.status_code, 200)
             metrics_content = metrics_response.text

+            print(f"{metrics_content=}")
+
             # Verify essential metrics are present
             essential_metrics = [
                 "sglang:prompt_tokens_total",

@@ -53,7 +48,7 @@ class TestEnableMetrics(unittest.TestCase):
                 self.assertIn(metric, metrics_content, f"Missing metric: {metric}")

             # Verify model name label is present and correct
-            expected_model_name = TEST_MODEL
+            expected_model_name = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
             self.assertIn(f'model_name="{expected_model_name}"', metrics_content)
             # Verify metrics have values (not empty)
             self.assertIn("_sum{", metrics_content)
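
Note: for reference, a passing scrape contains Prometheus exposition text
along these lines (the metric name and model_name label come from the
assertions above; the value is illustrative):

    # TYPE sglang:prompt_tokens_total counter
    sglang:prompt_tokens_total{model_name="meta-llama/Llama-3.2-1B-Instruct"} 128.0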

@@ -63,22 +58,6 @@ class TestEnableMetrics(unittest.TestCase):
         finally:
             kill_child_process(process.pid, include_self=True)

-    def test_metrics_disabled(self):
-        """Test that metrics endpoint returns 404 when disabled"""
-        # Launch server with metrics disabled
-        process = popen_launch_server(
-            model=TEST_MODEL,
-            base_url=DEFAULT_URL_FOR_TEST,
-            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
-            enable_metrics=False,
-        )
-
-        try:
-            response = requests.get(f"{DEFAULT_URL_FOR_TEST}/health_generate")
-            self.assertEqual(response.status_code, 200)
-            # Verify metrics endpoint is not available
-            metrics_response = requests.get(f"{DEFAULT_URL_FOR_TEST}/metrics")
-            self.assertEqual(metrics_response.status_code, 404)
-
-        finally:
-            kill_child_process(process.pid, include_self=True)

 if __name__ == "__main__":
     unittest.main()

@@ -13,7 +13,7 @@ import openai
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.utils import kill_child_process
 from sglang.test.test_utils import (
-    DEFAULT_MODEL_NAME_FOR_TEST,
+    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
     DEFAULT_URL_FOR_TEST,
     popen_launch_server,

@@ -23,7 +23,7 @@ from sglang.test.test_utils import (
 class TestOpenAIServer(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
-        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
         cls.base_url = DEFAULT_URL_FOR_TEST
         cls.api_key = "sk-123456"
         cls.process = popen_launch_server(

@@ -33,7 +33,7 @@ class TestOpenAIServer(unittest.TestCase):
             api_key=cls.api_key,
         )
         cls.base_url += "/v1"
-        cls.tokenizer = get_tokenizer(DEFAULT_MODEL_NAME_FOR_TEST)
+        cls.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)

     @classmethod
     def tearDownClass(cls):
||||
@@ -5,7 +5,7 @@ import unittest
|
||||
import requests
|
||||
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
kill_child_process,
|
||||
@@ -62,7 +62,7 @@ def run_test(base_url, nodes):
|
||||
class TestRadixCacheFCFS(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
|
||||
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
@@ -90,7 +90,7 @@ class TestRadixCacheFCFS(unittest.TestCase):
|
||||
class TestRadixCacheLPM(TestRadixCacheFCFS):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
|
||||
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
@@ -110,7 +110,7 @@ class TestRadixCacheLPM(TestRadixCacheFCFS):
|
||||
class TestRadixCacheOverlapLPM(TestRadixCacheFCFS):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
|
||||
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
|
||||

@@ -9,7 +9,7 @@ import requests

 from sglang.srt.utils import kill_child_process
 from sglang.test.test_utils import (
-    DEFAULT_MODEL_NAME_FOR_TEST,
+    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
     DEFAULT_URL_FOR_TEST,
     popen_launch_server,

@@ -19,7 +19,7 @@ from sglang.test.test_utils import (
 class TestSkipTokenizerInit(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
-        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
         cls.base_url = DEFAULT_URL_FOR_TEST
         cls.process = popen_launch_server(
             cls.model,

@@ -10,7 +10,7 @@ import requests

 from sglang.srt.utils import kill_child_process
 from sglang.test.test_utils import (
-    DEFAULT_MODEL_NAME_FOR_TEST,
+    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
     DEFAULT_URL_FOR_TEST,
     popen_launch_server,

@@ -20,7 +20,7 @@ from sglang.test.test_utils import (
 class TestSRTEndpoint(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
-        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
         cls.base_url = DEFAULT_URL_FOR_TEST
         cls.process = popen_launch_server(
             cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH

@@ -11,14 +11,17 @@ from types import SimpleNamespace
 import sglang as sgl
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.test.few_shot_gsm8k_engine import run_eval
-from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST
+from sglang.test.test_utils import (
+    DEFAULT_MODEL_NAME_FOR_TEST,
+    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
+)


 class TestSRTEngine(unittest.TestCase):

     def test_1_engine_runtime_consistency(self):
         prompt = "Today is a sunny day and I like"
-        model_path = DEFAULT_MODEL_NAME_FOR_TEST
+        model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST

         sampling_params = {"temperature": 0, "max_new_tokens": 8}

@@ -40,7 +43,7 @@ class TestSRTEngine(unittest.TestCase):
     def test_2_engine_multiple_generate(self):
         # just to ensure there is no issue running multiple generate calls
         prompt = "Today is a sunny day and I like"
-        model_path = DEFAULT_MODEL_NAME_FOR_TEST
+        model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST

         sampling_params = {"temperature": 0, "max_new_tokens": 8}

@@ -66,7 +69,7 @@ class TestSRTEngine(unittest.TestCase):

         # Create an LLM.
         llm = sgl.Engine(
-            model_path=DEFAULT_MODEL_NAME_FOR_TEST,
+            model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
             log_level="error",
         )

@@ -110,7 +113,7 @@ class TestSRTEngine(unittest.TestCase):
     def test_5_prompt_input_ids_consistency(self):
         prompt = "The capital of UK is"

-        model_path = DEFAULT_MODEL_NAME_FOR_TEST
+        model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
         engine = sgl.Engine(model_path=model_path, random_seed=42, log_level="error")
         sampling_params = {"temperature": 0, "max_new_tokens": 8}
         out1 = engine.generate(prompt, sampling_params)["text"]

@@ -5,7 +5,7 @@ import requests

 from sglang.srt.utils import kill_child_process
 from sglang.test.test_utils import (
-    DEFAULT_MODEL_NAME_FOR_TEST,
+    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
     DEFAULT_URL_FOR_TEST,
     popen_launch_server,

@@ -15,7 +15,7 @@ from sglang.test.test_utils import (
 class TestUpdateWeights(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
-        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
         cls.base_url = DEFAULT_URL_FOR_TEST
         cls.process = popen_launch_server(
             cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH

@@ -64,7 +64,7 @@ class TestUpdateWeights(unittest.TestCase):
         origin_response = self.run_decode()

         # update weights
-        new_model_path = "meta-llama/Meta-Llama-3.1-8B"
+        new_model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST.replace("-Instruct", "")
         ret = self.run_update_weights(new_model_path)
         assert ret["success"]

@@ -92,7 +92,7 @@ class TestUpdateWeights(unittest.TestCase):
         origin_response = self.run_decode()

         # update weights
-        new_model_path = "meta-llama/Meta-Llama-3.1-8B-1"
+        new_model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST.replace("-Instruct", "wrong")
         ret = self.run_update_weights(new_model_path)
         assert not ret["success"]