diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index cf412b8fa..11e14e5b6 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -391,8 +391,12 @@ class TokenizerManager: async with self.model_update_lock: # wait for the previous generation requests to finish - while len(self.rid_to_state) > 0: - await asyncio.sleep(0.001) + for i in range(3): + while len(self.rid_to_state) > 0: + await asyncio.sleep(0.001) + # FIXME: We add some sleep here to avoid some race conditions. + # We can use a read-write lock as a better fix. + await asyncio.sleep(0.01) self.send_to_scheduler.send_pyobj(obj) self.model_update_result = asyncio.Future() diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index 75ffed325..ecde19f5b 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -25,20 +25,16 @@ import json import logging import multiprocessing as mp import os -import re -import tempfile import threading import time from http import HTTPStatus from typing import AsyncIterator, Dict, List, Optional, Union -import orjson -from starlette.routing import Mount - # Fix a bug of Python threading setattr(threading, "_register_atexit", lambda *args, **kwargs: None) import aiohttp +import orjson import requests import uvicorn import uvloop @@ -77,6 +73,7 @@ from sglang.srt.openai_api.protocol import ModelCard, ModelList from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.utils import ( add_api_key_middleware, + add_prometheus_middleware, assert_pkg_version, configure_logger, delete_directory, @@ -84,16 +81,13 @@ from sglang.srt.utils import ( kill_child_process, maybe_set_triton_cache_manager, prepare_model_and_tokenizer, + set_prometheus_multiproc_dir, set_ulimit, ) from sglang.utils import get_exception_traceback logger = logging.getLogger(__name__) -# Temporary directory for prometheus multiprocess mode 
-# Cleaned up automatically when this object is garbage collected -prometheus_multiproc_dir: tempfile.TemporaryDirectory - asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) @@ -445,10 +439,6 @@ def launch_server( 1. The HTTP server and Tokenizer Manager both run in the main process. 2. Inter-process communication is done through ICP (each process uses a different port) via the ZMQ library. """ - - if server_args.enable_metrics: - _set_prometheus_env() - launch_engine(server_args=server_args) # Add api key authorization @@ -487,36 +477,6 @@ def launch_server( t.join() -def add_prometheus_middleware(app: FastAPI): - # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.1/vllm/entrypoints/openai/api_server.py#L216 - from prometheus_client import CollectorRegistry, make_asgi_app, multiprocess - - registry = CollectorRegistry() - multiprocess.MultiProcessCollector(registry) - metrics_route = Mount("/metrics", make_asgi_app(registry=registry)) - - # Workaround for 307 Redirect for /metrics - metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$") - app.routes.append(metrics_route) - - -def _set_prometheus_env(): - # Set prometheus multiprocess directory - # sglang uses prometheus multiprocess mode - # we need to set this before importing prometheus_client - # https://prometheus.github.io/client_python/multiprocess/ - global prometheus_multiproc_dir - if "PROMETHEUS_MULTIPROC_DIR" in os.environ: - logger.debug(f"User set PROMETHEUS_MULTIPROC_DIR detected.") - prometheus_multiproc_dir = tempfile.TemporaryDirectory( - dir=os.environ["PROMETHEUS_MULTIPROC_DIR"] - ) - else: - prometheus_multiproc_dir = tempfile.TemporaryDirectory() - os.environ["PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name - logger.debug(f"PROMETHEUS_MULTIPROC_DIR: {os.environ['PROMETHEUS_MULTIPROC_DIR']}") - - def _set_envs_and_config(server_args: ServerArgs): # Set global environments os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" @@ -543,6 +503,10 @@ def 
_set_envs_and_config(server_args: ServerArgs): "at https://docs.flashinfer.ai/installation.html.", ) + # Set prometheus env vars + if server_args.enable_metrics: + set_prometheus_multiproc_dir() + mp.set_start_method("spawn", force=True) diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 5ee0fe59d..d8184b018 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -22,10 +22,12 @@ import logging import os import pickle import random +import re import resource import shutil import signal import socket +import tempfile import time import warnings from importlib.metadata import PackageNotFoundError, version @@ -41,6 +43,7 @@ import triton import zmq from fastapi.responses import ORJSONResponse from packaging import version as pkg_version +from starlette.routing import Mount from torch import nn from torch.profiler import ProfilerActivity, profile, record_function from triton.runtime.cache import ( @@ -752,3 +755,38 @@ def delete_directory(dirpath): shutil.rmtree(dirpath) except OSError as e: print(f"Warning: {dirpath} : {e.strerror}") + + +# Temporary directory for prometheus multiprocess mode +# Cleaned up automatically when this object is garbage collected +prometheus_multiproc_dir: tempfile.TemporaryDirectory + + +def set_prometheus_multiproc_dir(): + # Set prometheus multiprocess directory + # sglang uses prometheus multiprocess mode + # we need to set this before importing prometheus_client + # https://prometheus.github.io/client_python/multiprocess/ + global prometheus_multiproc_dir + + if "PROMETHEUS_MULTIPROC_DIR" in os.environ: + logger.debug("User set PROMETHEUS_MULTIPROC_DIR detected.") + prometheus_multiproc_dir = tempfile.TemporaryDirectory( + dir=os.environ["PROMETHEUS_MULTIPROC_DIR"] + ) + else: + prometheus_multiproc_dir = tempfile.TemporaryDirectory() + os.environ["PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name + logger.debug(f"PROMETHEUS_MULTIPROC_DIR: {os.environ['PROMETHEUS_MULTIPROC_DIR']}") + 
+ +def add_prometheus_middleware(app): + from prometheus_client import CollectorRegistry, make_asgi_app, multiprocess + + registry = CollectorRegistry() + multiprocess.MultiProcessCollector(registry) + metrics_route = Mount("/metrics", make_asgi_app(registry=registry)) + + # Workaround for 307 Redirect for /metrics + metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$") + app.routes.append(metrics_route) diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index a7893494b..2bd713898 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -27,6 +27,7 @@ from sglang.utils import get_exception_traceback DEFAULT_FP8_MODEL_NAME_FOR_TEST = "neuralmagic/Meta-Llama-3.1-8B-FP8" DEFAULT_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.1-8B-Instruct" +DEFAULT_SMALL_MODEL_NAME_FOR_TEST = "meta-llama/Llama-3.2-1B-Instruct" DEFAULT_MOE_MODEL_NAME_FOR_TEST = "mistralai/Mixtral-8x7B-Instruct-v0.1" DEFAULT_MLA_MODEL_NAME_FOR_TEST = "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct" DEFAULT_MLA_FP8_MODEL_NAME_FOR_TEST = "neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8" @@ -404,7 +405,6 @@ def popen_launch_server( other_args: tuple = (), env: Optional[dict] = None, return_stdout_stderr: Optional[tuple] = None, - enable_metrics: bool = False, ): _, host, port = base_url.split(":") host = host[2:] @@ -423,8 +423,6 @@ def popen_launch_server( ] if api_key: command += ["--api-key", api_key] - if enable_metrics: - command += ["--enable-metrics"] if return_stdout_stderr: process = subprocess.Popen( diff --git a/scripts/ci_install_dependency.sh b/scripts/ci_install_dependency.sh index a219e02e2..fd0299db0 100644 --- a/scripts/ci_install_dependency.sh +++ b/scripts/ci_install_dependency.sh @@ -4,5 +4,5 @@ Install the dependency in CI. 
pip install --upgrade pip pip install -e "python[all]" -pip install transformers==4.45.2 sentence_transformers +pip install transformers==4.45.2 sentence_transformers accelerate peft pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index f7277f03d..697fb8d21 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -16,6 +16,7 @@ suites = { "test_eval_accuracy_mini.py", "test_json_constrained.py", "test_large_max_new_tokens.py", + "test_metrics.py", "test_openai_server.py", "test_overlap_schedule.py", "test_pytorch_sampling_backend.py", diff --git a/test/srt/test_bench_latency.py b/test/srt/test_bench_latency.py index 4d2042ccf..fa6b8e2fa 100644 --- a/test/srt/test_bench_latency.py +++ b/test/srt/test_bench_latency.py @@ -1,7 +1,5 @@ -import subprocess import unittest -from sglang.srt.utils import kill_child_process from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_MOE_MODEL_NAME_FOR_TEST, diff --git a/test/srt/test_cache_report.py b/test/srt/test_cache_report.py index b790c3ae6..5d498ac3f 100644 --- a/test/srt/test_cache_report.py +++ b/test/srt/test_cache_report.py @@ -6,7 +6,7 @@ import requests from sglang.srt.utils import kill_child_process from sglang.test.test_utils import ( - DEFAULT_MODEL_NAME_FOR_TEST, + DEFAULT_SMALL_MODEL_NAME_FOR_TEST, DEFAULT_URL_FOR_TEST, popen_launch_server, ) @@ -15,7 +15,7 @@ from sglang.test.test_utils import ( class TestCacheReport(unittest.TestCase): @classmethod def setUpClass(cls): - cls.model = DEFAULT_MODEL_NAME_FOR_TEST + cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST cls.base_url = DEFAULT_URL_FOR_TEST cls.min_cached = 5 cls.process = popen_launch_server( diff --git a/test/srt/test_large_max_new_tokens.py b/test/srt/test_large_max_new_tokens.py index ea9c20e5c..0605444ba 100644 --- a/test/srt/test_large_max_new_tokens.py +++ b/test/srt/test_large_max_new_tokens.py @@ -3,6 +3,7 @@ python3 -m 
unittest test_large_max_new_tokens.TestLargeMaxNewTokens.test_chat_co """ import os +import time import unittest from concurrent.futures import ThreadPoolExecutor @@ -11,7 +12,7 @@ import openai from sglang.srt.hf_transformers_utils import get_tokenizer from sglang.srt.utils import kill_child_process from sglang.test.test_utils import ( - DEFAULT_MODEL_NAME_FOR_TEST, + DEFAULT_SMALL_MODEL_NAME_FOR_TEST, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, popen_launch_server, @@ -21,7 +22,7 @@ from sglang.test.test_utils import ( class TestLargeMaxNewTokens(unittest.TestCase): @classmethod def setUpClass(cls): - cls.model = DEFAULT_MODEL_NAME_FOR_TEST + cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST cls.base_url = DEFAULT_URL_FOR_TEST cls.api_key = "sk-123456" @@ -33,12 +34,19 @@ class TestLargeMaxNewTokens(unittest.TestCase): cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, api_key=cls.api_key, - other_args=("--max-total-token", "1024", "--context-len", "8192"), + other_args=( + "--max-total-token", + "1024", + "--context-len", + "8192", + "--decode-log-interval", + "2", + ), env={"SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION": "256", **os.environ}, return_stdout_stderr=(cls.stdout, cls.stderr), ) cls.base_url += "/v1" - cls.tokenizer = get_tokenizer(DEFAULT_MODEL_NAME_FOR_TEST) + cls.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST) @classmethod def tearDownClass(cls): @@ -75,6 +83,7 @@ class TestLargeMaxNewTokens(unittest.TestCase): # Ensure that they are running concurrently pt = 0 while pt >= 0: + time.sleep(5) lines = open("stderr.txt").readlines() for line in lines[pt:]: print(line, end="", flush=True) diff --git a/test/srt/test_enable_metrics.py b/test/srt/test_metrics.py similarity index 61% rename from test/srt/test_enable_metrics.py rename to test/srt/test_metrics.py index 794e2a325..37ba6c7ef 100644 --- a/test/srt/test_enable_metrics.py +++ b/test/srt/test_metrics.py @@ -1,31 +1,24 @@ import unittest -from types import SimpleNamespace 
import requests from sglang.srt.utils import kill_child_process -from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( - DEFAULT_MODEL_NAME_FOR_TEST, + DEFAULT_SMALL_MODEL_NAME_FOR_TEST, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, popen_launch_server, ) -TEST_MODEL = ( - DEFAULT_MODEL_NAME_FOR_TEST # I used "google/gemma-2-2b-it" for testing locally -) - class TestEnableMetrics(unittest.TestCase): def test_metrics_enabled(self): """Test that metrics endpoint returns data when enabled""" - # Launch server with metrics enabled process = popen_launch_server( - model=TEST_MODEL, - base_url=DEFAULT_URL_FOR_TEST, + DEFAULT_SMALL_MODEL_NAME_FOR_TEST, + DEFAULT_URL_FOR_TEST, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - enable_metrics=True, + other_args=["--enable-metrics"], ) try: @@ -38,6 +31,8 @@ class TestEnableMetrics(unittest.TestCase): self.assertEqual(metrics_response.status_code, 200) metrics_content = metrics_response.text + print(f"{metrics_content=}") + # Verify essential metrics are present essential_metrics = [ "sglang:prompt_tokens_total", @@ -53,7 +48,7 @@ class TestEnableMetrics(unittest.TestCase): self.assertIn(metric, metrics_content, f"Missing metric: {metric}") # Verify model name label is present and correct - expected_model_name = TEST_MODEL + expected_model_name = DEFAULT_SMALL_MODEL_NAME_FOR_TEST self.assertIn(f'model_name="{expected_model_name}"', metrics_content) # Verify metrics have values (not empty) self.assertIn("_sum{", metrics_content) @@ -63,22 +58,6 @@ class TestEnableMetrics(unittest.TestCase): finally: kill_child_process(process.pid, include_self=True) - def test_metrics_disabled(self): - """Test that metrics endpoint returns 404 when disabled""" - # Launch server with metrics disabled - process = popen_launch_server( - model=TEST_MODEL, - base_url=DEFAULT_URL_FOR_TEST, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - enable_metrics=False, - ) - try: - response = 
requests.get(f"{DEFAULT_URL_FOR_TEST}/health_generate") - self.assertEqual(response.status_code, 200) - # Verify metrics endpoint is not available - metrics_response = requests.get(f"{DEFAULT_URL_FOR_TEST}/metrics") - self.assertEqual(metrics_response.status_code, 404) - - finally: - kill_child_process(process.pid, include_self=True) +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/test_openai_server.py b/test/srt/test_openai_server.py index 070a0633c..048026b8b 100644 --- a/test/srt/test_openai_server.py +++ b/test/srt/test_openai_server.py @@ -13,7 +13,7 @@ import openai from sglang.srt.hf_transformers_utils import get_tokenizer from sglang.srt.utils import kill_child_process from sglang.test.test_utils import ( - DEFAULT_MODEL_NAME_FOR_TEST, + DEFAULT_SMALL_MODEL_NAME_FOR_TEST, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, popen_launch_server, @@ -23,7 +23,7 @@ from sglang.test.test_utils import ( class TestOpenAIServer(unittest.TestCase): @classmethod def setUpClass(cls): - cls.model = DEFAULT_MODEL_NAME_FOR_TEST + cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST cls.base_url = DEFAULT_URL_FOR_TEST cls.api_key = "sk-123456" cls.process = popen_launch_server( @@ -33,7 +33,7 @@ class TestOpenAIServer(unittest.TestCase): api_key=cls.api_key, ) cls.base_url += "/v1" - cls.tokenizer = get_tokenizer(DEFAULT_MODEL_NAME_FOR_TEST) + cls.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST) @classmethod def tearDownClass(cls): diff --git a/test/srt/test_radix_attention.py b/test/srt/test_radix_attention.py index e858ba9ee..f9da49a1d 100644 --- a/test/srt/test_radix_attention.py +++ b/test/srt/test_radix_attention.py @@ -5,7 +5,7 @@ import unittest import requests from sglang.test.test_utils import ( - DEFAULT_MODEL_NAME_FOR_TEST, + DEFAULT_SMALL_MODEL_NAME_FOR_TEST, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, kill_child_process, @@ -62,7 +62,7 @@ def run_test(base_url, nodes): class TestRadixCacheFCFS(unittest.TestCase): 
@classmethod def setUpClass(cls): - cls.model = DEFAULT_MODEL_NAME_FOR_TEST + cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST cls.base_url = DEFAULT_URL_FOR_TEST cls.process = popen_launch_server( cls.model, @@ -90,7 +90,7 @@ class TestRadixCacheFCFS(unittest.TestCase): class TestRadixCacheLPM(TestRadixCacheFCFS): @classmethod def setUpClass(cls): - cls.model = DEFAULT_MODEL_NAME_FOR_TEST + cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST cls.base_url = DEFAULT_URL_FOR_TEST cls.process = popen_launch_server( cls.model, @@ -110,7 +110,7 @@ class TestRadixCacheLPM(TestRadixCacheFCFS): class TestRadixCacheOverlapLPM(TestRadixCacheFCFS): @classmethod def setUpClass(cls): - cls.model = DEFAULT_MODEL_NAME_FOR_TEST + cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST cls.base_url = DEFAULT_URL_FOR_TEST cls.process = popen_launch_server( cls.model, diff --git a/test/srt/test_skip_tokenizer_init.py b/test/srt/test_skip_tokenizer_init.py index a95026e20..7ec73b15d 100644 --- a/test/srt/test_skip_tokenizer_init.py +++ b/test/srt/test_skip_tokenizer_init.py @@ -9,7 +9,7 @@ import requests from sglang.srt.utils import kill_child_process from sglang.test.test_utils import ( - DEFAULT_MODEL_NAME_FOR_TEST, + DEFAULT_SMALL_MODEL_NAME_FOR_TEST, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, popen_launch_server, @@ -19,7 +19,7 @@ from sglang.test.test_utils import ( class TestSkipTokenizerInit(unittest.TestCase): @classmethod def setUpClass(cls): - cls.model = DEFAULT_MODEL_NAME_FOR_TEST + cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST cls.base_url = DEFAULT_URL_FOR_TEST cls.process = popen_launch_server( cls.model, diff --git a/test/srt/test_srt_endpoint.py b/test/srt/test_srt_endpoint.py index 045f3100c..b13ed9ac8 100644 --- a/test/srt/test_srt_endpoint.py +++ b/test/srt/test_srt_endpoint.py @@ -10,7 +10,7 @@ import requests from sglang.srt.utils import kill_child_process from sglang.test.test_utils import ( - DEFAULT_MODEL_NAME_FOR_TEST, + DEFAULT_SMALL_MODEL_NAME_FOR_TEST, 
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, popen_launch_server, @@ -20,7 +20,7 @@ from sglang.test.test_utils import ( class TestSRTEndpoint(unittest.TestCase): @classmethod def setUpClass(cls): - cls.model = DEFAULT_MODEL_NAME_FOR_TEST + cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST cls.base_url = DEFAULT_URL_FOR_TEST cls.process = popen_launch_server( cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH diff --git a/test/srt/test_srt_engine.py b/test/srt/test_srt_engine.py index 0bf46c771..b44b4e2b6 100644 --- a/test/srt/test_srt_engine.py +++ b/test/srt/test_srt_engine.py @@ -11,14 +11,17 @@ from types import SimpleNamespace import sglang as sgl from sglang.srt.hf_transformers_utils import get_tokenizer from sglang.test.few_shot_gsm8k_engine import run_eval -from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST +from sglang.test.test_utils import ( + DEFAULT_MODEL_NAME_FOR_TEST, + DEFAULT_SMALL_MODEL_NAME_FOR_TEST, +) class TestSRTEngine(unittest.TestCase): def test_1_engine_runtime_consistency(self): prompt = "Today is a sunny day and I like" - model_path = DEFAULT_MODEL_NAME_FOR_TEST + model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST sampling_params = {"temperature": 0, "max_new_tokens": 8} @@ -40,7 +43,7 @@ class TestSRTEngine(unittest.TestCase): def test_2_engine_multiple_generate(self): # just to ensure there is no issue running multiple generate calls prompt = "Today is a sunny day and I like" - model_path = DEFAULT_MODEL_NAME_FOR_TEST + model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST sampling_params = {"temperature": 0, "max_new_tokens": 8} @@ -66,7 +69,7 @@ class TestSRTEngine(unittest.TestCase): # Create an LLM. 
llm = sgl.Engine( - model_path=DEFAULT_MODEL_NAME_FOR_TEST, + model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST, log_level="error", ) @@ -110,7 +113,7 @@ class TestSRTEngine(unittest.TestCase): def test_5_prompt_input_ids_consistency(self): prompt = "The capital of UK is" - model_path = DEFAULT_MODEL_NAME_FOR_TEST + model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST engine = sgl.Engine(model_path=model_path, random_seed=42, log_level="error") sampling_params = {"temperature": 0, "max_new_tokens": 8} out1 = engine.generate(prompt, sampling_params)["text"] diff --git a/test/srt/test_update_weights.py b/test/srt/test_update_weights.py index c3cde0f14..327da729a 100644 --- a/test/srt/test_update_weights.py +++ b/test/srt/test_update_weights.py @@ -5,7 +5,7 @@ import requests from sglang.srt.utils import kill_child_process from sglang.test.test_utils import ( - DEFAULT_MODEL_NAME_FOR_TEST, + DEFAULT_SMALL_MODEL_NAME_FOR_TEST, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, popen_launch_server, @@ -15,7 +15,7 @@ from sglang.test.test_utils import ( class TestUpdateWeights(unittest.TestCase): @classmethod def setUpClass(cls): - cls.model = DEFAULT_MODEL_NAME_FOR_TEST + cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST cls.base_url = DEFAULT_URL_FOR_TEST cls.process = popen_launch_server( cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH @@ -64,7 +64,7 @@ class TestUpdateWeights(unittest.TestCase): origin_response = self.run_decode() # update weights - new_model_path = "meta-llama/Meta-Llama-3.1-8B" + new_model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST.replace("-Instruct", "") ret = self.run_update_weights(new_model_path) assert ret["success"] @@ -92,7 +92,7 @@ class TestUpdateWeights(unittest.TestCase): origin_response = self.run_decode() # update weights - new_model_path = "meta-llama/Meta-Llama-3.1-8B-1" + new_model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST.replace("-Instruct", "wrong") ret = self.run_update_weights(new_model_path) assert not ret["success"]