[Feature] Adds basic support for image content in OpenAI chat routes (#113)

Author: Keith Stevens
Date: 2024-01-30 23:12:33 +09:00
Committed by: GitHub
Parent: 97aa9b3284
Commit: 1d0fbe8e43

8 changed files with 220 additions and 11 deletions
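
For context, this change targets OpenAI's structured chat format, in which a message's content may be a list of typed parts (text plus image URLs or base64 data) rather than a plain string. A hypothetical client call exercising the new path (the port, model name, and image URL are illustrative placeholders, not part of this commit):

# Hypothetical request showing the structured image content the route now accepts.
# The endpoint shape follows the OpenAI chat completions spec.
import requests

resp = requests.post(
    "http://localhost:30000/v1/chat/completions",
    json={
        "model": "llava",  # placeholder model name
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "What is shown in this image?"},
                    {
                        "type": "image_url",
                        "image_url": {"url": "https://example.com/cat.png"},
                    },
                ],
            }
        ],
        "max_tokens": 64,
    },
)
print(resp.json())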


@@ -16,7 +16,7 @@ import psutil
 import requests
 import uvicorn
 import uvloop
-from fastapi import FastAPI, Request
+from fastapi import FastAPI, HTTPException, Request
 from fastapi.responses import Response, StreamingResponse
 from sglang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.srt.conversation import (
@@ -190,16 +190,31 @@ async def v1_chat_completions(raw_request: Request):
     # TODO: Validate the request and return HTTPStatus.BAD_REQUEST if invalid.
     assert request.n == 1
+    # Prep the data needed for the underlying GenerateReqInput:
+    # - prompt: The full prompt string.
+    # - stop: Custom stop tokens.
+    # - image_data: None or a list of image strings (URLs or base64 strings).
+    #   None skips any image processing in GenerateReqInput.
     if not isinstance(request.messages, str):
         # Apply chat template and its stop strings.
         if chat_template_name is None:
+            # This flow doesn't support the full OpenAI spec. Verify messages
+            # has the right type before proceeding:
+            for m in request.messages:
+                if not isinstance(m.content, str):
+                    raise HTTPException(
+                        status_code=503,
+                        detail="Structured content requests not supported with HuggingFace Chat Templates. Make sure the server specifies a sglang chat template.",
+                    )
             prompt = tokenizer_manager.tokenizer.apply_chat_template(
                 request.messages, tokenize=False, add_generation_prompt=True
             )
             stop = request.stop
+            image_data = None
         else:
             conv = generate_chat_conv(request, chat_template_name)
             prompt = conv.get_prompt()
+            image_data = conv.image_data
             stop = conv.stop_str or []
             if request.stop:
                 if isinstance(request.stop, str):
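
In the template branch above, image extraction is delegated to generate_chat_conv, which populates conv.image_data. As a rough sketch of the idea (this is not sglang's actual implementation, and split_structured_content is a hypothetical helper), separating image parts from text parts in a structured message list looks like:

# Hypothetical helper approximating what generate_chat_conv does with
# structured content; not the actual sglang implementation.
def split_structured_content(messages):
    texts, image_data = [], []
    for m in messages:
        content = m["content"]
        if isinstance(content, str):
            texts.append(content)
            continue
        for part in content:
            if part["type"] == "text":
                texts.append(part["text"])
            elif part["type"] == "image_url":
                # URLs and base64 data URIs both travel as plain strings.
                image_data.append(part["image_url"]["url"])
    # Returning None rather than [] skips image processing downstream.
    return "\n".join(texts), image_data or None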
@@ -210,9 +225,11 @@ async def v1_chat_completions(raw_request: Request):
         # Use the raw prompt and stop strings if the messages is already a string.
         prompt = request.messages
         stop = request.stop
+        image_data = None
 
     adapted_request = GenerateReqInput(
         text=prompt,
+        image_data=image_data,
         sampling_params={
             "temperature": request.temperature,
             "max_new_tokens": request.max_tokens,
@@ -303,6 +320,7 @@ def launch_server(server_args, pipe_finish_writer):
     # Load chat template if needed
     if server_args.chat_template is not None:
+        print(server_args.chat_template)
         if not chat_template_exists(server_args.chat_template):
             if not os.path.exists(server_args.chat_template):
                 raise RuntimeError(
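
Per the fallback above, --chat-template accepts either a registered sglang template name or a path to a template file on disk. A hypothetical launch line (the model path and template name are illustrative, not taken from this commit):

python -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --chat-template vicuna_v1.1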