From d84c5e70f7a0d309978eb64fa3e7aa5ac47fbb7a Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Sun, 11 Aug 2024 16:41:03 -0700
Subject: [PATCH] Test the case when max_new_tokens is very large (#1038)

---
 .../srt/managers/detokenizer_manager.py      |  4 +-
 .../sglang/srt/managers/policy_scheduler.py  |  7 +-
 python/sglang/srt/openai_api/adapter.py      |  4 +-
 python/sglang/srt/server_args.py             |  2 +-
 python/sglang/test/test_utils.py             | 13 +++-
 test/srt/run_suite.py                        | 11 ++--
 test/srt/test_large_max_new_tokens.py        | 72 +++++++++++++++++++
 7 files changed, 99 insertions(+), 14 deletions(-)
 create mode 100644 test/srt/test_large_max_new_tokens.py

diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py
index d765a365f..08ccfd5ce 100644
--- a/python/sglang/srt/managers/detokenizer_manager.py
+++ b/python/sglang/srt/managers/detokenizer_manager.py
@@ -32,7 +32,7 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.utils import find_printable_text, get_exception_traceback, graceful_registry
+from sglang.utils import find_printable_text, get_exception_traceback
 
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
@@ -164,8 +164,6 @@ def start_detokenizer_process(
     port_args: PortArgs,
     pipe_writer,
 ):
-    graceful_registry(inspect.currentframe().f_code.co_name)
-
     try:
         manager = DetokenizerManager(server_args, port_args)
     except Exception:
diff --git a/python/sglang/srt/managers/policy_scheduler.py b/python/sglang/srt/managers/policy_scheduler.py
index 4fd0ea290..4bf700f51 100644
--- a/python/sglang/srt/managers/policy_scheduler.py
+++ b/python/sglang/srt/managers/policy_scheduler.py
@@ -15,6 +15,7 @@ limitations under the License.
 
 """Request policy scheduler"""
 
+import os
 import random
 from collections import defaultdict
 from contextlib import contextmanager
@@ -24,9 +25,11 @@ from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.radix_cache import TreeNode
 
-# Clip the max new tokens for the request whose max_new_tokens is very large.
+# Clip the estimation of max_new_tokens for the request whose max_new_tokens is very large.
 # This can prevent the server from being too conservative.
-CLIP_MAX_NEW_TOKENS = 4096
+# Note that this only clips the estimation in the scheduler but does not change the stop
+# condition. The request can still generate tokens until it hits the unclipped max_new_tokens.
+CLIP_MAX_NEW_TOKENS = int(os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS", "4096"))
 
 
 class PolicyScheduler:
diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py
index 5c4b9719d..8998cf39d 100644
--- a/python/sglang/srt/openai_api/adapter.py
+++ b/python/sglang/srt/openai_api/adapter.py
@@ -77,7 +77,7 @@ class FileMetadata:
 batch_storage: Dict[str, BatchResponse] = {}
 file_id_request: Dict[str, FileMetadata] = {}
 file_id_response: Dict[str, FileResponse] = {}
-# map file id to file path in SGlang backend
+# map file id to file path in SGLang backend
 file_id_storage: Dict[str, str] = {}
 
 
@@ -335,7 +335,7 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
             }
 
     except Exception as e:
-        print("error in SGlang:", e)
+        print("error in SGLang:", e)
         # Update batch status to "failed"
         retrieve_batch = batch_storage[batch_id]
         retrieve_batch.status = "failed"
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 539551c70..474c80b25 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -64,7 +64,7 @@ class ServerArgs:
 
     # Other
     api_key: Optional[str] = None
-    file_storage_pth: str = "SGlang_storage"
+    file_storage_pth: str = "SGLang_storage"
 
     # Data parallelism
     dp_size: int = 1
diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py
index 613645b57..22aa597f5 100644
--- a/python/sglang/test/test_utils.py
+++ b/python/sglang/test/test_utils.py
@@ -398,6 +398,8 @@ def popen_launch_server(
     timeout: float,
     api_key: Optional[str] = None,
     other_args: tuple = (),
+    env: Optional[dict] = None,
+    return_stdout_stderr: bool = False,
 ):
     _, host, port = base_url.split(":")
     host = host[2:]
@@ -417,7 +419,16 @@ def popen_launch_server(
     if api_key:
         command += ["--api-key", api_key]
 
-    process = subprocess.Popen(command, stdout=None, stderr=None)
+    if return_stdout_stderr:
+        process = subprocess.Popen(
+            command,
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+            env=env,
+            text=True,
+        )
+    else:
+        process = subprocess.Popen(command, stdout=None, stderr=None, env=env)
 
     start_time = time.time()
     while time.time() - start_time < timeout:
diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
index 288645c21..08122389f 100644
--- a/test/srt/run_suite.py
+++ b/test/srt/run_suite.py
@@ -5,13 +5,14 @@ from sglang.test.test_utils import run_unittest_files
 
 suites = {
     "minimal": [
-        "test_eval_accuracy.py",
-        "test_openai_server.py",
-        "test_vision_openai_server.py",
-        "test_embedding_openai_server.py",
         "test_chunked_prefill.py",
+        "test_embedding_openai_server.py",
+        "test_eval_accuracy.py",
+        "test_large_max_new_tokens.py",
+        "test_openai_server.py",
+        "test_skip_tokenizer_init.py",
         "test_torch_compile.py",
-        "test_models_from_modelscope.py",
+        "test_vision_openai_server.py",
         "models/test_generation_models.py",
         "models/test_embedding_models.py",
         "sampling/penaltylib",
diff --git a/test/srt/test_large_max_new_tokens.py b/test/srt/test_large_max_new_tokens.py
new file mode 100644
index 000000000..3b3212209
--- /dev/null
+++ b/test/srt/test_large_max_new_tokens.py
@@ -0,0 +1,72 @@
+import json
+import os
+import time
+import unittest
+from concurrent.futures import ThreadPoolExecutor
+
+import openai
+
+from sglang.srt.hf_transformers_utils import get_tokenizer
+from sglang.srt.utils import kill_child_process
+from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST, popen_launch_server
+
+
+class TestOpenAIServer(unittest.TestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+        cls.base_url = "http://127.0.0.1:8157"
+        cls.api_key = "sk-123456"
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=300,
+            api_key=cls.api_key,
+            other_args=("--max-total-token", "1024"),
+            env={"SGLANG_CLIP_MAX_NEW_TOKENS": "256", **os.environ},
+            return_stdout_stderr=True,
+        )
+        cls.base_url += "/v1"
+        cls.tokenizer = get_tokenizer(DEFAULT_MODEL_NAME_FOR_TEST)
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_child_process(cls.process.pid)
+
+    def run_chat_completion(self):
+        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
+        response = client.chat.completions.create(
+            model=self.model,
+            messages=[
+                {"role": "system", "content": "You are a helpful AI assistant"},
+                {
+                    "role": "user",
+                    "content": "Please repeat the word 'hello' for 10000 times.",
+                },
+            ],
+            temperature=0,
+        )
+        return response
+
+    def test_chat_completion(self):
+        num_requests = 4
+
+        futures = []
+        with ThreadPoolExecutor(16) as executor:
+            for i in range(num_requests):
+                futures.append(executor.submit(self.run_chat_completion))
+
+        all_requests_running = False
+        for line in iter(self.process.stderr.readline, ""):
+            line = str(line)
+            print(line, end="")
+            if f"#running-req: {num_requests}" in line:
+                all_requests_running = True
+                break
+
+        assert all_requests_running
+
+
+if __name__ == "__main__":
+    unittest.main()
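For context, the sketch below shows how the pieces introduced by this patch are meant to be used together: SGLANG_CLIP_MAX_NEW_TOKENS lowers only the scheduler's per-request estimate of max_new_tokens (the stop condition is untouched), and popen_launch_server with return_stdout_stderr=True pipes the server logs so a caller can watch scheduler output such as the "#running-req: N" lines the new test waits for. This is a minimal illustration, not part of the patch; the port, timeout, and clip value mirror the new test and are otherwise arbitrary.

import os

from sglang.srt.utils import kill_child_process
from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST, popen_launch_server

# Lower the scheduler's estimation clip for this server only. Each request may
# still generate up to its own max_new_tokens; only the scheduler's estimate
# used for admission is clipped, so many long requests can run concurrently.
env = {"SGLANG_CLIP_MAX_NEW_TOKENS": "256", **os.environ}

# The port and launch timeout mirror the new test; they are illustrative
# choices, not values required by the patch.
process = popen_launch_server(
    DEFAULT_MODEL_NAME_FOR_TEST,
    "http://127.0.0.1:8157",
    timeout=300,
    env=env,
    return_stdout_stderr=True,  # pipe the logs instead of inheriting the console
)

try:
    # With piped stderr the scheduler output can be inspected programmatically,
    # e.g. to wait for "#running-req: N" as the new test does.
    print(process.stderr.readline(), end="")
finally:
    kill_child_process(process.pid)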