sglang/python/sglang/srt/openai_api/adapter.py
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
"""Conversion between OpenAI APIs and native SRT APIs"""
import asyncio
import json
import os
from http import HTTPStatus

from fastapi import Request
from fastapi.responses import JSONResponse, StreamingResponse

from sglang.srt.conversation import (
Conversation,
SeparatorStyle,
chat_template_exists,
generate_chat_conv,
register_conv_template,
)
from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.openai_api.protocol import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse,
ChatMessage,
CompletionRequest,
CompletionResponse,
CompletionResponseChoice,
CompletionResponseStreamChoice,
CompletionStreamResponse,
DeltaMessage,
ErrorResponse,
LogProbs,
UsageInfo,
)
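
# Name of the conversation template registered via load_chat_template_for_openai_api().
# When it is None, the tokenizer's built-in Hugging Face chat template is used instead.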
chat_template_name = None


def create_error_response(
message: str,
err_type: str = "BadRequestError",
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
):
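    """Build a JSONResponse whose body is an OpenAI-style ErrorResponse."""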
error = ErrorResponse(message=message, type=err_type, code=status_code.value)
return JSONResponse(content=error.model_dump(), status_code=error.code)


def create_streaming_error_response(
message: str,
err_type: str = "BadRequestError",
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
) -> str:
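    """Build an OpenAI-style error payload as a JSON string for use in a stream."""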
error = ErrorResponse(message=message, type=err_type, code=status_code.value)
json_str = json.dumps({"error": error.model_dump()})
return json_str


def load_chat_template_for_openai_api(chat_template_arg):
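    """Select the chat template used by the OpenAI-compatible chat endpoint.

    chat_template_arg is either the name of a built-in conversation template
    or a path to a JSON file describing one. A sketch of such a file, with
    purely illustrative field values, could look like:

        {
            "name": "my_model",
            "system": "<|system|>",
            "system_message": "You are a helpful assistant.",
            "user": "<|user|>",
            "assistant": "<|assistant|>",
            "sep_style": "ADD_COLON_SINGLE",
            "stop_str": "<|user|>"
        }
    """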
global chat_template_name
print(f"Use chat template: {chat_template_arg}")
if not chat_template_exists(chat_template_arg):
if not os.path.exists(chat_template_arg):
raise RuntimeError(
f"Chat template {chat_template_arg} is not a built-in template name "
"or a valid chat template file path."
)
with open(chat_template_arg, "r") as filep:
template = json.load(filep)
try:
sep_style = SeparatorStyle[template["sep_style"]]
except KeyError:
raise ValueError(
f"Unknown separator style: {template['sep_style']}"
) from None
register_conv_template(
Conversation(
name=template["name"],
system_template=template["system"] + "\n{system_message}",
system_message=template.get("system_message", ""),
roles=(template["user"], template["assistant"]),
sep_style=sep_style,
sep=template.get("sep", "\n"),
stop_str=template["stop_str"],
),
override=True,
)
chat_template_name = template["name"]
else:
chat_template_name = chat_template_arg


async def v1_completions(tokenizer_manager, raw_request: Request):
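    """Handle a request to the OpenAI-compatible /v1/completions endpoint.

    The CompletionRequest is translated into a GenerateReqInput for the
    tokenizer manager, and the generation output is converted back into
    OpenAI-style completion responses (streaming or non-streaming),
    including optional echo and logprobs handling.
    """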
request_json = await raw_request.json()
request = CompletionRequest(**request_json)
prompt = request.prompt
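    # A string (or list of strings) is passed through as text; a list of token ids as input_ids.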
if isinstance(prompt, str) or isinstance(prompt[0], str):
prompt_kwargs = {"text": prompt}
else:
prompt_kwargs = {"input_ids": prompt}
adapted_request = GenerateReqInput(
**prompt_kwargs,
sampling_params={
"temperature": request.temperature,
"max_new_tokens": request.max_tokens,
"stop": request.stop,
"top_p": request.top_p,
"presence_penalty": request.presence_penalty,
"frequency_penalty": request.frequency_penalty,
"regex": request.regex,
"n": request.n,
"ignore_eos": request.ignore_eos,
},
return_logprob=request.logprobs is not None and request.logprobs > 0,
top_logprobs_num=request.logprobs if request.logprobs is not None else 0,
return_text_in_logprobs=True,
stream=request.stream,
)
if adapted_request.stream:

        async def generate_stream_resp():
stream_buffer = ""
n_prev_token = 0
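            # stream_buffer holds the text already emitted; n_prev_token counts the output logprobs already sent.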
try:
async for content in tokenizer_manager.generate_request(
adapted_request, raw_request
):
text = content["text"]
prompt_tokens = content["meta_info"]["prompt_tokens"]
completion_tokens = content["meta_info"]["completion_tokens"]
if not stream_buffer: # The first chunk
if request.echo:
# Prepend prompt in response text.
text = request.prompt + text
if request.logprobs:
                        # Include the prompt's logprobs only in the first chunk, when echo is enabled.
if not stream_buffer and request.echo:
input_token_logprobs = content["meta_info"][
"input_token_logprobs"
]
input_top_logprobs = content["meta_info"][
"input_top_logprobs"
]
else:
input_token_logprobs = None
input_top_logprobs = None
logprobs = to_openai_style_logprobs(
input_token_logprobs=input_token_logprobs,
input_top_logprobs=input_top_logprobs,
output_token_logprobs=content["meta_info"][
"output_token_logprobs"
][n_prev_token:],
output_top_logprobs=content["meta_info"][
"output_top_logprobs"
][n_prev_token:],
)
n_prev_token = len(
content["meta_info"]["output_token_logprobs"]
)
else:
logprobs = None
delta = text[len(stream_buffer) :]
stream_buffer = stream_buffer + delta
choice_data = CompletionResponseStreamChoice(
index=0,
text=delta,
logprobs=logprobs,
finish_reason=content["meta_info"]["finish_reason"],
)
chunk = CompletionStreamResponse(
id=content["meta_info"]["id"],
object="text_completion",
choices=[choice_data],
model=request.model,
usage=UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
),
)
yield f"data: {chunk.model_dump_json()}\n\n"
except ValueError as e:
error = create_streaming_error_response(str(e))
yield f"data: {error}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(
generate_stream_resp(),
media_type="text/event-stream",
background=tokenizer_manager.create_abort_task(adapted_request),
)

    # Non-streaming response.
try:
ret = await tokenizer_manager.generate_request(
adapted_request, raw_request
).__anext__()
except ValueError as e:
return create_error_response(str(e))
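
    # The tokenizer manager may return a single dict or a list of dicts (e.g. when n > 1).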
if not isinstance(ret, list):
ret = [ret]
choices = []
for idx, ret_item in enumerate(ret):
text = ret_item["text"]
if request.echo:
text = request.prompt + text
if request.logprobs:
if request.echo:
input_token_logprobs = ret_item["meta_info"]["input_token_logprobs"]
input_top_logprobs = ret_item["meta_info"]["input_top_logprobs"]
else:
input_token_logprobs = None
input_top_logprobs = None
logprobs = to_openai_style_logprobs(
input_token_logprobs=input_token_logprobs,
input_top_logprobs=input_top_logprobs,
output_token_logprobs=ret_item["meta_info"]["output_token_logprobs"],
output_top_logprobs=ret_item["meta_info"]["output_top_logprobs"],
)
else:
logprobs = None
choice_data = CompletionResponseChoice(
index=idx,
text=text,
logprobs=logprobs,
finish_reason=ret_item["meta_info"]["finish_reason"],
)
choices.append(choice_data)
response = CompletionResponse(
id=ret[0]["meta_info"]["id"],
model=request.model,
choices=choices,
usage=UsageInfo(
prompt_tokens=ret[0]["meta_info"]["prompt_tokens"],
completion_tokens=sum(
item["meta_info"]["completion_tokens"] for item in ret
),
total_tokens=ret[0]["meta_info"]["prompt_tokens"]
+ sum(item["meta_info"]["completion_tokens"] for item in ret),
),
)
return response


async def v1_chat_completions(tokenizer_manager, raw_request: Request):
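    """Handle a request to the OpenAI-compatible /v1/chat/completions endpoint.

    The chat messages are rendered into a prompt with the configured chat
    template (or the tokenizer's built-in Hugging Face template), forwarded
    as a GenerateReqInput, and the results are converted back into
    OpenAI-style chat completion responses (streaming or non-streaming).
    """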
request_json = await raw_request.json()
request = ChatCompletionRequest(**request_json)
# Prep the data needed for the underlying GenerateReqInput:
# - prompt: The full prompt string.
# - stop: Custom stop tokens.
# - image_data: None or a list of image strings (URLs or base64 strings).
# None skips any image processing in GenerateReqInput.
if not isinstance(request.messages, str):
# Apply chat template and its stop strings.
if chat_template_name is None:
prompt = tokenizer_manager.tokenizer.apply_chat_template(
request.messages, tokenize=False, add_generation_prompt=True
)
stop = request.stop
image_data = None
else:
conv = generate_chat_conv(request, chat_template_name)
prompt = conv.get_prompt()
image_data = conv.image_data
stop = conv.stop_str or []
if request.stop:
if isinstance(request.stop, str):
stop.append(request.stop)
else:
stop.extend(request.stop)
else:
        # Use the raw prompt and stop strings if the messages field is already a string.
prompt = request.messages
stop = request.stop
image_data = None
adapted_request = GenerateReqInput(
text=prompt,
image_data=image_data,
sampling_params={
"temperature": request.temperature,
"max_new_tokens": request.max_tokens,
"stop": stop,
"top_p": request.top_p,
"presence_penalty": request.presence_penalty,
"frequency_penalty": request.frequency_penalty,
"regex": request.regex,
"n": request.n,
},
stream=request.stream,
)
if adapted_request.stream:

        async def generate_stream_resp():
is_first = True
stream_buffer = ""
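            # The first chunk only announces the assistant role; later chunks carry content deltas.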
try:
async for content in tokenizer_manager.generate_request(
adapted_request, raw_request
):
if is_first:
# First chunk with role
is_first = False
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(role="assistant"),
finish_reason=content["meta_info"]["finish_reason"],
)
chunk = ChatCompletionStreamResponse(
id=content["meta_info"]["id"],
choices=[choice_data],
model=request.model,
)
yield f"data: {chunk.model_dump_json()}\n\n"
text = content["text"]
delta = text[len(stream_buffer) :]
stream_buffer = stream_buffer + delta
choice_data = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(content=delta),
finish_reason=content["meta_info"]["finish_reason"],
)
chunk = ChatCompletionStreamResponse(
id=content["meta_info"]["id"],
choices=[choice_data],
model=request.model,
)
yield f"data: {chunk.model_dump_json()}\n\n"
except ValueError as e:
error = create_streaming_error_response(str(e))
yield f"data: {error}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(
generate_stream_resp(),
media_type="text/event-stream",
background=tokenizer_manager.create_abort_task(adapted_request),
)

    # Non-streaming response.
try:
ret = await tokenizer_manager.generate_request(
adapted_request, raw_request
).__anext__()
except ValueError as e:
return create_error_response(str(e))
if not isinstance(ret, list):
ret = [ret]
choices = []
total_prompt_tokens = 0
total_completion_tokens = 0
for idx, ret_item in enumerate(ret):
prompt_tokens = ret_item["meta_info"]["prompt_tokens"]
completion_tokens = ret_item["meta_info"]["completion_tokens"]
choice_data = ChatCompletionResponseChoice(
index=idx,
message=ChatMessage(role="assistant", content=ret_item["text"]),
finish_reason=ret_item["meta_info"]["finish_reason"],
)
choices.append(choice_data)
total_prompt_tokens = prompt_tokens
total_completion_tokens += completion_tokens
response = ChatCompletionResponse(
id=ret[0]["meta_info"]["id"],
model=request.model,
choices=choices,
usage=UsageInfo(
prompt_tokens=total_prompt_tokens,
completion_tokens=total_completion_tokens,
total_tokens=total_prompt_tokens + total_completion_tokens,
),
)
return response


def to_openai_style_logprobs(
input_token_logprobs=None,
output_token_logprobs=None,
input_top_logprobs=None,
output_top_logprobs=None,
):
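    """Convert SRT logprob tuples into an OpenAI-style LogProbs object."""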
ret_logprobs = LogProbs()
def append_token_logprobs(token_logprobs):
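        # Each entry is a (logprob, token id, token text) tuple; the token id is unused here.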
for logprob, _, token_text in token_logprobs:
ret_logprobs.tokens.append(token_text)
ret_logprobs.token_logprobs.append(logprob)
# Not supported yet
ret_logprobs.text_offset.append(-1)
def append_top_logprobs(top_logprobs):
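        # Each position holds a list of candidate (logprob, token id, token text) tuples, or None.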
for tokens in top_logprobs:
if tokens is not None:
ret_logprobs.top_logprobs.append(
{token[2]: token[0] for token in tokens}
)
else:
ret_logprobs.top_logprobs.append(None)
if input_token_logprobs is not None:
append_token_logprobs(input_token_logprobs)
if output_token_logprobs is not None:
append_token_logprobs(output_token_logprobs)
if input_top_logprobs is not None:
append_top_logprobs(input_top_logprobs)
if output_top_logprobs is not None:
append_top_logprobs(output_top_logprobs)
return ret_logprobs