481 lines
17 KiB
Python
481 lines
17 KiB
Python
"""
|
|
Base test class for Response API e2e tests.
|
|
|
|
This module provides base test classes that can be reused across different backends
|
|
(OpenAI, XAI, gRPC) with common test logic.
|
|
"""
|
|
|
|
import json
|
|
import sys
|
|
import time
|
|
import unittest
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
|
|
import requests
|
|
|
|
# Add current directory for local imports
|
|
_TEST_DIR = Path(__file__).parent
|
|
sys.path.insert(0, str(_TEST_DIR))
|
|
|
|
from util import CustomTestCase
|
|
|
|
|
|
class ResponseAPIBaseTest(CustomTestCase):
|
|
"""Base class for Response API tests with common utilities."""
|
|
|
|
# To be set by subclasses
|
|
base_url: str = None
|
|
api_key: str = None
|
|
model: str = None
|
|
|
|
def make_request(
|
|
self,
|
|
endpoint: str,
|
|
method: str = "POST",
|
|
json_data: Optional[dict] = None,
|
|
params: Optional[dict] = None,
|
|
) -> requests.Response:
|
|
"""
|
|
Make HTTP request to router.
|
|
|
|
Args:
|
|
endpoint: Endpoint path (e.g., "/v1/responses")
|
|
method: HTTP method (GET, POST, DELETE)
|
|
json_data: JSON body for POST requests
|
|
params: Query parameters
|
|
|
|
Returns:
|
|
requests.Response object
|
|
"""
|
|
url = f"{self.base_url}{endpoint}"
|
|
headers = {"Content-Type": "application/json"}
|
|
if self.api_key:
|
|
headers["Authorization"] = f"Bearer {self.api_key}"
|
|
|
|
if method == "POST":
|
|
resp = requests.post(url, json=json_data, headers=headers, params=params)
|
|
elif method == "GET":
|
|
resp = requests.get(url, headers=headers, params=params)
|
|
elif method == "DELETE":
|
|
resp = requests.delete(url, headers=headers, params=params)
|
|
else:
|
|
raise ValueError(f"Unsupported method: {method}")
|
|
return resp
|
|
|
|
def create_response(
|
|
self,
|
|
input_text: str,
|
|
instructions: Optional[str] = None,
|
|
stream: bool = False,
|
|
max_output_tokens: Optional[int] = None,
|
|
temperature: Optional[float] = None,
|
|
previous_response_id: Optional[str] = None,
|
|
conversation: Optional[str] = None,
|
|
tools: Optional[list] = None,
|
|
background: bool = False,
|
|
**kwargs,
|
|
) -> requests.Response:
|
|
"""
|
|
Create a response via POST /v1/responses.
|
|
|
|
Args:
|
|
input_text: User input
|
|
instructions: Optional system instructions
|
|
stream: Whether to stream response
|
|
max_output_tokens: Optional max tokens to generate
|
|
temperature: Sampling temperature
|
|
previous_response_id: Optional previous response ID for state management
|
|
conversation: Optional conversation ID for state management
|
|
tools: Optional list of MCP tools
|
|
background: Whether to run in background mode
|
|
**kwargs: Additional request parameters
|
|
|
|
Returns:
|
|
requests.Response object
|
|
"""
|
|
data = {
|
|
"model": self.model,
|
|
"input": input_text,
|
|
"stream": stream,
|
|
**kwargs,
|
|
}
|
|
|
|
if instructions:
|
|
data["instructions"] = instructions
|
|
|
|
if max_output_tokens is not None:
|
|
data["max_output_tokens"] = max_output_tokens
|
|
|
|
if temperature is not None:
|
|
data["temperature"] = temperature
|
|
|
|
if previous_response_id:
|
|
data["previous_response_id"] = previous_response_id
|
|
|
|
if conversation:
|
|
data["conversation"] = conversation
|
|
|
|
if tools:
|
|
data["tools"] = tools
|
|
|
|
if background:
|
|
data["background"] = background
|
|
|
|
if stream:
|
|
# For streaming, we need to handle SSE
|
|
return self._create_streaming_response(data)
|
|
else:
|
|
return self.make_request("/v1/responses", "POST", data)
|
|
|
|
def _create_streaming_response(self, data: dict) -> requests.Response:
|
|
"""Handle streaming response creation."""
|
|
url = f"{self.base_url}/v1/responses"
|
|
headers = {"Content-Type": "application/json"}
|
|
if self.api_key:
|
|
headers["Authorization"] = f"Bearer {self.api_key}"
|
|
|
|
# Return response object with stream=True
|
|
return requests.post(url, json=data, headers=headers, stream=True)
|
|
|
|
def get_response(self, response_id: str) -> requests.Response:
|
|
"""Get response by ID via GET /v1/responses/{response_id}."""
|
|
return self.make_request(f"/v1/responses/{response_id}", "GET")
|
|
|
|
def delete_response(self, response_id: str) -> requests.Response:
|
|
"""Delete response by ID via DELETE /v1/responses/{response_id}."""
|
|
return self.make_request(f"/v1/responses/{response_id}", "DELETE")
|
|
|
|
def cancel_response(self, response_id: str) -> requests.Response:
|
|
"""Cancel response by ID via POST /v1/responses/{response_id}/cancel."""
|
|
return self.make_request(f"/v1/responses/{response_id}/cancel", "POST", {})
|
|
|
|
def get_response_input(self, response_id: str) -> requests.Response:
|
|
"""Get response input items via GET /v1/responses/{response_id}/input."""
|
|
return self.make_request(f"/v1/responses/{response_id}/input", "GET")
|
|
|
|
def create_conversation(self, metadata: Optional[dict] = None) -> requests.Response:
|
|
"""Create conversation via POST /v1/conversations."""
|
|
data = {}
|
|
if metadata:
|
|
data["metadata"] = metadata
|
|
return self.make_request("/v1/conversations", "POST", data)
|
|
|
|
def get_conversation(self, conversation_id: str) -> requests.Response:
|
|
"""Get conversation by ID via GET /v1/conversations/{conversation_id}."""
|
|
return self.make_request(f"/v1/conversations/{conversation_id}", "GET")
|
|
|
|
def update_conversation(
|
|
self, conversation_id: str, metadata: dict
|
|
) -> requests.Response:
|
|
"""Update conversation via POST /v1/conversations/{conversation_id}."""
|
|
return self.make_request(
|
|
f"/v1/conversations/{conversation_id}", "POST", {"metadata": metadata}
|
|
)
|
|
|
|
def delete_conversation(self, conversation_id: str) -> requests.Response:
|
|
"""Delete conversation via DELETE /v1/conversations/{conversation_id}."""
|
|
return self.make_request(f"/v1/conversations/{conversation_id}", "DELETE")
|
|
|
|
def list_conversation_items(
|
|
self,
|
|
conversation_id: str,
|
|
limit: Optional[int] = None,
|
|
after: Optional[str] = None,
|
|
before: Optional[str] = None,
|
|
order: str = "asc",
|
|
) -> requests.Response:
|
|
"""List conversation items via GET /v1/conversations/{conversation_id}/items."""
|
|
params = {"order": order}
|
|
if limit:
|
|
params["limit"] = limit
|
|
if after:
|
|
params["after"] = after
|
|
if before:
|
|
params["before"] = before
|
|
return self.make_request(
|
|
f"/v1/conversations/{conversation_id}/items", "GET", params=params
|
|
)
|
|
|
|
def create_conversation_items(
|
|
self, conversation_id: str, items: list
|
|
) -> requests.Response:
|
|
"""Create conversation items via POST /v1/conversations/{conversation_id}/items."""
|
|
return self.make_request(
|
|
f"/v1/conversations/{conversation_id}/items", "POST", {"items": items}
|
|
)
|
|
|
|
def get_conversation_item(
|
|
self, conversation_id: str, item_id: str
|
|
) -> requests.Response:
|
|
"""Get conversation item via GET /v1/conversations/{conversation_id}/items/{item_id}."""
|
|
return self.make_request(
|
|
f"/v1/conversations/{conversation_id}/items/{item_id}", "GET"
|
|
)
|
|
|
|
def delete_conversation_item(
|
|
self, conversation_id: str, item_id: str
|
|
) -> requests.Response:
|
|
"""Delete conversation item via DELETE /v1/conversations/{conversation_id}/items/{item_id}."""
|
|
return self.make_request(
|
|
f"/v1/conversations/{conversation_id}/items/{item_id}", "DELETE"
|
|
)
|
|
|
|
def parse_sse_events(self, response: requests.Response) -> list:
|
|
"""
|
|
Parse Server-Sent Events from streaming response.
|
|
|
|
Args:
|
|
response: requests.Response with stream=True
|
|
|
|
Returns:
|
|
List of event dictionaries with 'event' and 'data' keys
|
|
"""
|
|
events = []
|
|
current_event = None
|
|
|
|
for line in response.iter_lines():
|
|
if not line:
|
|
# Empty line signals end of event
|
|
if current_event and current_event.get("data"):
|
|
events.append(current_event)
|
|
current_event = None
|
|
continue
|
|
|
|
line = line.decode("utf-8")
|
|
|
|
if line.startswith("event:"):
|
|
current_event = {"event": line[6:].strip()}
|
|
elif line.startswith("data:"):
|
|
if current_event is None:
|
|
current_event = {}
|
|
data_str = line[5:].strip()
|
|
try:
|
|
current_event["data"] = json.loads(data_str)
|
|
except json.JSONDecodeError:
|
|
current_event["data"] = data_str
|
|
|
|
# Don't forget the last event if stream ends without empty line
|
|
if current_event and current_event.get("data"):
|
|
events.append(current_event)
|
|
|
|
return events
|
|
|
|
def wait_for_background_task(
|
|
self, response_id: str, timeout: int = 30, poll_interval: float = 0.5
|
|
) -> dict:
|
|
"""
|
|
Wait for background task to complete.
|
|
|
|
Args:
|
|
response_id: Response ID to poll
|
|
timeout: Max seconds to wait
|
|
poll_interval: Seconds between polls
|
|
|
|
Returns:
|
|
Final response data
|
|
|
|
Raises:
|
|
TimeoutError: If task doesn't complete in time
|
|
AssertionError: If task fails
|
|
"""
|
|
start_time = time.time()
|
|
|
|
while time.time() - start_time < timeout:
|
|
resp = self.get_response(response_id)
|
|
self.assertEqual(resp.status_code, 200)
|
|
|
|
data = resp.json()
|
|
status = data.get("status")
|
|
|
|
if status == "completed":
|
|
return data
|
|
elif status == "failed":
|
|
raise AssertionError(
|
|
f"Background task failed: {data.get('error', 'Unknown error')}"
|
|
)
|
|
elif status == "cancelled":
|
|
raise AssertionError("Background task was cancelled")
|
|
|
|
time.sleep(poll_interval)
|
|
|
|
raise TimeoutError(
|
|
f"Background task {response_id} did not complete within {timeout}s"
|
|
)
|
|
|
|
|
|
class StateManagementBaseTest(ResponseAPIBaseTest):
|
|
"""Base class for state management tests (previous_response_id and conversation)."""
|
|
|
|
def test_basic_response_creation(self):
|
|
"""Test basic response creation without state."""
|
|
resp = self.create_response("What is 2+2?", max_output_tokens=50)
|
|
self.assertEqual(resp.status_code, 200)
|
|
|
|
data = resp.json()
|
|
self.assertIn("id", data)
|
|
self.assertIn("output", data)
|
|
self.assertEqual(data["status"], "completed")
|
|
self.assertIn("usage", data)
|
|
|
|
def test_streaming_response(self):
|
|
"""Test streaming response."""
|
|
resp = self.create_response("Count to 5", stream=True, max_output_tokens=50)
|
|
self.assertEqual(resp.status_code, 200)
|
|
|
|
events = self.parse_sse_events(resp)
|
|
self.assertGreater(len(events), 0)
|
|
|
|
# Check for response.created event
|
|
created_events = [e for e in events if e.get("event") == "response.created"]
|
|
self.assertGreater(len(created_events), 0)
|
|
|
|
# Check for final completed event or in_progress events
|
|
self.assertTrue(
|
|
any(
|
|
e.get("event") in ["response.completed", "response.in_progress"]
|
|
for e in events
|
|
)
|
|
)
|
|
|
|
|
|
class ResponseCRUDBaseTest(ResponseAPIBaseTest):
|
|
"""Base class for Response API CRUD tests."""
|
|
|
|
def test_create_and_get_response(self):
|
|
"""Test creating response and retrieving it."""
|
|
# Create response
|
|
create_resp = self.create_response("Hello, world!")
|
|
self.assertEqual(create_resp.status_code, 200)
|
|
|
|
create_data = create_resp.json()
|
|
response_id = create_data["id"]
|
|
|
|
# Get response
|
|
get_resp = self.get_response(response_id)
|
|
self.assertEqual(get_resp.status_code, 200)
|
|
|
|
get_data = get_resp.json()
|
|
self.assertEqual(get_data["id"], response_id)
|
|
self.assertEqual(get_data["status"], "completed")
|
|
|
|
input_resp = self.get_response_input(get_data["id"])
|
|
# change not merge yet
|
|
self.assertEqual(input_resp.status_code, 501)
|
|
# self.assertEqual(input_resp.status_code, 200)
|
|
# input_data = input_resp.json()
|
|
# self.assertIn("data", input_data)
|
|
# self.assertGreater(len(input_data["data"]), 0)
|
|
|
|
@unittest.skip("TODO: Add delete response feature")
|
|
def test_delete_response(self):
|
|
"""Test deleting response."""
|
|
# Create response
|
|
create_resp = self.create_response("Test deletion", max_output_tokens=50)
|
|
self.assertEqual(create_resp.status_code, 200)
|
|
|
|
response_id = create_resp.json()["id"]
|
|
|
|
# Delete response
|
|
delete_resp = self.delete_response(response_id)
|
|
self.assertEqual(delete_resp.status_code, 200)
|
|
|
|
# Verify it's deleted (should return 404)
|
|
get_resp = self.get_response(response_id)
|
|
self.assertEqual(get_resp.status_code, 404)
|
|
|
|
@unittest.skip("TODO: Add background response feature")
|
|
def test_background_response(self):
|
|
"""Test background response execution."""
|
|
# Create background response
|
|
create_resp = self.create_response(
|
|
"Write a short story", background=True, max_output_tokens=100
|
|
)
|
|
self.assertEqual(create_resp.status_code, 200)
|
|
|
|
create_data = create_resp.json()
|
|
response_id = create_data["id"]
|
|
self.assertEqual(create_data["status"], "in_progress")
|
|
|
|
# Wait for completion
|
|
final_data = self.wait_for_background_task(response_id, timeout=60)
|
|
self.assertEqual(final_data["status"], "completed")
|
|
|
|
|
|
class ConversationCRUDBaseTest(ResponseAPIBaseTest):
|
|
"""Base class for Conversation API CRUD tests."""
|
|
|
|
def test_create_and_get_conversation(self):
|
|
"""Test creating and retrieving conversation."""
|
|
# Create conversation
|
|
create_resp = self.create_conversation(metadata={"user": "test_user"})
|
|
self.assertEqual(create_resp.status_code, 200)
|
|
|
|
create_data = create_resp.json()
|
|
conversation_id = create_data["id"]
|
|
self.assertEqual(create_data["metadata"]["user"], "test_user")
|
|
|
|
# Get conversation
|
|
get_resp = self.get_conversation(conversation_id)
|
|
self.assertEqual(get_resp.status_code, 200)
|
|
|
|
get_data = get_resp.json()
|
|
self.assertEqual(get_data["id"], conversation_id)
|
|
self.assertEqual(get_data["metadata"]["user"], "test_user")
|
|
|
|
def test_update_conversation(self):
|
|
"""Test updating conversation metadata."""
|
|
# Create conversation
|
|
create_resp = self.create_conversation(metadata={"key1": "value1"})
|
|
self.assertEqual(create_resp.status_code, 200)
|
|
conversation_id = create_resp.json()["id"]
|
|
|
|
# Update conversation
|
|
update_resp = self.update_conversation(
|
|
conversation_id, metadata={"key1": "value1", "key2": "value2"}
|
|
)
|
|
self.assertEqual(update_resp.status_code, 200)
|
|
|
|
# Verify update
|
|
get_resp = self.get_conversation(conversation_id)
|
|
get_data = get_resp.json()
|
|
self.assertEqual(get_data["metadata"]["key2"], "value2")
|
|
|
|
def test_delete_conversation(self):
|
|
"""Test deleting conversation."""
|
|
# Create conversation
|
|
create_resp = self.create_conversation()
|
|
self.assertEqual(create_resp.status_code, 200)
|
|
conversation_id = create_resp.json()["id"]
|
|
|
|
# Delete conversation
|
|
delete_resp = self.delete_conversation(conversation_id)
|
|
self.assertEqual(delete_resp.status_code, 200)
|
|
|
|
# Verify deletion
|
|
get_resp = self.get_conversation(conversation_id)
|
|
self.assertEqual(get_resp.status_code, 404)
|
|
|
|
def test_list_conversation_items(self):
|
|
"""Test listing conversation items."""
|
|
# Create conversation
|
|
conv_resp = self.create_conversation()
|
|
conversation_id = conv_resp.json()["id"]
|
|
|
|
# Create response with conversation
|
|
self.create_response(
|
|
"First message", conversation=conversation_id, max_output_tokens=50
|
|
)
|
|
self.create_response(
|
|
"Second message", conversation=conversation_id, max_output_tokens=50
|
|
)
|
|
|
|
# List items
|
|
list_resp = self.list_conversation_items(conversation_id)
|
|
self.assertEqual(list_resp.status_code, 200)
|
|
|
|
list_data = list_resp.json()
|
|
self.assertIn("data", list_data)
|
|
# Should have at least 4 items (2 inputs + 2 outputs)
|
|
self.assertGreaterEqual(len(list_data["data"]), 4)
|