Add support for Batch API test (#936)
This commit is contained in:
@@ -308,7 +308,6 @@ class TokenizerManager:
|
|||||||
event = asyncio.Event()
|
event = asyncio.Event()
|
||||||
state = ReqState([], False, event)
|
state = ReqState([], False, event)
|
||||||
self.rid_to_state[rid] = state
|
self.rid_to_state[rid] = state
|
||||||
|
|
||||||
# Then wait for all responses
|
# Then wait for all responses
|
||||||
output_list = []
|
output_list = []
|
||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
@@ -341,7 +340,6 @@ class TokenizerManager:
|
|||||||
)
|
)
|
||||||
assert state.finished
|
assert state.finished
|
||||||
del self.rid_to_state[rid]
|
del self.rid_to_state[rid]
|
||||||
|
|
||||||
yield output_list
|
yield output_list
|
||||||
|
|
||||||
def _validate_input_length(self, input_ids: List[int]):
|
def _validate_input_length(self, input_ids: List[int]):
|
||||||
|
|||||||
@@ -53,6 +53,7 @@ from sglang.srt.openai_api.protocol import (
|
|||||||
CompletionStreamResponse,
|
CompletionStreamResponse,
|
||||||
DeltaMessage,
|
DeltaMessage,
|
||||||
ErrorResponse,
|
ErrorResponse,
|
||||||
|
FileDeleteResponse,
|
||||||
FileRequest,
|
FileRequest,
|
||||||
FileResponse,
|
FileResponse,
|
||||||
LogProbs,
|
LogProbs,
|
||||||
@@ -174,6 +175,20 @@ async def v1_files_create(file: UploadFile, purpose: str, file_storage_pth: str
|
|||||||
return {"error": "Invalid input", "details": e.errors()}
|
return {"error": "Invalid input", "details": e.errors()}
|
||||||
|
|
||||||
|
|
||||||
|
async def v1_delete_file(file_id: str):
    """Delete a stored file and drop its bookkeeping records.

    Looks ``file_id`` up in ``file_id_response`` and ``file_id_storage``,
    removes the file at the stored path, and deletes both registry entries.

    Args:
        file_id: Identifier of the file to delete.

    Returns:
        A ``FileDeleteResponse`` for ``file_id`` with ``deleted=True``.

    Raises:
        HTTPException: 404 when ``file_id`` is missing from either registry.
    """
    file_response = file_id_response.get(file_id)
    if file_response is None:
        raise HTTPException(status_code=404, detail="File not found")
    file_path = file_id_storage.get(file_id)
    if file_path is None:
        raise HTTPException(status_code=404, detail="File not found")

    try:
        os.remove(file_path)
    except FileNotFoundError:
        # The file is already gone from disk. Still drop the records below;
        # otherwise the stale entries could never be deleted through the API.
        pass

    del file_id_response[file_id]
    del file_id_storage[file_id]
    return FileDeleteResponse(id=file_id, deleted=True)
|
||||||
|
|
||||||
|
|
||||||
async def v1_batches(tokenizer_manager, raw_request: Request):
|
async def v1_batches(tokenizer_manager, raw_request: Request):
|
||||||
try:
|
try:
|
||||||
body = await raw_request.json()
|
body = await raw_request.json()
|
||||||
@@ -287,6 +302,13 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
|
|||||||
retrieve_batch = batch_storage[batch_id]
|
retrieve_batch = batch_storage[batch_id]
|
||||||
retrieve_batch.output_file_id = output_file_id
|
retrieve_batch.output_file_id = output_file_id
|
||||||
file_id_storage[output_file_id] = output_file_path
|
file_id_storage[output_file_id] = output_file_path
|
||||||
|
file_id_response[output_file_id] = FileResponse(
|
||||||
|
id=output_file_id,
|
||||||
|
bytes=os.path.getsize(output_file_path),
|
||||||
|
created_at=int(time.time()),
|
||||||
|
filename=f"{output_file_id}.jsonl",
|
||||||
|
purpose="batch_result",
|
||||||
|
)
|
||||||
# Update batch status to "completed"
|
# Update batch status to "completed"
|
||||||
retrieve_batch.status = "completed"
|
retrieve_batch.status = "completed"
|
||||||
retrieve_batch.completed_at = int(time.time())
|
retrieve_batch.completed_at = int(time.time())
|
||||||
@@ -380,7 +402,7 @@ def v1_generate_request(all_requests):
|
|||||||
else:
|
else:
|
||||||
prompt_kwargs = {"input_ids": prompt}
|
prompt_kwargs = {"input_ids": prompt}
|
||||||
else:
|
else:
|
||||||
if isinstance(prompts[0], str) or isinstance(propmt[0][0], str):
|
if isinstance(prompts[0], str):
|
||||||
prompt_kwargs = {"text": prompts}
|
prompt_kwargs = {"text": prompts}
|
||||||
else:
|
else:
|
||||||
prompt_kwargs = {"input_ids": prompts}
|
prompt_kwargs = {"input_ids": prompts}
|
||||||
@@ -931,7 +953,6 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
|
|||||||
).__anext__()
|
).__anext__()
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
return create_error_response(str(e))
|
return create_error_response(str(e))
|
||||||
|
|
||||||
if not isinstance(ret, list):
|
if not isinstance(ret, list):
|
||||||
ret = [ret]
|
ret = [ret]
|
||||||
|
|
||||||
|
|||||||
@@ -95,6 +95,12 @@ class FileResponse(BaseModel):
|
|||||||
purpose: str
|
purpose: str
|
||||||
|
|
||||||
|
|
||||||
|
class FileDeleteResponse(BaseModel):
    # Response body describing the outcome of a file-deletion request.

    # Identifier of the file the request targeted.
    id: str
    # Object-type tag; defaults to "file".
    object: str = "file"
    # Deletion outcome flag supplied by the code building the response.
    deleted: bool
