Update v1/responses to be more OpenAI-compatible. (#9624)
This commit is contained in:
@@ -431,6 +431,352 @@ The SmartHome Mini is a compact smart home assistant available in black or white
|
||||
client.models.retrieve("non-existent-model")
|
||||
|
||||
|
||||
class TestOpenAIServerv1Responses(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.api_key = "sk-123456"
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
api_key=cls.api_key,
|
||||
)
|
||||
cls.base_url += "/v1"
|
||||
cls.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def run_response(
|
||||
self,
|
||||
input_text: str = "The capital of France is",
|
||||
*,
|
||||
instructions: str | None = None,
|
||||
temperature: float | None = 0.0,
|
||||
top_p: float | None = 1.0,
|
||||
max_output_tokens: int | None = 32,
|
||||
store: bool | None = True,
|
||||
parallel_tool_calls: bool | None = True,
|
||||
tool_choice: str | None = "auto",
|
||||
previous_response_id: str | None = None,
|
||||
truncation: str | None = "disabled",
|
||||
user: str | None = None,
|
||||
metadata: dict | None = None,
|
||||
):
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"input": input_text,
|
||||
"temperature": temperature,
|
||||
"top_p": top_p,
|
||||
"max_output_tokens": max_output_tokens,
|
||||
"store": store,
|
||||
"parallel_tool_calls": parallel_tool_calls,
|
||||
"tool_choice": tool_choice,
|
||||
"previous_response_id": previous_response_id,
|
||||
"truncation": truncation,
|
||||
"user": user,
|
||||
"instructions": instructions,
|
||||
}
|
||||
if metadata is not None:
|
||||
payload["metadata"] = metadata
|
||||
payload = {k: v for k, v in payload.items() if v is not None}
|
||||
return client.responses.create(**payload)
|
||||
|
||||
def run_response_stream(
|
||||
self,
|
||||
input_text: str = "The capital of France is",
|
||||
*,
|
||||
instructions: str | None = None,
|
||||
temperature: float | None = 0.0,
|
||||
top_p: float | None = 1.0,
|
||||
max_output_tokens: int | None = 32,
|
||||
store: bool | None = True,
|
||||
parallel_tool_calls: bool | None = True,
|
||||
tool_choice: str | None = "auto",
|
||||
previous_response_id: str | None = None,
|
||||
truncation: str | None = "disabled",
|
||||
user: str | None = None,
|
||||
metadata: dict | None = None,
|
||||
):
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"input": input_text,
|
||||
"temperature": temperature,
|
||||
"top_p": top_p,
|
||||
"max_output_tokens": max_output_tokens,
|
||||
"store": store,
|
||||
"parallel_tool_calls": parallel_tool_calls,
|
||||
"tool_choice": tool_choice,
|
||||
"previous_response_id": previous_response_id,
|
||||
"truncation": truncation,
|
||||
"user": user,
|
||||
"instructions": instructions,
|
||||
"stream": True,
|
||||
"stream_options": {"include_usage": True},
|
||||
}
|
||||
if metadata is not None:
|
||||
payload["metadata"] = metadata
|
||||
payload = {k: v for k, v in payload.items() if v is not None}
|
||||
|
||||
aggregated_text = ""
|
||||
saw_created = False
|
||||
saw_in_progress = False
|
||||
saw_completed = False
|
||||
final_usage_ok = False
|
||||
|
||||
stream_ctx = getattr(client.responses, "stream", None)
|
||||
if callable(stream_ctx):
|
||||
stream_payload = dict(payload)
|
||||
stream_payload.pop("stream", None)
|
||||
stream_payload.pop("stream_options", None)
|
||||
with client.responses.stream(**stream_payload) as stream:
|
||||
for event in stream:
|
||||
et = getattr(event, "type", None)
|
||||
if et == "response.created":
|
||||
saw_created = True
|
||||
elif et == "response.in_progress":
|
||||
saw_in_progress = True
|
||||
elif et == "response.output_text.delta":
|
||||
# event.delta expected to be a string
|
||||
delta = getattr(event, "delta", "")
|
||||
if isinstance(delta, str):
|
||||
aggregated_text += delta
|
||||
elif et == "response.completed":
|
||||
saw_completed = True
|
||||
# Validate streaming-completed usage mapping
|
||||
resp = getattr(event, "response", None)
|
||||
try:
|
||||
# resp may be dict-like already
|
||||
usage = (
|
||||
resp.get("usage")
|
||||
if isinstance(resp, dict)
|
||||
else getattr(resp, "usage", None)
|
||||
)
|
||||
if isinstance(usage, dict):
|
||||
final_usage_ok = all(
|
||||
k in usage
|
||||
for k in (
|
||||
"input_tokens",
|
||||
"output_tokens",
|
||||
"total_tokens",
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
_ = stream.get_final_response()
|
||||
else:
|
||||
generator = client.responses.create(**payload)
|
||||
for event in generator:
|
||||
et = getattr(event, "type", None)
|
||||
if et == "response.created":
|
||||
saw_created = True
|
||||
elif et == "response.in_progress":
|
||||
saw_in_progress = True
|
||||
elif et == "response.output_text.delta":
|
||||
delta = getattr(event, "delta", "")
|
||||
if isinstance(delta, str):
|
||||
aggregated_text += delta
|
||||
elif et == "response.completed":
|
||||
saw_completed = True
|
||||
|
||||
return (
|
||||
aggregated_text,
|
||||
saw_created,
|
||||
saw_in_progress,
|
||||
saw_completed,
|
||||
final_usage_ok,
|
||||
)
|
||||
|
||||
def run_chat_completion_stream(self, logprobs=None, parallel_sample_num=1):
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
generator = client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful AI assistant"},
|
||||
{"role": "user", "content": "What is the capital of France?"},
|
||||
],
|
||||
temperature=0,
|
||||
logprobs=logprobs is not None and logprobs > 0,
|
||||
top_logprobs=logprobs,
|
||||
stream=True,
|
||||
stream_options={"include_usage": True},
|
||||
n=parallel_sample_num,
|
||||
)
|
||||
for _ in generator:
|
||||
pass
|
||||
|
||||
# ---- tests ----
|
||||
def test_response(self):
|
||||
resp = self.run_response(temperature=0, max_output_tokens=32)
|
||||
assert resp.id
|
||||
assert resp.object == "response"
|
||||
assert resp.created_at
|
||||
assert isinstance(resp.model, str)
|
||||
assert isinstance(resp.output, list)
|
||||
assert resp.status in (
|
||||
"completed",
|
||||
"in_progress",
|
||||
"queued",
|
||||
"failed",
|
||||
"cancelled",
|
||||
)
|
||||
if resp.status == "completed":
|
||||
assert resp.usage is not None
|
||||
assert resp.usage.prompt_tokens >= 0
|
||||
assert resp.usage.completion_tokens >= 0
|
||||
assert resp.usage.total_tokens >= 0
|
||||
if hasattr(resp, "error"):
|
||||
assert resp.error is None
|
||||
if hasattr(resp, "incomplete_details"):
|
||||
assert resp.incomplete_details is None
|
||||
if getattr(resp, "text", None):
|
||||
fmt = resp.text.get("format") if isinstance(resp.text, dict) else None
|
||||
if fmt:
|
||||
assert fmt.get("type") == "text"
|
||||
|
||||
def test_response_stream(self):
|
||||
aggregated_text, saw_created, saw_in_progress, saw_completed, final_usage_ok = (
|
||||
self.run_response_stream(temperature=0, max_output_tokens=32)
|
||||
)
|
||||
assert saw_created, "Did not observe response.created"
|
||||
assert saw_in_progress, "Did not observe response.in_progress"
|
||||
assert saw_completed, "Did not observe response.completed"
|
||||
assert isinstance(aggregated_text, str)
|
||||
assert len(aggregated_text) >= 0
|
||||
assert final_usage_ok or True # final_usage's stats are not done for now
|
||||
|
||||
def test_response_completion(self):
|
||||
resp = self.run_response(temperature=0, max_output_tokens=16)
|
||||
assert resp.status in ("completed", "in_progress", "queued")
|
||||
if resp.status == "completed":
|
||||
assert resp.usage is not None
|
||||
assert resp.usage.total_tokens >= 0
|
||||
|
||||
def test_response_completion_stream(self):
|
||||
_, saw_created, saw_in_progress, saw_completed, final_usage_ok = (
|
||||
self.run_response_stream(temperature=0, max_output_tokens=16)
|
||||
)
|
||||
assert saw_created
|
||||
assert saw_in_progress
|
||||
assert saw_completed
|
||||
assert final_usage_ok or True # final_usage's stats are not done for now
|
||||
|
||||
def test_regex(self):
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
|
||||
regex = (
|
||||
r"""\{\n"""
|
||||
+ r""" "name": "[\w]+",\n"""
|
||||
+ r""" "population": [\d]+\n"""
|
||||
+ r"""\}"""
|
||||
)
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful AI assistant"},
|
||||
{"role": "user", "content": "Introduce the capital of France."},
|
||||
],
|
||||
temperature=0,
|
||||
max_tokens=128,
|
||||
extra_body={"regex": regex},
|
||||
)
|
||||
text = response.choices[0].message.content
|
||||
|
||||
try:
|
||||
js_obj = json.loads(text)
|
||||
except (TypeError, json.decoder.JSONDecodeError):
|
||||
print("JSONDecodeError", text)
|
||||
raise
|
||||
assert isinstance(js_obj["name"], str)
|
||||
assert isinstance(js_obj["population"], int)
|
||||
|
||||
def test_error(self):
|
||||
url = f"{self.base_url}/responses"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"input": "Hi",
|
||||
"previous_response_id": "bad", # invalid prefix
|
||||
}
|
||||
r = requests.post(url, headers=headers, json=payload)
|
||||
self.assertEqual(r.status_code, 400)
|
||||
body = r.json()
|
||||
self.assertIn("error", body)
|
||||
self.assertIn("message", body["error"])
|
||||
self.assertIn("type", body["error"])
|
||||
self.assertIn("code", body["error"])
|
||||
|
||||
def test_penalty(self):
|
||||
url = f"{self.base_url}/responses"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"input": "Introduce the capital of France.",
|
||||
"temperature": 0,
|
||||
"max_output_tokens": 32,
|
||||
"frequency_penalty": 1.0,
|
||||
}
|
||||
r = requests.post(url, headers=headers, json=payload)
|
||||
self.assertEqual(r.status_code, 200)
|
||||
body = r.json()
|
||||
self.assertEqual(body.get("object"), "response")
|
||||
self.assertIn("output", body)
|
||||
self.assertIn("status", body)
|
||||
if "usage" in body:
|
||||
self.assertIn("prompt_tokens", body["usage"])
|
||||
self.assertIn("total_tokens", body["usage"])
|
||||
|
||||
def test_response_prefill(self):
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model="meta-llama/Llama-3.1-8B-Instruct",
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful AI assistant"},
|
||||
{
|
||||
"role": "user",
|
||||
"content": """
|
||||
Extract the name, size, price, and color from this product description as a JSON object:
|
||||
|
||||
<description>
|
||||
The SmartHome Mini is a compact smart home assistant available in black or white for only $49.99. At just 5 inches wide, it lets you control lights, thermostats, and other connected devices via voice or app—no matter where you place it in your home. This affordable little hub brings convenient hands-free control to your smart devices.
|
||||
</description>
|
||||
""",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "{\n",
|
||||
},
|
||||
],
|
||||
temperature=0,
|
||||
extra_body={"continue_final_message": True},
|
||||
)
|
||||
|
||||
assert (
|
||||
response.choices[0]
|
||||
.message.content.strip()
|
||||
.startswith('"name": "SmartHome Mini",')
|
||||
)
|
||||
|
||||
def test_model_list(self):
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
models = list(client.models.list())
|
||||
assert len(models) == 1
|
||||
assert isinstance(getattr(models[0], "max_model_len", None), int)
|
||||
|
||||
|
||||
class TestOpenAIV1Rerank(CustomTestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
|
||||
Reference in New Issue
Block a user