From 795eab6dda7b7c3df552b9c44dce65c695a0f97c Mon Sep 17 00:00:00 2001
From: yichuan~ <73766326+yichuan520030910320@users.noreply.github.com>
Date: Wed, 7 Aug 2024 14:52:10 +0800
Subject: [PATCH] Add support for Batch API test (#936)

---
 .../sglang/srt/managers/tokenizer_manager.py |   2 -
 python/sglang/srt/openai_api/adapter.py      |  25 +++-
 python/sglang/srt/openai_api/protocol.py     |   6 +
 python/sglang/srt/server.py                  |   7 +
 test/srt/test_openai_server.py               | 128 ++++++++++++++++++
 5 files changed, 164 insertions(+), 4 deletions(-)

diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 1f45eed1f..2b3d38f78 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -308,7 +308,6 @@ class TokenizerManager:
             event = asyncio.Event()
             state = ReqState([], False, event)
             self.rid_to_state[rid] = state
-
         # Then wait for all responses
         output_list = []
         for i in range(batch_size):
@@ -341,7 +340,6 @@ class TokenizerManager:
             )
             assert state.finished
             del self.rid_to_state[rid]
-
         yield output_list

     def _validate_input_length(self, input_ids: List[int]):
diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py
index affa720f5..2b6fd961a 100644
--- a/python/sglang/srt/openai_api/adapter.py
+++ b/python/sglang/srt/openai_api/adapter.py
@@ -53,6 +53,7 @@ from sglang.srt.openai_api.protocol import (
     CompletionStreamResponse,
     DeltaMessage,
     ErrorResponse,
+    FileDeleteResponse,
     FileRequest,
     FileResponse,
     LogProbs,
@@ -174,6 +175,20 @@ async def v1_files_create(file: UploadFile, purpose: str, file_storage_pth: str
         return {"error": "Invalid input", "details": e.errors()}


+async def v1_delete_file(file_id: str):
+    # Retrieve the file metadata from the in-memory storage
+    file_response = file_id_response.get(file_id)
+    if file_response is None:
+        raise HTTPException(status_code=404, detail="File not found")
+    file_path = file_id_storage.get(file_id)
+    if file_path is None:
+        raise HTTPException(status_code=404, detail="File not found")
+    os.remove(file_path)
+    del file_id_response[file_id]
+    del file_id_storage[file_id]
+    return FileDeleteResponse(id=file_id, deleted=True)
+
+
 async def v1_batches(tokenizer_manager, raw_request: Request):
     try:
         body = await raw_request.json()
@@ -287,6 +302,13 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
         retrieve_batch = batch_storage[batch_id]
         retrieve_batch.output_file_id = output_file_id
         file_id_storage[output_file_id] = output_file_path
+        file_id_response[output_file_id] = FileResponse(
+            id=output_file_id,
+            bytes=os.path.getsize(output_file_path),
+            created_at=int(time.time()),
+            filename=f"{output_file_id}.jsonl",
+            purpose="batch_result",
+        )
         # Update batch status to "completed"
         retrieve_batch.status = "completed"
         retrieve_batch.completed_at = int(time.time())
@@ -380,7 +402,7 @@ def v1_generate_request(all_requests):
         else:
             prompt_kwargs = {"input_ids": prompt}
     else:
-        if isinstance(prompts[0], str) or isinstance(propmt[0][0], str):
+        if isinstance(prompts[0], str):
             prompt_kwargs = {"text": prompts}
         else:
             prompt_kwargs = {"input_ids": prompts}
@@ -931,7 +953,6 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
         ).__anext__()
     except ValueError as e:
         return create_error_response(str(e))
-
     if not isinstance(ret, list):
         ret = [ret]

diff --git a/python/sglang/srt/openai_api/protocol.py b/python/sglang/srt/openai_api/protocol.py
index 8c079dd2a..0e9b90223 100644
--- a/python/sglang/srt/openai_api/protocol.py
+++ b/python/sglang/srt/openai_api/protocol.py
@@ -95,6 +95,12 @@ class FileResponse(BaseModel):
     purpose: str


+class FileDeleteResponse(BaseModel):
+    id: str
+    object: str = "file"
+    deleted: bool
+
+
 class BatchRequest(BaseModel):
     input_file_id: (
         str  # The ID of an uploaded file that contains requests for the new batch
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index 4df1431cb..9c9540a89 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -59,6 +59,7 @@ from sglang.srt.openai_api.adapter import (
     v1_batches,
     v1_chat_completions,
     v1_completions,
+    v1_delete_file,
     v1_files_create,
     v1_retrieve_batch,
     v1_retrieve_file,
@@ -175,6 +176,12 @@ async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("bat
     )


+@app.delete("/v1/files/{file_id}")
+async def delete_file(file_id: str):
+    # https://platform.openai.com/docs/api-reference/files/delete
+    return await v1_delete_file(file_id)
+
+
 @app.post("/v1/batches")
 async def openai_v1_batches(raw_request: Request):
     return await v1_batches(tokenizer_manager, raw_request)
diff --git a/test/srt/test_openai_server.py b/test/srt/test_openai_server.py
index c98728ca8..db1a3c027 100644
--- a/test/srt/test_openai_server.py
+++ b/test/srt/test_openai_server.py
@@ -1,4 +1,5 @@
 import json
+import time
 import unittest

 import openai
@@ -207,6 +208,129 @@ class TestOpenAIServer(unittest.TestCase):
         assert response.id
         assert response.created

+    def run_batch(self, mode):
+        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
+        if mode == "completion":
+            input_file_path = "complete_input.jsonl"
+            # write content to input file
+            content = [
+                {
+                    "custom_id": "request-1",
+                    "method": "POST",
+                    "url": "/v1/completions",
+                    "body": {
+                        "model": "gpt-3.5-turbo-instruct",
+                        "prompt": "List 3 names of famous soccer players: ",
+                        "max_tokens": 20,
+                    },
+                },
+                {
+                    "custom_id": "request-2",
+                    "method": "POST",
+                    "url": "/v1/completions",
+                    "body": {
+                        "model": "gpt-3.5-turbo-instruct",
+                        "prompt": "List 6 names of famous basketball players: ",
+                        "max_tokens": 40,
+                    },
+                },
+                {
+                    "custom_id": "request-3",
+                    "method": "POST",
+                    "url": "/v1/completions",
+                    "body": {
+                        "model": "gpt-3.5-turbo-instruct",
+                        "prompt": "List 6 names of famous tennis players: ",
+                        "max_tokens": 40,
+                    },
+                },
+            ]
+
+        else:
+            input_file_path = "chat_input.jsonl"
+            content = [
+                {
+                    "custom_id": "request-1",
+                    "method": "POST",
+                    "url": "/v1/chat/completions",
+                    "body": {
+                        "model": "gpt-3.5-turbo-0125",
+                        "messages": [
+                            {
+                                "role": "system",
+                                "content": "You are a helpful assistant.",
+                            },
+                            {
+                                "role": "user",
+                                "content": "Hello! List 3 NBA players and tell a story",
+                            },
+                        ],
+                        "max_tokens": 30,
+                    },
+                },
+                {
+                    "custom_id": "request-2",
+                    "method": "POST",
+                    "url": "/v1/chat/completions",
+                    "body": {
+                        "model": "gpt-3.5-turbo-0125",
+                        "messages": [
+                            {"role": "system", "content": "You are an assistant. "},
+                            {
+                                "role": "user",
+                                "content": "Hello! List three capitals and tell a story",
+                            },
+                        ],
+                        "max_tokens": 50,
+                    },
+                },
+            ]
+        with open(input_file_path, "w") as file:
+            for line in content:
+                file.write(json.dumps(line) + "\n")
+        with open(input_file_path, "rb") as file:
+            uploaded_file = client.files.create(file=file, purpose="batch")
+        if mode == "completion":
+            endpoint = "/v1/completions"
+        elif mode == "chat":
+            endpoint = "/v1/chat/completions"
+        completion_window = "24h"
+        batch_job = client.batches.create(
+            input_file_id=uploaded_file.id,
+            endpoint=endpoint,
+            completion_window=completion_window,
+        )
+        while batch_job.status not in ["completed", "failed", "cancelled"]:
+            time.sleep(3)
+            print(
+                f"Batch job status: {batch_job.status}...trying again in 3 seconds..."
+            )
+            batch_job = client.batches.retrieve(batch_job.id)
+        assert batch_job.status == "completed"
+        assert batch_job.request_counts.completed == len(content)
+        assert batch_job.request_counts.failed == 0
+        assert batch_job.request_counts.total == len(content)
+
+        result_file_id = batch_job.output_file_id
+        file_response = client.files.content(result_file_id)
+        result_content = file_response.read()
+
+        if mode == "completion":
+            result_file_name = "batch_job_complete_results.jsonl"
+        else:
+            result_file_name = "batch_job_chat_results.jsonl"
+        with open(result_file_name, "wb") as file:
+            file.write(result_content)
+        results = []
+        with open(result_file_name, "r", encoding="utf-8") as file:
+            for line in file:
+                json_object = json.loads(line.strip())
+                results.append(json_object)
+        for delete_fid in [uploaded_file.id, result_file_id]:
+            del_response = client.files.delete(delete_fid)
+            assert del_response.deleted
+        assert len(results) == len(content)
+
     def test_completion(self):
         for echo in [False, True]:
             for logprobs in [None, 5]:
@@ -237,6 +361,10 @@ class TestOpenAIServer(unittest.TestCase):
         for logprobs in [None, 5]:
             self.run_chat_completion_stream(logprobs)

+    def test_batch(self):
+        for mode in ["completion", "chat"]:
+            self.run_batch(mode)
+
     def test_regex(self):
         client = openai.Client(api_key=self.api_key, base_url=self.base_url)