diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml
index f228901dd..59a6abece 100644
--- a/.github/workflows/pr-test.yml
+++ b/.github/workflows/pr-test.yml
@@ -47,10 +47,10 @@ jobs:
           bash scripts/ci_install_dependency.sh
 
       - name: Run test
-        timeout-minutes: 20
+        timeout-minutes: 25
         run: |
           cd test/srt
-          python3 run_suite.py --suite minimal --range-begin 0 --range-end 6
+          python3 run_suite.py --suite minimal --range-begin 0 --range-end 5
 
   unit-test-backend-part-2:
     if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
@@ -64,10 +64,10 @@ jobs:
           bash scripts/ci_install_dependency.sh
 
      - name: Run test
-        timeout-minutes: 20
+        timeout-minutes: 25
         run: |
           cd test/srt
-          python3 run_suite.py --suite minimal --range-begin 6 --range-end 14
+          python3 run_suite.py --suite minimal --range-begin 5 --range-end 14
 
   unit-test-backend-part-3:
     if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
@@ -81,7 +81,7 @@ jobs:
           bash scripts/ci_install_dependency.sh
 
       - name: Run test
-        timeout-minutes: 20
+        timeout-minutes: 25
         run: |
           cd test/srt
           python3 run_suite.py --suite minimal --range-begin 14 --range-end 20
@@ -98,7 +98,7 @@ jobs:
           bash scripts/ci_install_dependency.sh
 
       - name: Run test
-        timeout-minutes: 20
+        timeout-minutes: 25
         run: |
           cd test/srt
           python3 run_suite.py --suite minimal --range-begin 20
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index ecde19f5b..ce01154a6 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -114,9 +114,16 @@ async def health() -> Response:
 @app.get("/health_generate")
 async def health_generate(request: Request) -> Response:
     """Check the health of the inference server by generating one token."""
-    gri = GenerateReqInput(
-        text="s", sampling_params={"max_new_tokens": 1, "temperature": 0.7}
-    )
+
+    if tokenizer_manager.is_generation:
+        gri = GenerateReqInput(
+            input_ids=[0], sampling_params={"max_new_tokens": 1, "temperature": 0.7}
+        )
+    else:
+        gri = EmbeddingReqInput(
+            input_ids=[0], sampling_params={"max_new_tokens": 1, "temperature": 0.7}
+        )
+
     try:
         async for _ in tokenizer_manager.generate_request(gri, request):
             break
diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py
index 2bd713898..070ca508a 100644
--- a/python/sglang/test/test_utils.py
+++ b/python/sglang/test/test_utils.py
@@ -442,7 +442,7 @@ def popen_launch_server(
                 "Content-Type": "application/json; charset=utf-8",
                 "Authorization": f"Bearer {api_key}",
             }
-            response = requests.get(f"{base_url}/v1/models", headers=headers)
+            response = requests.get(f"{base_url}/health_generate", headers=headers)
             if response.status_code == 200:
                 return process
         except requests.RequestException:
@@ -637,8 +637,8 @@ def calculate_rouge_l(output_strs_list1, output_strs_list2):
     return rouge_l_scores
 
 
-STDOUT_FILENAME = "stdout.txt"
 STDERR_FILENAME = "stderr.txt"
+STDOUT_FILENAME = "stdout.txt"
 
 
 def read_output(output_lines):
diff --git a/test/srt/test_large_max_new_tokens.py b/test/srt/test_large_max_new_tokens.py
index 0605444ba..6ebe5e0d9 100644
--- a/test/srt/test_large_max_new_tokens.py
+++ b/test/srt/test_large_max_new_tokens.py
@@ -15,6 +15,8 @@ from sglang.test.test_utils import (
     DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
     DEFAULT_URL_FOR_TEST,
+    STDERR_FILENAME,
+    STDOUT_FILENAME,
     popen_launch_server,
 )
 
@@ -26,8 +28,8 @@ class TestLargeMaxNewTokens(unittest.TestCase):
         cls.base_url = DEFAULT_URL_FOR_TEST
         cls.api_key = "sk-123456"
 
-        cls.stdout = open("stdout.txt", "w")
-        cls.stderr = open("stderr.txt", "w")
+        cls.stdout = open(STDOUT_FILENAME, "w")
+        cls.stderr = open(STDERR_FILENAME, "w")
 
         cls.process = popen_launch_server(
             cls.model,
@@ -53,8 +55,8 @@ class TestLargeMaxNewTokens(unittest.TestCase):
         kill_child_process(cls.process.pid, include_self=True)
         cls.stdout.close()
         cls.stderr.close()
-        os.remove("stdout.txt")
-        os.remove("stderr.txt")
+        os.remove(STDOUT_FILENAME)
+        os.remove(STDERR_FILENAME)
 
     def run_chat_completion(self):
         client = openai.Client(api_key=self.api_key, base_url=self.base_url)
@@ -84,7 +86,7 @@ class TestLargeMaxNewTokens(unittest.TestCase):
         pt = 0
         while pt >= 0:
             time.sleep(5)
-            lines = open("stderr.txt").readlines()
+            lines = open(STDERR_FILENAME).readlines()
             for line in lines[pt:]:
                 print(line, end="", flush=True)
                 if f"#running-req: {num_requests}" in line: