Add support for Batch API test (#936)
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import json
|
||||
import time
|
||||
import unittest
|
||||
|
||||
import openai
|
||||
@@ -207,6 +208,129 @@ class TestOpenAIServer(unittest.TestCase):
|
||||
assert response.id
|
||||
assert response.created
|
||||
|
||||
def run_batch(self, mode):
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
if mode == "completion":
|
||||
input_file_path = "complete_input.jsonl"
|
||||
# write content to input file
|
||||
content = [
|
||||
{
|
||||
"custom_id": "request-1",
|
||||
"method": "POST",
|
||||
"url": "/v1/completions",
|
||||
"body": {
|
||||
"model": "gpt-3.5-turbo-instruct",
|
||||
"prompt": "List 3 names of famous soccer player: ",
|
||||
"max_tokens": 20,
|
||||
},
|
||||
},
|
||||
{
|
||||
"custom_id": "request-2",
|
||||
"method": "POST",
|
||||
"url": "/v1/completions",
|
||||
"body": {
|
||||
"model": "gpt-3.5-turbo-instruct",
|
||||
"prompt": "List 6 names of famous basketball player: ",
|
||||
"max_tokens": 40,
|
||||
},
|
||||
},
|
||||
{
|
||||
"custom_id": "request-3",
|
||||
"method": "POST",
|
||||
"url": "/v1/completions",
|
||||
"body": {
|
||||
"model": "gpt-3.5-turbo-instruct",
|
||||
"prompt": "List 6 names of famous tenniss player: ",
|
||||
"max_tokens": 40,
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
else:
|
||||
input_file_path = "chat_input.jsonl"
|
||||
content = [
|
||||
{
|
||||
"custom_id": "request-1",
|
||||
"method": "POST",
|
||||
"url": "/v1/chat/completions",
|
||||
"body": {
|
||||
"model": "gpt-3.5-turbo-0125",
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant.",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello! List 3 NBA players and tell a story",
|
||||
},
|
||||
],
|
||||
"max_tokens": 30,
|
||||
},
|
||||
},
|
||||
{
|
||||
"custom_id": "request-2",
|
||||
"method": "POST",
|
||||
"url": "/v1/chat/completions",
|
||||
"body": {
|
||||
"model": "gpt-3.5-turbo-0125",
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are an assistant. "},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello! List three capital and tell a story",
|
||||
},
|
||||
],
|
||||
"max_tokens": 50,
|
||||
},
|
||||
},
|
||||
]
|
||||
with open(input_file_path, "w") as file:
|
||||
for line in content:
|
||||
file.write(json.dumps(line) + "\n")
|
||||
with open(input_file_path, "rb") as file:
|
||||
uploaded_file = client.files.create(file=file, purpose="batch")
|
||||
if mode == "completion":
|
||||
endpoint = "/v1/completions"
|
||||
elif mode == "chat":
|
||||
endpoint = "/v1/chat/completions"
|
||||
completion_window = "24h"
|
||||
batch_job = client.batches.create(
|
||||
input_file_id=uploaded_file.id,
|
||||
endpoint=endpoint,
|
||||
completion_window=completion_window,
|
||||
)
|
||||
while batch_job.status not in ["completed", "failed", "cancelled"]:
|
||||
time.sleep(3)
|
||||
print(
|
||||
f"Batch job status: {batch_job.status}...trying again in 3 seconds..."
|
||||
)
|
||||
batch_job = client.batches.retrieve(batch_job.id)
|
||||
assert batch_job.status == "completed"
|
||||
assert batch_job.request_counts.completed == len(content)
|
||||
assert batch_job.request_counts.failed == 0
|
||||
assert batch_job.request_counts.total == len(content)
|
||||
|
||||
result_file_id = batch_job.output_file_id
|
||||
file_response = client.files.content(result_file_id)
|
||||
result_content = file_response.read()
|
||||
|
||||
if mode == "completion":
|
||||
result_file_name = "batch_job_complete_results.jsonl"
|
||||
else:
|
||||
result_file_name = "batch_job_chat_results.jsonl"
|
||||
with open(result_file_name, "wb") as file:
|
||||
file.write(result_content)
|
||||
results = []
|
||||
with open(result_file_name, "r", encoding="utf-8") as file:
|
||||
for line in file:
|
||||
json_object = json.loads(line.strip())
|
||||
results.append(json_object)
|
||||
for delete_fid in [uploaded_file.id, result_file_id]:
|
||||
del_pesponse = client.files.delete(delete_fid)
|
||||
assert del_pesponse.deleted
|
||||
assert len(results) == len(content)
|
||||
|
||||
def test_completion(self):
|
||||
for echo in [False, True]:
|
||||
for logprobs in [None, 5]:
|
||||
@@ -237,6 +361,10 @@ class TestOpenAIServer(unittest.TestCase):
|
||||
for logprobs in [None, 5]:
|
||||
self.run_chat_completion_stream(logprobs)
|
||||
|
||||
def test_batch(self):
|
||||
for mode in ["completion", "chat"]:
|
||||
self.run_batch(mode)
|
||||
|
||||
def test_regex(self):
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user