From 084fa54d371e439cbca8b21930c6f658c1ef4671 Mon Sep 17 00:00:00 2001 From: yichuan~ <73766326+yichuan520030910320@users.noreply.github.com> Date: Tue, 30 Jul 2024 04:07:18 +0800 Subject: [PATCH] Add support for OpenAI API : offline batch(file) processing (#699) Co-authored-by: hnyls2002 --- .pre-commit-config.yaml | 2 +- examples/usage/openai_batch_chat.py | 86 +++ examples/usage/openai_batch_complete.py | 86 +++ examples/usage/openai_parallel_sample.py | 37 + python/sglang/srt/managers/io_struct.py | 22 +- .../sglang/srt/managers/tokenizer_manager.py | 24 +- python/sglang/srt/openai_api/adapter.py | 644 ++++++++++++++---- python/sglang/srt/openai_api/protocol.py | 49 ++ python/sglang/srt/server.py | 36 +- python/sglang/srt/server_args.py | 7 + 10 files changed, 839 insertions(+), 154 deletions(-) create mode 100644 examples/usage/openai_batch_chat.py create mode 100644 examples/usage/openai_batch_complete.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 393c999d2..2fa1254a6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,6 +4,6 @@ repos: hooks: - id: isort - repo: https://github.com/psf/black - rev: stable + rev: 24.4.2 hooks: - id: black diff --git a/examples/usage/openai_batch_chat.py b/examples/usage/openai_batch_chat.py new file mode 100644 index 000000000..cffa50c67 --- /dev/null +++ b/examples/usage/openai_batch_chat.py @@ -0,0 +1,86 @@ +import json +import os +import time + +import openai +from openai import OpenAI + + +class OpenAIBatchProcessor: + def __init__(self, api_key): + # client = OpenAI(api_key=api_key) + client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY") + + self.client = client + + def process_batch(self, input_file_path, endpoint, completion_window): + + # Upload the input file + with open(input_file_path, "rb") as file: + uploaded_file = self.client.files.create(file=file, purpose="batch") + + # Create the batch job + batch_job = self.client.batches.create( + input_file_id=uploaded_file.id, + endpoint=endpoint, + completion_window=completion_window, + ) + + # Monitor the batch job status + while batch_job.status not in ["completed", "failed", "cancelled"]: + time.sleep(3) # Wait for 3 seconds before checking the status again + print( + f"Batch job status: {batch_job.status}...trying again in 3 seconds..." 
+ ) + batch_job = self.client.batches.retrieve(batch_job.id) + + # Check the batch job status and errors + if batch_job.status == "failed": + print(f"Batch job failed with status: {batch_job.status}") + print(f"Batch job errors: {batch_job.errors}") + return None + + # If the batch job is completed, process the results + if batch_job.status == "completed": + + # print result of batch job + print("batch", batch_job.request_counts) + + result_file_id = batch_job.output_file_id + # Retrieve the file content from the server + file_response = self.client.files.content(result_file_id) + result_content = file_response.read() # Read the content of the file + + # Save the content to a local file + result_file_name = "batch_job_chat_results.jsonl" + with open(result_file_name, "wb") as file: + file.write(result_content) # Write the binary content to the file + # Load data from the saved JSONL file + results = [] + with open(result_file_name, "r", encoding="utf-8") as file: + for line in file: + json_object = json.loads( + line.strip() + ) # Parse each line as a JSON object + results.append(json_object) + + return results + else: + print(f"Batch job failed with status: {batch_job.status}") + return None + + +# Initialize the OpenAIBatchProcessor +api_key = os.environ.get("OPENAI_API_KEY") +processor = OpenAIBatchProcessor(api_key) + +# Process the batch job +input_file_path = "input.jsonl" +endpoint = "/v1/chat/completions" +completion_window = "24h" + +# Process the batch job +results = processor.process_batch(input_file_path, endpoint, completion_window) + +# Print the results +print(results) diff --git a/examples/usage/openai_batch_complete.py b/examples/usage/openai_batch_complete.py new file mode 100644 index 000000000..3cf2ede0b --- /dev/null +++ b/examples/usage/openai_batch_complete.py @@ -0,0 +1,86 @@ +import json +import os +import time + +import openai +from openai import OpenAI + + +class OpenAIBatchProcessor: + def __init__(self, api_key): + # client = OpenAI(api_key=api_key) + client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY") + + self.client = client + + def process_batch(self, input_file_path, endpoint, completion_window): + + # Upload the input file + with open(input_file_path, "rb") as file: + uploaded_file = self.client.files.create(file=file, purpose="batch") + + # Create the batch job + batch_job = self.client.batches.create( + input_file_id=uploaded_file.id, + endpoint=endpoint, + completion_window=completion_window, + ) + + # Monitor the batch job status + while batch_job.status not in ["completed", "failed", "cancelled"]: + time.sleep(3) # Wait for 3 seconds before checking the status again + print( + f"Batch job status: {batch_job.status}...trying again in 3 seconds..." 
+ ) + batch_job = self.client.batches.retrieve(batch_job.id) + + # Check the batch job status and errors + if batch_job.status == "failed": + print(f"Batch job failed with status: {batch_job.status}") + print(f"Batch job errors: {batch_job.errors}") + return None + + # If the batch job is completed, process the results + if batch_job.status == "completed": + + # print result of batch job + print("batch", batch_job.request_counts) + + result_file_id = batch_job.output_file_id + # Retrieve the file content from the server + file_response = self.client.files.content(result_file_id) + result_content = file_response.read() # Read the content of the file + + # Save the content to a local file + result_file_name = "batch_job_complete_results.jsonl" + with open(result_file_name, "wb") as file: + file.write(result_content) # Write the binary content to the file + # Load data from the saved JSONL file + results = [] + with open(result_file_name, "r", encoding="utf-8") as file: + for line in file: + json_object = json.loads( + line.strip() + ) # Parse each line as a JSON object + results.append(json_object) + + return results + else: + print(f"Batch job failed with status: {batch_job.status}") + return None + + +# Initialize the OpenAIBatchProcessor +api_key = os.environ.get("OPENAI_API_KEY") +processor = OpenAIBatchProcessor(api_key) + +# Process the batch job +input_file_path = "input_complete.jsonl" +endpoint = "/v1/completions" +completion_window = "24h" + +# Process the batch job +results = processor.process_batch(input_file_path, endpoint, completion_window) + +# Print the results +print(results) diff --git a/examples/usage/openai_parallel_sample.py b/examples/usage/openai_parallel_sample.py index d2d1e406f..0d3a372b4 100644 --- a/examples/usage/openai_parallel_sample.py +++ b/examples/usage/openai_parallel_sample.py @@ -13,6 +13,17 @@ response = client.completions.create( print(response) +# Text completion +response = client.completions.create( + model="default", + prompt="I am a robot and I want to study like humans. Now let's tell a story. 
Once upon a time, there was a little", + n=1, + temperature=0.8, + max_tokens=32, +) +print(response) + + # Text completion response = client.completions.create( model="default", @@ -24,6 +35,17 @@ response = client.completions.create( print(response) +# Text completion +response = client.completions.create( + model="default", + prompt=["The name of the famous soccer player is"], + n=1, + temperature=0.8, + max_tokens=128, +) +print(response) + + # Text completion response = client.completions.create( model="default", @@ -60,6 +82,21 @@ response = client.completions.create( ) print(response) +# Chat completion +response = client.chat.completions.create( + model="default", + messages=[ + {"role": "system", "content": "You are a helpful AI assistant"}, + {"role": "user", "content": "List 3 countries and their capitals."}, + ], + temperature=0.8, + max_tokens=64, + logprobs=True, + n=1, +) +print(response) + + # Chat completion response = client.chat.completions.create( model="default", diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 5698b0264..f0b927a69 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -79,8 +79,26 @@ class GenerateReqInput: if self.top_logprobs_num is None: self.top_logprobs_num = 0 else: - - parallel_sample_num = self.sampling_params.get("n", 1) + parallel_sample_num_list = [] + if isinstance(self.sampling_params, dict): + parallel_sample_num = self.sampling_params.get("n", 1) + elif isinstance(self.sampling_params, list): + for sp in self.sampling_params: + parallel_sample_num = sp.get("n", 1) + parallel_sample_num_list.append(parallel_sample_num) + parallel_sample_num = max(parallel_sample_num_list) + all_equal = all( + element == parallel_sample_num + for element in parallel_sample_num_list + ) + if parallel_sample_num > 1 and (not all_equal): + ## TODO cope with the case that the parallel_sample_num is different for different samples + raise ValueError( + "The parallel_sample_num should be the same for all samples in sample params." 
+ ) + else: + parallel_sample_num = 1 + self.parallel_sample_num = parallel_sample_num if parallel_sample_num != 1: # parallel sampling +1 represents the original prefill stage diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 6e29aee0b..efdc933cd 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -84,6 +84,7 @@ class TokenizerManager: trust_remote_code=server_args.trust_remote_code, model_overide_args=model_overide_args, ) + if server_args.context_length is not None: self.context_len = server_args.context_length else: @@ -152,31 +153,33 @@ class TokenizerManager: self, obj, request, index=None, is_cache_for_prefill=False ): if not is_cache_for_prefill: - rid = obj.rid if index is None else obj.rid[index] - input_text = obj.text if index is None else obj.text[index] + not_use_index = not (index is not None) + rid = obj.rid if not_use_index else obj.rid[index] + input_text = obj.text if not_use_index else obj.text[index] input_ids = ( self.tokenizer.encode(input_text) if obj.input_ids is None else obj.input_ids ) - if index is not None and obj.input_ids: + if not not_use_index and obj.input_ids: input_ids = obj.input_ids[index] self._validate_input_length(input_ids) + sampling_params = self._get_sampling_params( - obj.sampling_params if index is None else obj.sampling_params[index] + obj.sampling_params if not_use_index else obj.sampling_params[index] ) pixel_values, image_hash, image_size = await self._get_pixel_values( - obj.image_data if index is None else obj.image_data[index] + obj.image_data if not_use_index else obj.image_data[index] ) return_logprob = ( - obj.return_logprob if index is None else obj.return_logprob[index] + obj.return_logprob if not_use_index else obj.return_logprob[index] ) logprob_start_len = ( - obj.logprob_start_len if index is None else obj.logprob_start_len[index] + obj.logprob_start_len if not_use_index else obj.logprob_start_len[index] ) top_logprobs_num = ( - obj.top_logprobs_num if index is None else obj.top_logprobs_num[index] + obj.top_logprobs_num if not_use_index else obj.top_logprobs_num[index] ) else: if isinstance(obj.text, list): @@ -224,7 +227,7 @@ class TokenizerManager: async def _handle_batch_request(self, obj: GenerateReqInput, request): batch_size = obj.batch_size - parallel_sample_num = obj.sampling_params[0].get("n", 1) + parallel_sample_num = obj.parallel_sample_num if parallel_sample_num != 1: # Send prefill requests to cache the common input @@ -241,7 +244,6 @@ class TokenizerManager: obj.input_ids = input_id_result elif input_id_result is not None: obj.input_ids = input_id_result[0] - # First send out all requests for i in range(batch_size): for j in range(parallel_sample_num): @@ -249,7 +251,7 @@ class TokenizerManager: continue index = i * parallel_sample_num + j if parallel_sample_num != 1: - # Here when using parallel sampling we shoul consider prefill stage so the index is : j + i * (parallel_sample_num-1) + batch_size - 1 + # Here when using parallel sampling we should consider prefill stage so the index is : j + i * (parallel_sample_num-1) + batch_size - 1 index += batch_size - 1 - i rid = obj.rid[index] if parallel_sample_num == 1: diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py index 7f1b8fd2f..5fa75f1b8 100644 --- a/python/sglang/srt/openai_api/adapter.py +++ b/python/sglang/srt/openai_api/adapter.py @@ -18,10 +18,14 @@ limitations under the License. 
import asyncio import json import os +import time +import uuid from http import HTTPStatus +from typing import Dict, List, Optional -from fastapi import Request +from fastapi import HTTPException, Request, UploadFile from fastapi.responses import JSONResponse, StreamingResponse +from pydantic import ValidationError from sglang.srt.conversation import ( Conversation, @@ -32,6 +36,8 @@ from sglang.srt.conversation import ( ) from sglang.srt.managers.io_struct import GenerateReqInput from sglang.srt.openai_api.protocol import ( + BatchRequest, + BatchResponse, ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseChoice, @@ -45,6 +51,8 @@ from sglang.srt.openai_api.protocol import ( CompletionStreamResponse, DeltaMessage, ErrorResponse, + FileRequest, + FileResponse, LogProbs, UsageInfo, ) @@ -52,6 +60,24 @@ from sglang.srt.openai_api.protocol import ( chat_template_name = None +class FileMetadata: + def __init__(self, filename: str, purpose: str): + self.filename = filename + self.purpose = purpose + + +# In-memory storage for batch jobs and files +batch_storage: Dict[str, BatchResponse] = {} +file_id_request: Dict[str, FileMetadata] = {} +file_id_response: Dict[str, FileResponse] = {} +## map file id to file path in SGlang backend +file_id_storage: Dict[str, str] = {} + + +# backend storage directory +storage_dir = None + + def create_error_response( message: str, err_type: str = "BadRequestError", @@ -106,33 +132,364 @@ def load_chat_template_for_openai_api(chat_template_arg): chat_template_name = chat_template_arg -async def v1_completions(tokenizer_manager, raw_request: Request): - request_json = await raw_request.json() - request = CompletionRequest(**request_json) - prompt = request.prompt - if isinstance(prompt, str) or isinstance(prompt[0], str): - prompt_kwargs = {"text": prompt} +async def v1_files_create(file: UploadFile, purpose: str, file_storage_pth: str = None): + try: + global storage_dir + if file_storage_pth: + storage_dir = file_storage_pth + # Read the file content + file_content = await file.read() + + # Create an instance of RequestBody + request_body = FileRequest(file=file_content, purpose=purpose) + + # Save the file to the sglang_oai_storage directory + os.makedirs(storage_dir, exist_ok=True) + file_id = f"backend_input_file-{uuid.uuid4()}" + filename = f"{file_id}.jsonl" + file_path = os.path.join(storage_dir, filename) + + with open(file_path, "wb") as f: + f.write(request_body.file) + + # add info to global file map + file_id_request[file_id] = FileMetadata(filename=file.filename, purpose=purpose) + file_id_storage[file_id] = file_path + + # Return the response in the required format + response = FileResponse( + id=file_id, + bytes=len(request_body.file), + created_at=int(time.time()), + filename=file.filename, + purpose=request_body.purpose, + ) + file_id_response[file_id] = response + + return response + except ValidationError as e: + return {"error": "Invalid input", "details": e.errors()} + + +async def v1_batches(tokenizer_manager, raw_request: Request): + try: + body = await raw_request.json() + + batch_request = BatchRequest(**body) + + batch_id = f"batch_{uuid.uuid4()}" + + # Create an instance of BatchResponse + batch_response = BatchResponse( + id=batch_id, + endpoint=batch_request.endpoint, + input_file_id=batch_request.input_file_id, + completion_window=batch_request.completion_window, + created_at=int(time.time()), + metadata=batch_request.metadata, + ) + + batch_storage[batch_id] = batch_response + + # Start processing the batch 
asynchronously + asyncio.create_task(process_batch(tokenizer_manager, batch_id, batch_request)) + + # Return the initial batch_response + return batch_response + + except ValidationError as e: + return {"error": "Invalid input", "details": e.errors()} + except Exception as e: + return {"error": str(e)} + + +async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRequest): + try: + # Update the batch status to "in_progress" + batch_storage[batch_id].status = "in_progress" + batch_storage[batch_id].in_progress_at = int(time.time()) + + # Retrieve the input file content + input_file_request = file_id_request.get(batch_request.input_file_id) + if not input_file_request: + raise ValueError("Input file not found") + + # Parse the JSONL file and process each request + input_file_path = file_id_storage.get(batch_request.input_file_id) + with open(input_file_path, "r", encoding="utf-8") as f: + lines = f.readlines() + + total_requests = len(lines) + completed_requests = 0 + failed_requests = 0 + + all_ret = [] + end_point = batch_storage[batch_id].endpoint + file_request_list = [] + all_requests = [] + for line in lines: + request_data = json.loads(line) + file_request_list.append(request_data) + body = request_data["body"] + if end_point == "/v1/chat/completions": + all_requests.append(ChatCompletionRequest(**body)) + elif end_point == "/v1/completions": + all_requests.append(CompletionRequest(**body)) + if end_point == "/v1/chat/completions": + adapted_request, request = v1_chat_generate_request( + all_requests, tokenizer_manager + ) + elif end_point == "/v1/completions": + adapted_request, request = v1_generate_request(all_requests) + try: + ret = await tokenizer_manager.generate_request(adapted_request).__anext__() + if not isinstance(ret, list): + ret = [ret] + if end_point == "/v1/chat/completions": + responses = v1_chat_generate_response(request, ret, to_file=True) + else: + responses = v1_generate_response(request, ret, to_file=True) + + except Exception as e: + error_json = { + "id": f"batch_req_{uuid.uuid4()}", + "custom_id": request_data.get("custom_id"), + "response": None, + "error": {"message": str(e)}, + } + all_ret.append(error_json) + failed_requests += len(file_request_list) + + for idx, response in enumerate(responses): + ## the batch_req here can be changed to be named within a batch granularity + response_json = { + "id": f"batch_req_{uuid.uuid4()}", + "custom_id": file_request_list[idx].get("custom_id"), + "response": response, + "error": None, + } + all_ret.append(response_json) + completed_requests += 1 + # Write results to a new file + output_file_id = f"backend_result_file-{uuid.uuid4()}" + global storage_dir + output_file_path = os.path.join(storage_dir, f"{output_file_id}.jsonl") + with open(output_file_path, "w", encoding="utf-8") as f: + for ret in all_ret: + f.write(json.dumps(ret) + "\n") + + # Update batch response with output file information + retrieve_batch = batch_storage[batch_id] + retrieve_batch.output_file_id = output_file_id + file_id_storage[output_file_id] = output_file_path + # Update batch status to "completed" + retrieve_batch.status = "completed" + retrieve_batch.completed_at = int(time.time()) + retrieve_batch.request_counts = { + "total": total_requests, + "completed": completed_requests, + "failed": failed_requests, + } + + except Exception as e: + print("error in SGlang:", e) + # Update batch status to "failed" + retrieve_batch = batch_storage[batch_id] + retrieve_batch.status = "failed" + retrieve_batch.failed_at = 
int(time.time()) + retrieve_batch.errors = {"message": str(e)} + + +async def v1_retrieve_batch(batch_id: str): + # Retrieve the batch job from the in-memory storage + batch_response = batch_storage.get(batch_id) + if batch_response is None: + raise HTTPException(status_code=404, detail="Batch not found") + + return batch_response + + +async def v1_retrieve_file(file_id: str): + # Retrieve the batch job from the in-memory storage + file_response = file_id_response.get(file_id) + if file_response is None: + raise HTTPException(status_code=404, detail="File not found") + return file_response + + +async def v1_retrieve_file_content(file_id: str): + file_pth = file_id_storage.get(file_id) + if not file_pth or not os.path.exists(file_pth): + raise HTTPException(status_code=404, detail="File not found") + + def iter_file(): + with open(file_pth, mode="rb") as file_like: + yield from file_like + + return StreamingResponse(iter_file(), media_type="application/octet-stream") + + +def v1_generate_request(all_requests): + + prompts = [] + sampling_params_list = [] + first_prompt_type = type(all_requests[0].prompt) + for request in all_requests: + prompt = request.prompt + assert ( + type(prompt) == first_prompt_type + ), "All prompts must be of the same type in file input settings" + prompts.append(prompt) + sampling_params_list.append( + { + "temperature": request.temperature, + "max_new_tokens": request.max_tokens, + "stop": request.stop, + "top_p": request.top_p, + "presence_penalty": request.presence_penalty, + "frequency_penalty": request.frequency_penalty, + "regex": request.regex, + "n": request.n, + "ignore_eos": request.ignore_eos, + } + ) + if len(all_requests) > 1 and request.n > 1: + raise ValueError( + "Batch operation is not supported for completions from files" + ) + + if len(all_requests) == 1: + prompt = prompts[0] + sampling_params_list = sampling_params_list[0] + if isinstance(prompts, str) or isinstance(prompts[0], str): + prompt_kwargs = {"text": prompt} + else: + prompt_kwargs = {"input_ids": prompt} else: - prompt_kwargs = {"input_ids": prompt} + if isinstance(prompts[0], str): + prompt_kwargs = {"text": prompts} + else: + prompt_kwargs = {"input_ids": prompts} adapted_request = GenerateReqInput( **prompt_kwargs, - sampling_params={ - "temperature": request.temperature, - "max_new_tokens": request.max_tokens, - "stop": request.stop, - "top_p": request.top_p, - "presence_penalty": request.presence_penalty, - "frequency_penalty": request.frequency_penalty, - "regex": request.regex, - "n": request.n, - "ignore_eos": request.ignore_eos, - }, - return_logprob=request.logprobs is not None and request.logprobs > 0, - top_logprobs_num=request.logprobs if request.logprobs is not None else 0, + sampling_params=sampling_params_list, + return_logprob=all_requests[0].logprobs is not None + and all_requests[0].logprobs > 0, + top_logprobs_num=( + all_requests[0].logprobs if all_requests[0].logprobs is not None else 0 + ), return_text_in_logprobs=True, - stream=request.stream, + stream=all_requests[0].stream, ) + if len(all_requests) == 1: + return adapted_request, all_requests[0] + return adapted_request, all_requests + + +def v1_generate_response(request, ret, to_file=False): + choices = [] + echo = False + + if (not isinstance(request, List)) and request.echo: + # TODO: handle the case propmt is token ids + if isinstance(request.prompt, list): + prompts = request.prompt + else: + prompts = [request.prompt] + echo = True + + for idx, ret_item in enumerate(ret): + text = ret_item["text"] + 
if isinstance(request, List) and request[idx].echo: + echo = True + text = request[idx].prompt + text + if (not isinstance(request, List)) and echo: + text = prompts[idx] + text + + logprobs = False + if isinstance(request, List) and request[idx].logprobs: + logprobs = True + elif (not isinstance(request, List)) and request.logprobs: + logprobs = True + if logprobs: + if echo: + input_token_logprobs = ret_item["meta_info"]["input_token_logprobs"] + input_top_logprobs = ret_item["meta_info"]["input_top_logprobs"] + else: + input_token_logprobs = None + input_top_logprobs = None + + logprobs = to_openai_style_logprobs( + input_token_logprobs=input_token_logprobs, + input_top_logprobs=input_top_logprobs, + output_token_logprobs=ret_item["meta_info"]["output_token_logprobs"], + output_top_logprobs=ret_item["meta_info"]["output_top_logprobs"], + ) + else: + logprobs = None + + if to_file: + ## to make the choise data json serializable + choice_data = { + "index": 0, + "text": text, + "logprobs": logprobs, + "finish_reason": ret_item["meta_info"]["finish_reason"], + } + else: + choice_data = CompletionResponseChoice( + index=idx, + text=text, + logprobs=logprobs, + finish_reason=ret_item["meta_info"]["finish_reason"], + ) + + choices.append(choice_data) + + if to_file: + responses = [] + for i, choice in enumerate(choices): + response = { + "status_code": 200, + "request_id": ret[i]["meta_info"]["id"], + "body": { + ## remain the same but if needed we can change that + "id": ret[i]["meta_info"]["id"], + "object": "text_completion", + "created": int(time.time()), + "model": request[i].model, + "choices": choice, + "usage": { + "prompt_tokens": ret[i]["meta_info"]["prompt_tokens"], + "completion_tokens": ret[i]["meta_info"]["completion_tokens"], + "total_tokens": ret[i]["meta_info"]["prompt_tokens"] + + ret[i]["meta_info"]["completion_tokens"], + }, + "system_fingerprint": None, + }, + } + responses.append(response) + return responses + else: + completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret) + response = CompletionResponse( + id=ret[0]["meta_info"]["id"], + model=request.model, + choices=choices, + usage=UsageInfo( + prompt_tokens=ret[0]["meta_info"]["prompt_tokens"], + completion_tokens=completion_tokens, + total_tokens=ret[0]["meta_info"]["prompt_tokens"] + completion_tokens, + ), + ) + return response + + +async def v1_completions(tokenizer_manager, raw_request: Request): + request_json = await raw_request.json() + all_requests = [CompletionRequest(**request_json)] + adapted_request, request = v1_generate_request(all_requests) if adapted_request.stream: @@ -223,109 +580,144 @@ async def v1_completions(tokenizer_manager, raw_request: Request): if not isinstance(ret, list): ret = [ret] - if request.echo: - # TODO: handle the case propmt is token ids - if isinstance(request.prompt, list): - prompts = request.prompt + + response = v1_generate_response(request, ret) + return response + + +def v1_chat_generate_request(all_requests, tokenizer_manager): + + texts = [] + sampling_params_list = [] + image_data_list = [] + for request in all_requests: + # Prep the data needed for the underlying GenerateReqInput: + # - prompt: The full prompt string. + # - stop: Custom stop tokens. + # - image_data: None or a list of image strings (URLs or base64 strings). + # None skips any image processing in GenerateReqInput. + if not isinstance(request.messages, str): + # Apply chat template and its stop strings. 
+ if chat_template_name is None: + prompt = tokenizer_manager.tokenizer.apply_chat_template( + request.messages, tokenize=False, add_generation_prompt=True + ) + stop = request.stop + image_data = None + else: + conv = generate_chat_conv(request, chat_template_name) + prompt = conv.get_prompt() + image_data = conv.image_data + stop = conv.stop_str or [] + if request.stop: + if isinstance(request.stop, str): + stop.append(request.stop) + else: + stop.extend(request.stop) else: - prompts = [request.prompt] + # Use the raw prompt and stop strings if the messages is already a string. + prompt = request.messages + stop = request.stop + image_data = None + texts.append(prompt) + sampling_params_list.append( + { + "temperature": request.temperature, + "max_new_tokens": request.max_tokens, + "stop": stop, + "top_p": request.top_p, + "presence_penalty": request.presence_penalty, + "frequency_penalty": request.frequency_penalty, + "regex": request.regex, + "n": request.n, + } + ) + image_data_list.append(image_data) + if len(all_requests) == 1: + texts = texts[0] + sampling_params_list = sampling_params_list[0] + image_data = image_data_list[0] + adapted_request = GenerateReqInput( + text=texts, + image_data=image_data, + sampling_params=sampling_params_list, + stream=request.stream, + ) + if len(all_requests) == 1: + return adapted_request, all_requests[0] + return adapted_request, all_requests + + +def v1_chat_generate_response(request, ret, to_file=False): choices = [] + total_prompt_tokens = 0 + total_completion_tokens = 0 for idx, ret_item in enumerate(ret): - text = ret_item["text"] + prompt_tokens = ret_item["meta_info"]["prompt_tokens"] + completion_tokens = ret_item["meta_info"]["completion_tokens"] - if request.echo: - text = prompts[idx] + text - - if request.logprobs: - if request.echo: - input_token_logprobs = ret_item["meta_info"]["input_token_logprobs"] - input_top_logprobs = ret_item["meta_info"]["input_top_logprobs"] - else: - input_token_logprobs = None - input_top_logprobs = None - - logprobs = to_openai_style_logprobs( - input_token_logprobs=input_token_logprobs, - input_top_logprobs=input_top_logprobs, - output_token_logprobs=ret_item["meta_info"]["output_token_logprobs"], - output_top_logprobs=ret_item["meta_info"]["output_top_logprobs"], - ) + if to_file: + ## to make the choice data json serializable + choice_data = { + "index": 0, + "message": {"role": "assistant", "content": ret_item["text"]}, + "logprobs": None, + "finish_reason": ret_item["meta_info"]["finish_reason"], + } else: - logprobs = None - - choice_data = CompletionResponseChoice( - index=idx, - text=text, - logprobs=logprobs, - finish_reason=ret_item["meta_info"]["finish_reason"], - ) + choice_data = ChatCompletionResponseChoice( + index=idx, + message=ChatMessage(role="assistant", content=ret_item["text"]), + finish_reason=ret_item["meta_info"]["finish_reason"], + ) choices.append(choice_data) + total_prompt_tokens = prompt_tokens + total_completion_tokens += completion_tokens + if to_file: + responses = [] - completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret) - response = CompletionResponse( - id=ret[0]["meta_info"]["id"], - model=request.model, - choices=choices, - usage=UsageInfo( - prompt_tokens=ret[0]["meta_info"]["prompt_tokens"], - completion_tokens=completion_tokens, - total_tokens=ret[0]["meta_info"]["prompt_tokens"] + completion_tokens, - ), - ) - - return response + for i, choice in enumerate(choices): + response = { + "status_code": 200, + "request_id": 
ret[i]["meta_info"]["id"], + "body": { + ## remain the same but if needed we can change that + "id": ret[i]["meta_info"]["id"], + "object": "chat.completion", + "created": int(time.time()), + "model": request[i].model, + "choices": choice, + "usage": { + "prompt_tokens": ret[i]["meta_info"]["prompt_tokens"], + "completion_tokens": ret[i]["meta_info"]["completion_tokens"], + "total_tokens": ret[i]["meta_info"]["prompt_tokens"] + + ret[i]["meta_info"]["completion_tokens"], + }, + "system_fingerprint": None, + }, + } + responses.append(response) + return responses + else: + response = ChatCompletionResponse( + id=ret[0]["meta_info"]["id"], + model=request.model, + choices=choices, + usage=UsageInfo( + prompt_tokens=total_prompt_tokens, + completion_tokens=total_completion_tokens, + total_tokens=total_prompt_tokens + total_completion_tokens, + ), + ) + return response async def v1_chat_completions(tokenizer_manager, raw_request: Request): request_json = await raw_request.json() - request = ChatCompletionRequest(**request_json) - - # Prep the data needed for the underlying GenerateReqInput: - # - prompt: The full prompt string. - # - stop: Custom stop tokens. - # - image_data: None or a list of image strings (URLs or base64 strings). - # None skips any image processing in GenerateReqInput. - if not isinstance(request.messages, str): - # Apply chat template and its stop strings. - if chat_template_name is None: - prompt = tokenizer_manager.tokenizer.apply_chat_template( - request.messages, tokenize=False, add_generation_prompt=True - ) - stop = request.stop - image_data = None - else: - conv = generate_chat_conv(request, chat_template_name) - prompt = conv.get_prompt() - image_data = conv.image_data - stop = conv.stop_str or [] - if request.stop: - if isinstance(request.stop, str): - stop.append(request.stop) - else: - stop.extend(request.stop) - else: - # Use the raw prompt and stop strings if the messages is already a string. 
- prompt = request.messages - stop = request.stop - image_data = None - - adapted_request = GenerateReqInput( - text=prompt, - image_data=image_data, - sampling_params={ - "temperature": request.temperature, - "max_new_tokens": request.max_tokens, - "stop": stop, - "top_p": request.top_p, - "presence_penalty": request.presence_penalty, - "frequency_penalty": request.frequency_penalty, - "regex": request.regex, - "n": request.n, - }, - stream=request.stream, - ) + all_requests = [ChatCompletionRequest(**request_json)] + adapted_request, request = v1_chat_generate_request(all_requests, tokenizer_manager) if adapted_request.stream: @@ -387,34 +779,8 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request): if not isinstance(ret, list): ret = [ret] - choices = [] - total_prompt_tokens = 0 - total_completion_tokens = 0 - for idx, ret_item in enumerate(ret): - prompt_tokens = ret_item["meta_info"]["prompt_tokens"] - completion_tokens = ret_item["meta_info"]["completion_tokens"] - - choice_data = ChatCompletionResponseChoice( - index=idx, - message=ChatMessage(role="assistant", content=ret_item["text"]), - finish_reason=ret_item["meta_info"]["finish_reason"], - ) - - choices.append(choice_data) - total_prompt_tokens = prompt_tokens - total_completion_tokens += completion_tokens - - response = ChatCompletionResponse( - id=ret[0]["meta_info"]["id"], - model=request.model, - choices=choices, - usage=UsageInfo( - prompt_tokens=total_prompt_tokens, - completion_tokens=total_completion_tokens, - total_tokens=total_prompt_tokens + total_completion_tokens, - ), - ) + response = v1_chat_generate_response(request, ret) return response diff --git a/python/sglang/srt/openai_api/protocol.py b/python/sglang/srt/openai_api/protocol.py index c7c18be55..853165e34 100644 --- a/python/sglang/srt/openai_api/protocol.py +++ b/python/sglang/srt/openai_api/protocol.py @@ -60,6 +60,55 @@ class UsageInfo(BaseModel): completion_tokens: Optional[int] = 0 +class FileRequest(BaseModel): + # https://platform.openai.com/docs/api-reference/files/create + file: bytes # The File object (not file name) to be uploaded + purpose: str = ( + "batch" # The intended purpose of the uploaded file, default is "batch" + ) + + +class FileResponse(BaseModel): + id: str + object: str = "file" + bytes: int + created_at: int + filename: str + purpose: str + + +class BatchRequest(BaseModel): + input_file_id: ( + str # The ID of an uploaded file that contains requests for the new batch + ) + endpoint: str # The endpoint to be used for all requests in the batch + completion_window: str # The time frame within which the batch should be processed + metadata: Optional[dict] = None # Optional custom metadata for the batch + + +class BatchResponse(BaseModel): + id: str + object: str = "batch" + endpoint: str + errors: Optional[dict] = None + input_file_id: str + completion_window: str + status: str = "validating" + output_file_id: Optional[str] = None + error_file_id: Optional[str] = None + created_at: int + in_progress_at: Optional[int] = None + expires_at: Optional[int] = None + finalizing_at: Optional[int] = None + completed_at: Optional[int] = None + failed_at: Optional[int] = None + expired_at: Optional[int] = None + cancelling_at: Optional[int] = None + cancelled_at: Optional[int] = None + request_counts: dict = {"total": 0, "completed": 0, "failed": 0} + metadata: Optional[dict] = None + + class CompletionRequest(BaseModel): # Ordered by official OpenAI API documentation # 
https://platform.openai.com/docs/api-reference/completions/create diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index 2b73a6e65..f1b5dae9c 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -38,7 +38,7 @@ import psutil import requests import uvicorn import uvloop -from fastapi import FastAPI, Request +from fastapi import FastAPI, File, Form, Request, UploadFile from fastapi.responses import JSONResponse, Response, StreamingResponse from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint @@ -56,8 +56,13 @@ from sglang.srt.managers.io_struct import GenerateReqInput from sglang.srt.managers.tokenizer_manager import TokenizerManager from sglang.srt.openai_api.adapter import ( load_chat_template_for_openai_api, + v1_batches, v1_chat_completions, v1_completions, + v1_files_create, + v1_retrieve_batch, + v1_retrieve_file, + v1_retrieve_file_content, ) from sglang.srt.openai_api.protocol import ModelCard, ModelList from sglang.srt.server_args import PortArgs, ServerArgs @@ -152,6 +157,35 @@ async def openai_v1_chat_completions(raw_request: Request): return await v1_chat_completions(tokenizer_manager, raw_request) +@app.post("/v1/files") +async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("batch")): + return await v1_files_create( + file, purpose, tokenizer_manager.server_args.file_storage_pth + ) + + +@app.post("/v1/batches") +async def openai_v1_batches(raw_request: Request): + return await v1_batches(tokenizer_manager, raw_request) + + +@app.get("/v1/batches/{batch_id}") +async def retrieve_batch(batch_id: str): + return await v1_retrieve_batch(batch_id) + + +@app.get("/v1/files/{file_id}") +async def retrieve_file(file_id: str): + # https://platform.openai.com/docs/api-reference/files/retrieve + return await v1_retrieve_file(file_id) + + +@app.get("/v1/files/{file_id}/content") +async def retrieve_file_content(file_id: str): + # https://platform.openai.com/docs/api-reference/files/retrieve-contents + return await v1_retrieve_file_content(file_id) + + @app.get("/v1/models") def available_models(): """Show available models.""" diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index c9535f402..69829a7fc 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -60,6 +60,7 @@ class ServerArgs: # Other api_key: str = "" + file_storage_pth: str = "SGlang_storage" # Data parallelism dp_size: int = 1 @@ -290,6 +291,12 @@ class ServerArgs: default=ServerArgs.api_key, help="Set API key of the server.", ) + parser.add_argument( + "--file-storage-pth", + type=str, + default=ServerArgs.file_storage_pth, + help="The path of the file storage in backend.", + ) # Data parallelism parser.add_argument(
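The example scripts above assume an input.jsonl / input_complete.jsonl already exists but never show its layout. As a minimal sketch (not part of this patch), the snippet below builds such a file in the OpenAI Batch input format and submits it through the new /v1/files and /v1/batches routes. The adapter added here only reads the "custom_id" and "body" fields of each line; "method" and "url" are included for compatibility with the official batch file format. The model name, file name, prompt text, and server address are illustrative assumptions.

    import json

    import openai

    # Build a few chat-completion requests in the OpenAI Batch line format.
    requests = [
        {
            "custom_id": f"request-{i}",
            "method": "POST",
            "url": "/v1/chat/completions",
            "body": {
                "model": "default",
                "messages": [
                    {"role": "user", "content": f"Give me fact #{i} about the moon."}
                ],
                "max_tokens": 32,
            },
        }
        for i in range(3)
    ]

    # Write one JSON object per line, as expected by the batch file parser.
    with open("input.jsonl", "w", encoding="utf-8") as f:
        for req in requests:
            f.write(json.dumps(req) + "\n")

    # Upload the file and start a batch job against a locally running sglang server.
    client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")
    uploaded = client.files.create(file=open("input.jsonl", "rb"), purpose="batch")
    batch = client.batches.create(
        input_file_id=uploaded.id,
        endpoint="/v1/chat/completions",
        completion_window="24h",
    )
    print(batch.id, batch.status)

Once the job reports "completed", the results can be fetched with client.files.content(batch.output_file_id), exactly as the examples/usage/openai_batch_chat.py script in this patch does.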