Add support for OpenAI API : offline batch(file) processing (#699)
Co-authored-by: hnyls2002 <hnyls2002@gmail.com>
This commit is contained in:
@@ -4,6 +4,6 @@ repos:
|
|||||||
hooks:
|
hooks:
|
||||||
- id: isort
|
- id: isort
|
||||||
- repo: https://github.com/psf/black
|
- repo: https://github.com/psf/black
|
||||||
rev: stable
|
rev: 24.4.2
|
||||||
hooks:
|
hooks:
|
||||||
- id: black
|
- id: black
|
||||||
|
|||||||
86
examples/usage/openai_batch_chat.py
Normal file
86
examples/usage/openai_batch_chat.py
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
|
||||||
|
import openai
|
||||||
|
from openai import OpenAI
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIBatchProcessor:
|
||||||
|
def __init__(self, api_key):
|
||||||
|
# client = OpenAI(api_key=api_key)
|
||||||
|
client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")
|
||||||
|
|
||||||
|
self.client = client
|
||||||
|
|
||||||
|
def process_batch(self, input_file_path, endpoint, completion_window):
|
||||||
|
|
||||||
|
# Upload the input file
|
||||||
|
with open(input_file_path, "rb") as file:
|
||||||
|
uploaded_file = self.client.files.create(file=file, purpose="batch")
|
||||||
|
|
||||||
|
# Create the batch job
|
||||||
|
batch_job = self.client.batches.create(
|
||||||
|
input_file_id=uploaded_file.id,
|
||||||
|
endpoint=endpoint,
|
||||||
|
completion_window=completion_window,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Monitor the batch job status
|
||||||
|
while batch_job.status not in ["completed", "failed", "cancelled"]:
|
||||||
|
time.sleep(3) # Wait for 3 seconds before checking the status again
|
||||||
|
print(
|
||||||
|
f"Batch job status: {batch_job.status}...trying again in 3 seconds..."
|
||||||
|
)
|
||||||
|
batch_job = self.client.batches.retrieve(batch_job.id)
|
||||||
|
|
||||||
|
# Check the batch job status and errors
|
||||||
|
if batch_job.status == "failed":
|
||||||
|
print(f"Batch job failed with status: {batch_job.status}")
|
||||||
|
print(f"Batch job errors: {batch_job.errors}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# If the batch job is completed, process the results
|
||||||
|
if batch_job.status == "completed":
|
||||||
|
|
||||||
|
# print result of batch job
|
||||||
|
print("batch", batch_job.request_counts)
|
||||||
|
|
||||||
|
result_file_id = batch_job.output_file_id
|
||||||
|
# Retrieve the file content from the server
|
||||||
|
file_response = self.client.files.content(result_file_id)
|
||||||
|
result_content = file_response.read() # Read the content of the file
|
||||||
|
|
||||||
|
# Save the content to a local file
|
||||||
|
result_file_name = "batch_job_chat_results.jsonl"
|
||||||
|
with open(result_file_name, "wb") as file:
|
||||||
|
file.write(result_content) # Write the binary content to the file
|
||||||
|
# Load data from the saved JSONL file
|
||||||
|
results = []
|
||||||
|
with open(result_file_name, "r", encoding="utf-8") as file:
|
||||||
|
for line in file:
|
||||||
|
json_object = json.loads(
|
||||||
|
line.strip()
|
||||||
|
) # Parse each line as a JSON object
|
||||||
|
results.append(json_object)
|
||||||
|
|
||||||
|
return results
|
||||||
|
else:
|
||||||
|
print(f"Batch job failed with status: {batch_job.status}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# Initialize the OpenAIBatchProcessor
|
||||||
|
api_key = os.environ.get("OPENAI_API_KEY")
|
||||||
|
processor = OpenAIBatchProcessor(api_key)
|
||||||
|
|
||||||
|
# Process the batch job
|
||||||
|
input_file_path = "input.jsonl"
|
||||||
|
endpoint = "/v1/chat/completions"
|
||||||
|
completion_window = "24h"
|
||||||
|
|
||||||
|
# Process the batch job
|
||||||
|
results = processor.process_batch(input_file_path, endpoint, completion_window)
|
||||||
|
|
||||||
|
# Print the results
|
||||||
|
print(results)
|
||||||
86
examples/usage/openai_batch_complete.py
Normal file
86
examples/usage/openai_batch_complete.py
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
|
||||||
|
import openai
|
||||||
|
from openai import OpenAI
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIBatchProcessor:
|
||||||
|
def __init__(self, api_key):
|
||||||
|
# client = OpenAI(api_key=api_key)
|
||||||
|
client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")
|
||||||
|
|
||||||
|
self.client = client
|
||||||
|
|
||||||
|
def process_batch(self, input_file_path, endpoint, completion_window):
|
||||||
|
|
||||||
|
# Upload the input file
|
||||||
|
with open(input_file_path, "rb") as file:
|
||||||
|
uploaded_file = self.client.files.create(file=file, purpose="batch")
|
||||||
|
|
||||||
|
# Create the batch job
|
||||||
|
batch_job = self.client.batches.create(
|
||||||
|
input_file_id=uploaded_file.id,
|
||||||
|
endpoint=endpoint,
|
||||||
|
completion_window=completion_window,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Monitor the batch job status
|
||||||
|
while batch_job.status not in ["completed", "failed", "cancelled"]:
|
||||||
|
time.sleep(3) # Wait for 3 seconds before checking the status again
|
||||||
|
print(
|
||||||
|
f"Batch job status: {batch_job.status}...trying again in 3 seconds..."
|
||||||
|
)
|
||||||
|
batch_job = self.client.batches.retrieve(batch_job.id)
|
||||||
|
|
||||||
|
# Check the batch job status and errors
|
||||||
|
if batch_job.status == "failed":
|
||||||
|
print(f"Batch job failed with status: {batch_job.status}")
|
||||||
|
print(f"Batch job errors: {batch_job.errors}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# If the batch job is completed, process the results
|
||||||
|
if batch_job.status == "completed":
|
||||||
|
|
||||||
|
# print result of batch job
|
||||||
|
print("batch", batch_job.request_counts)
|
||||||
|
|
||||||
|
result_file_id = batch_job.output_file_id
|
||||||
|
# Retrieve the file content from the server
|
||||||
|
file_response = self.client.files.content(result_file_id)
|
||||||
|
result_content = file_response.read() # Read the content of the file
|
||||||
|
|
||||||
|
# Save the content to a local file
|
||||||
|
result_file_name = "batch_job_complete_results.jsonl"
|
||||||
|
with open(result_file_name, "wb") as file:
|
||||||
|
file.write(result_content) # Write the binary content to the file
|
||||||
|
# Load data from the saved JSONL file
|
||||||
|
results = []
|
||||||
|
with open(result_file_name, "r", encoding="utf-8") as file:
|
||||||
|
for line in file:
|
||||||
|
json_object = json.loads(
|
||||||
|
line.strip()
|
||||||
|
) # Parse each line as a JSON object
|
||||||
|
results.append(json_object)
|
||||||
|
|
||||||
|
return results
|
||||||
|
else:
|
||||||
|
print(f"Batch job failed with status: {batch_job.status}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# Initialize the OpenAIBatchProcessor
|
||||||
|
api_key = os.environ.get("OPENAI_API_KEY")
|
||||||
|
processor = OpenAIBatchProcessor(api_key)
|
||||||
|
|
||||||
|
# Process the batch job
|
||||||
|
input_file_path = "input_complete.jsonl"
|
||||||
|
endpoint = "/v1/completions"
|
||||||
|
completion_window = "24h"
|
||||||
|
|
||||||
|
# Process the batch job
|
||||||
|
results = processor.process_batch(input_file_path, endpoint, completion_window)
|
||||||
|
|
||||||
|
# Print the results
|
||||||
|
print(results)
|
||||||
@@ -13,6 +13,17 @@ response = client.completions.create(
|
|||||||
print(response)
|
print(response)
|
||||||
|
|
||||||
|
|
||||||
|
# Text completion
|
||||||
|
response = client.completions.create(
|
||||||
|
model="default",
|
||||||
|
prompt="I am a robot and I want to study like humans. Now let's tell a story. Once upon a time, there was a little",
|
||||||
|
n=1,
|
||||||
|
temperature=0.8,
|
||||||
|
max_tokens=32,
|
||||||
|
)
|
||||||
|
print(response)
|
||||||
|
|
||||||
|
|
||||||
# Text completion
|
# Text completion
|
||||||
response = client.completions.create(
|
response = client.completions.create(
|
||||||
model="default",
|
model="default",
|
||||||
@@ -24,6 +35,17 @@ response = client.completions.create(
|
|||||||
print(response)
|
print(response)
|
||||||
|
|
||||||
|
|
||||||
|
# Text completion
|
||||||
|
response = client.completions.create(
|
||||||
|
model="default",
|
||||||
|
prompt=["The name of the famous soccer player is"],
|
||||||
|
n=1,
|
||||||
|
temperature=0.8,
|
||||||
|
max_tokens=128,
|
||||||
|
)
|
||||||
|
print(response)
|
||||||
|
|
||||||
|
|
||||||
# Text completion
|
# Text completion
|
||||||
response = client.completions.create(
|
response = client.completions.create(
|
||||||
model="default",
|
model="default",
|
||||||
@@ -60,6 +82,21 @@ response = client.completions.create(
|
|||||||
)
|
)
|
||||||
print(response)
|
print(response)
|
||||||
|
|
||||||
|
# Chat completion
|
||||||
|
response = client.chat.completions.create(
|
||||||
|
model="default",
|
||||||
|
messages=[
|
||||||
|
{"role": "system", "content": "You are a helpful AI assistant"},
|
||||||
|
{"role": "user", "content": "List 3 countries and their capitals."},
|
||||||
|
],
|
||||||
|
temperature=0.8,
|
||||||
|
max_tokens=64,
|
||||||
|
logprobs=True,
|
||||||
|
n=1,
|
||||||
|
)
|
||||||
|
print(response)
|
||||||
|
|
||||||
|
|
||||||
# Chat completion
|
# Chat completion
|
||||||
response = client.chat.completions.create(
|
response = client.chat.completions.create(
|
||||||
model="default",
|
model="default",
|
||||||
|
|||||||
@@ -79,8 +79,26 @@ class GenerateReqInput:
|
|||||||
if self.top_logprobs_num is None:
|
if self.top_logprobs_num is None:
|
||||||
self.top_logprobs_num = 0
|
self.top_logprobs_num = 0
|
||||||
else:
|
else:
|
||||||
|
parallel_sample_num_list = []
|
||||||
|
if isinstance(self.sampling_params, dict):
|
||||||
parallel_sample_num = self.sampling_params.get("n", 1)
|
parallel_sample_num = self.sampling_params.get("n", 1)
|
||||||
|
elif isinstance(self.sampling_params, list):
|
||||||
|
for sp in self.sampling_params:
|
||||||
|
parallel_sample_num = sp.get("n", 1)
|
||||||
|
parallel_sample_num_list.append(parallel_sample_num)
|
||||||
|
parallel_sample_num = max(parallel_sample_num_list)
|
||||||
|
all_equal = all(
|
||||||
|
element == parallel_sample_num
|
||||||
|
for element in parallel_sample_num_list
|
||||||
|
)
|
||||||
|
if parallel_sample_num > 1 and (not all_equal):
|
||||||
|
## TODO cope with the case that the parallel_sample_num is different for different samples
|
||||||
|
raise ValueError(
|
||||||
|
"The parallel_sample_num should be the same for all samples in sample params."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
parallel_sample_num = 1
|
||||||
|
self.parallel_sample_num = parallel_sample_num
|
||||||
|
|
||||||
if parallel_sample_num != 1:
|
if parallel_sample_num != 1:
|
||||||
# parallel sampling +1 represents the original prefill stage
|
# parallel sampling +1 represents the original prefill stage
|
||||||
|
|||||||
@@ -84,6 +84,7 @@ class TokenizerManager:
|
|||||||
trust_remote_code=server_args.trust_remote_code,
|
trust_remote_code=server_args.trust_remote_code,
|
||||||
model_overide_args=model_overide_args,
|
model_overide_args=model_overide_args,
|
||||||
)
|
)
|
||||||
|
|
||||||
if server_args.context_length is not None:
|
if server_args.context_length is not None:
|
||||||
self.context_len = server_args.context_length
|
self.context_len = server_args.context_length
|
||||||
else:
|
else:
|
||||||
@@ -152,31 +153,33 @@ class TokenizerManager:
|
|||||||
self, obj, request, index=None, is_cache_for_prefill=False
|
self, obj, request, index=None, is_cache_for_prefill=False
|
||||||
):
|
):
|
||||||
if not is_cache_for_prefill:
|
if not is_cache_for_prefill:
|
||||||
rid = obj.rid if index is None else obj.rid[index]
|
not_use_index = not (index is not None)
|
||||||
input_text = obj.text if index is None else obj.text[index]
|
rid = obj.rid if not_use_index else obj.rid[index]
|
||||||
|
input_text = obj.text if not_use_index else obj.text[index]
|
||||||
input_ids = (
|
input_ids = (
|
||||||
self.tokenizer.encode(input_text)
|
self.tokenizer.encode(input_text)
|
||||||
if obj.input_ids is None
|
if obj.input_ids is None
|
||||||
else obj.input_ids
|
else obj.input_ids
|
||||||
)
|
)
|
||||||
if index is not None and obj.input_ids:
|
if not not_use_index and obj.input_ids:
|
||||||
input_ids = obj.input_ids[index]
|
input_ids = obj.input_ids[index]
|
||||||
|
|
||||||
self._validate_input_length(input_ids)
|
self._validate_input_length(input_ids)
|
||||||
|
|
||||||
sampling_params = self._get_sampling_params(
|
sampling_params = self._get_sampling_params(
|
||||||
obj.sampling_params if index is None else obj.sampling_params[index]
|
obj.sampling_params if not_use_index else obj.sampling_params[index]
|
||||||
)
|
)
|
||||||
pixel_values, image_hash, image_size = await self._get_pixel_values(
|
pixel_values, image_hash, image_size = await self._get_pixel_values(
|
||||||
obj.image_data if index is None else obj.image_data[index]
|
obj.image_data if not_use_index else obj.image_data[index]
|
||||||
)
|
)
|
||||||
return_logprob = (
|
return_logprob = (
|
||||||
obj.return_logprob if index is None else obj.return_logprob[index]
|
obj.return_logprob if not_use_index else obj.return_logprob[index]
|
||||||
)
|
)
|
||||||
logprob_start_len = (
|
logprob_start_len = (
|
||||||
obj.logprob_start_len if index is None else obj.logprob_start_len[index]
|
obj.logprob_start_len if not_use_index else obj.logprob_start_len[index]
|
||||||
)
|
)
|
||||||
top_logprobs_num = (
|
top_logprobs_num = (
|
||||||
obj.top_logprobs_num if index is None else obj.top_logprobs_num[index]
|
obj.top_logprobs_num if not_use_index else obj.top_logprobs_num[index]
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if isinstance(obj.text, list):
|
if isinstance(obj.text, list):
|
||||||
@@ -224,7 +227,7 @@ class TokenizerManager:
|
|||||||
|
|
||||||
async def _handle_batch_request(self, obj: GenerateReqInput, request):
|
async def _handle_batch_request(self, obj: GenerateReqInput, request):
|
||||||
batch_size = obj.batch_size
|
batch_size = obj.batch_size
|
||||||
parallel_sample_num = obj.sampling_params[0].get("n", 1)
|
parallel_sample_num = obj.parallel_sample_num
|
||||||
|
|
||||||
if parallel_sample_num != 1:
|
if parallel_sample_num != 1:
|
||||||
# Send prefill requests to cache the common input
|
# Send prefill requests to cache the common input
|
||||||
@@ -241,7 +244,6 @@ class TokenizerManager:
|
|||||||
obj.input_ids = input_id_result
|
obj.input_ids = input_id_result
|
||||||
elif input_id_result is not None:
|
elif input_id_result is not None:
|
||||||
obj.input_ids = input_id_result[0]
|
obj.input_ids = input_id_result[0]
|
||||||
|
|
||||||
# First send out all requests
|
# First send out all requests
|
||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
for j in range(parallel_sample_num):
|
for j in range(parallel_sample_num):
|
||||||
@@ -249,7 +251,7 @@ class TokenizerManager:
|
|||||||
continue
|
continue
|
||||||
index = i * parallel_sample_num + j
|
index = i * parallel_sample_num + j
|
||||||
if parallel_sample_num != 1:
|
if parallel_sample_num != 1:
|
||||||
# Here when using parallel sampling we shoul consider prefill stage so the index is : j + i * (parallel_sample_num-1) + batch_size - 1
|
# Here when using parallel sampling we should consider prefill stage so the index is : j + i * (parallel_sample_num-1) + batch_size - 1
|
||||||
index += batch_size - 1 - i
|
index += batch_size - 1 - i
|
||||||
rid = obj.rid[index]
|
rid = obj.rid[index]
|
||||||
if parallel_sample_num == 1:
|
if parallel_sample_num == 1:
|
||||||
|
|||||||
@@ -18,10 +18,14 @@ limitations under the License.
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
from fastapi import Request
|
from fastapi import HTTPException, Request, UploadFile
|
||||||
from fastapi.responses import JSONResponse, StreamingResponse
|
from fastapi.responses import JSONResponse, StreamingResponse
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
from sglang.srt.conversation import (
|
from sglang.srt.conversation import (
|
||||||
Conversation,
|
Conversation,
|
||||||
@@ -32,6 +36,8 @@ from sglang.srt.conversation import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.managers.io_struct import GenerateReqInput
|
from sglang.srt.managers.io_struct import GenerateReqInput
|
||||||
from sglang.srt.openai_api.protocol import (
|
from sglang.srt.openai_api.protocol import (
|
||||||
|
BatchRequest,
|
||||||
|
BatchResponse,
|
||||||
ChatCompletionRequest,
|
ChatCompletionRequest,
|
||||||
ChatCompletionResponse,
|
ChatCompletionResponse,
|
||||||
ChatCompletionResponseChoice,
|
ChatCompletionResponseChoice,
|
||||||
@@ -45,6 +51,8 @@ from sglang.srt.openai_api.protocol import (
|
|||||||
CompletionStreamResponse,
|
CompletionStreamResponse,
|
||||||
DeltaMessage,
|
DeltaMessage,
|
||||||
ErrorResponse,
|
ErrorResponse,
|
||||||
|
FileRequest,
|
||||||
|
FileResponse,
|
||||||
LogProbs,
|
LogProbs,
|
||||||
UsageInfo,
|
UsageInfo,
|
||||||
)
|
)
|
||||||
@@ -52,6 +60,24 @@ from sglang.srt.openai_api.protocol import (
|
|||||||
chat_template_name = None
|
chat_template_name = None
|
||||||
|
|
||||||
|
|
||||||
|
class FileMetadata:
|
||||||
|
def __init__(self, filename: str, purpose: str):
|
||||||
|
self.filename = filename
|
||||||
|
self.purpose = purpose
|
||||||
|
|
||||||
|
|
||||||
|
# In-memory storage for batch jobs and files
|
||||||
|
batch_storage: Dict[str, BatchResponse] = {}
|
||||||
|
file_id_request: Dict[str, FileMetadata] = {}
|
||||||
|
file_id_response: Dict[str, FileResponse] = {}
|
||||||
|
## map file id to file path in SGlang backend
|
||||||
|
file_id_storage: Dict[str, str] = {}
|
||||||
|
|
||||||
|
|
||||||
|
# backend storage directory
|
||||||
|
storage_dir = None
|
||||||
|
|
||||||
|
|
||||||
def create_error_response(
|
def create_error_response(
|
||||||
message: str,
|
message: str,
|
||||||
err_type: str = "BadRequestError",
|
err_type: str = "BadRequestError",
|
||||||
@@ -106,18 +132,216 @@ def load_chat_template_for_openai_api(chat_template_arg):
|
|||||||
chat_template_name = chat_template_arg
|
chat_template_name = chat_template_arg
|
||||||
|
|
||||||
|
|
||||||
async def v1_completions(tokenizer_manager, raw_request: Request):
|
async def v1_files_create(file: UploadFile, purpose: str, file_storage_pth: str = None):
|
||||||
request_json = await raw_request.json()
|
try:
|
||||||
request = CompletionRequest(**request_json)
|
global storage_dir
|
||||||
prompt = request.prompt
|
if file_storage_pth:
|
||||||
if isinstance(prompt, str) or isinstance(prompt[0], str):
|
storage_dir = file_storage_pth
|
||||||
prompt_kwargs = {"text": prompt}
|
# Read the file content
|
||||||
else:
|
file_content = await file.read()
|
||||||
prompt_kwargs = {"input_ids": prompt}
|
|
||||||
|
|
||||||
adapted_request = GenerateReqInput(
|
# Create an instance of RequestBody
|
||||||
**prompt_kwargs,
|
request_body = FileRequest(file=file_content, purpose=purpose)
|
||||||
sampling_params={
|
|
||||||
|
# Save the file to the sglang_oai_storage directory
|
||||||
|
os.makedirs(storage_dir, exist_ok=True)
|
||||||
|
file_id = f"backend_input_file-{uuid.uuid4()}"
|
||||||
|
filename = f"{file_id}.jsonl"
|
||||||
|
file_path = os.path.join(storage_dir, filename)
|
||||||
|
|
||||||
|
with open(file_path, "wb") as f:
|
||||||
|
f.write(request_body.file)
|
||||||
|
|
||||||
|
# add info to global file map
|
||||||
|
file_id_request[file_id] = FileMetadata(filename=file.filename, purpose=purpose)
|
||||||
|
file_id_storage[file_id] = file_path
|
||||||
|
|
||||||
|
# Return the response in the required format
|
||||||
|
response = FileResponse(
|
||||||
|
id=file_id,
|
||||||
|
bytes=len(request_body.file),
|
||||||
|
created_at=int(time.time()),
|
||||||
|
filename=file.filename,
|
||||||
|
purpose=request_body.purpose,
|
||||||
|
)
|
||||||
|
file_id_response[file_id] = response
|
||||||
|
|
||||||
|
return response
|
||||||
|
except ValidationError as e:
|
||||||
|
return {"error": "Invalid input", "details": e.errors()}
|
||||||
|
|
||||||
|
|
||||||
|
async def v1_batches(tokenizer_manager, raw_request: Request):
|
||||||
|
try:
|
||||||
|
body = await raw_request.json()
|
||||||
|
|
||||||
|
batch_request = BatchRequest(**body)
|
||||||
|
|
||||||
|
batch_id = f"batch_{uuid.uuid4()}"
|
||||||
|
|
||||||
|
# Create an instance of BatchResponse
|
||||||
|
batch_response = BatchResponse(
|
||||||
|
id=batch_id,
|
||||||
|
endpoint=batch_request.endpoint,
|
||||||
|
input_file_id=batch_request.input_file_id,
|
||||||
|
completion_window=batch_request.completion_window,
|
||||||
|
created_at=int(time.time()),
|
||||||
|
metadata=batch_request.metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
batch_storage[batch_id] = batch_response
|
||||||
|
|
||||||
|
# Start processing the batch asynchronously
|
||||||
|
asyncio.create_task(process_batch(tokenizer_manager, batch_id, batch_request))
|
||||||
|
|
||||||
|
# Return the initial batch_response
|
||||||
|
return batch_response
|
||||||
|
|
||||||
|
except ValidationError as e:
|
||||||
|
return {"error": "Invalid input", "details": e.errors()}
|
||||||
|
except Exception as e:
|
||||||
|
return {"error": str(e)}
|
||||||
|
|
||||||
|
|
||||||
|
async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRequest):
|
||||||
|
try:
|
||||||
|
# Update the batch status to "in_progress"
|
||||||
|
batch_storage[batch_id].status = "in_progress"
|
||||||
|
batch_storage[batch_id].in_progress_at = int(time.time())
|
||||||
|
|
||||||
|
# Retrieve the input file content
|
||||||
|
input_file_request = file_id_request.get(batch_request.input_file_id)
|
||||||
|
if not input_file_request:
|
||||||
|
raise ValueError("Input file not found")
|
||||||
|
|
||||||
|
# Parse the JSONL file and process each request
|
||||||
|
input_file_path = file_id_storage.get(batch_request.input_file_id)
|
||||||
|
with open(input_file_path, "r", encoding="utf-8") as f:
|
||||||
|
lines = f.readlines()
|
||||||
|
|
||||||
|
total_requests = len(lines)
|
||||||
|
completed_requests = 0
|
||||||
|
failed_requests = 0
|
||||||
|
|
||||||
|
all_ret = []
|
||||||
|
end_point = batch_storage[batch_id].endpoint
|
||||||
|
file_request_list = []
|
||||||
|
all_requests = []
|
||||||
|
for line in lines:
|
||||||
|
request_data = json.loads(line)
|
||||||
|
file_request_list.append(request_data)
|
||||||
|
body = request_data["body"]
|
||||||
|
if end_point == "/v1/chat/completions":
|
||||||
|
all_requests.append(ChatCompletionRequest(**body))
|
||||||
|
elif end_point == "/v1/completions":
|
||||||
|
all_requests.append(CompletionRequest(**body))
|
||||||
|
if end_point == "/v1/chat/completions":
|
||||||
|
adapted_request, request = v1_chat_generate_request(
|
||||||
|
all_requests, tokenizer_manager
|
||||||
|
)
|
||||||
|
elif end_point == "/v1/completions":
|
||||||
|
adapted_request, request = v1_generate_request(all_requests)
|
||||||
|
try:
|
||||||
|
ret = await tokenizer_manager.generate_request(adapted_request).__anext__()
|
||||||
|
if not isinstance(ret, list):
|
||||||
|
ret = [ret]
|
||||||
|
if end_point == "/v1/chat/completions":
|
||||||
|
responses = v1_chat_generate_response(request, ret, to_file=True)
|
||||||
|
else:
|
||||||
|
responses = v1_generate_response(request, ret, to_file=True)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_json = {
|
||||||
|
"id": f"batch_req_{uuid.uuid4()}",
|
||||||
|
"custom_id": request_data.get("custom_id"),
|
||||||
|
"response": None,
|
||||||
|
"error": {"message": str(e)},
|
||||||
|
}
|
||||||
|
all_ret.append(error_json)
|
||||||
|
failed_requests += len(file_request_list)
|
||||||
|
|
||||||
|
for idx, response in enumerate(responses):
|
||||||
|
## the batch_req here can be changed to be named within a batch granularity
|
||||||
|
response_json = {
|
||||||
|
"id": f"batch_req_{uuid.uuid4()}",
|
||||||
|
"custom_id": file_request_list[idx].get("custom_id"),
|
||||||
|
"response": response,
|
||||||
|
"error": None,
|
||||||
|
}
|
||||||
|
all_ret.append(response_json)
|
||||||
|
completed_requests += 1
|
||||||
|
# Write results to a new file
|
||||||
|
output_file_id = f"backend_result_file-{uuid.uuid4()}"
|
||||||
|
global storage_dir
|
||||||
|
output_file_path = os.path.join(storage_dir, f"{output_file_id}.jsonl")
|
||||||
|
with open(output_file_path, "w", encoding="utf-8") as f:
|
||||||
|
for ret in all_ret:
|
||||||
|
f.write(json.dumps(ret) + "\n")
|
||||||
|
|
||||||
|
# Update batch response with output file information
|
||||||
|
retrieve_batch = batch_storage[batch_id]
|
||||||
|
retrieve_batch.output_file_id = output_file_id
|
||||||
|
file_id_storage[output_file_id] = output_file_path
|
||||||
|
# Update batch status to "completed"
|
||||||
|
retrieve_batch.status = "completed"
|
||||||
|
retrieve_batch.completed_at = int(time.time())
|
||||||
|
retrieve_batch.request_counts = {
|
||||||
|
"total": total_requests,
|
||||||
|
"completed": completed_requests,
|
||||||
|
"failed": failed_requests,
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print("error in SGlang:", e)
|
||||||
|
# Update batch status to "failed"
|
||||||
|
retrieve_batch = batch_storage[batch_id]
|
||||||
|
retrieve_batch.status = "failed"
|
||||||
|
retrieve_batch.failed_at = int(time.time())
|
||||||
|
retrieve_batch.errors = {"message": str(e)}
|
||||||
|
|
||||||
|
|
||||||
|
async def v1_retrieve_batch(batch_id: str):
|
||||||
|
# Retrieve the batch job from the in-memory storage
|
||||||
|
batch_response = batch_storage.get(batch_id)
|
||||||
|
if batch_response is None:
|
||||||
|
raise HTTPException(status_code=404, detail="Batch not found")
|
||||||
|
|
||||||
|
return batch_response
|
||||||
|
|
||||||
|
|
||||||
|
async def v1_retrieve_file(file_id: str):
|
||||||
|
# Retrieve the batch job from the in-memory storage
|
||||||
|
file_response = file_id_response.get(file_id)
|
||||||
|
if file_response is None:
|
||||||
|
raise HTTPException(status_code=404, detail="File not found")
|
||||||
|
return file_response
|
||||||
|
|
||||||
|
|
||||||
|
async def v1_retrieve_file_content(file_id: str):
|
||||||
|
file_pth = file_id_storage.get(file_id)
|
||||||
|
if not file_pth or not os.path.exists(file_pth):
|
||||||
|
raise HTTPException(status_code=404, detail="File not found")
|
||||||
|
|
||||||
|
def iter_file():
|
||||||
|
with open(file_pth, mode="rb") as file_like:
|
||||||
|
yield from file_like
|
||||||
|
|
||||||
|
return StreamingResponse(iter_file(), media_type="application/octet-stream")
|
||||||
|
|
||||||
|
|
||||||
|
def v1_generate_request(all_requests):
|
||||||
|
|
||||||
|
prompts = []
|
||||||
|
sampling_params_list = []
|
||||||
|
first_prompt_type = type(all_requests[0].prompt)
|
||||||
|
for request in all_requests:
|
||||||
|
prompt = request.prompt
|
||||||
|
assert (
|
||||||
|
type(prompt) == first_prompt_type
|
||||||
|
), "All prompts must be of the same type in file input settings"
|
||||||
|
prompts.append(prompt)
|
||||||
|
sampling_params_list.append(
|
||||||
|
{
|
||||||
"temperature": request.temperature,
|
"temperature": request.temperature,
|
||||||
"max_new_tokens": request.max_tokens,
|
"max_new_tokens": request.max_tokens,
|
||||||
"stop": request.stop,
|
"stop": request.stop,
|
||||||
@@ -127,12 +351,145 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
|
|||||||
"regex": request.regex,
|
"regex": request.regex,
|
||||||
"n": request.n,
|
"n": request.n,
|
||||||
"ignore_eos": request.ignore_eos,
|
"ignore_eos": request.ignore_eos,
|
||||||
},
|
}
|
||||||
return_logprob=request.logprobs is not None and request.logprobs > 0,
|
|
||||||
top_logprobs_num=request.logprobs if request.logprobs is not None else 0,
|
|
||||||
return_text_in_logprobs=True,
|
|
||||||
stream=request.stream,
|
|
||||||
)
|
)
|
||||||
|
if len(all_requests) > 1 and request.n > 1:
|
||||||
|
raise ValueError(
|
||||||
|
"Batch operation is not supported for completions from files"
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(all_requests) == 1:
|
||||||
|
prompt = prompts[0]
|
||||||
|
sampling_params_list = sampling_params_list[0]
|
||||||
|
if isinstance(prompts, str) or isinstance(prompts[0], str):
|
||||||
|
prompt_kwargs = {"text": prompt}
|
||||||
|
else:
|
||||||
|
prompt_kwargs = {"input_ids": prompt}
|
||||||
|
else:
|
||||||
|
if isinstance(prompts[0], str):
|
||||||
|
prompt_kwargs = {"text": prompts}
|
||||||
|
else:
|
||||||
|
prompt_kwargs = {"input_ids": prompts}
|
||||||
|
|
||||||
|
adapted_request = GenerateReqInput(
|
||||||
|
**prompt_kwargs,
|
||||||
|
sampling_params=sampling_params_list,
|
||||||
|
return_logprob=all_requests[0].logprobs is not None
|
||||||
|
and all_requests[0].logprobs > 0,
|
||||||
|
top_logprobs_num=(
|
||||||
|
all_requests[0].logprobs if all_requests[0].logprobs is not None else 0
|
||||||
|
),
|
||||||
|
return_text_in_logprobs=True,
|
||||||
|
stream=all_requests[0].stream,
|
||||||
|
)
|
||||||
|
if len(all_requests) == 1:
|
||||||
|
return adapted_request, all_requests[0]
|
||||||
|
return adapted_request, all_requests
|
||||||
|
|
||||||
|
|
||||||
|
def v1_generate_response(request, ret, to_file=False):
|
||||||
|
choices = []
|
||||||
|
echo = False
|
||||||
|
|
||||||
|
if (not isinstance(request, List)) and request.echo:
|
||||||
|
# TODO: handle the case propmt is token ids
|
||||||
|
if isinstance(request.prompt, list):
|
||||||
|
prompts = request.prompt
|
||||||
|
else:
|
||||||
|
prompts = [request.prompt]
|
||||||
|
echo = True
|
||||||
|
|
||||||
|
for idx, ret_item in enumerate(ret):
|
||||||
|
text = ret_item["text"]
|
||||||
|
if isinstance(request, List) and request[idx].echo:
|
||||||
|
echo = True
|
||||||
|
text = request[idx].prompt + text
|
||||||
|
if (not isinstance(request, List)) and echo:
|
||||||
|
text = prompts[idx] + text
|
||||||
|
|
||||||
|
logprobs = False
|
||||||
|
if isinstance(request, List) and request[idx].logprobs:
|
||||||
|
logprobs = True
|
||||||
|
elif (not isinstance(request, List)) and request.logprobs:
|
||||||
|
logprobs = True
|
||||||
|
if logprobs:
|
||||||
|
if echo:
|
||||||
|
input_token_logprobs = ret_item["meta_info"]["input_token_logprobs"]
|
||||||
|
input_top_logprobs = ret_item["meta_info"]["input_top_logprobs"]
|
||||||
|
else:
|
||||||
|
input_token_logprobs = None
|
||||||
|
input_top_logprobs = None
|
||||||
|
|
||||||
|
logprobs = to_openai_style_logprobs(
|
||||||
|
input_token_logprobs=input_token_logprobs,
|
||||||
|
input_top_logprobs=input_top_logprobs,
|
||||||
|
output_token_logprobs=ret_item["meta_info"]["output_token_logprobs"],
|
||||||
|
output_top_logprobs=ret_item["meta_info"]["output_top_logprobs"],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logprobs = None
|
||||||
|
|
||||||
|
if to_file:
|
||||||
|
## to make the choise data json serializable
|
||||||
|
choice_data = {
|
||||||
|
"index": 0,
|
||||||
|
"text": text,
|
||||||
|
"logprobs": logprobs,
|
||||||
|
"finish_reason": ret_item["meta_info"]["finish_reason"],
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
choice_data = CompletionResponseChoice(
|
||||||
|
index=idx,
|
||||||
|
text=text,
|
||||||
|
logprobs=logprobs,
|
||||||
|
finish_reason=ret_item["meta_info"]["finish_reason"],
|
||||||
|
)
|
||||||
|
|
||||||
|
choices.append(choice_data)
|
||||||
|
|
||||||
|
if to_file:
|
||||||
|
responses = []
|
||||||
|
for i, choice in enumerate(choices):
|
||||||
|
response = {
|
||||||
|
"status_code": 200,
|
||||||
|
"request_id": ret[i]["meta_info"]["id"],
|
||||||
|
"body": {
|
||||||
|
## remain the same but if needed we can change that
|
||||||
|
"id": ret[i]["meta_info"]["id"],
|
||||||
|
"object": "text_completion",
|
||||||
|
"created": int(time.time()),
|
||||||
|
"model": request[i].model,
|
||||||
|
"choices": choice,
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": ret[i]["meta_info"]["prompt_tokens"],
|
||||||
|
"completion_tokens": ret[i]["meta_info"]["completion_tokens"],
|
||||||
|
"total_tokens": ret[i]["meta_info"]["prompt_tokens"]
|
||||||
|
+ ret[i]["meta_info"]["completion_tokens"],
|
||||||
|
},
|
||||||
|
"system_fingerprint": None,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
responses.append(response)
|
||||||
|
return responses
|
||||||
|
else:
|
||||||
|
completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)
|
||||||
|
response = CompletionResponse(
|
||||||
|
id=ret[0]["meta_info"]["id"],
|
||||||
|
model=request.model,
|
||||||
|
choices=choices,
|
||||||
|
usage=UsageInfo(
|
||||||
|
prompt_tokens=ret[0]["meta_info"]["prompt_tokens"],
|
||||||
|
completion_tokens=completion_tokens,
|
||||||
|
total_tokens=ret[0]["meta_info"]["prompt_tokens"] + completion_tokens,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
async def v1_completions(tokenizer_manager, raw_request: Request):
|
||||||
|
request_json = await raw_request.json()
|
||||||
|
all_requests = [CompletionRequest(**request_json)]
|
||||||
|
adapted_request, request = v1_generate_request(all_requests)
|
||||||
|
|
||||||
if adapted_request.stream:
|
if adapted_request.stream:
|
||||||
|
|
||||||
@@ -223,65 +580,17 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
|
|||||||
|
|
||||||
if not isinstance(ret, list):
|
if not isinstance(ret, list):
|
||||||
ret = [ret]
|
ret = [ret]
|
||||||
if request.echo:
|
|
||||||
# TODO: handle the case propmt is token ids
|
|
||||||
if isinstance(request.prompt, list):
|
|
||||||
prompts = request.prompt
|
|
||||||
else:
|
|
||||||
prompts = [request.prompt]
|
|
||||||
choices = []
|
|
||||||
|
|
||||||
for idx, ret_item in enumerate(ret):
|
|
||||||
text = ret_item["text"]
|
|
||||||
|
|
||||||
if request.echo:
|
|
||||||
text = prompts[idx] + text
|
|
||||||
|
|
||||||
if request.logprobs:
|
|
||||||
if request.echo:
|
|
||||||
input_token_logprobs = ret_item["meta_info"]["input_token_logprobs"]
|
|
||||||
input_top_logprobs = ret_item["meta_info"]["input_top_logprobs"]
|
|
||||||
else:
|
|
||||||
input_token_logprobs = None
|
|
||||||
input_top_logprobs = None
|
|
||||||
|
|
||||||
logprobs = to_openai_style_logprobs(
|
|
||||||
input_token_logprobs=input_token_logprobs,
|
|
||||||
input_top_logprobs=input_top_logprobs,
|
|
||||||
output_token_logprobs=ret_item["meta_info"]["output_token_logprobs"],
|
|
||||||
output_top_logprobs=ret_item["meta_info"]["output_top_logprobs"],
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logprobs = None
|
|
||||||
|
|
||||||
choice_data = CompletionResponseChoice(
|
|
||||||
index=idx,
|
|
||||||
text=text,
|
|
||||||
logprobs=logprobs,
|
|
||||||
finish_reason=ret_item["meta_info"]["finish_reason"],
|
|
||||||
)
|
|
||||||
|
|
||||||
choices.append(choice_data)
|
|
||||||
|
|
||||||
completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)
|
|
||||||
response = CompletionResponse(
|
|
||||||
id=ret[0]["meta_info"]["id"],
|
|
||||||
model=request.model,
|
|
||||||
choices=choices,
|
|
||||||
usage=UsageInfo(
|
|
||||||
prompt_tokens=ret[0]["meta_info"]["prompt_tokens"],
|
|
||||||
completion_tokens=completion_tokens,
|
|
||||||
total_tokens=ret[0]["meta_info"]["prompt_tokens"] + completion_tokens,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
response = v1_generate_response(request, ret)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
async def v1_chat_completions(tokenizer_manager, raw_request: Request):
|
def v1_chat_generate_request(all_requests, tokenizer_manager):
|
||||||
request_json = await raw_request.json()
|
|
||||||
request = ChatCompletionRequest(**request_json)
|
|
||||||
|
|
||||||
|
texts = []
|
||||||
|
sampling_params_list = []
|
||||||
|
image_data_list = []
|
||||||
|
for request in all_requests:
|
||||||
# Prep the data needed for the underlying GenerateReqInput:
|
# Prep the data needed for the underlying GenerateReqInput:
|
||||||
# - prompt: The full prompt string.
|
# - prompt: The full prompt string.
|
||||||
# - stop: Custom stop tokens.
|
# - stop: Custom stop tokens.
|
||||||
@@ -310,11 +619,9 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
|
|||||||
prompt = request.messages
|
prompt = request.messages
|
||||||
stop = request.stop
|
stop = request.stop
|
||||||
image_data = None
|
image_data = None
|
||||||
|
texts.append(prompt)
|
||||||
adapted_request = GenerateReqInput(
|
sampling_params_list.append(
|
||||||
text=prompt,
|
{
|
||||||
image_data=image_data,
|
|
||||||
sampling_params={
|
|
||||||
"temperature": request.temperature,
|
"temperature": request.temperature,
|
||||||
"max_new_tokens": request.max_tokens,
|
"max_new_tokens": request.max_tokens,
|
||||||
"stop": stop,
|
"stop": stop,
|
||||||
@@ -323,9 +630,94 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
|
|||||||
"frequency_penalty": request.frequency_penalty,
|
"frequency_penalty": request.frequency_penalty,
|
||||||
"regex": request.regex,
|
"regex": request.regex,
|
||||||
"n": request.n,
|
"n": request.n,
|
||||||
},
|
}
|
||||||
|
)
|
||||||
|
image_data_list.append(image_data)
|
||||||
|
if len(all_requests) == 1:
|
||||||
|
texts = texts[0]
|
||||||
|
sampling_params_list = sampling_params_list[0]
|
||||||
|
image_data = image_data_list[0]
|
||||||
|
adapted_request = GenerateReqInput(
|
||||||
|
text=texts,
|
||||||
|
image_data=image_data,
|
||||||
|
sampling_params=sampling_params_list,
|
||||||
stream=request.stream,
|
stream=request.stream,
|
||||||
)
|
)
|
||||||
|
if len(all_requests) == 1:
|
||||||
|
return adapted_request, all_requests[0]
|
||||||
|
return adapted_request, all_requests
|
||||||
|
|
||||||
|
|
||||||
|
def v1_chat_generate_response(request, ret, to_file=False):
|
||||||
|
choices = []
|
||||||
|
total_prompt_tokens = 0
|
||||||
|
total_completion_tokens = 0
|
||||||
|
|
||||||
|
for idx, ret_item in enumerate(ret):
|
||||||
|
prompt_tokens = ret_item["meta_info"]["prompt_tokens"]
|
||||||
|
completion_tokens = ret_item["meta_info"]["completion_tokens"]
|
||||||
|
|
||||||
|
if to_file:
|
||||||
|
## to make the choice data json serializable
|
||||||
|
choice_data = {
|
||||||
|
"index": 0,
|
||||||
|
"message": {"role": "assistant", "content": ret_item["text"]},
|
||||||
|
"logprobs": None,
|
||||||
|
"finish_reason": ret_item["meta_info"]["finish_reason"],
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
choice_data = ChatCompletionResponseChoice(
|
||||||
|
index=idx,
|
||||||
|
message=ChatMessage(role="assistant", content=ret_item["text"]),
|
||||||
|
finish_reason=ret_item["meta_info"]["finish_reason"],
|
||||||
|
)
|
||||||
|
|
||||||
|
choices.append(choice_data)
|
||||||
|
total_prompt_tokens = prompt_tokens
|
||||||
|
total_completion_tokens += completion_tokens
|
||||||
|
if to_file:
|
||||||
|
responses = []
|
||||||
|
|
||||||
|
for i, choice in enumerate(choices):
|
||||||
|
response = {
|
||||||
|
"status_code": 200,
|
||||||
|
"request_id": ret[i]["meta_info"]["id"],
|
||||||
|
"body": {
|
||||||
|
## remain the same but if needed we can change that
|
||||||
|
"id": ret[i]["meta_info"]["id"],
|
||||||
|
"object": "chat.completion",
|
||||||
|
"created": int(time.time()),
|
||||||
|
"model": request[i].model,
|
||||||
|
"choices": choice,
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": ret[i]["meta_info"]["prompt_tokens"],
|
||||||
|
"completion_tokens": ret[i]["meta_info"]["completion_tokens"],
|
||||||
|
"total_tokens": ret[i]["meta_info"]["prompt_tokens"]
|
||||||
|
+ ret[i]["meta_info"]["completion_tokens"],
|
||||||
|
},
|
||||||
|
"system_fingerprint": None,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
responses.append(response)
|
||||||
|
return responses
|
||||||
|
else:
|
||||||
|
response = ChatCompletionResponse(
|
||||||
|
id=ret[0]["meta_info"]["id"],
|
||||||
|
model=request.model,
|
||||||
|
choices=choices,
|
||||||
|
usage=UsageInfo(
|
||||||
|
prompt_tokens=total_prompt_tokens,
|
||||||
|
completion_tokens=total_completion_tokens,
|
||||||
|
total_tokens=total_prompt_tokens + total_completion_tokens,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
async def v1_chat_completions(tokenizer_manager, raw_request: Request):
|
||||||
|
request_json = await raw_request.json()
|
||||||
|
all_requests = [ChatCompletionRequest(**request_json)]
|
||||||
|
adapted_request, request = v1_chat_generate_request(all_requests, tokenizer_manager)
|
||||||
|
|
||||||
if adapted_request.stream:
|
if adapted_request.stream:
|
||||||
|
|
||||||
@@ -387,34 +779,8 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
|
|||||||
|
|
||||||
if not isinstance(ret, list):
|
if not isinstance(ret, list):
|
||||||
ret = [ret]
|
ret = [ret]
|
||||||
choices = []
|
|
||||||
total_prompt_tokens = 0
|
|
||||||
total_completion_tokens = 0
|
|
||||||
|
|
||||||
for idx, ret_item in enumerate(ret):
|
response = v1_chat_generate_response(request, ret)
|
||||||
prompt_tokens = ret_item["meta_info"]["prompt_tokens"]
|
|
||||||
completion_tokens = ret_item["meta_info"]["completion_tokens"]
|
|
||||||
|
|
||||||
choice_data = ChatCompletionResponseChoice(
|
|
||||||
index=idx,
|
|
||||||
message=ChatMessage(role="assistant", content=ret_item["text"]),
|
|
||||||
finish_reason=ret_item["meta_info"]["finish_reason"],
|
|
||||||
)
|
|
||||||
|
|
||||||
choices.append(choice_data)
|
|
||||||
total_prompt_tokens = prompt_tokens
|
|
||||||
total_completion_tokens += completion_tokens
|
|
||||||
|
|
||||||
response = ChatCompletionResponse(
|
|
||||||
id=ret[0]["meta_info"]["id"],
|
|
||||||
model=request.model,
|
|
||||||
choices=choices,
|
|
||||||
usage=UsageInfo(
|
|
||||||
prompt_tokens=total_prompt_tokens,
|
|
||||||
completion_tokens=total_completion_tokens,
|
|
||||||
total_tokens=total_prompt_tokens + total_completion_tokens,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|||||||
@@ -60,6 +60,55 @@ class UsageInfo(BaseModel):
|
|||||||
completion_tokens: Optional[int] = 0
|
completion_tokens: Optional[int] = 0
|
||||||
|
|
||||||
|
|
||||||
|
class FileRequest(BaseModel):
|
||||||
|
# https://platform.openai.com/docs/api-reference/files/create
|
||||||
|
file: bytes # The File object (not file name) to be uploaded
|
||||||
|
purpose: str = (
|
||||||
|
"batch" # The intended purpose of the uploaded file, default is "batch"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FileResponse(BaseModel):
|
||||||
|
id: str
|
||||||
|
object: str = "file"
|
||||||
|
bytes: int
|
||||||
|
created_at: int
|
||||||
|
filename: str
|
||||||
|
purpose: str
|
||||||
|
|
||||||
|
|
||||||
|
class BatchRequest(BaseModel):
|
||||||
|
input_file_id: (
|
||||||
|
str # The ID of an uploaded file that contains requests for the new batch
|
||||||
|
)
|
||||||
|
endpoint: str # The endpoint to be used for all requests in the batch
|
||||||
|
completion_window: str # The time frame within which the batch should be processed
|
||||||
|
metadata: Optional[dict] = None # Optional custom metadata for the batch
|
||||||
|
|
||||||
|
|
||||||
|
class BatchResponse(BaseModel):
|
||||||
|
id: str
|
||||||
|
object: str = "batch"
|
||||||
|
endpoint: str
|
||||||
|
errors: Optional[dict] = None
|
||||||
|
input_file_id: str
|
||||||
|
completion_window: str
|
||||||
|
status: str = "validating"
|
||||||
|
output_file_id: Optional[str] = None
|
||||||
|
error_file_id: Optional[str] = None
|
||||||
|
created_at: int
|
||||||
|
in_progress_at: Optional[int] = None
|
||||||
|
expires_at: Optional[int] = None
|
||||||
|
finalizing_at: Optional[int] = None
|
||||||
|
completed_at: Optional[int] = None
|
||||||
|
failed_at: Optional[int] = None
|
||||||
|
expired_at: Optional[int] = None
|
||||||
|
cancelling_at: Optional[int] = None
|
||||||
|
cancelled_at: Optional[int] = None
|
||||||
|
request_counts: dict = {"total": 0, "completed": 0, "failed": 0}
|
||||||
|
metadata: Optional[dict] = None
|
||||||
|
|
||||||
|
|
||||||
class CompletionRequest(BaseModel):
|
class CompletionRequest(BaseModel):
|
||||||
# Ordered by official OpenAI API documentation
|
# Ordered by official OpenAI API documentation
|
||||||
# https://platform.openai.com/docs/api-reference/completions/create
|
# https://platform.openai.com/docs/api-reference/completions/create
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ import psutil
|
|||||||
import requests
|
import requests
|
||||||
import uvicorn
|
import uvicorn
|
||||||
import uvloop
|
import uvloop
|
||||||
from fastapi import FastAPI, Request
|
from fastapi import FastAPI, File, Form, Request, UploadFile
|
||||||
from fastapi.responses import JSONResponse, Response, StreamingResponse
|
from fastapi.responses import JSONResponse, Response, StreamingResponse
|
||||||
|
|
||||||
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
|
from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
|
||||||
@@ -56,8 +56,13 @@ from sglang.srt.managers.io_struct import GenerateReqInput
|
|||||||
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
||||||
from sglang.srt.openai_api.adapter import (
|
from sglang.srt.openai_api.adapter import (
|
||||||
load_chat_template_for_openai_api,
|
load_chat_template_for_openai_api,
|
||||||
|
v1_batches,
|
||||||
v1_chat_completions,
|
v1_chat_completions,
|
||||||
v1_completions,
|
v1_completions,
|
||||||
|
v1_files_create,
|
||||||
|
v1_retrieve_batch,
|
||||||
|
v1_retrieve_file,
|
||||||
|
v1_retrieve_file_content,
|
||||||
)
|
)
|
||||||
from sglang.srt.openai_api.protocol import ModelCard, ModelList
|
from sglang.srt.openai_api.protocol import ModelCard, ModelList
|
||||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
@@ -152,6 +157,35 @@ async def openai_v1_chat_completions(raw_request: Request):
|
|||||||
return await v1_chat_completions(tokenizer_manager, raw_request)
|
return await v1_chat_completions(tokenizer_manager, raw_request)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/v1/files")
|
||||||
|
async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("batch")):
|
||||||
|
return await v1_files_create(
|
||||||
|
file, purpose, tokenizer_manager.server_args.file_storage_pth
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/v1/batches")
|
||||||
|
async def openai_v1_batches(raw_request: Request):
|
||||||
|
return await v1_batches(tokenizer_manager, raw_request)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/v1/batches/{batch_id}")
|
||||||
|
async def retrieve_batch(batch_id: str):
|
||||||
|
return await v1_retrieve_batch(batch_id)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/v1/files/{file_id}")
|
||||||
|
async def retrieve_file(file_id: str):
|
||||||
|
# https://platform.openai.com/docs/api-reference/files/retrieve
|
||||||
|
return await v1_retrieve_file(file_id)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/v1/files/{file_id}/content")
|
||||||
|
async def retrieve_file_content(file_id: str):
|
||||||
|
# https://platform.openai.com/docs/api-reference/files/retrieve-contents
|
||||||
|
return await v1_retrieve_file_content(file_id)
|
||||||
|
|
||||||
|
|
||||||
@app.get("/v1/models")
|
@app.get("/v1/models")
|
||||||
def available_models():
|
def available_models():
|
||||||
"""Show available models."""
|
"""Show available models."""
|
||||||
|
|||||||
@@ -60,6 +60,7 @@ class ServerArgs:
|
|||||||
|
|
||||||
# Other
|
# Other
|
||||||
api_key: str = ""
|
api_key: str = ""
|
||||||
|
file_storage_pth: str = "SGlang_storage"
|
||||||
|
|
||||||
# Data parallelism
|
# Data parallelism
|
||||||
dp_size: int = 1
|
dp_size: int = 1
|
||||||
@@ -290,6 +291,12 @@ class ServerArgs:
|
|||||||
default=ServerArgs.api_key,
|
default=ServerArgs.api_key,
|
||||||
help="Set API key of the server.",
|
help="Set API key of the server.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--file-storage-pth",
|
||||||
|
type=str,
|
||||||
|
default=ServerArgs.file_storage_pth,
|
||||||
|
help="The path of the file storage in backend.",
|
||||||
|
)
|
||||||
|
|
||||||
# Data parallelism
|
# Data parallelism
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
|||||||
Reference in New Issue
Block a user