Add support for OpenAI API: offline batch (file) processing (#699)
Co-authored-by: hnyls2002 <hnyls2002@gmail.com>
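The new endpoints mirror OpenAI's Files and Batches APIs. A minimal client-side sketch of the intended flow (assumes a local server on the default port; the JSONL shape follows the handlers added below):

    import time

    import requests

    base = "http://localhost:30000"  # assumed server address

    # 1. Upload a JSONL file: one request per line, payload nested under "body"
    with open("batch_input.jsonl", "rb") as f:
        file_id = requests.post(
            f"{base}/v1/files", files={"file": f}, data={"purpose": "batch"}
        ).json()["id"]

    # 2. Create the batch job; a single endpoint applies to every line
    batch = requests.post(
        f"{base}/v1/batches",
        json={
            "input_file_id": file_id,
            "endpoint": "/v1/chat/completions",
            "completion_window": "24h",
        },
    ).json()

    # 3. Poll the job, then download the result file
    while batch["status"] not in ("completed", "failed"):
        time.sleep(1)
        batch = requests.get(f"{base}/v1/batches/{batch['id']}").json()
    results = requests.get(f"{base}/v1/files/{batch['output_file_id']}/content").text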
@@ -79,8 +79,26 @@ class GenerateReqInput:
             if self.top_logprobs_num is None:
                 self.top_logprobs_num = 0
         else:
-            parallel_sample_num = self.sampling_params.get("n", 1)
+            parallel_sample_num_list = []
+            if isinstance(self.sampling_params, dict):
+                parallel_sample_num = self.sampling_params.get("n", 1)
+            elif isinstance(self.sampling_params, list):
+                for sp in self.sampling_params:
+                    parallel_sample_num = sp.get("n", 1)
+                    parallel_sample_num_list.append(parallel_sample_num)
+                parallel_sample_num = max(parallel_sample_num_list)
+                all_equal = all(
+                    element == parallel_sample_num
+                    for element in parallel_sample_num_list
+                )
+                if parallel_sample_num > 1 and (not all_equal):
+                    # TODO: handle the case where parallel_sample_num differs across samples
+                    raise ValueError(
+                        "The parallel_sample_num should be the same for all samples in sample params."
+                    )
+            else:
+                parallel_sample_num = 1
+            self.parallel_sample_num = parallel_sample_num

             if parallel_sample_num != 1:
                 # parallel sampling +1 represents the original prefill stage
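In effect, the new branch normalizes `n` across a batch of per-sample sampling params and rejects mixed values. A standalone sketch of the same rule (hypothetical helper name, not part of the commit):

    def resolve_parallel_sample_num(sampling_params):
        # dict: one sampling config shared by the whole batch
        if isinstance(sampling_params, dict):
            return sampling_params.get("n", 1)
        # list: one config per sample; all "n" values must agree
        if isinstance(sampling_params, list):
            ns = [sp.get("n", 1) for sp in sampling_params]
            if max(ns) > 1 and len(set(ns)) > 1:
                raise ValueError("parallel_sample_num must match across samples")
            return max(ns)
        return 1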
@@ -84,6 +84,7 @@ class TokenizerManager:
             trust_remote_code=server_args.trust_remote_code,
             model_overide_args=model_overide_args,
         )

         if server_args.context_length is not None:
             self.context_len = server_args.context_length
         else:
@@ -152,31 +153,33 @@ class TokenizerManager:
         self, obj, request, index=None, is_cache_for_prefill=False
     ):
         if not is_cache_for_prefill:
-            rid = obj.rid if index is None else obj.rid[index]
-            input_text = obj.text if index is None else obj.text[index]
+            not_use_index = not (index is not None)
+            rid = obj.rid if not_use_index else obj.rid[index]
+            input_text = obj.text if not_use_index else obj.text[index]
             input_ids = (
                 self.tokenizer.encode(input_text)
                 if obj.input_ids is None
                 else obj.input_ids
             )
-            if index is not None and obj.input_ids:
+            if not not_use_index and obj.input_ids:
                 input_ids = obj.input_ids[index]

             self._validate_input_length(input_ids)

             sampling_params = self._get_sampling_params(
-                obj.sampling_params if index is None else obj.sampling_params[index]
+                obj.sampling_params if not_use_index else obj.sampling_params[index]
             )
             pixel_values, image_hash, image_size = await self._get_pixel_values(
-                obj.image_data if index is None else obj.image_data[index]
+                obj.image_data if not_use_index else obj.image_data[index]
             )
             return_logprob = (
-                obj.return_logprob if index is None else obj.return_logprob[index]
+                obj.return_logprob if not_use_index else obj.return_logprob[index]
             )
             logprob_start_len = (
-                obj.logprob_start_len if index is None else obj.logprob_start_len[index]
+                obj.logprob_start_len if not_use_index else obj.logprob_start_len[index]
             )
             top_logprobs_num = (
-                obj.top_logprobs_num if index is None else obj.top_logprobs_num[index]
+                obj.top_logprobs_num if not_use_index else obj.top_logprobs_num[index]
             )
         else:
             if isinstance(obj.text, list):
@@ -224,7 +227,7 @@ class TokenizerManager:

     async def _handle_batch_request(self, obj: GenerateReqInput, request):
         batch_size = obj.batch_size
-        parallel_sample_num = obj.sampling_params[0].get("n", 1)
+        parallel_sample_num = obj.parallel_sample_num

         if parallel_sample_num != 1:
             # Send prefill requests to cache the common input
@@ -241,7 +244,6 @@ class TokenizerManager:
                     obj.input_ids = input_id_result
                 elif input_id_result is not None:
                     obj.input_ids = input_id_result[0]

         # First send out all requests
         for i in range(batch_size):
             for j in range(parallel_sample_num):
@@ -249,7 +251,7 @@ class TokenizerManager:
                     continue
                 index = i * parallel_sample_num + j
                 if parallel_sample_num != 1:
-                    # Here when using parallel sampling we shoul consider prefill stage so the index is : j + i * (parallel_sample_num-1) + batch_size - 1
+                    # With parallel sampling we should account for the prefill stage, so the index is: j + i * (parallel_sample_num - 1) + batch_size - 1
                     index += batch_size - 1 - i
                 rid = obj.rid[index]
                 if parallel_sample_num == 1:
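To see why the offset works: when parallel sampling is on, the first batch_size rids are reserved for the shared prefill requests, and j == 0 is skipped as the cached copy (the `continue` above, assuming it guards j == 0). A quick check of the resulting layout:

    batch_size, n = 2, 3
    prefill = list(range(batch_size))      # rids 0 and 1 cache the two prompts
    generation = [
        i * n + j + (batch_size - 1 - i)   # == j + i * (n - 1) + batch_size - 1
        for i in range(batch_size)
        for j in range(1, n)               # j == 0 skipped as the prefill copy
    ]
    print(prefill, generation)             # [0, 1] [2, 3, 4, 5]: six distinct rids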
@@ -18,10 +18,14 @@ limitations under the License.

 import asyncio
 import json
 import os
 import time
 import uuid
 from http import HTTPStatus
 from typing import Dict, List, Optional

-from fastapi import Request
+from fastapi import HTTPException, Request, UploadFile
 from fastapi.responses import JSONResponse, StreamingResponse
 from pydantic import ValidationError

 from sglang.srt.conversation import (
     Conversation,
@@ -32,6 +36,8 @@ from sglang.srt.conversation import (
 )
 from sglang.srt.managers.io_struct import GenerateReqInput
 from sglang.srt.openai_api.protocol import (
+    BatchRequest,
+    BatchResponse,
     ChatCompletionRequest,
     ChatCompletionResponse,
     ChatCompletionResponseChoice,
@@ -45,6 +51,8 @@ from sglang.srt.openai_api.protocol import (
     CompletionStreamResponse,
     DeltaMessage,
     ErrorResponse,
+    FileRequest,
+    FileResponse,
     LogProbs,
     UsageInfo,
 )
@@ -52,6 +60,24 @@ from sglang.srt.openai_api.protocol import (
 chat_template_name = None


+class FileMetadata:
+    def __init__(self, filename: str, purpose: str):
+        self.filename = filename
+        self.purpose = purpose
+
+
+# In-memory storage for batch jobs and files
+batch_storage: Dict[str, BatchResponse] = {}
+file_id_request: Dict[str, FileMetadata] = {}
+file_id_response: Dict[str, FileResponse] = {}
+# Maps a file id to its file path in the SGLang backend
+file_id_storage: Dict[str, str] = {}
+
+
+# Backend storage directory
+storage_dir = None
+
+
 def create_error_response(
     message: str,
     err_type: str = "BadRequestError",
@@ -106,33 +132,364 @@ def load_chat_template_for_openai_api(chat_template_arg):
         chat_template_name = chat_template_arg


-async def v1_completions(tokenizer_manager, raw_request: Request):
-    request_json = await raw_request.json()
-    request = CompletionRequest(**request_json)
-    prompt = request.prompt
-    if isinstance(prompt, str) or isinstance(prompt[0], str):
-        prompt_kwargs = {"text": prompt}
+async def v1_files_create(file: UploadFile, purpose: str, file_storage_pth: str = None):
+    try:
+        global storage_dir
+        if file_storage_pth:
+            storage_dir = file_storage_pth
+        # Read the file content
+        file_content = await file.read()
+
+        # Create an instance of RequestBody
+        request_body = FileRequest(file=file_content, purpose=purpose)
+
+        # Save the file to the sglang_oai_storage directory
+        os.makedirs(storage_dir, exist_ok=True)
+        file_id = f"backend_input_file-{uuid.uuid4()}"
+        filename = f"{file_id}.jsonl"
+        file_path = os.path.join(storage_dir, filename)
+
+        with open(file_path, "wb") as f:
+            f.write(request_body.file)
+
+        # Add the info to the global file maps
+        file_id_request[file_id] = FileMetadata(filename=file.filename, purpose=purpose)
+        file_id_storage[file_id] = file_path
+
+        # Return the response in the required format
+        response = FileResponse(
+            id=file_id,
+            bytes=len(request_body.file),
+            created_at=int(time.time()),
+            filename=file.filename,
+            purpose=request_body.purpose,
+        )
+        file_id_response[file_id] = response
+
+        return response
+    except ValidationError as e:
+        return {"error": "Invalid input", "details": e.errors()}
+
+
+async def v1_batches(tokenizer_manager, raw_request: Request):
+    try:
+        body = await raw_request.json()
+
+        batch_request = BatchRequest(**body)
+
+        batch_id = f"batch_{uuid.uuid4()}"
+
+        # Create an instance of BatchResponse
+        batch_response = BatchResponse(
+            id=batch_id,
+            endpoint=batch_request.endpoint,
+            input_file_id=batch_request.input_file_id,
+            completion_window=batch_request.completion_window,
+            created_at=int(time.time()),
+            metadata=batch_request.metadata,
+        )
+
+        batch_storage[batch_id] = batch_response
+
+        # Start processing the batch asynchronously
+        asyncio.create_task(process_batch(tokenizer_manager, batch_id, batch_request))
+
+        # Return the initial batch_response
+        return batch_response
+
+    except ValidationError as e:
+        return {"error": "Invalid input", "details": e.errors()}
+    except Exception as e:
+        return {"error": str(e)}
+
+
+async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRequest):
+    try:
+        # Update the batch status to "in_progress"
+        batch_storage[batch_id].status = "in_progress"
+        batch_storage[batch_id].in_progress_at = int(time.time())
+
+        # Retrieve the input file content
+        input_file_request = file_id_request.get(batch_request.input_file_id)
+        if not input_file_request:
+            raise ValueError("Input file not found")
+
+        # Parse the JSONL file and process each request
+        input_file_path = file_id_storage.get(batch_request.input_file_id)
+        with open(input_file_path, "r", encoding="utf-8") as f:
+            lines = f.readlines()
+
+        total_requests = len(lines)
+        completed_requests = 0
+        failed_requests = 0
+
+        all_ret = []
+        end_point = batch_storage[batch_id].endpoint
+        file_request_list = []
+        all_requests = []
+        for line in lines:
+            request_data = json.loads(line)
+            file_request_list.append(request_data)
+            body = request_data["body"]
+            if end_point == "/v1/chat/completions":
+                all_requests.append(ChatCompletionRequest(**body))
+            elif end_point == "/v1/completions":
+                all_requests.append(CompletionRequest(**body))
+        if end_point == "/v1/chat/completions":
+            adapted_request, request = v1_chat_generate_request(
+                all_requests, tokenizer_manager
+            )
+        elif end_point == "/v1/completions":
+            adapted_request, request = v1_generate_request(all_requests)
+        try:
+            ret = await tokenizer_manager.generate_request(adapted_request).__anext__()
+            if not isinstance(ret, list):
+                ret = [ret]
+            if end_point == "/v1/chat/completions":
+                responses = v1_chat_generate_response(request, ret, to_file=True)
+            else:
+                responses = v1_generate_response(request, ret, to_file=True)
+
+        except Exception as e:
+            error_json = {
+                "id": f"batch_req_{uuid.uuid4()}",
+                "custom_id": request_data.get("custom_id"),
+                "response": None,
+                "error": {"message": str(e)},
+            }
+            all_ret.append(error_json)
+            failed_requests += len(file_request_list)
+
+        for idx, response in enumerate(responses):
+            # The batch_req id here could instead be assigned at batch granularity
+            response_json = {
+                "id": f"batch_req_{uuid.uuid4()}",
+                "custom_id": file_request_list[idx].get("custom_id"),
+                "response": response,
+                "error": None,
+            }
+            all_ret.append(response_json)
+            completed_requests += 1
+        # Write the results to a new file
+        output_file_id = f"backend_result_file-{uuid.uuid4()}"
+        global storage_dir
+        output_file_path = os.path.join(storage_dir, f"{output_file_id}.jsonl")
+        with open(output_file_path, "w", encoding="utf-8") as f:
+            for ret in all_ret:
+                f.write(json.dumps(ret) + "\n")
+
+        # Update the batch response with the output file information
+        retrieve_batch = batch_storage[batch_id]
+        retrieve_batch.output_file_id = output_file_id
+        file_id_storage[output_file_id] = output_file_path
+        # Update the batch status to "completed"
+        retrieve_batch.status = "completed"
+        retrieve_batch.completed_at = int(time.time())
+        retrieve_batch.request_counts = {
+            "total": total_requests,
+            "completed": completed_requests,
+            "failed": failed_requests,
+        }
+
+    except Exception as e:
+        print("error in SGLang:", e)
+        # Update the batch status to "failed"
+        retrieve_batch = batch_storage[batch_id]
+        retrieve_batch.status = "failed"
+        retrieve_batch.failed_at = int(time.time())
+        retrieve_batch.errors = {"message": str(e)}
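process_batch expects the OpenAI batch input format: one JSON object per line with the per-request payload nested under "body", and it writes one result object per request. Illustrative input and output lines (ids and model name are placeholders):

    {"custom_id": "req-1", "body": {"model": "default", "messages": [{"role": "user", "content": "Hello"}]}}

    {"id": "batch_req_<uuid>", "custom_id": "req-1", "response": {"status_code": 200, "request_id": "...", "body": {...}}, "error": null}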
+async def v1_retrieve_batch(batch_id: str):
+    # Retrieve the batch job from the in-memory storage
+    batch_response = batch_storage.get(batch_id)
+    if batch_response is None:
+        raise HTTPException(status_code=404, detail="Batch not found")
+
+    return batch_response
+
+
+async def v1_retrieve_file(file_id: str):
+    # Retrieve the file metadata from the in-memory storage
+    file_response = file_id_response.get(file_id)
+    if file_response is None:
+        raise HTTPException(status_code=404, detail="File not found")
+    return file_response
+
+
+async def v1_retrieve_file_content(file_id: str):
+    file_pth = file_id_storage.get(file_id)
+    if not file_pth or not os.path.exists(file_pth):
+        raise HTTPException(status_code=404, detail="File not found")
+
+    def iter_file():
+        with open(file_pth, mode="rb") as file_like:
+            yield from file_like
+
+    return StreamingResponse(iter_file(), media_type="application/octet-stream")
+
+
+def v1_generate_request(all_requests):
+
+    prompts = []
+    sampling_params_list = []
+    first_prompt_type = type(all_requests[0].prompt)
+    for request in all_requests:
+        prompt = request.prompt
+        assert (
+            type(prompt) == first_prompt_type
+        ), "All prompts must be of the same type in file input settings"
+        prompts.append(prompt)
+        sampling_params_list.append(
+            {
+                "temperature": request.temperature,
+                "max_new_tokens": request.max_tokens,
+                "stop": request.stop,
+                "top_p": request.top_p,
+                "presence_penalty": request.presence_penalty,
+                "frequency_penalty": request.frequency_penalty,
+                "regex": request.regex,
+                "n": request.n,
+                "ignore_eos": request.ignore_eos,
+            }
+        )
+        if len(all_requests) > 1 and request.n > 1:
+            raise ValueError(
+                "Batch operation is not supported for completions from files"
+            )
+
+    if len(all_requests) == 1:
+        prompt = prompts[0]
+        sampling_params_list = sampling_params_list[0]
+        if isinstance(prompts, str) or isinstance(prompts[0], str):
+            prompt_kwargs = {"text": prompt}
+        else:
+            prompt_kwargs = {"input_ids": prompt}
+    else:
-        prompt_kwargs = {"input_ids": prompt}
+        if isinstance(prompts[0], str):
+            prompt_kwargs = {"text": prompts}
+        else:
+            prompt_kwargs = {"input_ids": prompts}

     adapted_request = GenerateReqInput(
         **prompt_kwargs,
-        sampling_params={
-            "temperature": request.temperature,
-            "max_new_tokens": request.max_tokens,
-            "stop": request.stop,
-            "top_p": request.top_p,
-            "presence_penalty": request.presence_penalty,
-            "frequency_penalty": request.frequency_penalty,
-            "regex": request.regex,
-            "n": request.n,
-            "ignore_eos": request.ignore_eos,
-        },
-        return_logprob=request.logprobs is not None and request.logprobs > 0,
-        top_logprobs_num=request.logprobs if request.logprobs is not None else 0,
+        sampling_params=sampling_params_list,
+        return_logprob=all_requests[0].logprobs is not None
+        and all_requests[0].logprobs > 0,
+        top_logprobs_num=(
+            all_requests[0].logprobs if all_requests[0].logprobs is not None else 0
+        ),
         return_text_in_logprobs=True,
-        stream=request.stream,
+        stream=all_requests[0].stream,
     )
+    if len(all_requests) == 1:
+        return adapted_request, all_requests[0]
+    return adapted_request, all_requests


+def v1_generate_response(request, ret, to_file=False):
+    choices = []
+    echo = False
+
+    if (not isinstance(request, List)) and request.echo:
+        # TODO: handle the case where the prompt is token ids
+        if isinstance(request.prompt, list):
+            prompts = request.prompt
+        else:
+            prompts = [request.prompt]
+        echo = True
+
+    for idx, ret_item in enumerate(ret):
+        text = ret_item["text"]
+        if isinstance(request, List) and request[idx].echo:
+            echo = True
+            text = request[idx].prompt + text
+        if (not isinstance(request, List)) and echo:
+            text = prompts[idx] + text
+
+        logprobs = False
+        if isinstance(request, List) and request[idx].logprobs:
+            logprobs = True
+        elif (not isinstance(request, List)) and request.logprobs:
+            logprobs = True
+        if logprobs:
+            if echo:
+                input_token_logprobs = ret_item["meta_info"]["input_token_logprobs"]
+                input_top_logprobs = ret_item["meta_info"]["input_top_logprobs"]
+            else:
+                input_token_logprobs = None
+                input_top_logprobs = None
+
+            logprobs = to_openai_style_logprobs(
+                input_token_logprobs=input_token_logprobs,
+                input_top_logprobs=input_top_logprobs,
+                output_token_logprobs=ret_item["meta_info"]["output_token_logprobs"],
+                output_top_logprobs=ret_item["meta_info"]["output_top_logprobs"],
+            )
+        else:
+            logprobs = None
+
+        if to_file:
+            # Keep the choice data JSON serializable
+            choice_data = {
+                "index": 0,
+                "text": text,
+                "logprobs": logprobs,
+                "finish_reason": ret_item["meta_info"]["finish_reason"],
+            }
+        else:
+            choice_data = CompletionResponseChoice(
+                index=idx,
+                text=text,
+                logprobs=logprobs,
+                finish_reason=ret_item["meta_info"]["finish_reason"],
+            )
+
+        choices.append(choice_data)
+
+    if to_file:
+        responses = []
+        for i, choice in enumerate(choices):
+            response = {
+                "status_code": 200,
+                "request_id": ret[i]["meta_info"]["id"],
+                "body": {
+                    # Same shape as the non-file response; adjust here if needed
+                    "id": ret[i]["meta_info"]["id"],
+                    "object": "text_completion",
+                    "created": int(time.time()),
+                    "model": request[i].model,
+                    "choices": choice,
+                    "usage": {
+                        "prompt_tokens": ret[i]["meta_info"]["prompt_tokens"],
+                        "completion_tokens": ret[i]["meta_info"]["completion_tokens"],
+                        "total_tokens": ret[i]["meta_info"]["prompt_tokens"]
+                        + ret[i]["meta_info"]["completion_tokens"],
+                    },
+                    "system_fingerprint": None,
+                },
+            }
+            responses.append(response)
+        return responses
+    else:
+        completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)
+        response = CompletionResponse(
+            id=ret[0]["meta_info"]["id"],
+            model=request.model,
+            choices=choices,
+            usage=UsageInfo(
+                prompt_tokens=ret[0]["meta_info"]["prompt_tokens"],
+                completion_tokens=completion_tokens,
+                total_tokens=ret[0]["meta_info"]["prompt_tokens"] + completion_tokens,
+            ),
+        )
+        return response


 async def v1_completions(tokenizer_manager, raw_request: Request):
     request_json = await raw_request.json()
+    all_requests = [CompletionRequest(**request_json)]
+    adapted_request, request = v1_generate_request(all_requests)

     if adapted_request.stream:
@@ -223,109 +580,144 @@ async def v1_completions(tokenizer_manager, raw_request: Request):

     if not isinstance(ret, list):
         ret = [ret]
-    if request.echo:
-        # TODO: handle the case where the prompt is token ids
-        if isinstance(request.prompt, list):
-            prompts = request.prompt

+    response = v1_generate_response(request, ret)
     return response


+def v1_chat_generate_request(all_requests, tokenizer_manager):
+
+    texts = []
+    sampling_params_list = []
+    image_data_list = []
+    for request in all_requests:
+        # Prep the data needed for the underlying GenerateReqInput:
+        #  - prompt: The full prompt string.
+        #  - stop: Custom stop tokens.
+        #  - image_data: None or a list of image strings (URLs or base64 strings).
+        #    None skips any image processing in GenerateReqInput.
+        if not isinstance(request.messages, str):
+            # Apply chat template and its stop strings.
+            if chat_template_name is None:
+                prompt = tokenizer_manager.tokenizer.apply_chat_template(
+                    request.messages, tokenize=False, add_generation_prompt=True
+                )
+                stop = request.stop
+                image_data = None
+            else:
+                conv = generate_chat_conv(request, chat_template_name)
+                prompt = conv.get_prompt()
+                image_data = conv.image_data
+                stop = conv.stop_str or []
+                if request.stop:
+                    if isinstance(request.stop, str):
+                        stop.append(request.stop)
+                    else:
+                        stop.extend(request.stop)
+        else:
-            prompts = [request.prompt]
+            # Use the raw prompt and stop strings if the messages is already a string.
+            prompt = request.messages
+            stop = request.stop
+            image_data = None
+        texts.append(prompt)
+        sampling_params_list.append(
+            {
+                "temperature": request.temperature,
+                "max_new_tokens": request.max_tokens,
+                "stop": stop,
+                "top_p": request.top_p,
+                "presence_penalty": request.presence_penalty,
+                "frequency_penalty": request.frequency_penalty,
+                "regex": request.regex,
+                "n": request.n,
+            }
+        )
+        image_data_list.append(image_data)
+    if len(all_requests) == 1:
+        texts = texts[0]
+        sampling_params_list = sampling_params_list[0]
+        image_data = image_data_list[0]
+    adapted_request = GenerateReqInput(
+        text=texts,
+        image_data=image_data,
+        sampling_params=sampling_params_list,
+        stream=request.stream,
+    )
+    if len(all_requests) == 1:
+        return adapted_request, all_requests[0]
+    return adapted_request, all_requests


+def v1_chat_generate_response(request, ret, to_file=False):
+    choices = []
+    total_prompt_tokens = 0
+    total_completion_tokens = 0
+
+    for idx, ret_item in enumerate(ret):
-        text = ret_item["text"]
+        prompt_tokens = ret_item["meta_info"]["prompt_tokens"]
+        completion_tokens = ret_item["meta_info"]["completion_tokens"]

-        if request.echo:
-            text = prompts[idx] + text

-        if request.logprobs:
-            if request.echo:
-                input_token_logprobs = ret_item["meta_info"]["input_token_logprobs"]
-                input_top_logprobs = ret_item["meta_info"]["input_top_logprobs"]
-            else:
-                input_token_logprobs = None
-                input_top_logprobs = None

-            logprobs = to_openai_style_logprobs(
-                input_token_logprobs=input_token_logprobs,
-                input_top_logprobs=input_top_logprobs,
-                output_token_logprobs=ret_item["meta_info"]["output_token_logprobs"],
-                output_top_logprobs=ret_item["meta_info"]["output_top_logprobs"],
-            )
+        if to_file:
+            # Keep the choice data JSON serializable
+            choice_data = {
+                "index": 0,
+                "message": {"role": "assistant", "content": ret_item["text"]},
+                "logprobs": None,
+                "finish_reason": ret_item["meta_info"]["finish_reason"],
+            }
         else:
-            logprobs = None
-
-            choice_data = CompletionResponseChoice(
-                index=idx,
-                text=text,
-                logprobs=logprobs,
-                finish_reason=ret_item["meta_info"]["finish_reason"],
-            )
+            choice_data = ChatCompletionResponseChoice(
+                index=idx,
+                message=ChatMessage(role="assistant", content=ret_item["text"]),
+                finish_reason=ret_item["meta_info"]["finish_reason"],
+            )

         choices.append(choice_data)
+        total_prompt_tokens = prompt_tokens
+        total_completion_tokens += completion_tokens
+    if to_file:
+        responses = []

-    completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)
-    response = CompletionResponse(
-        id=ret[0]["meta_info"]["id"],
-        model=request.model,
-        choices=choices,
-        usage=UsageInfo(
-            prompt_tokens=ret[0]["meta_info"]["prompt_tokens"],
-            completion_tokens=completion_tokens,
-            total_tokens=ret[0]["meta_info"]["prompt_tokens"] + completion_tokens,
-        ),
-    )
-
-    return response
+        for i, choice in enumerate(choices):
+            response = {
+                "status_code": 200,
+                "request_id": ret[i]["meta_info"]["id"],
+                "body": {
+                    # Same shape as the non-file response; adjust here if needed
+                    "id": ret[i]["meta_info"]["id"],
+                    "object": "chat.completion",
+                    "created": int(time.time()),
+                    "model": request[i].model,
+                    "choices": choice,
+                    "usage": {
+                        "prompt_tokens": ret[i]["meta_info"]["prompt_tokens"],
+                        "completion_tokens": ret[i]["meta_info"]["completion_tokens"],
+                        "total_tokens": ret[i]["meta_info"]["prompt_tokens"]
+                        + ret[i]["meta_info"]["completion_tokens"],
+                    },
+                    "system_fingerprint": None,
+                },
+            }
+            responses.append(response)
+        return responses
+    else:
+        response = ChatCompletionResponse(
+            id=ret[0]["meta_info"]["id"],
+            model=request.model,
+            choices=choices,
+            usage=UsageInfo(
+                prompt_tokens=total_prompt_tokens,
+                completion_tokens=total_completion_tokens,
+                total_tokens=total_prompt_tokens + total_completion_tokens,
+            ),
+        )
+        return response


 async def v1_chat_completions(tokenizer_manager, raw_request: Request):
     request_json = await raw_request.json()
-    request = ChatCompletionRequest(**request_json)
-
-    # Prep the data needed for the underlying GenerateReqInput:
-    #  - prompt: The full prompt string.
-    #  - stop: Custom stop tokens.
-    #  - image_data: None or a list of image strings (URLs or base64 strings).
-    #    None skips any image processing in GenerateReqInput.
-    if not isinstance(request.messages, str):
-        # Apply chat template and its stop strings.
-        if chat_template_name is None:
-            prompt = tokenizer_manager.tokenizer.apply_chat_template(
-                request.messages, tokenize=False, add_generation_prompt=True
-            )
-            stop = request.stop
-            image_data = None
-        else:
-            conv = generate_chat_conv(request, chat_template_name)
-            prompt = conv.get_prompt()
-            image_data = conv.image_data
-            stop = conv.stop_str or []
-            if request.stop:
-                if isinstance(request.stop, str):
-                    stop.append(request.stop)
-                else:
-                    stop.extend(request.stop)
-    else:
-        # Use the raw prompt and stop strings if the messages is already a string.
-        prompt = request.messages
-        stop = request.stop
-        image_data = None
-
-    adapted_request = GenerateReqInput(
-        text=prompt,
-        image_data=image_data,
-        sampling_params={
-            "temperature": request.temperature,
-            "max_new_tokens": request.max_tokens,
-            "stop": stop,
-            "top_p": request.top_p,
-            "presence_penalty": request.presence_penalty,
-            "frequency_penalty": request.frequency_penalty,
-            "regex": request.regex,
-            "n": request.n,
-        },
-        stream=request.stream,
-    )
+    all_requests = [ChatCompletionRequest(**request_json)]
+    adapted_request, request = v1_chat_generate_request(all_requests, tokenizer_manager)

     if adapted_request.stream:
@@ -387,34 +779,8 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):

     if not isinstance(ret, list):
         ret = [ret]
-    choices = []
-    total_prompt_tokens = 0
-    total_completion_tokens = 0
-
-    for idx, ret_item in enumerate(ret):
-        prompt_tokens = ret_item["meta_info"]["prompt_tokens"]
-        completion_tokens = ret_item["meta_info"]["completion_tokens"]
-
-        choice_data = ChatCompletionResponseChoice(
-            index=idx,
-            message=ChatMessage(role="assistant", content=ret_item["text"]),
-            finish_reason=ret_item["meta_info"]["finish_reason"],
-        )
-
-        choices.append(choice_data)
-        total_prompt_tokens = prompt_tokens
-        total_completion_tokens += completion_tokens
-
-    response = ChatCompletionResponse(
-        id=ret[0]["meta_info"]["id"],
-        model=request.model,
-        choices=choices,
-        usage=UsageInfo(
-            prompt_tokens=total_prompt_tokens,
-            completion_tokens=total_completion_tokens,
-            total_tokens=total_prompt_tokens + total_completion_tokens,
-        ),
-    )
+    response = v1_chat_generate_response(request, ret)

     return response
@@ -60,6 +60,55 @@ class UsageInfo(BaseModel):
     completion_tokens: Optional[int] = 0


+class FileRequest(BaseModel):
+    # https://platform.openai.com/docs/api-reference/files/create
+    file: bytes  # The File object (not file name) to be uploaded
+    purpose: str = (
+        "batch"  # The intended purpose of the uploaded file; default is "batch"
+    )
+
+
+class FileResponse(BaseModel):
+    id: str
+    object: str = "file"
+    bytes: int
+    created_at: int
+    filename: str
+    purpose: str
+
+
+class BatchRequest(BaseModel):
+    input_file_id: (
+        str  # The ID of an uploaded file that contains requests for the new batch
+    )
+    endpoint: str  # The endpoint to be used for all requests in the batch
+    completion_window: str  # The time frame within which the batch should be processed
+    metadata: Optional[dict] = None  # Optional custom metadata for the batch
+
+
+class BatchResponse(BaseModel):
+    id: str
+    object: str = "batch"
+    endpoint: str
+    errors: Optional[dict] = None
+    input_file_id: str
+    completion_window: str
+    status: str = "validating"
+    output_file_id: Optional[str] = None
+    error_file_id: Optional[str] = None
+    created_at: int
+    in_progress_at: Optional[int] = None
+    expires_at: Optional[int] = None
+    finalizing_at: Optional[int] = None
+    completed_at: Optional[int] = None
+    failed_at: Optional[int] = None
+    expired_at: Optional[int] = None
+    cancelling_at: Optional[int] = None
+    cancelled_at: Optional[int] = None
+    request_counts: dict = {"total": 0, "completed": 0, "failed": 0}
+    metadata: Optional[dict] = None
+
+
 class CompletionRequest(BaseModel):
     # Ordered by official OpenAI API documentation
     # https://platform.openai.com/docs/api-reference/completions/create
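For reference, a POST /v1/batches payload must validate against BatchRequest; a minimal example (values illustrative; the handlers store completion_window verbatim without enforcing it):

    batch_request = BatchRequest(
        input_file_id="backend_input_file-<uuid>",  # id returned by /v1/files
        endpoint="/v1/chat/completions",
        completion_window="24h",
    )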
@@ -38,7 +38,7 @@ import psutil
 import requests
 import uvicorn
 import uvloop
-from fastapi import FastAPI, Request
+from fastapi import FastAPI, File, Form, Request, UploadFile
 from fastapi.responses import JSONResponse, Response, StreamingResponse

 from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
@@ -56,8 +56,13 @@ from sglang.srt.managers.io_struct import GenerateReqInput
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.openai_api.adapter import (
     load_chat_template_for_openai_api,
+    v1_batches,
     v1_chat_completions,
     v1_completions,
+    v1_files_create,
+    v1_retrieve_batch,
+    v1_retrieve_file,
+    v1_retrieve_file_content,
 )
 from sglang.srt.openai_api.protocol import ModelCard, ModelList
 from sglang.srt.server_args import PortArgs, ServerArgs
@@ -152,6 +157,35 @@ async def openai_v1_chat_completions(raw_request: Request):
     return await v1_chat_completions(tokenizer_manager, raw_request)


+@app.post("/v1/files")
+async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("batch")):
+    return await v1_files_create(
+        file, purpose, tokenizer_manager.server_args.file_storage_pth
+    )
+
+
+@app.post("/v1/batches")
+async def openai_v1_batches(raw_request: Request):
+    return await v1_batches(tokenizer_manager, raw_request)
+
+
+@app.get("/v1/batches/{batch_id}")
+async def retrieve_batch(batch_id: str):
+    return await v1_retrieve_batch(batch_id)
+
+
+@app.get("/v1/files/{file_id}")
+async def retrieve_file(file_id: str):
+    # https://platform.openai.com/docs/api-reference/files/retrieve
+    return await v1_retrieve_file(file_id)
+
+
+@app.get("/v1/files/{file_id}/content")
+async def retrieve_file_content(file_id: str):
+    # https://platform.openai.com/docs/api-reference/files/retrieve-contents
+    return await v1_retrieve_file_content(file_id)
+
+
 @app.get("/v1/models")
 def available_models():
     """Show available models."""
@@ -60,6 +60,7 @@ class ServerArgs:

     # Other
     api_key: str = ""
+    file_storage_pth: str = "SGlang_storage"

     # Data parallelism
     dp_size: int = 1
@@ -290,6 +291,12 @@ class ServerArgs:
         default=ServerArgs.api_key,
         help="Set API key of the server.",
     )
+    parser.add_argument(
+        "--file-storage-pth",
+        type=str,
+        default=ServerArgs.file_storage_pth,
+        help="The path of the file storage in the backend.",
+    )

     # Data parallelism
     parser.add_argument(
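Assuming the usual launch entry point, the new flag points the backend file storage at an explicit directory:

    python -m sglang.launch_server --model-path <model> --file-storage-pth /tmp/sglang_oai_storage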