diff --git a/.github/workflows/pr-test-rust.yml b/.github/workflows/pr-test-rust.yml index b527a7c70..594c51055 100644 --- a/.github/workflows/pr-test-rust.yml +++ b/.github/workflows/pr-test-rust.yml @@ -144,12 +144,6 @@ jobs: python3 -m pip --no-cache-dir install --upgrade --break-system-packages genai-bench==0.0.2 pytest -m e2e -s -vv -o log_cli=true --log-cli-level=INFO - - name: Run Python E2E gRPC tests - run: | - bash scripts/killall_sglang.sh "nuk_gpus" - cd sgl-router - SHOW_ROUTER_LOGS=1 ROUTER_LOCAL_MODEL_PATH="/home/ubuntu/models" pytest py_test/e2e_grpc -s -vv -o log_cli=true --log-cli-level=INFO - - name: Upload benchmark results if: success() uses: actions/upload-artifact@v4 @@ -157,8 +151,58 @@ jobs: name: genai-bench-results-all-policies path: sgl-router/benchmark_**/ + pytest-rust-2: + if: github.event_name != 'pull_request' || contains(github.event.pull_request.labels.*.name, 'run-ci') + runs-on: 4-gpu-a10 + timeout-minutes: 16 + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Install rust dependencies + run: | + bash scripts/ci/ci_install_rust.sh + + - name: Configure sccache + uses: mozilla-actions/sccache-action@v0.0.9 + with: + version: "v0.10.0" + + - name: Rust cache + uses: Swatinem/rust-cache@v2 + with: + workspaces: sgl-router + cache-all-crates: true + cache-on-failure: true + + - name: Install SGLang dependencies + run: | + sudo --preserve-env=PATH bash scripts/ci/ci_install_dependency.sh + + - name: Build python binding + run: | + source "$HOME/.cargo/env" + export RUSTC_WRAPPER=sccache + cd sgl-router + pip install setuptools-rust wheel build + python3 -m build + pip install --force-reinstall dist/*.whl + + - name: Run Python E2E response API tests + run: | + bash scripts/killall_sglang.sh "nuk_gpus" + cd sgl-router + SHOW_ROUTER_LOGS=1 pytest py_test/e2e_response_api -s -vv -o log_cli=true --log-cli-level=INFO + + - name: Run Python E2E gRPC tests + run: | + bash scripts/killall_sglang.sh "nuk_gpus" + cd sgl-router + SHOW_ROUTER_LOGS=1 ROUTER_LOCAL_MODEL_PATH="/home/ubuntu/models" pytest py_test/e2e_grpc -s -vv -o log_cli=true --log-cli-level=INFO + + finish: - needs: [unit-test-rust, pytest-rust] + needs: [unit-test-rust, pytest-rust, pytest-rust-2] runs-on: ubuntu-latest steps: - name: Finish diff --git a/sgl-router/py_test/e2e_grpc/fixtures.py b/sgl-router/py_test/e2e_grpc/fixtures.py index 869c70167..f7047394c 100644 --- a/sgl-router/py_test/e2e_grpc/fixtures.py +++ b/sgl-router/py_test/e2e_grpc/fixtures.py @@ -267,8 +267,6 @@ def popen_launch_workers_and_router( policy, "--model-path", model, - "--log-level", - "warn", ] # Add worker URLs diff --git a/sgl-router/py_test/e2e_response_api/base.py b/sgl-router/py_test/e2e_response_api/base.py new file mode 100644 index 000000000..07cf50bb2 --- /dev/null +++ b/sgl-router/py_test/e2e_response_api/base.py @@ -0,0 +1,480 @@ +""" +Base test class for Response API e2e tests. + +This module provides base test classes that can be reused across different backends +(OpenAI, XAI, gRPC) with common test logic. 
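+
+Subclasses set ``base_url``, ``api_key``, and ``model`` and inherit the tests.
+A minimal sketch (the class name, URL, env var, and model below are
+placeholders, not fixed values):
+
+    class TestMyBackend(StateManagementBaseTest):
+        base_url = "http://127.0.0.1:30010"
+        api_key = os.environ.get("MY_API_KEY")
+        model = "my-model"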
+""" + +import json +import sys +import time +import unittest +from pathlib import Path +from typing import Optional + +import requests + +# Add current directory for local imports +_TEST_DIR = Path(__file__).parent +sys.path.insert(0, str(_TEST_DIR)) + +from util import CustomTestCase + + +class ResponseAPIBaseTest(CustomTestCase): + """Base class for Response API tests with common utilities.""" + + # To be set by subclasses + base_url: str = None + api_key: str = None + model: str = None + + def make_request( + self, + endpoint: str, + method: str = "POST", + json_data: Optional[dict] = None, + params: Optional[dict] = None, + ) -> requests.Response: + """ + Make HTTP request to router. + + Args: + endpoint: Endpoint path (e.g., "/v1/responses") + method: HTTP method (GET, POST, DELETE) + json_data: JSON body for POST requests + params: Query parameters + + Returns: + requests.Response object + """ + url = f"{self.base_url}{endpoint}" + headers = {"Content-Type": "application/json"} + if self.api_key: + headers["Authorization"] = f"Bearer {self.api_key}" + + if method == "POST": + resp = requests.post(url, json=json_data, headers=headers, params=params) + elif method == "GET": + resp = requests.get(url, headers=headers, params=params) + elif method == "DELETE": + resp = requests.delete(url, headers=headers, params=params) + else: + raise ValueError(f"Unsupported method: {method}") + return resp + + def create_response( + self, + input_text: str, + instructions: Optional[str] = None, + stream: bool = False, + max_output_tokens: Optional[int] = None, + temperature: Optional[float] = None, + previous_response_id: Optional[str] = None, + conversation: Optional[str] = None, + tools: Optional[list] = None, + background: bool = False, + **kwargs, + ) -> requests.Response: + """ + Create a response via POST /v1/responses. 
+ + Args: + input_text: User input + instructions: Optional system instructions + stream: Whether to stream response + max_output_tokens: Optional max tokens to generate + temperature: Sampling temperature + previous_response_id: Optional previous response ID for state management + conversation: Optional conversation ID for state management + tools: Optional list of MCP tools + background: Whether to run in background mode + **kwargs: Additional request parameters + + Returns: + requests.Response object + """ + data = { + "model": self.model, + "input": input_text, + "stream": stream, + **kwargs, + } + + if instructions: + data["instructions"] = instructions + + if max_output_tokens is not None: + data["max_output_tokens"] = max_output_tokens + + if temperature is not None: + data["temperature"] = temperature + + if previous_response_id: + data["previous_response_id"] = previous_response_id + + if conversation: + data["conversation"] = conversation + + if tools: + data["tools"] = tools + + if background: + data["background"] = background + + if stream: + # For streaming, we need to handle SSE + return self._create_streaming_response(data) + else: + return self.make_request("/v1/responses", "POST", data) + + def _create_streaming_response(self, data: dict) -> requests.Response: + """Handle streaming response creation.""" + url = f"{self.base_url}/v1/responses" + headers = {"Content-Type": "application/json"} + if self.api_key: + headers["Authorization"] = f"Bearer {self.api_key}" + + # Return response object with stream=True + return requests.post(url, json=data, headers=headers, stream=True) + + def get_response(self, response_id: str) -> requests.Response: + """Get response by ID via GET /v1/responses/{response_id}.""" + return self.make_request(f"/v1/responses/{response_id}", "GET") + + def delete_response(self, response_id: str) -> requests.Response: + """Delete response by ID via DELETE /v1/responses/{response_id}.""" + return self.make_request(f"/v1/responses/{response_id}", "DELETE") + + def cancel_response(self, response_id: str) -> requests.Response: + """Cancel response by ID via POST /v1/responses/{response_id}/cancel.""" + return self.make_request(f"/v1/responses/{response_id}/cancel", "POST", {}) + + def get_response_input(self, response_id: str) -> requests.Response: + """Get response input items via GET /v1/responses/{response_id}/input.""" + return self.make_request(f"/v1/responses/{response_id}/input", "GET") + + def create_conversation(self, metadata: Optional[dict] = None) -> requests.Response: + """Create conversation via POST /v1/conversations.""" + data = {} + if metadata: + data["metadata"] = metadata + return self.make_request("/v1/conversations", "POST", data) + + def get_conversation(self, conversation_id: str) -> requests.Response: + """Get conversation by ID via GET /v1/conversations/{conversation_id}.""" + return self.make_request(f"/v1/conversations/{conversation_id}", "GET") + + def update_conversation( + self, conversation_id: str, metadata: dict + ) -> requests.Response: + """Update conversation via POST /v1/conversations/{conversation_id}.""" + return self.make_request( + f"/v1/conversations/{conversation_id}", "POST", {"metadata": metadata} + ) + + def delete_conversation(self, conversation_id: str) -> requests.Response: + """Delete conversation via DELETE /v1/conversations/{conversation_id}.""" + return self.make_request(f"/v1/conversations/{conversation_id}", "DELETE") + + def list_conversation_items( + self, + conversation_id: str, + limit: 
Optional[int] = None, + after: Optional[str] = None, + before: Optional[str] = None, + order: str = "asc", + ) -> requests.Response: + """List conversation items via GET /v1/conversations/{conversation_id}/items.""" + params = {"order": order} + if limit: + params["limit"] = limit + if after: + params["after"] = after + if before: + params["before"] = before + return self.make_request( + f"/v1/conversations/{conversation_id}/items", "GET", params=params + ) + + def create_conversation_items( + self, conversation_id: str, items: list + ) -> requests.Response: + """Create conversation items via POST /v1/conversations/{conversation_id}/items.""" + return self.make_request( + f"/v1/conversations/{conversation_id}/items", "POST", {"items": items} + ) + + def get_conversation_item( + self, conversation_id: str, item_id: str + ) -> requests.Response: + """Get conversation item via GET /v1/conversations/{conversation_id}/items/{item_id}.""" + return self.make_request( + f"/v1/conversations/{conversation_id}/items/{item_id}", "GET" + ) + + def delete_conversation_item( + self, conversation_id: str, item_id: str + ) -> requests.Response: + """Delete conversation item via DELETE /v1/conversations/{conversation_id}/items/{item_id}.""" + return self.make_request( + f"/v1/conversations/{conversation_id}/items/{item_id}", "DELETE" + ) + + def parse_sse_events(self, response: requests.Response) -> list: + """ + Parse Server-Sent Events from streaming response. + + Args: + response: requests.Response with stream=True + + Returns: + List of event dictionaries with 'event' and 'data' keys + """ + events = [] + current_event = None + + for line in response.iter_lines(): + if not line: + # Empty line signals end of event + if current_event and current_event.get("data"): + events.append(current_event) + current_event = None + continue + + line = line.decode("utf-8") + + if line.startswith("event:"): + current_event = {"event": line[6:].strip()} + elif line.startswith("data:"): + if current_event is None: + current_event = {} + data_str = line[5:].strip() + try: + current_event["data"] = json.loads(data_str) + except json.JSONDecodeError: + current_event["data"] = data_str + + # Don't forget the last event if stream ends without empty line + if current_event and current_event.get("data"): + events.append(current_event) + + return events + + def wait_for_background_task( + self, response_id: str, timeout: int = 30, poll_interval: float = 0.5 + ) -> dict: + """ + Wait for background task to complete. 
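+
+        Example (sketch; assumes the backend supports background mode):
+
+            resp = self.create_response("hi", background=True)
+            data = self.wait_for_background_task(resp.json()["id"], timeout=60)
+            assert data["status"] == "completed"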
+
+        Args:
+            response_id: Response ID to poll
+            timeout: Max seconds to wait
+            poll_interval: Seconds between polls
+
+        Returns:
+            Final response data
+
+        Raises:
+            TimeoutError: If task doesn't complete in time
+            AssertionError: If task fails
+        """
+        start_time = time.time()
+
+        while time.time() - start_time < timeout:
+            resp = self.get_response(response_id)
+            self.assertEqual(resp.status_code, 200)
+
+            data = resp.json()
+            status = data.get("status")
+
+            if status == "completed":
+                return data
+            elif status == "failed":
+                raise AssertionError(
+                    f"Background task failed: {data.get('error', 'Unknown error')}"
+                )
+            elif status == "cancelled":
+                raise AssertionError("Background task was cancelled")
+
+            time.sleep(poll_interval)
+
+        raise TimeoutError(
+            f"Background task {response_id} did not complete within {timeout}s"
+        )
+
+
+class StateManagementBaseTest(ResponseAPIBaseTest):
+    """Base class for state management tests (previous_response_id and
+    conversation), plus basic smoke tests shared by every backend."""
+
+    def test_basic_response_creation(self):
+        """Test basic response creation without state."""
+        resp = self.create_response("What is 2+2?", max_output_tokens=50)
+        self.assertEqual(resp.status_code, 200)
+
+        data = resp.json()
+        self.assertIn("id", data)
+        self.assertIn("output", data)
+        self.assertEqual(data["status"], "completed")
+        self.assertIn("usage", data)
+
+    def test_streaming_response(self):
+        """Test streaming response."""
+        resp = self.create_response("Count to 5", stream=True, max_output_tokens=50)
+        self.assertEqual(resp.status_code, 200)
+
+        events = self.parse_sse_events(resp)
+        self.assertGreater(len(events), 0)
+
+        # Check for response.created event
+        created_events = [e for e in events if e.get("event") == "response.created"]
+        self.assertGreater(len(created_events), 0)
+
+        # Check for final completed event or in_progress events
+        self.assertTrue(
+            any(
+                e.get("event") in ["response.completed", "response.in_progress"]
+                for e in events
+            )
+        )
+
+
+class ResponseCRUDBaseTest(ResponseAPIBaseTest):
+    """Base class for Response API CRUD tests."""
+
+    def test_create_and_get_response(self):
+        """Test creating response and retrieving it."""
+        # Create response
+        create_resp = self.create_response("Hello, world!")
+        self.assertEqual(create_resp.status_code, 200)
+
+        create_data = create_resp.json()
+        response_id = create_data["id"]
+
+        # Get response
+        get_resp = self.get_response(response_id)
+        self.assertEqual(get_resp.status_code, 200)
+
+        get_data = get_resp.json()
+        self.assertEqual(get_data["id"], response_id)
+        self.assertEqual(get_data["status"], "completed")
+
+        input_resp = self.get_response_input(get_data["id"])
+        # input-items endpoint not merged yet; expect 501 until it lands
+        self.assertEqual(input_resp.status_code, 501)
+        # self.assertEqual(input_resp.status_code, 200)
+        # input_data = input_resp.json()
+        # self.assertIn("data", input_data)
+        # self.assertGreater(len(input_data["data"]), 0)
+
+    @unittest.skip("TODO: Add delete response feature")
+    def test_delete_response(self):
+        """Test deleting response."""
+        # Create response
+        create_resp = self.create_response("Test deletion", max_output_tokens=50)
+        self.assertEqual(create_resp.status_code, 200)
+
+        response_id = create_resp.json()["id"]
+
+        # Delete response
+        delete_resp = self.delete_response(response_id)
+        self.assertEqual(delete_resp.status_code, 200)
+
+        # Verify it's deleted (should return 404)
+        get_resp = self.get_response(response_id)
+        self.assertEqual(get_resp.status_code, 404)
+
+    @unittest.skip("TODO: Add background response feature")
+    def test_background_response(self):
+        """Test background response execution."""
+        # Create background response
+        create_resp = self.create_response(
+            "Write a short story", background=True, max_output_tokens=100
+        )
+        self.assertEqual(create_resp.status_code, 200)
+
+        create_data = create_resp.json()
+        response_id = create_data["id"]
+        self.assertEqual(create_data["status"], "in_progress")
+
+        # Wait for completion
+        final_data = self.wait_for_background_task(response_id, timeout=60)
+        self.assertEqual(final_data["status"], "completed")
+
+
+class ConversationCRUDBaseTest(ResponseAPIBaseTest):
+    """Base class for Conversation API CRUD tests."""
+
+    def test_create_and_get_conversation(self):
+        """Test creating and retrieving conversation."""
+        # Create conversation
+        create_resp = self.create_conversation(metadata={"user": "test_user"})
+        self.assertEqual(create_resp.status_code, 200)
+
+        create_data = create_resp.json()
+        conversation_id = create_data["id"]
+        self.assertEqual(create_data["metadata"]["user"], "test_user")
+
+        # Get conversation
+        get_resp = self.get_conversation(conversation_id)
+        self.assertEqual(get_resp.status_code, 200)
+
+        get_data = get_resp.json()
+        self.assertEqual(get_data["id"], conversation_id)
+        self.assertEqual(get_data["metadata"]["user"], "test_user")
+
+    def test_update_conversation(self):
+        """Test updating conversation metadata."""
+        # Create conversation
+        create_resp = self.create_conversation(metadata={"key1": "value1"})
+        self.assertEqual(create_resp.status_code, 200)
+        conversation_id = create_resp.json()["id"]
+
+        # Update conversation
+        update_resp = self.update_conversation(
+            conversation_id, metadata={"key1": "value1", "key2": "value2"}
+        )
+        self.assertEqual(update_resp.status_code, 200)
+
+        # Verify update
+        get_resp = self.get_conversation(conversation_id)
+        get_data = get_resp.json()
+        self.assertEqual(get_data["metadata"]["key2"], "value2")
+
+    def test_delete_conversation(self):
+        """Test deleting conversation."""
+        # Create conversation
+        create_resp = self.create_conversation()
+        self.assertEqual(create_resp.status_code, 200)
+        conversation_id = create_resp.json()["id"]
+
+        # Delete conversation
+        delete_resp = self.delete_conversation(conversation_id)
+        self.assertEqual(delete_resp.status_code, 200)
+
+        # Verify deletion
+        get_resp = self.get_conversation(conversation_id)
+        self.assertEqual(get_resp.status_code, 404)
+
+    def test_list_conversation_items(self):
+        """Test listing conversation items."""
+        # Create conversation
+        conv_resp = self.create_conversation()
+        conversation_id = conv_resp.json()["id"]
+
+        # Create response with conversation
+        self.create_response(
+            "First message", conversation=conversation_id, max_output_tokens=50
+        )
+        self.create_response(
+            "Second message", conversation=conversation_id, max_output_tokens=50
+        )
+
+        # List items
+        list_resp = self.list_conversation_items(conversation_id)
+        self.assertEqual(list_resp.status_code, 200)
+
+        list_data = list_resp.json()
+        self.assertIn("data", list_data)
+        # Should have at least 4 items (2 inputs + 2 outputs)
+        self.assertGreaterEqual(len(list_data["data"]), 4)
diff --git a/sgl-router/py_test/e2e_response_api/conftest.py b/sgl-router/py_test/e2e_response_api/conftest.py
new file mode 100644
index 000000000..04bd6c453
--- /dev/null
+++ b/sgl-router/py_test/e2e_response_api/conftest.py
@@ -0,0 +1,39 @@
+"""
+pytest configuration for e2e_response_api tests.
+
+This configures pytest to not collect base test classes that are meant to be inherited.
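+
+For example, tests defined on ``StateManagementTests`` are dropped during
+collection, while ``TestOpenaiBackend(StateManagementTests)`` in
+test_response_api.py still runs every inherited test method.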
+""" + +import pytest + + +def pytest_collection_modifyitems(config, items): + """ + Modify test collection to exclude base test classes. + + Base test classes are meant to be inherited, not run directly. + We exclude any test that comes from these base classes: + - StateManagementBaseTest + - ResponseCRUDBaseTest + - ConversationCRUDBaseTest + - MCPTests + - StateManagementTests + """ + base_class_names = { + "StateManagementBaseTest", + "ResponseCRUDBaseTest", + "ConversationCRUDBaseTest", + "MCPTests", + "StateManagementTests", + } + + # Filter out tests from base classes + filtered_items = [] + for item in items: + # Check if the test's parent class is a base class + parent_name = item.parent.name if hasattr(item, "parent") else None + if parent_name not in base_class_names: + filtered_items.append(item) + + # Update items list + items[:] = filtered_items diff --git a/sgl-router/py_test/e2e_response_api/mcp.py b/sgl-router/py_test/e2e_response_api/mcp.py new file mode 100644 index 000000000..57dda72c8 --- /dev/null +++ b/sgl-router/py_test/e2e_response_api/mcp.py @@ -0,0 +1,229 @@ +""" +MCP (Model Context Protocol) tests for Response API. + +Tests MCP tool calling in both streaming and non-streaming modes. +These tests should work across all backends that support MCP (OpenAI, XAI). +""" + +from base import ResponseAPIBaseTest + + +class MCPTests(ResponseAPIBaseTest): + """Tests for MCP tool calling in both streaming and non-streaming modes.""" + + def test_mcp_basic_tool_call(self): + """Test basic MCP tool call (non-streaming).""" + tools = [ + { + "type": "mcp", + "server_label": "deepwiki", + "server_url": "https://mcp.deepwiki.com/mcp", + "require_approval": "never", + } + ] + + resp = self.create_response( + "What transport protocols does the 2025-03-26 version of the MCP spec (modelcontextprotocol/modelcontextprotocol) support?", + tools=tools, + stream=False, + ) + + # Should successfully make the request + self.assertEqual(resp.status_code, 200) + + data = resp.json() + print(f"MCP response: {data}") + + # Basic response structure + self.assertIn("id", data) + self.assertIn("status", data) + self.assertEqual(data["status"], "completed") + self.assertIn("output", data) + self.assertIn("model", data) + + # Verify output array is not empty + output = data["output"] + self.assertIsInstance(output, list) + self.assertGreater(len(output), 0) + + # Check for MCP-specific output types + output_types = [item.get("type") for item in output] + + # Should have mcp_list_tools - tools are listed before calling + self.assertIn( + "mcp_list_tools", output_types, "Response should contain mcp_list_tools" + ) + + # Should have at least one mcp_call + mcp_calls = [item for item in output if item.get("type") == "mcp_call"] + self.assertGreater( + len(mcp_calls), 0, "Response should contain at least one mcp_call" + ) + + # Verify mcp_call structure + for mcp_call in mcp_calls: + self.assertIn("id", mcp_call) + self.assertIn("status", mcp_call) + self.assertEqual(mcp_call["status"], "completed") + self.assertIn("server_label", mcp_call) + self.assertEqual(mcp_call["server_label"], "deepwiki") + self.assertIn("name", mcp_call) + self.assertIn("arguments", mcp_call) + self.assertIn("output", mcp_call) + + # Should have final message output + messages = [item for item in output if item.get("type") == "message"] + self.assertGreater( + len(messages), 0, "Response should contain at least one message" + ) + + # Verify message structure + for msg in messages: + self.assertIn("content", msg) + 
self.assertIsInstance(msg["content"], list) + + # Check content has text + for content_item in msg["content"]: + if content_item.get("type") == "output_text": + self.assertIn("text", content_item) + self.assertIsInstance(content_item["text"], str) + self.assertGreater(len(content_item["text"]), 0) + + def test_mcp_basic_tool_call_streaming(self): + """Test basic MCP tool call (streaming).""" + tools = [ + { + "type": "mcp", + "server_label": "deepwiki", + "server_url": "https://mcp.deepwiki.com/mcp", + "require_approval": "never", + } + ] + + resp = self.create_response( + "What transport protocols does the 2025-03-26 version of the MCP spec (modelcontextprotocol/modelcontextprotocol) support?", + tools=tools, + stream=True, + ) + + # Should successfully make the request + self.assertEqual(resp.status_code, 200) + + events = self.parse_sse_events(resp) + self.assertGreater(len(events), 0) + + event_types = [e.get("event") for e in events] + + # Check for lifecycle events + self.assertIn( + "response.created", event_types, "Should have response.created event" + ) + self.assertIn( + "response.completed", event_types, "Should have response.completed event" + ) + + # Check for MCP list tools events + self.assertIn( + "response.output_item.added", + event_types, + "Should have output_item.added events", + ) + self.assertIn( + "response.mcp_list_tools.in_progress", + event_types, + "Should have mcp_list_tools.in_progress event", + ) + self.assertIn( + "response.mcp_list_tools.completed", + event_types, + "Should have mcp_list_tools.completed event", + ) + + # Check for MCP call events + self.assertIn( + "response.mcp_call.in_progress", + event_types, + "Should have mcp_call.in_progress event", + ) + self.assertIn( + "response.mcp_call_arguments.delta", + event_types, + "Should have mcp_call_arguments.delta event", + ) + self.assertIn( + "response.mcp_call_arguments.done", + event_types, + "Should have mcp_call_arguments.done event", + ) + self.assertIn( + "response.mcp_call.completed", + event_types, + "Should have mcp_call.completed event", + ) + + # Check for text output events + self.assertIn( + "response.content_part.added", + event_types, + "Should have content_part.added event", + ) + self.assertIn( + "response.output_text.delta", + event_types, + "Should have output_text.delta events", + ) + self.assertIn( + "response.output_text.done", + event_types, + "Should have output_text.done event", + ) + self.assertIn( + "response.content_part.done", + event_types, + "Should have content_part.done event", + ) + + # Verify final completed event has full response + completed_events = [e for e in events if e.get("event") == "response.completed"] + self.assertEqual(len(completed_events), 1) + + final_response = completed_events[0].get("data", {}).get("response", {}) + self.assertIn("id", final_response) + self.assertEqual(final_response.get("status"), "completed") + self.assertIn("output", final_response) + + # Verify final output contains expected items + final_output = final_response.get("output", []) + final_output_types = [item.get("type") for item in final_output] + + self.assertIn("mcp_list_tools", final_output_types) + self.assertIn("mcp_call", final_output_types) + self.assertIn("message", final_output_types) + + # Verify mcp_call items in final output + mcp_calls = [item for item in final_output if item.get("type") == "mcp_call"] + self.assertGreater(len(mcp_calls), 0) + + for mcp_call in mcp_calls: + self.assertEqual(mcp_call.get("status"), "completed") + 
self.assertEqual(mcp_call.get("server_label"), "deepwiki") + self.assertIn("name", mcp_call) + self.assertIn("arguments", mcp_call) + self.assertIn("output", mcp_call) + + # Verify text deltas combine to final message + text_deltas = [ + e.get("data", {}).get("delta", "") + for e in events + if e.get("event") == "response.output_text.delta" + ] + self.assertGreater(len(text_deltas), 0, "Should have text deltas") + + # Get final text from output_text.done event + text_done_events = [ + e for e in events if e.get("event") == "response.output_text.done" + ] + self.assertGreater(len(text_done_events), 0) + + final_text = text_done_events[0].get("data", {}).get("text", "") + self.assertGreater(len(final_text), 0, "Final text should not be empty") diff --git a/sgl-router/py_test/e2e_response_api/router_fixtures.py b/sgl-router/py_test/e2e_response_api/router_fixtures.py new file mode 100644 index 000000000..90c192cca --- /dev/null +++ b/sgl-router/py_test/e2e_response_api/router_fixtures.py @@ -0,0 +1,554 @@ +""" +Fixtures for launching OpenAI/XAI router for response API e2e testing. + +This module provides fixtures for launching SGLang router with OpenAI or XAI backends: + 1. Launch router with --backend openai pointing to OpenAI or XAI API + 2. Configure history backend (memory or oracle) + +This supports testing the Response API against real cloud providers. +""" + +import os +import socket +import subprocess +import time +from typing import Optional + +import requests + + +def wait_for_workers_ready( + router_url: str, + expected_workers: int, + timeout: int = 300, + api_key: Optional[str] = None, +) -> None: + """ + Wait for router to have all workers connected. + + Polls the /workers endpoint until the 'total' field matches expected_workers. + + Example response from /workers endpoint: + {"workers":[],"total":0,"stats":{"prefill_count":0,"decode_count":0,"regular_count":0}} + + Args: + router_url: Base URL of router (e.g., "http://127.0.0.1:30000") + expected_workers: Number of workers expected to be connected + timeout: Max seconds to wait + api_key: Optional API key for authentication + """ + start_time = time.time() + last_error = None + attempt = 0 + + headers = {} + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + + with requests.Session() as session: + while time.time() - start_time < timeout: + attempt += 1 + elapsed = int(time.time() - start_time) + + # Print progress every 10 seconds + if elapsed > 0 and elapsed % 10 == 0 and attempt % 10 == 0: + print(f" Still waiting for workers... 
({elapsed}/{timeout}s elapsed)") + + try: + response = session.get( + f"{router_url}/workers", headers=headers, timeout=5 + ) + if response.status_code == 200: + data = response.json() + total_workers = data.get("total", 0) + + if total_workers == expected_workers: + print( + f" All {expected_workers} workers connected after {elapsed}s" + ) + return + else: + last_error = f"Workers: {total_workers}/{expected_workers}" + else: + last_error = f"HTTP {response.status_code}" + except requests.ConnectionError: + last_error = "Connection refused (router not ready yet)" + except requests.Timeout: + last_error = "Timeout" + except requests.RequestException as e: + last_error = str(e) + except (ValueError, KeyError) as e: + last_error = f"Invalid response: {e}" + + time.sleep(1) + + raise TimeoutError( + f"Router at {router_url} did not get {expected_workers} workers within {timeout}s.\n" + f"Last status: {last_error}\n" + f"Hint: Run with SHOW_ROUTER_LOGS=1 to see startup logs" + ) + + +def find_free_port() -> int: + """Find an available port on localhost.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + + +def wait_for_router_ready( + router_url: str, + timeout: int = 60, + api_key: Optional[str] = None, +) -> None: + """ + Wait for router to be ready. + + Polls the /health endpoint until it returns 200. + + Args: + router_url: Base URL of router (e.g., "http://127.0.0.1:30000") + timeout: Max seconds to wait + api_key: Optional API key for authentication + """ + start_time = time.time() + last_error = None + attempt = 0 + + headers = {} + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + + with requests.Session() as session: + while time.time() - start_time < timeout: + attempt += 1 + elapsed = int(time.time() - start_time) + + # Print progress every 10 seconds + if elapsed > 0 and elapsed % 10 == 0 and attempt % 10 == 0: + print(f" Still waiting for router... ({elapsed}/{timeout}s elapsed)") + + try: + response = session.get( + f"{router_url}/health", headers=headers, timeout=5 + ) + if response.status_code == 200: + print(f" Router ready after {elapsed}s") + return + else: + last_error = f"HTTP {response.status_code}" + except requests.ConnectionError: + last_error = "Connection refused (router not ready yet)" + except requests.Timeout: + last_error = "Timeout" + except requests.RequestException as e: + last_error = str(e) + + time.sleep(1) + + raise TimeoutError( + f"Router at {router_url} did not become ready within {timeout}s.\n" + f"Last status: {last_error}\n" + f"Hint: Run with SHOW_ROUTER_LOGS=1 to see startup logs" + ) + + +def popen_launch_openai_xai_router( + backend: str, # "openai" or "xai" + base_url: str, + timeout: int = 60, + history_backend: str = "memory", + api_key: Optional[str] = None, + router_args: Optional[list] = None, + stdout=None, + stderr=None, + prometheus_port: Optional[int] = None, +) -> dict: + """ + Launch SGLang router with OpenAI or XAI backend. + + This approach: + 1. Starts router with --backend openai + 2. Points to OpenAI or XAI API via --worker-urls + 3. Configures history backend (memory or oracle) + 4. 
Waits for router health check to pass + + Args: + backend: "openai" or "xai" + base_url: Base URL for router (e.g., "http://127.0.0.1:30000") + timeout: Timeout for router startup (default: 60s) + history_backend: "memory" or "oracle" (default: memory) + api_key: Optional API key for router authentication + router_args: Additional arguments for router + stdout: Optional file handle for router stdout + stderr: Optional file handle for router stderr + + Returns: + dict with: + - router: router process object + - base_url: router URL (HTTP endpoint) + + Example: + >>> cluster = popen_launch_openai_xai_router( + ... "openai", "http://127.0.0.1:30000" + ... ) + >>> # Use cluster['base_url'] for HTTP requests + >>> # Cleanup: + >>> kill_process_tree(cluster['router'].pid) + """ + show_output = os.environ.get("SHOW_ROUTER_LOGS", "0") == "1" + + # Parse router port from base_url + if ":" in base_url.split("//")[-1]: + router_port = int(base_url.split(":")[-1]) + else: + router_port = find_free_port() + + print(f"\n{'='*70}") + print(f"Launching {backend.upper()} router") + print(f"{'='*70}") + print(f" Backend: {backend}") + print(f" Router port: {router_port}") + print(f" History backend: {history_backend}") + + # Determine worker URL based on backend + if backend == "openai": + worker_url = "https://api.openai.com" + # Get API key from environment + backend_api_key = os.environ.get("OPENAI_API_KEY") + if not backend_api_key: + raise ValueError( + "OPENAI_API_KEY environment variable must be set for OpenAI backend" + ) + elif backend == "xai": + worker_url = "https://api.x.ai" + # Get API key from environment + backend_api_key = os.environ.get("XAI_API_KEY") + if not backend_api_key: + raise ValueError( + "XAI_API_KEY environment variable must be set for XAI backend" + ) + else: + raise ValueError(f"Unsupported backend: {backend}") + + print(f" Worker URL: {worker_url}") + + # Build router command + router_cmd = [ + "python3", + "-m", + "sglang_router.launch_router", + "--host", + "127.0.0.1", + "--port", + str(router_port), + "--backend", + "openai", + "--worker-urls", + worker_url, + "--history-backend", + history_backend, + "--log-level", + "warn", + ] + + # Note: Not adding --api-key to router command for local testing + # The router will not require authentication + + # Add Prometheus port to avoid conflicts (use unique port or disable) + if prometheus_port is None: + # Auto-assign a unique prometheus port based on router port + prometheus_port = router_port + 1000 + router_cmd.extend(["--prometheus-port", str(prometheus_port)]) + + # Add router-specific args + if router_args: + router_cmd.extend(router_args) + + if show_output: + print(f" Command: {' '.join(router_cmd)}") + + # Set up environment with backend API key + env = os.environ.copy() + if backend == "openai": + env["OPENAI_API_KEY"] = backend_api_key + else: + env["XAI_API_KEY"] = backend_api_key + + # Launch router + if show_output: + router_proc = subprocess.Popen( + router_cmd, + env=env, + stdout=stdout, + stderr=stderr, + ) + else: + router_proc = subprocess.Popen( + router_cmd, + stdout=stdout if stdout is not None else subprocess.PIPE, + stderr=stderr if stderr is not None else subprocess.PIPE, + env=env, + ) + + print(f" PID: {router_proc.pid}") + + # Wait for router to be ready + router_url = f"http://127.0.0.1:{router_port}" + print(f"\nWaiting for router to start at {router_url}...") + + try: + wait_for_router_ready(router_url, timeout=timeout, api_key=None) + print(f"✓ Router ready at {router_url}") + except 
TimeoutError: + print(f"✗ Router failed to start") + # Cleanup: kill router + try: + router_proc.kill() + except: + pass + raise + + print(f"\n{'='*70}") + print(f"✓ {backend.upper()} router ready!") + print(f" Router: {router_url}") + print(f"{'='*70}\n") + + return { + "router": router_proc, + "base_url": router_url, + } + + +def popen_launch_workers_and_router( + model: str, + base_url: str, + timeout: int = 300, + num_workers: int = 2, + policy: str = "round_robin", + api_key: Optional[str] = None, + worker_args: Optional[list] = None, + router_args: Optional[list] = None, + tp_size: int = 1, + env: Optional[dict] = None, + stdout=None, + stderr=None, +) -> dict: + """ + Launch SGLang workers and gRPC router separately. + + This approach: + 1. Starts N SGLang workers with --grpc-mode flag + 2. Waits for workers to initialize (process startup) + 3. Starts a gRPC router pointing to those workers + 4. Waits for router health check to pass (router validates worker connectivity) + + This matches production deployment patterns better than the integrated approach. + + Args: + model: Model path (e.g., /home/ubuntu/models/llama-3.1-8b-instruct) + base_url: Base URL for router (e.g., "http://127.0.0.1:8080") + timeout: Timeout for server startup (default: 300s) + num_workers: Number of workers to launch + policy: Routing policy (round_robin, random, power_of_two, cache_aware) + api_key: Optional API key for router + worker_args: Additional arguments for workers (e.g., ["--context-len", "8192"]) + router_args: Additional arguments for router (e.g., ["--max-total-token", "1536"]) + tp_size: Tensor parallelism size for workers (default: 1) + env: Optional environment variables for workers (e.g., {"SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION": "256"}) + stdout: Optional file handle for worker stdout (default: subprocess.PIPE) + stderr: Optional file handle for worker stderr (default: subprocess.PIPE) + + Returns: + dict with: + - workers: list of worker process objects + - worker_urls: list of gRPC worker URLs + - router: router process object + - base_url: router URL (HTTP endpoint) + + Example: + >>> cluster = popen_launch_workers_and_router(model, base_url, num_workers=2) + >>> # Use cluster['base_url'] for HTTP requests + >>> # Cleanup: + >>> for worker in cluster['workers']: + >>> kill_process_tree(worker.pid) + >>> kill_process_tree(cluster['router'].pid) + """ + show_output = os.environ.get("SHOW_ROUTER_LOGS", "0") == "1" + + # Parse router port from base_url + if ":" in base_url.split("//")[-1]: + router_port = int(base_url.split(":")[-1]) + else: + router_port = find_free_port() + + print(f"\n{'='*70}") + print(f"Launching gRPC cluster (separate workers + router)") + print(f"{'='*70}") + print(f" Model: {model}") + print(f" Router port: {router_port}") + print(f" Workers: {num_workers}") + print(f" TP size: {tp_size}") + print(f" Policy: {policy}") + + # Step 1: Launch workers with gRPC enabled + workers = [] + worker_urls = [] + + for i in range(num_workers): + worker_port = find_free_port() + worker_url = f"grpc://127.0.0.1:{worker_port}" + worker_urls.append(worker_url) + + print(f"\n[Worker {i+1}/{num_workers}]") + print(f" Port: {worker_port}") + print(f" URL: {worker_url}") + + # Build worker command + worker_cmd = [ + "python3", + "-m", + "sglang.launch_server", + "--model-path", + model, + "--host", + "127.0.0.1", + "--port", + str(worker_port), + "--grpc-mode", # Enable gRPC for this worker + "--mem-fraction-static", + "0.8", + "--attention-backend", + "fa3", + ] + + # Add TP size + if 
tp_size > 1: + worker_cmd.extend(["--tp-size", str(tp_size)]) + + # Add worker-specific args + if worker_args: + worker_cmd.extend(worker_args) + + # Launch worker with optional environment variables + if show_output: + worker_proc = subprocess.Popen( + worker_cmd, + env=env, + stdout=stdout, + stderr=stderr, + ) + else: + worker_proc = subprocess.Popen( + worker_cmd, + stdout=stdout if stdout is not None else subprocess.PIPE, + stderr=stderr if stderr is not None else subprocess.PIPE, + env=env, + ) + + workers.append(worker_proc) + print(f" PID: {worker_proc.pid}") + + # Give workers a moment to start binding to ports + # The router will check worker health when it starts + print(f"\nWaiting for {num_workers} workers to initialize (20s)...") + time.sleep(20) + + # Quick check: make sure worker processes are still alive + for i, worker in enumerate(workers): + if worker.poll() is not None: + print(f" ✗ Worker {i+1} died during startup (exit code: {worker.poll()})") + # Cleanup: kill all workers + for w in workers: + try: + w.kill() + except: + pass + raise RuntimeError(f"Worker {i+1} failed to start") + + print(f"✓ All {num_workers} workers started (router will verify connectivity)") + + # Step 2: Launch router pointing to workers + print(f"\n[Router]") + print(f" Port: {router_port}") + print(f" Worker URLs: {', '.join(worker_urls)}") + + # Build router command + router_cmd = [ + "python3", + "-m", + "sglang_router.launch_router", + "--host", + "127.0.0.1", + "--port", + str(router_port), + "--prometheus-port", + "9321", + "--policy", + policy, + "--model-path", + model, + ] + + # Add worker URLs + router_cmd.append("--worker-urls") + router_cmd.extend(worker_urls) + + # Add API key + if api_key: + router_cmd.extend(["--api-key", api_key]) + + # Add router-specific args + if router_args: + router_cmd.extend(router_args) + + if show_output: + print(f" Command: {' '.join(router_cmd)}") + + # Launch router + if show_output: + router_proc = subprocess.Popen(router_cmd) + else: + router_proc = subprocess.Popen( + router_cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + + print(f" PID: {router_proc.pid}") + + # Wait for router to be ready + router_url = f"http://127.0.0.1:{router_port}" + print(f"\nWaiting for router to start at {router_url}...") + + try: + wait_for_workers_ready( + router_url, expected_workers=num_workers, timeout=180, api_key=api_key + ) + print(f"✓ Router ready at {router_url}") + except TimeoutError: + print(f"✗ Router failed to start") + # Cleanup: kill router and all workers + try: + router_proc.kill() + except: + pass + for worker in workers: + try: + worker.kill() + except: + pass + raise + + print(f"\n{'='*70}") + print(f"✓ gRPC cluster ready!") + print(f" Router: {router_url}") + print(f" Workers: {len(workers)}") + print(f"{'='*70}\n") + + return { + "workers": workers, + "worker_urls": worker_urls, + "router": router_proc, + "base_url": router_url, + } diff --git a/sgl-router/py_test/e2e_response_api/state_management.py b/sgl-router/py_test/e2e_response_api/state_management.py new file mode 100644 index 000000000..b35049093 --- /dev/null +++ b/sgl-router/py_test/e2e_response_api/state_management.py @@ -0,0 +1,135 @@ +""" +State management tests for Response API. + +Tests both previous_response_id and conversation-based state management. +These tests should work across all backends (OpenAI, XAI, gRPC). 
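+
+Two chaining mechanisms are exercised; the request shapes look like this
+(a sketch, with placeholder IDs):
+
+    # explicit chaining: turn 2 references turn 1 by response ID
+    self.create_response("What is my name", previous_response_id=resp1_id)
+    # implicit chaining: all turns share one server-side conversation
+    self.create_response("What is my name", conversation=conv_id)
+
+Setting both at once is rejected (see test_mutually_exclusive_parameters).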
+""" + +import unittest + +from base import ResponseAPIBaseTest + + +class StateManagementTests(ResponseAPIBaseTest): + """Tests for state management using previous_response_id and conversation.""" + + def test_previous_response_id_chaining(self): + """Test chaining responses using previous_response_id.""" + # First response + resp1 = self.create_response( + "My name is Alice and my friend is Bob. Remember it." + ) + self.assertEqual(resp1.status_code, 200) + response1_id = resp1.json()["id"] + + # Second response referencing first + resp2 = self.create_response( + "What is my name", previous_response_id=response1_id + ) + self.assertEqual(resp2.status_code, 200) + response2_data = resp2.json() + + # The model should remember the name from previous response + output_text = self._extract_output_text(response2_data) + self.assertIn("Alice", output_text) + + # Third response referencing second + resp3 = self.create_response( + "What is my friend name?", + previous_response_id=response2_data["id"], + ) + response3_data = resp3.json() + output_text = self._extract_output_text(response3_data) + self.assertEqual(resp3.status_code, 200) + self.assertIn("Bob", output_text) + + @unittest.skip("TODO: Add the invalid previous_response_id check") + def test_previous_response_id_invalid(self): + """Test using invalid previous_response_id.""" + resp = self.create_response( + "Test", previous_response_id="resp_invalid123", max_output_tokens=50 + ) + # Should return 404 or 400 for invalid response ID + if resp.status_code != 200: + print(f"\n❌ Response creation failed!") + print(f"Status: {resp.status_code}") + print(f"Response: {resp.text}") + self.assertIn(resp.status_code, [400, 404]) + + def test_conversation_with_multiple_turns(self): + """Test state management using conversation ID.""" + # Create conversation + conv_resp = self.create_conversation(metadata={"topic": "math"}) + self.assertEqual(conv_resp.status_code, 200) + + conversation_id = conv_resp.json()["id"] + + # First response in conversation + resp1 = self.create_response("I have 5 apples.", conversation=conversation_id) + self.assertEqual(resp1.status_code, 200) + + # Second response in same conversation + resp2 = self.create_response( + "How many apples do I have?", + conversation=conversation_id, + ) + self.assertEqual(resp2.status_code, 200) + output_text = self._extract_output_text(resp2.json()) + + # Should remember "5 apples" + self.assertTrue("5" in output_text or "five" in output_text.lower()) + + # Third response in same conversation + resp3 = self.create_response( + "If I get 3 more, how many total?", + conversation=conversation_id, + ) + self.assertEqual(resp3.status_code, 200) + output_text = self._extract_output_text(resp3.json()) + + # Should calculate 5 + 3 = 8 + self.assertTrue("8" in output_text or "eight" in output_text.lower()) + list_resp = self.list_conversation_items(conversation_id) + self.assertEqual(list_resp.status_code, 200) + items = list_resp.json()["data"] + # Should have at least 6 items (3 inputs + 3 outputs) + self.assertGreaterEqual(len(items), 6) + + def test_mutually_exclusive_parameters(self): + """Test that previous_response_id and conversation are mutually exclusive.""" + # Create conversation and response + conv_resp = self.create_conversation() + conversation_id = conv_resp.json()["id"] + + resp1 = self.create_response("Test") + response1_id = resp1.json()["id"] + + # Try to use both parameters + resp = self.create_response( + "This should fail", + previous_response_id=response1_id, + 
conversation=conversation_id,
+        )
+
+        # Should return 400 Bad Request
+        self.assertEqual(resp.status_code, 400)
+        error_data = resp.json()
+        self.assertIn("error", error_data)
+        self.assertIn("mutually exclusive", error_data["error"]["message"].lower())
+
+    # Helper methods
+
+    def _extract_output_text(self, response_data: dict) -> str:
+        """Extract text content from response output."""
+        output = response_data.get("output", [])
+        if not output:
+            return ""
+
+        text_parts = []
+        for item in output:
+            content = item.get("content", [])
+            for part in content:
+                if part.get("type") == "output_text":
+                    text_parts.append(part.get("text", ""))
+
+        return " ".join(text_parts)
diff --git a/sgl-router/py_test/e2e_response_api/test_response_api.py b/sgl-router/py_test/e2e_response_api/test_response_api.py
new file mode 100644
index 000000000..5a053be1d
--- /dev/null
+++ b/sgl-router/py_test/e2e_response_api/test_response_api.py
@@ -0,0 +1,130 @@
+"""
+Backend e2e tests for the Response API (OpenAI, XAI, and gRPC backends).
+
+Run with:
+    export OPENAI_API_KEY=your_key
+    python3 -m pytest py_test/e2e_response_api/test_response_api.py -v
+    python3 -m unittest e2e_response_api.test_response_api.TestOpenaiBackend
+"""
+
+import os
+import sys
+import unittest
+from pathlib import Path
+
+# Add current directory for imports
+_TEST_DIR = Path(__file__).parent
+sys.path.insert(0, str(_TEST_DIR))
+
+# Import local modules
+from base import ConversationCRUDBaseTest, ResponseCRUDBaseTest
+from mcp import MCPTests
+from router_fixtures import (
+    popen_launch_openai_xai_router,
+    popen_launch_workers_and_router,
+)
+from state_management import StateManagementTests
+from util import kill_process_tree
+
+
+class TestOpenaiBackend(
+    ResponseCRUDBaseTest, ConversationCRUDBaseTest, StateManagementTests, MCPTests
+):
+    """End to end tests for OpenAI backend."""
+
+    api_key = os.environ.get("OPENAI_API_KEY")
+
+    @classmethod
+    def setUpClass(cls):
+        cls.model = "gpt-5-nano"
+        cls.base_url_port = "http://127.0.0.1:30010"
+
+        cls.cluster = popen_launch_openai_xai_router(
+            backend="openai",
+            base_url=cls.base_url_port,
+            history_backend="memory",
+        )
+
+        cls.base_url = cls.cluster["base_url"]
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.cluster["router"].pid)
+
+
+class TestXaiBackend(StateManagementTests):
+    """End to end tests for XAI backend."""
+
+    api_key = os.environ.get("XAI_API_KEY")
+
+    @classmethod
+    def setUpClass(cls):
+        cls.model = "grok-4-fast"
+        cls.base_url_port = "http://127.0.0.1:30023"
+
+        cls.cluster = popen_launch_openai_xai_router(
+            backend="xai",
+            base_url=cls.base_url_port,
+            history_backend="memory",
+        )
+
+        cls.base_url = cls.cluster["base_url"]
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.cluster["router"].pid)
+
+
+class TestGrpcBackend(StateManagementTests, MCPTests):
+    """End to end tests for gRPC backend."""
+
+    @classmethod
+    def setUpClass(cls):
+        cls.model = "/home/ubuntu/models/meta-llama/Llama-3.1-8B-Instruct"
+        cls.base_url_port = "http://127.0.0.1:30030"
+
+        cls.cluster = popen_launch_workers_and_router(
+            cls.model,
+            cls.base_url_port,
+            timeout=90,
+            num_workers=1,
+            tp_size=2,
+            policy="round_robin",
+            router_args=["--history-backend", "memory"],
+        )
+
+        cls.base_url = cls.cluster["base_url"]
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.cluster["router"].pid)
+        for worker in cls.cluster.get("workers", []):
+            kill_process_tree(worker.pid)
+
+    @unittest.skip(
+        "TODO: transport error, details: [], metadata: MetadataMap { headers: {} }"
+    )
+    def test_previous_response_id_chaining(self):
+        super().test_previous_response_id_chaining()
+
+    @unittest.skip("TODO: return 501 Not Implemented")
+    def test_conversation_with_multiple_turns(self):
+        super().test_conversation_with_multiple_turns()
+
+    @unittest.skip("TODO: decode error message")
+    def test_mutually_exclusive_parameters(self):
+        super().test_mutually_exclusive_parameters()
+
+    @unittest.skip(
+        "TODO: Pipeline execution failed: Pipeline stage WorkerSelection failed"
+    )
+    def test_mcp_basic_tool_call(self):
+        super().test_mcp_basic_tool_call()
+
+    @unittest.skip("TODO: no event fields")
+    def test_mcp_basic_tool_call_streaming(self):
+        super().test_mcp_basic_tool_call_streaming()
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/sgl-router/py_test/e2e_response_api/util.py b/sgl-router/py_test/e2e_response_api/util.py
new file mode 100644
index 000000000..74dddf015
--- /dev/null
+++ b/sgl-router/py_test/e2e_response_api/util.py
@@ -0,0 +1,82 @@
+"""
+Utility functions for Response API e2e tests.
+"""
+
+import os
+import signal
+import threading
+import unittest
+
+import psutil
+
+
+def kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = None):
+    """
+    Kill the process and all its child processes.
+
+    Args:
+        parent_pid: PID of the parent process
+        include_parent: Whether to kill the parent process itself
+        skip_pid: Optional PID to skip during cleanup
+    """
+    # Remove sigchld handler to avoid spammy logs
+    if threading.current_thread() is threading.main_thread():
+        signal.signal(signal.SIGCHLD, signal.SIG_DFL)
+
+    if parent_pid is None:
+        parent_pid = os.getpid()
+        include_parent = False
+
+    try:
+        itself = psutil.Process(parent_pid)
+    except psutil.NoSuchProcess:
+        return
+
+    children = itself.children(recursive=True)
+    for child in children:
+        if child.pid == skip_pid:
+            continue
+        try:
+            child.kill()
+        except psutil.NoSuchProcess:
+            pass
+
+    if include_parent:
+        try:
+            itself.kill()
+        except psutil.NoSuchProcess:
+            pass
+
+
+class CustomTestCase(unittest.TestCase):
+    """
+    Custom test case base class with retry support.
+
+    This provides automatic test retry functionality based on environment variables.
+    """
+
+    def _callTestMethod(self, method):
+        """Override to add retry logic."""
+        max_retry = int(os.environ.get("SGLANG_TEST_MAX_RETRY", "0"))
+
+        if max_retry == 0:
+            # No retry, just run once
+            return super(CustomTestCase, self)._callTestMethod(method)
+
+        # Retry logic
+        for attempt in range(max_retry + 1):
+            try:
+                return super(CustomTestCase, self)._callTestMethod(method)
+            except Exception:
+                if attempt < max_retry:
+                    print(
+                        f"Test failed on attempt {attempt + 1}/{max_retry + 1}, retrying..."
+                    )
+                    continue
+                else:
+                    # Last attempt, re-raise the exception
+                    raise
+
+    def setUp(self):
+        """Print test method name at the start of each test."""
+        print(f"[Test Method] {self._testMethodName}", flush=True)