[router] Add comprehensive E2E tests for Response API (#11988)
This commit is contained in:
58
.github/workflows/pr-test-rust.yml
vendored
58
.github/workflows/pr-test-rust.yml
vendored
@@ -144,12 +144,6 @@ jobs:
|
|||||||
python3 -m pip --no-cache-dir install --upgrade --break-system-packages genai-bench==0.0.2
|
python3 -m pip --no-cache-dir install --upgrade --break-system-packages genai-bench==0.0.2
|
||||||
pytest -m e2e -s -vv -o log_cli=true --log-cli-level=INFO
|
pytest -m e2e -s -vv -o log_cli=true --log-cli-level=INFO
|
||||||
|
|
||||||
- name: Run Python E2E gRPC tests
|
|
||||||
run: |
|
|
||||||
bash scripts/killall_sglang.sh "nuk_gpus"
|
|
||||||
cd sgl-router
|
|
||||||
SHOW_ROUTER_LOGS=1 ROUTER_LOCAL_MODEL_PATH="/home/ubuntu/models" pytest py_test/e2e_grpc -s -vv -o log_cli=true --log-cli-level=INFO
|
|
||||||
|
|
||||||
- name: Upload benchmark results
|
- name: Upload benchmark results
|
||||||
if: success()
|
if: success()
|
||||||
uses: actions/upload-artifact@v4
|
uses: actions/upload-artifact@v4
|
||||||
@@ -157,8 +151,58 @@ jobs:
|
|||||||
name: genai-bench-results-all-policies
|
name: genai-bench-results-all-policies
|
||||||
path: sgl-router/benchmark_**/
|
path: sgl-router/benchmark_**/
|
||||||
|
|
||||||
|
pytest-rust-2:
|
||||||
|
if: github.event_name != 'pull_request' || contains(github.event.pull_request.labels.*.name, 'run-ci')
|
||||||
|
runs-on: 4-gpu-a10
|
||||||
|
timeout-minutes: 16
|
||||||
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Install rust dependencies
|
||||||
|
run: |
|
||||||
|
bash scripts/ci/ci_install_rust.sh
|
||||||
|
|
||||||
|
- name: Configure sccache
|
||||||
|
uses: mozilla-actions/sccache-action@v0.0.9
|
||||||
|
with:
|
||||||
|
version: "v0.10.0"
|
||||||
|
|
||||||
|
- name: Rust cache
|
||||||
|
uses: Swatinem/rust-cache@v2
|
||||||
|
with:
|
||||||
|
workspaces: sgl-router
|
||||||
|
cache-all-crates: true
|
||||||
|
cache-on-failure: true
|
||||||
|
|
||||||
|
- name: Install SGLang dependencies
|
||||||
|
run: |
|
||||||
|
sudo --preserve-env=PATH bash scripts/ci/ci_install_dependency.sh
|
||||||
|
|
||||||
|
- name: Build python binding
|
||||||
|
run: |
|
||||||
|
source "$HOME/.cargo/env"
|
||||||
|
export RUSTC_WRAPPER=sccache
|
||||||
|
cd sgl-router
|
||||||
|
pip install setuptools-rust wheel build
|
||||||
|
python3 -m build
|
||||||
|
pip install --force-reinstall dist/*.whl
|
||||||
|
|
||||||
|
- name: Run Python E2E response API tests
|
||||||
|
run: |
|
||||||
|
bash scripts/killall_sglang.sh "nuk_gpus"
|
||||||
|
cd sgl-router
|
||||||
|
SHOW_ROUTER_LOGS=1 pytest py_test/e2e_response_api -s -vv -o log_cli=true --log-cli-level=INFO
|
||||||
|
|
||||||
|
- name: Run Python E2E gRPC tests
|
||||||
|
run: |
|
||||||
|
bash scripts/killall_sglang.sh "nuk_gpus"
|
||||||
|
cd sgl-router
|
||||||
|
SHOW_ROUTER_LOGS=1 ROUTER_LOCAL_MODEL_PATH="/home/ubuntu/models" pytest py_test/e2e_grpc -s -vv -o log_cli=true --log-cli-level=INFO
|
||||||
|
|
||||||
|
|
||||||
finish:
|
finish:
|
||||||
needs: [unit-test-rust, pytest-rust]
|
needs: [unit-test-rust, pytest-rust, pytest-rust-2]
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Finish
|
- name: Finish
|
||||||
|
|||||||
@@ -267,8 +267,6 @@ def popen_launch_workers_and_router(
|
|||||||
policy,
|
policy,
|
||||||
"--model-path",
|
"--model-path",
|
||||||
model,
|
model,
|
||||||
"--log-level",
|
|
||||||
"warn",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
# Add worker URLs
|
# Add worker URLs
|
||||||
|
|||||||
480
sgl-router/py_test/e2e_response_api/base.py
Normal file
480
sgl-router/py_test/e2e_response_api/base.py
Normal file
@@ -0,0 +1,480 @@
|
|||||||
|
"""
|
||||||
|
Base test class for Response API e2e tests.
|
||||||
|
|
||||||
|
This module provides base test classes that can be reused across different backends
|
||||||
|
(OpenAI, XAI, gRPC) with common test logic.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
import unittest
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
# Add current directory for local imports
|
||||||
|
_TEST_DIR = Path(__file__).parent
|
||||||
|
sys.path.insert(0, str(_TEST_DIR))
|
||||||
|
|
||||||
|
from util import CustomTestCase
|
||||||
|
|
||||||
|
|
||||||
|
class ResponseAPIBaseTest(CustomTestCase):
    """Shared HTTP helpers for Response API end-to-end tests.

    Subclasses configure ``base_url``, ``api_key`` and ``model``; every helper
    returns the raw ``requests.Response`` so the calling test can assert on the
    status code and body itself.
    """

    # Router base URL; endpoint paths are appended to it verbatim.
    # Must be set by subclasses before any helper is called.
    base_url: Optional[str] = None
    # Bearer token for the Authorization header; the header is omitted when falsy.
    api_key: Optional[str] = None
    # Value sent in the "model" field of POST /v1/responses.
    model: Optional[str] = None

    def _build_headers(self) -> dict:
        """Return JSON request headers, adding a bearer token when ``api_key`` is set."""
        headers = {"Content-Type": "application/json"}
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        return headers

    def make_request(
        self,
        endpoint: str,
        method: str = "POST",
        json_data: Optional[dict] = None,
        params: Optional[dict] = None,
    ) -> requests.Response:
        """
        Make HTTP request to router.

        Args:
            endpoint: Endpoint path (e.g., "/v1/responses")
            method: HTTP method (GET, POST, DELETE)
            json_data: JSON body; only sent for POST requests
            params: Query parameters

        Returns:
            requests.Response object

        Raises:
            ValueError: If ``method`` is not GET, POST or DELETE
        """
        url = f"{self.base_url}{endpoint}"
        headers = self._build_headers()

        if method == "POST":
            return requests.post(url, json=json_data, headers=headers, params=params)
        if method == "GET":
            return requests.get(url, headers=headers, params=params)
        if method == "DELETE":
            return requests.delete(url, headers=headers, params=params)
        raise ValueError(f"Unsupported method: {method}")

    def create_response(
        self,
        input_text: str,
        instructions: Optional[str] = None,
        stream: bool = False,
        max_output_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        previous_response_id: Optional[str] = None,
        conversation: Optional[str] = None,
        tools: Optional[list] = None,
        background: bool = False,
        **kwargs,
    ) -> requests.Response:
        """
        Create a response via POST /v1/responses.

        Args:
            input_text: User input
            instructions: Optional system instructions
            stream: Whether to stream response
            max_output_tokens: Optional max tokens to generate
            temperature: Sampling temperature
            previous_response_id: Optional previous response ID for state management
            conversation: Optional conversation ID for state management
            tools: Optional list of MCP tools
            background: Whether to run in background mode
            **kwargs: Additional request parameters, merged into the request body

        Returns:
            requests.Response object (opened with ``stream=True`` when streaming)
        """
        # kwargs are spread last in the literal so they can override the three
        # base fields; the explicit optional arguments below override kwargs.
        data = {
            "model": self.model,
            "input": input_text,
            "stream": stream,
            **kwargs,
        }

        if instructions:
            data["instructions"] = instructions

        # Numeric options use "is not None" so that an explicit 0 is still sent.
        if max_output_tokens is not None:
            data["max_output_tokens"] = max_output_tokens

        if temperature is not None:
            data["temperature"] = temperature

        if previous_response_id:
            data["previous_response_id"] = previous_response_id

        if conversation:
            data["conversation"] = conversation

        if tools:
            data["tools"] = tools

        if background:
            data["background"] = background

        if stream:
            # Streaming needs the connection left open so the SSE body can be
            # consumed incrementally by parse_sse_events().
            return self._create_streaming_response(data)
        return self.make_request("/v1/responses", "POST", data)

    def _create_streaming_response(self, data: dict) -> requests.Response:
        """POST ``data`` to /v1/responses with ``stream=True`` and return the open response."""
        url = f"{self.base_url}/v1/responses"
        return requests.post(
            url, json=data, headers=self._build_headers(), stream=True
        )

    def get_response(self, response_id: str) -> requests.Response:
        """Get response by ID via GET /v1/responses/{response_id}."""
        return self.make_request(f"/v1/responses/{response_id}", "GET")

    def delete_response(self, response_id: str) -> requests.Response:
        """Delete response by ID via DELETE /v1/responses/{response_id}."""
        return self.make_request(f"/v1/responses/{response_id}", "DELETE")

    def cancel_response(self, response_id: str) -> requests.Response:
        """Cancel response by ID via POST /v1/responses/{response_id}/cancel."""
        return self.make_request(f"/v1/responses/{response_id}/cancel", "POST", {})

    def get_response_input(self, response_id: str) -> requests.Response:
        """Get response input items via GET /v1/responses/{response_id}/input."""
        return self.make_request(f"/v1/responses/{response_id}/input", "GET")

    def create_conversation(self, metadata: Optional[dict] = None) -> requests.Response:
        """Create conversation via POST /v1/conversations.

        ``metadata`` is only included in the body when it is non-empty.
        """
        data = {}
        if metadata:
            data["metadata"] = metadata
        return self.make_request("/v1/conversations", "POST", data)

    def get_conversation(self, conversation_id: str) -> requests.Response:
        """Get conversation by ID via GET /v1/conversations/{conversation_id}."""
        return self.make_request(f"/v1/conversations/{conversation_id}", "GET")

    def update_conversation(
        self, conversation_id: str, metadata: dict
    ) -> requests.Response:
        """Update conversation via POST /v1/conversations/{conversation_id}."""
        return self.make_request(
            f"/v1/conversations/{conversation_id}", "POST", {"metadata": metadata}
        )

    def delete_conversation(self, conversation_id: str) -> requests.Response:
        """Delete conversation via DELETE /v1/conversations/{conversation_id}."""
        return self.make_request(f"/v1/conversations/{conversation_id}", "DELETE")

    def list_conversation_items(
        self,
        conversation_id: str,
        limit: Optional[int] = None,
        after: Optional[str] = None,
        before: Optional[str] = None,
        order: str = "asc",
    ) -> requests.Response:
        """List conversation items via GET /v1/conversations/{conversation_id}/items.

        Args:
            conversation_id: Conversation whose items are listed
            limit: Optional page size; sent whenever it is not None
            after: Optional cursor; sent only when non-empty
            before: Optional cursor; sent only when non-empty
            order: Sort order, always sent (defaults to "asc")
        """
        params = {"order": order}
        # "is not None" rather than truthiness: an explicit limit=0 must reach
        # the server instead of silently falling back to the server default.
        if limit is not None:
            params["limit"] = limit
        if after:
            params["after"] = after
        if before:
            params["before"] = before
        return self.make_request(
            f"/v1/conversations/{conversation_id}/items", "GET", params=params
        )

    def create_conversation_items(
        self, conversation_id: str, items: list
    ) -> requests.Response:
        """Create conversation items via POST /v1/conversations/{conversation_id}/items."""
        return self.make_request(
            f"/v1/conversations/{conversation_id}/items", "POST", {"items": items}
        )

    def get_conversation_item(
        self, conversation_id: str, item_id: str
    ) -> requests.Response:
        """Get conversation item via GET /v1/conversations/{conversation_id}/items/{item_id}."""
        return self.make_request(
            f"/v1/conversations/{conversation_id}/items/{item_id}", "GET"
        )

    def delete_conversation_item(
        self, conversation_id: str, item_id: str
    ) -> requests.Response:
        """Delete conversation item via DELETE /v1/conversations/{conversation_id}/items/{item_id}."""
        return self.make_request(
            f"/v1/conversations/{conversation_id}/items/{item_id}", "DELETE"
        )

    @staticmethod
    def _make_sse_event(event_name: Optional[str], data_lines: list) -> Optional[dict]:
        """Build one event dict from the fields buffered for an SSE block.

        Returns None when the block carried no ``data:`` line. Otherwise the
        data lines are joined with newlines and decoded as JSON, falling back
        to the raw string when the payload is not valid JSON.
        """
        if not data_lines:
            return None

        raw = "\n".join(data_lines)
        try:
            payload = json.loads(raw)
        except json.JSONDecodeError:
            payload = raw

        event = {}
        if event_name is not None:
            event["event"] = event_name
        event["data"] = payload
        return event

    def parse_sse_events(self, response: requests.Response) -> list:
        """
        Parse Server-Sent Events from streaming response.

        An event is emitted for every block that contains at least one
        ``data:`` line, regardless of whether the decoded payload is falsy
        (``0``, ``{}``, ``[]``, ...). ``event:`` and ``data:`` lines may appear
        in either order within a block, and several ``data:`` lines in one
        block are joined with newlines before decoding. Other lines (comments,
        ``id:``, ``retry:``) are ignored.

        Args:
            response: requests.Response with stream=True

        Returns:
            List of event dictionaries with a 'data' key and, when the block
            named one, an 'event' key
        """
        events = []
        event_name = None
        data_lines = []

        for raw_line in response.iter_lines():
            if not raw_line:
                # A blank line terminates the current event block.
                event = self._make_sse_event(event_name, data_lines)
                if event is not None:
                    events.append(event)
                event_name, data_lines = None, []
                continue

            # iter_lines() yields bytes unless the caller asked for decoding.
            line = raw_line.decode("utf-8") if isinstance(raw_line, bytes) else raw_line

            if line.startswith("event:"):
                event_name = line[6:].strip()
            elif line.startswith("data:"):
                data_lines.append(line[5:].strip())

        # The stream may end without a trailing blank line after the last event.
        event = self._make_sse_event(event_name, data_lines)
        if event is not None:
            events.append(event)

        return events

    def wait_for_background_task(
        self, response_id: str, timeout: int = 30, poll_interval: float = 0.5
    ) -> dict:
        """
        Wait for background task to complete.

        Args:
            response_id: Response ID to poll
            timeout: Max seconds to wait
            poll_interval: Seconds between polls

        Returns:
            Final response data (the JSON body whose status is "completed")

        Raises:
            TimeoutError: If task doesn't complete in time
            AssertionError: If a poll does not return 200, or the task reports
                "failed" or "cancelled"
        """
        # A monotonic clock keeps the deadline immune to wall-clock adjustments.
        deadline = time.monotonic() + timeout

        while time.monotonic() < deadline:
            resp = self.get_response(response_id)
            self.assertEqual(resp.status_code, 200)

            data = resp.json()
            status = data.get("status")

            if status == "completed":
                return data
            if status == "failed":
                raise AssertionError(
                    f"Background task failed: {data.get('error', 'Unknown error')}"
                )
            if status == "cancelled":
                raise AssertionError("Background task was cancelled")

            # Any other status means the task is still running; poll again.
            time.sleep(poll_interval)

        raise TimeoutError(
            f"Background task {response_id} did not complete within {timeout}s"
        )
|
||||||
|
|
||||||
|
|
||||||
|
class StateManagementBaseTest(ResponseAPIBaseTest):
    """Reusable test cases for state management (previous_response_id and conversation).

    The cases defined here cover plain and streaming response creation.
    """

    def test_basic_response_creation(self):
        """A stateless, non-streaming request returns a completed response body."""
        response = self.create_response("What is 2+2?", max_output_tokens=50)
        self.assertEqual(response.status_code, 200)

        body = response.json()
        for required_key in ("id", "output"):
            self.assertIn(required_key, body)
        self.assertEqual(body["status"], "completed")
        self.assertIn("usage", body)

    def test_streaming_response(self):
        """A streaming request yields SSE events including the lifecycle markers."""
        response = self.create_response(
            "Count to 5", stream=True, max_output_tokens=50
        )
        self.assertEqual(response.status_code, 200)

        events = self.parse_sse_events(response)
        self.assertGreater(len(events), 0)

        event_names = [event.get("event") for event in events]

        # The stream must announce the response at least once.
        created = [name for name in event_names if name == "response.created"]
        self.assertGreater(len(created), 0)

        # Either a completed event or at least an in_progress event must be present.
        progress_markers = ("response.completed", "response.in_progress")
        saw_progress = any(name in progress_markers for name in event_names)
        self.assertTrue(saw_progress)
|
||||||
|
|
||||||
|
|
||||||
|
class ResponseCRUDBaseTest(ResponseAPIBaseTest):
    """Reusable create / get / delete / background test cases for /v1/responses."""

    def test_create_and_get_response(self):
        """A created response can be fetched back by ID and reports completion."""
        created = self.create_response("Hello, world!")
        self.assertEqual(created.status_code, 200)
        response_id = created.json()["id"]

        fetched = self.get_response(response_id)
        self.assertEqual(fetched.status_code, 200)

        fetched_body = fetched.json()
        self.assertEqual(fetched_body["id"], response_id)
        self.assertEqual(fetched_body["status"], "completed")

        # The change implementing the input-items endpoint is not merged yet,
        # so the test currently expects 501.
        # TODO: once it lands, expect 200 and a non-empty "data" list instead.
        input_items = self.get_response_input(fetched_body["id"])
        self.assertEqual(input_items.status_code, 501)

    @unittest.skip("TODO: Add delete response feature")
    def test_delete_response(self):
        """Deleting a response returns 200 and a later GET returns 404."""
        created = self.create_response("Test deletion", max_output_tokens=50)
        self.assertEqual(created.status_code, 200)
        response_id = created.json()["id"]

        deleted = self.delete_response(response_id)
        self.assertEqual(deleted.status_code, 200)

        # A deleted response must no longer be retrievable.
        refetched = self.get_response(response_id)
        self.assertEqual(refetched.status_code, 404)

    @unittest.skip("TODO: Add background response feature")
    def test_background_response(self):
        """A background request starts in_progress and eventually completes."""
        created = self.create_response(
            "Write a short story", background=True, max_output_tokens=100
        )
        self.assertEqual(created.status_code, 200)

        created_body = created.json()
        response_id = created_body["id"]
        self.assertEqual(created_body["status"], "in_progress")

        # Poll until the task reaches a terminal state (up to 60 seconds).
        finished_body = self.wait_for_background_task(response_id, timeout=60)
        self.assertEqual(finished_body["status"], "completed")
|
||||||
|
|
||||||
|
|
||||||
|
class ConversationCRUDBaseTest(ResponseAPIBaseTest):
    """Base class for Conversation API CRUD tests.

    Every request a test depends on has its status code asserted before the
    body is used, so a failing setup call is reported at the call that failed.
    """

    def test_create_and_get_conversation(self):
        """Test creating and retrieving conversation."""
        # Create conversation
        create_resp = self.create_conversation(metadata={"user": "test_user"})
        self.assertEqual(create_resp.status_code, 200)

        create_data = create_resp.json()
        conversation_id = create_data["id"]
        self.assertEqual(create_data["metadata"]["user"], "test_user")

        # Get conversation
        get_resp = self.get_conversation(conversation_id)
        self.assertEqual(get_resp.status_code, 200)

        get_data = get_resp.json()
        self.assertEqual(get_data["id"], conversation_id)
        self.assertEqual(get_data["metadata"]["user"], "test_user")

    def test_update_conversation(self):
        """Test updating conversation metadata."""
        # Create conversation
        create_resp = self.create_conversation(metadata={"key1": "value1"})
        self.assertEqual(create_resp.status_code, 200)
        conversation_id = create_resp.json()["id"]

        # Update conversation
        update_resp = self.update_conversation(
            conversation_id, metadata={"key1": "value1", "key2": "value2"}
        )
        self.assertEqual(update_resp.status_code, 200)

        # Verify update; check the status first so an error body is not
        # reported as a KeyError on "metadata".
        get_resp = self.get_conversation(conversation_id)
        self.assertEqual(get_resp.status_code, 200)
        get_data = get_resp.json()
        self.assertEqual(get_data["metadata"]["key2"], "value2")

    def test_delete_conversation(self):
        """Test deleting conversation."""
        # Create conversation
        create_resp = self.create_conversation()
        self.assertEqual(create_resp.status_code, 200)
        conversation_id = create_resp.json()["id"]

        # Delete conversation
        delete_resp = self.delete_conversation(conversation_id)
        self.assertEqual(delete_resp.status_code, 200)

        # Verify deletion
        get_resp = self.get_conversation(conversation_id)
        self.assertEqual(get_resp.status_code, 404)

    def test_list_conversation_items(self):
        """Test listing conversation items."""
        # Create conversation
        conv_resp = self.create_conversation()
        self.assertEqual(conv_resp.status_code, 200)
        conversation_id = conv_resp.json()["id"]

        # Create two responses attached to the conversation; each must succeed
        # for the item count below to be meaningful.
        first_resp = self.create_response(
            "First message", conversation=conversation_id, max_output_tokens=50
        )
        self.assertEqual(first_resp.status_code, 200)
        second_resp = self.create_response(
            "Second message", conversation=conversation_id, max_output_tokens=50
        )
        self.assertEqual(second_resp.status_code, 200)

        # List items
        list_resp = self.list_conversation_items(conversation_id)
        self.assertEqual(list_resp.status_code, 200)

        list_data = list_resp.json()
        self.assertIn("data", list_data)
        # Should have at least 4 items (2 inputs + 2 outputs)
        self.assertGreaterEqual(len(list_data["data"]), 4)
|
||||||
39
sgl-router/py_test/e2e_response_api/conftest.py
Normal file
39
sgl-router/py_test/e2e_response_api/conftest.py
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
"""
|
||||||
|
pytest configuration for e2e_response_api tests.
|
||||||
|
|
||||||
|
This configures pytest to not collect base test classes that are meant to be inherited.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
def pytest_collection_modifyitems(config, items):
    """
    Modify test collection to exclude base test classes.

    Base test classes are meant to be inherited, not run directly.
    We exclude any test whose direct parent node is named after one of them:
    - StateManagementBaseTest
    - ResponseCRUDBaseTest
    - ConversationCRUDBaseTest
    - MCPTests
    - StateManagementTests

    Removed items are reported through the ``pytest_deselected`` hook so they
    show up in pytest's "deselected" summary instead of vanishing silently.

    Args:
        config: pytest config object; only ``config.hook`` is used
        items: collected test items, filtered in place
    """
    base_class_names = {
        "StateManagementBaseTest",
        "ResponseCRUDBaseTest",
        "ConversationCRUDBaseTest",
        "MCPTests",
        "StateManagementTests",
    }

    selected = []
    deselected = []
    for item in items:
        # getattr on both levels: an item without a parent (or with parent
        # None) simply has no class name and is therefore kept.
        parent = getattr(item, "parent", None)
        parent_name = getattr(parent, "name", None)
        if parent_name in base_class_names:
            deselected.append(item)
        else:
            selected.append(item)

    if deselected:
        config.hook.pytest_deselected(items=deselected)
        # Slice assignment mutates the list pytest handed us, which is how
        # this hook communicates the new selection.
        items[:] = selected
|
||||||
229
sgl-router/py_test/e2e_response_api/mcp.py
Normal file
229
sgl-router/py_test/e2e_response_api/mcp.py
Normal file
@@ -0,0 +1,229 @@
|
|||||||
|
"""
|
||||||
|
MCP (Model Context Protocol) tests for Response API.
|
||||||
|
|
||||||
|
Tests MCP tool calling in both streaming and non-streaming modes.
|
||||||
|
These tests should work across all backends that support MCP (OpenAI, XAI).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from base import ResponseAPIBaseTest
|
||||||
|
|
||||||
|
|
||||||
|
class MCPTests(ResponseAPIBaseTest):
    """MCP tool-calling test cases, run once without and once with streaming."""

    # Prompt shared by both tests.
    _MCP_SPEC_QUESTION = "What transport protocols does the 2025-03-26 version of the MCP spec (modelcontextprotocol/modelcontextprotocol) support?"

    @staticmethod
    def _deepwiki_tools():
        """Return a fresh tools list describing the deepwiki MCP server."""
        return [
            {
                "type": "mcp",
                "server_label": "deepwiki",
                "server_url": "https://mcp.deepwiki.com/mcp",
                "require_approval": "never",
            }
        ]

    def test_mcp_basic_tool_call(self):
        """Non-streaming request: the body lists tools, calls one, and answers."""
        resp = self.create_response(
            self._MCP_SPEC_QUESTION,
            tools=self._deepwiki_tools(),
            stream=False,
        )
        self.assertEqual(resp.status_code, 200)

        data = resp.json()
        print(f"MCP response: {data}")

        # Top-level response shape.
        for key in ("id", "status"):
            self.assertIn(key, data)
        self.assertEqual(data["status"], "completed")
        for key in ("output", "model"):
            self.assertIn(key, data)

        # The output must be a non-empty list of typed items.
        output = data["output"]
        self.assertIsInstance(output, list)
        self.assertGreater(len(output), 0)

        # Tools are listed before any of them is called.
        output_types = [entry.get("type") for entry in output]
        self.assertIn(
            "mcp_list_tools", output_types, "Response should contain mcp_list_tools"
        )

        mcp_calls = [entry for entry in output if entry.get("type") == "mcp_call"]
        self.assertGreater(
            len(mcp_calls), 0, "Response should contain at least one mcp_call"
        )

        # Every tool call must be complete and attributed to the deepwiki server.
        for call in mcp_calls:
            self.assertIn("id", call)
            self.assertIn("status", call)
            self.assertEqual(call["status"], "completed")
            self.assertIn("server_label", call)
            self.assertEqual(call["server_label"], "deepwiki")
            for key in ("name", "arguments", "output"):
                self.assertIn(key, call)

        messages = [entry for entry in output if entry.get("type") == "message"]
        self.assertGreater(
            len(messages), 0, "Response should contain at least one message"
        )

        # Each message carries a content list; text parts must be non-empty strings.
        for message in messages:
            self.assertIn("content", message)
            self.assertIsInstance(message["content"], list)

            for part in message["content"]:
                if part.get("type") != "output_text":
                    continue
                self.assertIn("text", part)
                self.assertIsInstance(part["text"], str)
                self.assertGreater(len(part["text"]), 0)

    def test_mcp_basic_tool_call_streaming(self):
        """Streaming request: all expected SSE event types and a full final response."""
        resp = self.create_response(
            self._MCP_SPEC_QUESTION,
            tools=self._deepwiki_tools(),
            stream=True,
        )
        self.assertEqual(resp.status_code, 200)

        events = self.parse_sse_events(resp)
        self.assertGreater(len(events), 0)

        event_types = [event.get("event") for event in events]

        # Event types that must appear, checked in this order, each with the
        # failure message to report when it is missing.
        required_events = (
            # Lifecycle
            ("response.created", "Should have response.created event"),
            ("response.completed", "Should have response.completed event"),
            # MCP list tools
            ("response.output_item.added", "Should have output_item.added events"),
            (
                "response.mcp_list_tools.in_progress",
                "Should have mcp_list_tools.in_progress event",
            ),
            (
                "response.mcp_list_tools.completed",
                "Should have mcp_list_tools.completed event",
            ),
            # MCP call
            ("response.mcp_call.in_progress", "Should have mcp_call.in_progress event"),
            (
                "response.mcp_call_arguments.delta",
                "Should have mcp_call_arguments.delta event",
            ),
            (
                "response.mcp_call_arguments.done",
                "Should have mcp_call_arguments.done event",
            ),
            ("response.mcp_call.completed", "Should have mcp_call.completed event"),
            # Text output
            ("response.content_part.added", "Should have content_part.added event"),
            ("response.output_text.delta", "Should have output_text.delta events"),
            ("response.output_text.done", "Should have output_text.done event"),
            ("response.content_part.done", "Should have content_part.done event"),
        )
        for event_type, failure_message in required_events:
            self.assertIn(event_type, event_types, failure_message)

        # Exactly one completed event, carrying the full response object.
        completed_events = [
            event for event in events if event.get("event") == "response.completed"
        ]
        self.assertEqual(len(completed_events), 1)

        final_response = completed_events[0].get("data", {}).get("response", {})
        self.assertIn("id", final_response)
        self.assertEqual(final_response.get("status"), "completed")
        self.assertIn("output", final_response)

        # The final output must contain the tool listing, a call and a message.
        final_output = final_response.get("output", [])
        final_output_types = [entry.get("type") for entry in final_output]
        for expected_type in ("mcp_list_tools", "mcp_call", "message"):
            self.assertIn(expected_type, final_output_types)

        mcp_calls = [entry for entry in final_output if entry.get("type") == "mcp_call"]
        self.assertGreater(len(mcp_calls), 0)

        for call in mcp_calls:
            self.assertEqual(call.get("status"), "completed")
            self.assertEqual(call.get("server_label"), "deepwiki")
            for key in ("name", "arguments", "output"):
                self.assertIn(key, call)

        # At least one text delta must have been streamed.
        text_deltas = []
        for event in events:
            if event.get("event") == "response.output_text.delta":
                text_deltas.append(event.get("data", {}).get("delta", ""))
        self.assertGreater(len(text_deltas), 0, "Should have text deltas")

        # The first output_text.done event must carry non-empty final text.
        text_done_events = [
            event for event in events if event.get("event") == "response.output_text.done"
        ]
        self.assertGreater(len(text_done_events), 0)

        final_text = text_done_events[0].get("data", {}).get("text", "")
        self.assertGreater(len(final_text), 0, "Final text should not be empty")
|
||||||
554
sgl-router/py_test/e2e_response_api/router_fixtures.py
Normal file
554
sgl-router/py_test/e2e_response_api/router_fixtures.py
Normal file
@@ -0,0 +1,554 @@
|
|||||||
|
"""
|
||||||
|
Fixtures for launching OpenAI/XAI router for response API e2e testing.
|
||||||
|
|
||||||
|
This module provides fixtures for launching SGLang router with OpenAI or XAI backends:
|
||||||
|
1. Launch router with --backend openai pointing to OpenAI or XAI API
|
||||||
|
2. Configure history backend (memory or oracle)
|
||||||
|
|
||||||
|
This supports testing the Response API against real cloud providers.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import socket
|
||||||
|
import subprocess
|
||||||
|
import time
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
|
||||||
|
def wait_for_workers_ready(
    router_url: str,
    expected_workers: int,
    timeout: int = 300,
    api_key: Optional[str] = None,
) -> None:
    """
    Wait for router to have all workers connected.

    Polls the /workers endpoint roughly once per second until the 'total'
    field matches expected_workers.

    Example response from /workers endpoint:
    {"workers":[],"total":0,"stats":{"prefill_count":0,"decode_count":0,"regular_count":0}}

    Args:
        router_url: Base URL of router (e.g., "http://127.0.0.1:30000")
        expected_workers: Number of workers expected to be connected
        timeout: Max seconds to wait
        api_key: Optional API key for authentication (sent as a Bearer token)

    Raises:
        TimeoutError: If the expected worker count is not reported within
            ``timeout`` seconds. The message includes the last observed status.
    """
    # A monotonic clock keeps the deadline immune to wall-clock adjustments.
    start_time = time.monotonic()
    last_error = None
    # Elapsed-seconds mark at which the next progress line is due.
    next_progress_at = 10

    headers = {}
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"

    with requests.Session() as session:
        while time.monotonic() - start_time < timeout:
            elapsed = int(time.monotonic() - start_time)

            # Print progress every 10 seconds. Tracking the next due mark
            # (rather than testing elapsed % 10) keeps the message from being
            # skipped when a poll happens to land between multiples of 10.
            if elapsed >= next_progress_at:
                print(f" Still waiting for workers... ({elapsed}/{timeout}s elapsed)")
                next_progress_at = (elapsed // 10 + 1) * 10

            try:
                response = session.get(
                    f"{router_url}/workers", headers=headers, timeout=5
                )
                if response.status_code == 200:
                    data = response.json()
                    total_workers = data.get("total", 0)

                    if total_workers == expected_workers:
                        print(
                            f" All {expected_workers} workers connected after {elapsed}s"
                        )
                        return
                    last_error = f"Workers: {total_workers}/{expected_workers}"
                else:
                    last_error = f"HTTP {response.status_code}"
            except requests.ConnectionError:
                last_error = "Connection refused (router not ready yet)"
            except requests.Timeout:
                last_error = "Timeout"
            except requests.RequestException as e:
                last_error = str(e)
            except (ValueError, KeyError, AttributeError) as e:
                # Body was not JSON, or was JSON but not an object with .get().
                last_error = f"Invalid response: {e}"

            time.sleep(1)

    raise TimeoutError(
        f"Router at {router_url} did not get {expected_workers} workers within {timeout}s.\n"
        f"Last status: {last_error}\n"
        f"Hint: Run with SHOW_ROUTER_LOGS=1 to see startup logs"
    )
|
||||||
|
|
||||||
|
|
||||||
|
def find_free_port() -> int:
    """Return a TCP port number that was free on 127.0.0.1 at call time.

    Binds a temporary socket to port 0 and reads back the port it was given.
    The socket is closed before returning, so the port is not reserved for
    the caller.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        probe.bind(("127.0.0.1", 0))
        _host, assigned_port = probe.getsockname()
        return assigned_port
    finally:
        probe.close()
|
||||||
|
|
||||||
|
|
||||||
|
def wait_for_router_ready(
    router_url: str,
    timeout: int = 60,
    api_key: Optional[str] = None,
) -> None:
    """
    Wait for router to be ready.

    Polls the /health endpoint roughly once per second until it returns 200.

    Args:
        router_url: Base URL of router (e.g., "http://127.0.0.1:30000")
        timeout: Max seconds to wait
        api_key: Optional API key for authentication (sent as a Bearer token)

    Raises:
        TimeoutError: If /health does not return 200 within ``timeout``
            seconds. The message includes the last observed status.
    """
    # A monotonic clock keeps the deadline immune to wall-clock adjustments.
    start_time = time.monotonic()
    last_error = None
    # Elapsed-seconds mark at which the next progress line is due.
    next_progress_at = 10

    headers = {}
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"

    with requests.Session() as session:
        while time.monotonic() - start_time < timeout:
            elapsed = int(time.monotonic() - start_time)

            # Print progress every 10 seconds. Tracking the next due mark
            # (rather than testing elapsed % 10) keeps the message from being
            # skipped when a poll happens to land between multiples of 10.
            if elapsed >= next_progress_at:
                print(f" Still waiting for router... ({elapsed}/{timeout}s elapsed)")
                next_progress_at = (elapsed // 10 + 1) * 10

            try:
                response = session.get(
                    f"{router_url}/health", headers=headers, timeout=5
                )
                if response.status_code == 200:
                    print(f" Router ready after {elapsed}s")
                    return
                last_error = f"HTTP {response.status_code}"
            except requests.ConnectionError:
                last_error = "Connection refused (router not ready yet)"
            except requests.Timeout:
                last_error = "Timeout"
            except requests.RequestException as e:
                last_error = str(e)

            time.sleep(1)

    raise TimeoutError(
        f"Router at {router_url} did not become ready within {timeout}s.\n"
        f"Last status: {last_error}\n"
        f"Hint: Run with SHOW_ROUTER_LOGS=1 to see startup logs"
    )
|
||||||
|
|
||||||
|
|
||||||
|
def popen_launch_openai_xai_router(
    backend: str,  # "openai" or "xai"
    base_url: str,
    timeout: int = 60,
    history_backend: str = "memory",
    api_key: Optional[str] = None,
    router_args: Optional[list] = None,
    stdout=None,
    stderr=None,
    prometheus_port: Optional[int] = None,
) -> dict:
    """
    Launch SGLang router with OpenAI or XAI backend.

    This approach:
    1. Starts router with --backend openai
    2. Points to OpenAI or XAI API via --worker-urls
    3. Configures history backend (memory or oracle)
    4. Waits for router health check to pass

    Args:
        backend: "openai" or "xai"
        base_url: Base URL for router (e.g., "http://127.0.0.1:30000"). Only
            the port is used; the router always binds to 127.0.0.1. If no
            port is present a free one is picked.
        timeout: Timeout for router startup (default: 60s)
        history_backend: "memory" or "oracle" (default: memory)
        api_key: Optional API key for router authentication. Currently not
            passed to the router (see note in the body).
        router_args: Additional arguments for router
        stdout: Optional file handle for router stdout
        stderr: Optional file handle for router stderr
        prometheus_port: Port for --prometheus-port. Defaults to
            router port + 1000, or a free port if that exceeds 65535.

    Returns:
        dict with:
        - router: router process object
        - base_url: router URL (HTTP endpoint)

    Raises:
        ValueError: If ``backend`` is unsupported or the backend's API key
            environment variable (OPENAI_API_KEY / XAI_API_KEY) is not set.
        TimeoutError: If the router does not become healthy within
            ``timeout`` seconds; the router process is killed first.

    Example:
        >>> cluster = popen_launch_openai_xai_router(
        ...     "openai", "http://127.0.0.1:30000"
        ... )
        >>> # Use cluster['base_url'] for HTTP requests
        >>> # Cleanup:
        >>> kill_process_tree(cluster['router'].pid)
    """
    # Imported locally so this helper does not depend on a module-level import.
    from urllib.parse import urlparse

    show_output = os.environ.get("SHOW_ROUTER_LOGS", "0") == "1"

    # Parse router port from base_url. urlparse copes with a trailing slash or
    # path ("http://host:30000/"), which splitting on ":" does not. A bare
    # "host:port" gets a "//" prefix so it is parsed as a network location.
    parsed_url = urlparse(base_url if "//" in base_url else f"//{base_url}")
    router_port = parsed_url.port
    if router_port is None:
        router_port = find_free_port()

    print(f"\n{'='*70}")
    print(f"Launching {backend.upper()} router")
    print(f"{'='*70}")
    print(f" Backend: {backend}")
    print(f" Router port: {router_port}")
    print(f" History backend: {history_backend}")

    # Determine worker URL and API-key variable based on backend:
    # backend -> (worker URL, env var holding the key, display name)
    backend_settings = {
        "openai": ("https://api.openai.com", "OPENAI_API_KEY", "OpenAI"),
        "xai": ("https://api.x.ai", "XAI_API_KEY", "XAI"),
    }
    if backend not in backend_settings:
        raise ValueError(f"Unsupported backend: {backend}")
    worker_url, key_var, display_name = backend_settings[backend]

    # Get API key from environment
    backend_api_key = os.environ.get(key_var)
    if not backend_api_key:
        raise ValueError(
            f"{key_var} environment variable must be set for {display_name} backend"
        )

    print(f" Worker URL: {worker_url}")

    # Build router command. Both providers are driven through the router's
    # "openai" backend; only the worker URL differs.
    router_cmd = [
        "python3",
        "-m",
        "sglang_router.launch_router",
        "--host",
        "127.0.0.1",
        "--port",
        str(router_port),
        "--backend",
        "openai",
        "--worker-urls",
        worker_url,
        "--history-backend",
        history_backend,
        "--log-level",
        "warn",
    ]

    # Note: Not adding --api-key to router command for local testing
    # The router will not require authentication

    # Add Prometheus port to avoid conflicts between concurrently running routers
    if prometheus_port is None:
        # Auto-assign a unique prometheus port based on router port
        prometheus_port = router_port + 1000
        if prometheus_port > 65535:
            # Derived value is not a valid TCP port; fall back to any free one.
            prometheus_port = find_free_port()
    router_cmd.extend(["--prometheus-port", str(prometheus_port)])

    # Add router-specific args
    if router_args:
        router_cmd.extend(router_args)

    if show_output:
        print(f" Command: {' '.join(router_cmd)}")

    # Set up environment with backend API key
    env = os.environ.copy()
    env[key_var] = backend_api_key

    # Launch router. With SHOW_ROUTER_LOGS=1 the caller's handles are used
    # as-is (None inherits this process's streams); otherwise unset handles
    # default to pipes so router output stays off the console.
    # NOTE(review): these pipes are never drained here; a very chatty router
    # could fill them and block. Pass stdout/stderr explicitly if that matters.
    if not show_output:
        stdout = subprocess.PIPE if stdout is None else stdout
        stderr = subprocess.PIPE if stderr is None else stderr
    router_proc = subprocess.Popen(
        router_cmd,
        env=env,
        stdout=stdout,
        stderr=stderr,
    )

    print(f" PID: {router_proc.pid}")

    # Wait for router to be ready
    router_url = f"http://127.0.0.1:{router_port}"
    print(f"\nWaiting for router to start at {router_url}...")

    try:
        wait_for_router_ready(router_url, timeout=timeout, api_key=None)
        print(f"✓ Router ready at {router_url}")
    except TimeoutError:
        print(f"✗ Router failed to start")
        # Cleanup: kill router and reap it so no zombie is left behind.
        # Best-effort only: the original TimeoutError is what gets reported.
        try:
            router_proc.kill()
            router_proc.wait(timeout=10)
        except (OSError, subprocess.TimeoutExpired):
            pass
        raise

    print(f"\n{'='*70}")
    print(f"✓ {backend.upper()} router ready!")
    print(f" Router: {router_url}")
    print(f"{'='*70}\n")

    return {
        "router": router_proc,
        "base_url": router_url,
    }
|
||||||
|
|
||||||
|
|
||||||
|
def popen_launch_workers_and_router(
    model: str,
    base_url: str,
    timeout: int = 300,
    num_workers: int = 2,
    policy: str = "round_robin",
    api_key: Optional[str] = None,
    worker_args: Optional[list] = None,
    router_args: Optional[list] = None,
    tp_size: int = 1,
    env: Optional[dict] = None,
    stdout=None,
    stderr=None,
) -> dict:
    """
    Launch SGLang workers and gRPC router separately.

    This approach:
    1. Starts N SGLang workers with --grpc-mode flag
    2. Waits for workers to initialize (process startup)
    3. Starts a gRPC router pointing to those workers
    4. Waits until the router's /workers endpoint reports all workers

    This matches production deployment patterns better than the integrated approach.

    Args:
        model: Model path (e.g., /home/ubuntu/models/llama-3.1-8b-instruct)
        base_url: Base URL for router (e.g., "http://127.0.0.1:8080"). Only
            the port is used; if no port is present a free one is picked.
        timeout: Seconds to wait for the router to report all workers
            (default: 300s). Values below 180 are raised to 180, the wait
            this helper previously always used, so existing callers are not
            cut short.
        num_workers: Number of workers to launch
        policy: Routing policy (round_robin, random, power_of_two, cache_aware)
        api_key: Optional API key for router
        worker_args: Additional arguments for workers (e.g., ["--context-len", "8192"])
        router_args: Additional arguments for router (e.g., ["--max-total-token", "1536"])
        tp_size: Tensor parallelism size for workers (default: 1)
        env: Optional extra environment variables for workers, applied on top
            of the current environment
            (e.g., {"SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION": "256"})
        stdout: Optional file handle for worker stdout (default: subprocess.PIPE)
        stderr: Optional file handle for worker stderr (default: subprocess.PIPE)

    Returns:
        dict with:
        - workers: list of worker process objects
        - worker_urls: list of gRPC worker URLs
        - router: router process object
        - base_url: router URL (HTTP endpoint)

    Raises:
        RuntimeError: If a worker process exits during the initial wait.
        TimeoutError: If the router does not report all workers in time;
            the router and workers are killed first.

    Example:
        >>> cluster = popen_launch_workers_and_router(model, base_url, num_workers=2)
        >>> # Use cluster['base_url'] for HTTP requests
        >>> # Cleanup:
        >>> for worker in cluster['workers']:
        ...     kill_process_tree(worker.pid)
        >>> kill_process_tree(cluster['router'].pid)
    """
    # Imported locally so this helper does not depend on a module-level import.
    from urllib.parse import urlparse

    show_output = os.environ.get("SHOW_ROUTER_LOGS", "0") == "1"

    # Parse router port from base_url. urlparse copes with a trailing slash or
    # path ("http://host:8080/"), which splitting on ":" does not. A bare
    # "host:port" gets a "//" prefix so it is parsed as a network location.
    parsed_url = urlparse(base_url if "//" in base_url else f"//{base_url}")
    router_port = parsed_url.port
    if router_port is None:
        router_port = find_free_port()

    print(f"\n{'='*70}")
    print(f"Launching gRPC cluster (separate workers + router)")
    print(f"{'='*70}")
    print(f" Model: {model}")
    print(f" Router port: {router_port}")
    print(f" Workers: {num_workers}")
    print(f" TP size: {tp_size}")
    print(f" Policy: {policy}")

    # Overlay the caller's variables on the current environment. Passing
    # `env` straight to Popen would REPLACE the environment, dropping PATH
    # and everything else the worker needs.
    worker_env = {**os.environ, **env} if env else None

    # With SHOW_ROUTER_LOGS=1 the caller's handles are used as-is (None
    # inherits this process's streams); otherwise unset handles become pipes.
    worker_stdout, worker_stderr = stdout, stderr
    if not show_output:
        if worker_stdout is None:
            worker_stdout = subprocess.PIPE
        if worker_stderr is None:
            worker_stderr = subprocess.PIPE

    # Step 1: Launch workers with gRPC enabled
    workers = []
    worker_urls = []

    for i in range(num_workers):
        worker_port = find_free_port()
        worker_url = f"grpc://127.0.0.1:{worker_port}"
        worker_urls.append(worker_url)

        print(f"\n[Worker {i+1}/{num_workers}]")
        print(f" Port: {worker_port}")
        print(f" URL: {worker_url}")

        # Build worker command
        worker_cmd = [
            "python3",
            "-m",
            "sglang.launch_server",
            "--model-path",
            model,
            "--host",
            "127.0.0.1",
            "--port",
            str(worker_port),
            "--grpc-mode",  # Enable gRPC for this worker
            "--mem-fraction-static",
            "0.8",
            "--attention-backend",
            "fa3",
        ]

        # Add TP size
        if tp_size > 1:
            worker_cmd.extend(["--tp-size", str(tp_size)])

        # Add worker-specific args
        if worker_args:
            worker_cmd.extend(worker_args)

        worker_proc = subprocess.Popen(
            worker_cmd,
            env=worker_env,
            stdout=worker_stdout,
            stderr=worker_stderr,
        )

        workers.append(worker_proc)
        print(f" PID: {worker_proc.pid}")

    # Give workers a moment to start binding to ports
    # The router will check worker health when it starts
    print(f"\nWaiting for {num_workers} workers to initialize (20s)...")
    time.sleep(20)

    # Quick check: make sure worker processes are still alive
    for i, worker in enumerate(workers):
        if worker.poll() is not None:
            print(f" ✗ Worker {i+1} died during startup (exit code: {worker.poll()})")
            # Cleanup: kill all workers (best-effort; some may already be gone)
            for w in workers:
                try:
                    w.kill()
                except OSError:
                    pass
            raise RuntimeError(f"Worker {i+1} failed to start")

    print(f"✓ All {num_workers} workers started (router will verify connectivity)")

    # Step 2: Launch router pointing to workers
    print(f"\n[Router]")
    print(f" Port: {router_port}")
    print(f" Worker URLs: {', '.join(worker_urls)}")

    # Build router command
    router_cmd = [
        "python3",
        "-m",
        "sglang_router.launch_router",
        "--host",
        "127.0.0.1",
        "--port",
        str(router_port),
        "--prometheus-port",
        "9321",
        "--policy",
        policy,
        "--model-path",
        model,
    ]

    # Add worker URLs
    router_cmd.append("--worker-urls")
    router_cmd.extend(worker_urls)

    # Add API key
    if api_key:
        router_cmd.extend(["--api-key", api_key])

    # Add router-specific args
    if router_args:
        router_cmd.extend(router_args)

    if show_output:
        print(f" Command: {' '.join(router_cmd)}")

    # Launch router
    if show_output:
        router_proc = subprocess.Popen(router_cmd)
    else:
        router_proc = subprocess.Popen(
            router_cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )

    print(f" PID: {router_proc.pid}")

    # Wait for router to be ready
    router_url = f"http://127.0.0.1:{router_port}"
    print(f"\nWaiting for router to start at {router_url}...")

    try:
        # Honor the caller's timeout, but never wait less than the 180s this
        # helper has always allowed for model loading.
        wait_for_workers_ready(
            router_url,
            expected_workers=num_workers,
            timeout=max(timeout, 180),
            api_key=api_key,
        )
        print(f"✓ Router ready at {router_url}")
    except TimeoutError:
        print(f"✗ Router failed to start")
        # Cleanup: kill router and all workers (best-effort; the original
        # TimeoutError is what gets reported)
        for proc in [router_proc, *workers]:
            try:
                proc.kill()
            except OSError:
                pass
        raise

    print(f"\n{'='*70}")
    print(f"✓ gRPC cluster ready!")
    print(f" Router: {router_url}")
    print(f" Workers: {len(workers)}")
    print(f"{'='*70}\n")

    return {
        "workers": workers,
        "worker_urls": worker_urls,
        "router": router_proc,
        "base_url": router_url,
    }
|
||||||
135
sgl-router/py_test/e2e_response_api/state_management.py
Normal file
135
sgl-router/py_test/e2e_response_api/state_management.py
Normal file
@@ -0,0 +1,135 @@
|
|||||||
|
"""
|
||||||
|
State management tests for Response API.
|
||||||
|
|
||||||
|
Tests both previous_response_id and conversation-based state management.
|
||||||
|
These tests should work across all backends (OpenAI, XAI, gRPC).
|
||||||
|
"""
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from base import ResponseAPIBaseTest
|
||||||
|
|
||||||
|
|
||||||
|
class StateManagementTests(ResponseAPIBaseTest):
    """Tests for state management using previous_response_id and conversation.

    Relies on helpers called on ``self`` and not defined in this class:
    ``create_response``, ``create_conversation`` and
    ``list_conversation_items``.
    """

    def test_previous_response_id_chaining(self):
        """Test chaining responses using previous_response_id."""
        # First response
        resp1 = self.create_response(
            "My name is Alice and my friend is Bob. Remember it."
        )
        self.assertEqual(resp1.status_code, 200)
        response1_id = resp1.json()["id"]

        # Second response referencing first
        resp2 = self.create_response(
            "What is my name", previous_response_id=response1_id
        )
        self.assertEqual(resp2.status_code, 200)
        response2_data = resp2.json()

        # The model should remember the name from previous response
        output_text = self._extract_output_text(response2_data)
        self.assertIn("Alice", output_text)

        # Third response referencing second
        resp3 = self.create_response(
            "What is my friend name?",
            previous_response_id=response2_data["id"],
        )
        # Check the status before parsing so an error response fails with a
        # clear status mismatch instead of a JSON/key error.
        self.assertEqual(resp3.status_code, 200)
        response3_data = resp3.json()
        output_text = self._extract_output_text(response3_data)
        self.assertIn("Bob", output_text)

    @unittest.skip("TODO: Add the invalid previous_response_id check")
    def test_previous_response_id_invalid(self):
        """Test using invalid previous_response_id."""
        resp = self.create_response(
            "Test", previous_response_id="resp_invalid123", max_output_tokens=50
        )
        # Should return 404 or 400 for invalid response ID
        if resp.status_code != 200:
            print(f"\n❌ Response creation failed!")
            print(f"Status: {resp.status_code}")
            print(f"Response: {resp.text}")
        self.assertIn(resp.status_code, [400, 404])

    def test_conversation_with_multiple_turns(self):
        """Test state management using conversation ID."""
        # Create conversation
        conv_resp = self.create_conversation(metadata={"topic": "math"})
        self.assertEqual(conv_resp.status_code, 200)

        conversation_id = conv_resp.json()["id"]

        # First response in conversation
        resp1 = self.create_response("I have 5 apples.", conversation=conversation_id)
        self.assertEqual(resp1.status_code, 200)

        # Second response in same conversation
        resp2 = self.create_response(
            "How many apples do I have?",
            conversation=conversation_id,
        )
        self.assertEqual(resp2.status_code, 200)
        output_text = self._extract_output_text(resp2.json())

        # Should remember "5 apples"
        self.assertTrue("5" in output_text or "five" in output_text.lower())

        # Third response in same conversation
        resp3 = self.create_response(
            "If I get 3 more, how many total?",
            conversation=conversation_id,
        )
        self.assertEqual(resp3.status_code, 200)
        output_text = self._extract_output_text(resp3.json())

        # Should calculate 5 + 3 = 8
        self.assertTrue("8" in output_text or "eight" in output_text.lower())

        list_resp = self.list_conversation_items(conversation_id)
        self.assertEqual(list_resp.status_code, 200)
        items = list_resp.json()["data"]
        # Should have at least 6 items (3 inputs + 3 outputs)
        self.assertGreaterEqual(len(items), 6)

    def test_mutually_exclusive_parameters(self):
        """Test that previous_response_id and conversation are mutually exclusive."""
        # Create conversation and response. Assert both succeeded so a setup
        # failure is reported as such rather than as a KeyError on "id".
        conv_resp = self.create_conversation()
        self.assertEqual(conv_resp.status_code, 200)
        conversation_id = conv_resp.json()["id"]

        resp1 = self.create_response("Test")
        self.assertEqual(resp1.status_code, 200)
        response1_id = resp1.json()["id"]

        # Try to use both parameters
        resp = self.create_response(
            "This should fail",
            previous_response_id=response1_id,
            conversation=conversation_id,
        )

        # Should return 400 Bad Request
        self.assertEqual(resp.status_code, 400)
        error_data = resp.json()
        self.assertIn("error", error_data)
        self.assertIn("mutually exclusive", error_data["error"]["message"].lower())

    # Helper methods

    def _extract_output_text(self, response_data: dict) -> str:
        """Extract text content from response output.

        Collects the ``text`` of every ``output_text`` part found in the
        ``content`` of each item in ``response_data["output"]`` and joins
        them with single spaces.

        Args:
            response_data: Parsed JSON body of a response.

        Returns:
            The joined text, or "" when there is no output text.
        """
        text_parts = []
        # `or []` / `or ""` also cover keys that are present but null, which
        # would otherwise break the iteration or the final join.
        for item in response_data.get("output") or []:
            for part in item.get("content") or []:
                if part.get("type") == "output_text":
                    text_parts.append(part.get("text") or "")

        return " ".join(text_parts)
|
||||||
130
sgl-router/py_test/e2e_response_api/test_response_api.py
Normal file
130
sgl-router/py_test/e2e_response_api/test_response_api.py
Normal file
@@ -0,0 +1,130 @@
|
|||||||
|
"""
|
||||||
|
End-to-end Response API tests for the OpenAI, XAI, and gRPC backends.
|
||||||
|
|
||||||
|
Run with:
|
||||||
|
export OPENAI_API_KEY=your_key
|
||||||
|
python3 -m pytest py_test/e2e_response_api/test_response_api.py -v
|
||||||
|
python3 -m unittest e2e_response_api.test_response_api.TestOpenaiBackend
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import unittest
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Add current directory for imports
|
||||||
|
_TEST_DIR = Path(__file__).parent
|
||||||
|
sys.path.insert(0, str(_TEST_DIR))
|
||||||
|
|
||||||
|
# Import local modules
|
||||||
|
from base import ConversationCRUDBaseTest, ResponseCRUDBaseTest
|
||||||
|
from mcp import MCPTests
|
||||||
|
from router_fixtures import (
|
||||||
|
popen_launch_openai_xai_router,
|
||||||
|
popen_launch_workers_and_router,
|
||||||
|
)
|
||||||
|
from state_management import StateManagementTests
|
||||||
|
from util import kill_process_tree
|
||||||
|
|
||||||
|
|
||||||
|
class TestOpenaiBackend(
    ResponseCRUDBaseTest, ConversationCRUDBaseTest, StateManagementTests, MCPTests
):
    """Run the inherited Response API suites through a router on the OpenAI backend."""

    api_key = os.environ.get("OPENAI_API_KEY")

    @classmethod
    def setUpClass(cls):
        """Start one router for the whole class and record its URL."""
        cls.model = "gpt-5-nano"
        cls.base_url_port = "http://127.0.0.1:30010"

        launched = popen_launch_openai_xai_router(
            backend="openai",
            base_url=cls.base_url_port,
            history_backend="memory",
        )
        cls.cluster = launched
        cls.base_url = launched["base_url"]

    @classmethod
    def tearDownClass(cls):
        """Kill the router process and any children it spawned."""
        router_pid = cls.cluster["router"].pid
        kill_process_tree(router_pid)
|
||||||
|
|
||||||
|
|
||||||
|
class TestXaiBackend(StateManagementTests):
    """Run the state-management suite through a router on the XAI backend."""

    api_key = os.environ.get("XAI_API_KEY")

    @classmethod
    def setUpClass(cls):
        """Start one router for the whole class and record its URL."""
        cls.model = "grok-4-fast"
        cls.base_url_port = "http://127.0.0.1:30023"

        launched = popen_launch_openai_xai_router(
            backend="xai",
            base_url=cls.base_url_port,
            history_backend="memory",
        )
        cls.cluster = launched
        cls.base_url = launched["base_url"]

    @classmethod
    def tearDownClass(cls):
        """Kill the router process and any children it spawned."""
        router_pid = cls.cluster["router"].pid
        kill_process_tree(router_pid)
|
||||||
|
|
||||||
|
|
||||||
|
class TestGrpcBackend(StateManagementTests, MCPTests):
    """Run the state-management and MCP suites against a local gRPC worker + router.

    Every inherited test overridden below is currently skipped; the skip
    reason records the failure seen on this backend.
    """

    @classmethod
    def setUpClass(cls):
        """Launch one worker and a router in front of it for the whole class."""
        cls.model = "/home/ubuntu/models/meta-llama/Llama-3.1-8B-Instruct"
        cls.base_url_port = "http://127.0.0.1:30030"

        launched = popen_launch_workers_and_router(
            model=cls.model,
            base_url=cls.base_url_port,
            timeout=90,
            num_workers=1,
            tp_size=2,
            policy="round_robin",
            router_args=["--history-backend", "memory"],
        )
        cls.cluster = launched
        cls.base_url = launched["base_url"]

    @classmethod
    def tearDownClass(cls):
        """Kill the router first, then each worker, including child processes."""
        to_kill = [cls.cluster["router"], *cls.cluster.get("workers", [])]
        for proc in to_kill:
            kill_process_tree(proc.pid)

    @unittest.skip(
        "TODO: transport error, details: [], metadata: MetadataMap { headers: {} }"
    )
    def test_previous_response_id_chaining(self):
        super().test_previous_response_id_chaining()

    @unittest.skip("TODO: return 501 Not Implemented")
    def test_conversation_with_multiple_turns(self):
        super().test_conversation_with_multiple_turns()

    @unittest.skip("TODO: decode error message")
    def test_mutually_exclusive_parameters(self):
        super().test_mutually_exclusive_parameters()

    @unittest.skip(
        "TODO: Pipeline execution failed: Pipeline stage WorkerSelection failed"
    )
    def test_mcp_basic_tool_call(self):
        super().test_mcp_basic_tool_call()

    @unittest.skip("TODO: no event fields")
    def test_mcp_basic_tool_call_streaming(self):
        return super().test_mcp_basic_tool_call_streaming()
|
||||||
|
|
||||||
|
|
||||||
|
# Run every test class in this module when the file is executed directly.
if "__main__" == __name__:
    unittest.main()
|
||||||
82
sgl-router/py_test/e2e_response_api/util.py
Normal file
82
sgl-router/py_test/e2e_response_api/util.py
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
"""
|
||||||
|
Utility functions for Response API e2e tests.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import signal
|
||||||
|
import threading
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import psutil
|
||||||
|
|
||||||
|
|
||||||
|
def kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = None):
    """
    Kill the process and all its child processes.

    Descendants are killed first, then (optionally) the parent. Processes
    that have already exited are ignored.

    Args:
        parent_pid: PID of the parent process. If None, the current process
            is used and only its descendants are killed (include_parent is
            forced to False).
        include_parent: Whether to kill the parent process itself
        skip_pid: Optional PID to skip during cleanup
    """
    # Remove sigchld handler to avoid spammy logs. Handlers can only be set
    # from the main thread, and SIGCHLD is not defined on every platform.
    if (
        hasattr(signal, "SIGCHLD")
        and threading.current_thread() is threading.main_thread()
    ):
        signal.signal(signal.SIGCHLD, signal.SIG_DFL)

    if parent_pid is None:
        parent_pid = os.getpid()
        include_parent = False

    try:
        itself = psutil.Process(parent_pid)
        # Enumerate inside the try: the process may exit between the lookup
        # and this call, and that should be a no-op, not a teardown error.
        children = itself.children(recursive=True)
    except psutil.NoSuchProcess:
        return

    for child in children:
        if child.pid == skip_pid:
            continue
        try:
            child.kill()
        except psutil.NoSuchProcess:
            pass

    if include_parent:
        try:
            itself.kill()
        except psutil.NoSuchProcess:
            pass
|
||||||
|
|
||||||
|
|
||||||
|
class CustomTestCase(unittest.TestCase):
|
||||||
|
"""
|
||||||
|
Custom test case base class with retry support.
|
||||||
|
|
||||||
|
This provides automatic test retry functionality based on environment variables.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _callTestMethod(self, method):
|
||||||
|
"""Override to add retry logic."""
|
||||||
|
max_retry = int(os.environ.get("SGLANG_TEST_MAX_RETRY", "0"))
|
||||||
|
|
||||||
|
if max_retry == 0:
|
||||||
|
# No retry, just run once
|
||||||
|
return super(CustomTestCase, self)._callTestMethod(method)
|
||||||
|
|
||||||
|
# Retry logic
|
||||||
|
for attempt in range(max_retry + 1):
|
||||||
|
try:
|
||||||
|
return super(CustomTestCase, self)._callTestMethod(method)
|
||||||
|
except Exception as e:
|
||||||
|
if attempt < max_retry:
|
||||||
|
print(
|
||||||
|
f"Test failed on attempt {attempt + 1}/{max_retry + 1}, retrying..."
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
# Last attempt, re-raise the exception
|
||||||
|
raise
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
"""Print test method name at the start of each test."""
|
||||||
|
print(f"[Test Method] {self._testMethodName}", flush=True)
|
||||||
Reference in New Issue
Block a user