|
||||||
|
|
||||||
|
|
||||||
class BatchRequest(BaseModel):
|
class BatchRequest(BaseModel):
|
||||||
input_file_id: (
|
input_file_id: (
|
||||||
str # The ID of an uploaded file that contains requests for the new batch
|
str # The ID of an uploaded file that contains requests for the new batch
|
||||||
|
|||||||
@@ -59,6 +59,7 @@ from sglang.srt.openai_api.adapter import (
|
|||||||
v1_batches,
|
v1_batches,
|
||||||
v1_chat_completions,
|
v1_chat_completions,
|
||||||
v1_completions,
|
v1_completions,
|
||||||
|
v1_delete_file,
|
||||||
v1_files_create,
|
v1_files_create,
|
||||||
v1_retrieve_batch,
|
v1_retrieve_batch,
|
||||||
v1_retrieve_file,
|
v1_retrieve_file,
|
||||||
@@ -175,6 +176,12 @@ async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("bat
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.delete("/v1/files/{file_id}")
async def delete_file(file_id: str):
    # Route for deleting a file by id; delegates to the OpenAI adapter.
    # Reference: https://platform.openai.com/docs/api-reference/files/delete
    deletion_result = await v1_delete_file(file_id)
    return deletion_result
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/batches")
|
@app.post("/v1/batches")
|
||||||
async def openai_v1_batches(raw_request: Request):
|
async def openai_v1_batches(raw_request: Request):
|
||||||
return await v1_batches(tokenizer_manager, raw_request)
|
return await v1_batches(tokenizer_manager, raw_request)
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import json
|
import json
|
||||||
|
import time
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import openai
|
import openai
|
||||||
@@ -207,6 +208,129 @@ class TestOpenAIServer(unittest.TestCase):
|
|||||||
assert response.id
|
assert response.id
|
||||||
assert response.created
|
assert response.created
|
||||||
|
|
||||||
|
def run_batch(self, mode):
    """Run one Batch API round trip and check the outcome.

    Writes a JSONL request file, uploads it, creates a batch for the
    matching endpoint, polls until the batch reaches a terminal status,
    downloads the output file, and finally deletes both server-side files.

    Args:
        mode: "completion" to target /v1/completions, or "chat" to target
            /v1/chat/completions.

    Raises:
        ValueError: If ``mode`` is neither "completion" nor "chat".
    """
    client = openai.Client(api_key=self.api_key, base_url=self.base_url)

    # Resolve everything that depends on the mode in one place so that an
    # unsupported mode fails before any file is written or uploaded.
    if mode == "completion":
        endpoint = "/v1/completions"
        input_file_path = "complete_input.jsonl"
        result_file_name = "batch_job_complete_results.jsonl"
        prompts = [
            ("List 3 names of famous soccer player: ", 20),
            ("List 6 names of famous basketball player: ", 40),
            ("List 6 names of famous tenniss player: ", 40),
        ]
        content = [
            {
                "custom_id": f"request-{index}",
                "method": "POST",
                "url": endpoint,
                "body": {
                    "model": "gpt-3.5-turbo-instruct",
                    "prompt": prompt,
                    "max_tokens": max_tokens,
                },
            }
            for index, (prompt, max_tokens) in enumerate(prompts, start=1)
        ]
    elif mode == "chat":
        endpoint = "/v1/chat/completions"
        input_file_path = "chat_input.jsonl"
        result_file_name = "batch_job_chat_results.jsonl"
        content = [
            {
                "custom_id": "request-1",
                "method": "POST",
                "url": endpoint,
                "body": {
                    "model": "gpt-3.5-turbo-0125",
                    "messages": [
                        {
                            "role": "system",
                            "content": "You are a helpful assistant.",
                        },
                        {
                            "role": "user",
                            "content": "Hello! List 3 NBA players and tell a story",
                        },
                    ],
                    "max_tokens": 30,
                },
            },
            {
                "custom_id": "request-2",
                "method": "POST",
                "url": endpoint,
                "body": {
                    "model": "gpt-3.5-turbo-0125",
                    "messages": [
                        {"role": "system", "content": "You are an assistant. "},
                        {
                            "role": "user",
                            "content": "Hello! List three capital and tell a story",
                        },
                    ],
                    "max_tokens": 50,
                },
            },
        ]
    else:
        raise ValueError(f"Unsupported batch mode: {mode!r}")

    # Write the requests as JSON Lines, then upload the file.
    with open(input_file_path, "w", encoding="utf-8") as file:
        for line in content:
            file.write(json.dumps(line) + "\n")
    with open(input_file_path, "rb") as file:
        uploaded_file = client.files.create(file=file, purpose="batch")

    completion_window = "24h"
    batch_job = client.batches.create(
        input_file_id=uploaded_file.id,
        endpoint=endpoint,
        completion_window=completion_window,
    )

    # Poll until the batch reaches a terminal status. The deadline keeps a
    # stuck batch from hanging the whole test run.
    deadline = time.monotonic() + 600
    while batch_job.status not in ["completed", "failed", "cancelled"]:
        assert (
            time.monotonic() < deadline
        ), f"Batch job still {batch_job.status} after 600 seconds"
        print(
            f"Batch job status: {batch_job.status}...trying again in 3 seconds..."
        )
        time.sleep(3)
        batch_job = client.batches.retrieve(batch_job.id)

    assert batch_job.status == "completed"
    assert batch_job.request_counts.completed == len(content)
    assert batch_job.request_counts.failed == 0
    assert batch_job.request_counts.total == len(content)

    result_file_id = batch_job.output_file_id
    file_response = client.files.content(result_file_id)
    result_content = file_response.read()

    # Keep a local copy of the raw output, then parse it back line by line.
    with open(result_file_name, "wb") as file:
        file.write(result_content)
    with open(result_file_name, "r", encoding="utf-8") as file:
        # Skip blank lines so a trailing newline cannot break json.loads.
        results = [json.loads(line) for line in file if line.strip()]

    # Delete both the uploaded input and the generated output on the server.
    for delete_fid in [uploaded_file.id, result_file_id]:
        del_response = client.files.delete(delete_fid)
        assert del_response.deleted

    assert len(results) == len(content)
|
||||||
|
|
||||||
def test_completion(self):
|
def test_completion(self):
|
||||||
for echo in [False, True]:
|
for echo in [False, True]:
|
||||||
for logprobs in [None, 5]:
|
for logprobs in [None, 5]:
|
||||||
@@ -237,6 +361,10 @@ class TestOpenAIServer(unittest.TestCase):
|
|||||||
for logprobs in [None, 5]:
|
for logprobs in [None, 5]:
|
||||||
self.run_chat_completion_stream(logprobs)
|
self.run_chat_completion_stream(logprobs)
|
||||||
|
|
||||||
|
def test_batch(self):
    """Run the batch round trip once per supported mode."""
    # Completion first, then chat.
    for batch_mode in ("completion", "chat"):
        self.run_batch(batch_mode)
|
||||||
|
|
||||||
def test_regex(self):
|
def test_regex(self):
|
||||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user