[Feature] Adds basic support for image content in OpenAI chat routes (#113)
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
# https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
|
||||
import dataclasses
|
||||
from enum import IntEnum, auto
|
||||
from typing import Dict, List, Tuple, Union
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
from sglang.srt.managers.openai_protocol import ChatCompletionRequest
|
||||
|
||||
@@ -52,6 +52,7 @@ class Conversation:
|
||||
sep2: str = None
|
||||
# Stop criteria (the default one is EOS token)
|
||||
stop_str: Union[str, List[str]] = None
|
||||
image_data: Optional[List[str]] = None
|
||||
|
||||
def get_prompt(self) -> str:
|
||||
"""Get the prompt for generation."""
|
||||
@@ -251,6 +252,10 @@ class Conversation:
|
||||
"""Append a new message."""
|
||||
self.messages.append([role, message])
|
||||
|
||||
def append_image(self, image: str):
|
||||
"""Append a new message."""
|
||||
self.image_data.append(image)
|
||||
|
||||
def update_last_message(self, message: str):
|
||||
"""Update the last output.
|
||||
|
||||
@@ -341,18 +346,31 @@ def generate_chat_conv(
|
||||
sep=conv.sep,
|
||||
sep2=conv.sep2,
|
||||
stop_str=conv.stop_str,
|
||||
image_data=[],
|
||||
)
|
||||
|
||||
if isinstance(request.messages, str):
|
||||
raise ValueError("The messages should be a list of dict.")
|
||||
for message in request.messages:
|
||||
msg_role = message["role"]
|
||||
msg_role = message.role
|
||||
if msg_role == "system":
|
||||
conv.system_message = message["content"]
|
||||
conv.system_message = message.content
|
||||
elif msg_role == "user":
|
||||
conv.append_message(conv.roles[0], message["content"])
|
||||
# Handle the various types of Chat Request content types here.
|
||||
role = conv.roles[0]
|
||||
if isinstance(message.content, str):
|
||||
conv.append_message(conv.roles[0], message.content)
|
||||
else:
|
||||
real_content = ""
|
||||
for content in message.content:
|
||||
if content.type == "text":
|
||||
real_content += content.text
|
||||
elif content.type == "image_url":
|
||||
real_content += "<image>"
|
||||
conv.append_image(content.image_url.url)
|
||||
conv.append_message(conv.roles[0], real_content)
|
||||
elif msg_role == "assistant":
|
||||
conv.append_message(conv.roles[1], message["content"])
|
||||
conv.append_message(conv.roles[1], message.content)
|
||||
else:
|
||||
raise ValueError(f"Unknown role: {msg_role}")
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import time
|
||||
from typing import Dict, List, Optional, Union
|
||||
from typing_extensions import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -68,9 +69,44 @@ class CompletionStreamResponse(BaseModel):
|
||||
usage: UsageInfo
|
||||
|
||||
|
||||
class ChatCompletionMessageGenericParam(BaseModel):
|
||||
role: Literal["system", "assistant"]
|
||||
content: str
|
||||
|
||||
|
||||
class ChatCompletionMessageContentTextPart(BaseModel):
|
||||
type: Literal["text"]
|
||||
text: str
|
||||
|
||||
|
||||
class ChatCompletionMessageContentImageURL(BaseModel):
|
||||
url: str
|
||||
detail: Optional[Literal["auto", "low", "high"]] = "auto"
|
||||
|
||||
|
||||
class ChatCompletionMessageContentImagePart(BaseModel):
|
||||
type: Literal["image_url"]
|
||||
image_url: ChatCompletionMessageContentImageURL
|
||||
|
||||
|
||||
ChatCompletionMessageContentPart = Union[
|
||||
ChatCompletionMessageContentTextPart, ChatCompletionMessageContentImagePart
|
||||
]
|
||||
|
||||
|
||||
class ChatCompletionMessageUserParam(BaseModel):
|
||||
role: Literal["user"]
|
||||
content: Union[str, List[ChatCompletionMessageContentPart]]
|
||||
|
||||
|
||||
ChatCompletionMessageParam = Union[
|
||||
ChatCompletionMessageGenericParam, ChatCompletionMessageUserParam
|
||||
]
|
||||
|
||||
|
||||
class ChatCompletionRequest(BaseModel):
|
||||
model: str
|
||||
messages: Union[str, List[Dict[str, str]]]
|
||||
messages: Union[str, List[ChatCompletionMessageParam]]
|
||||
temperature: Optional[float] = 0.7
|
||||
top_p: Optional[float] = 1.0
|
||||
n: Optional[int] = 1
|
||||
|
||||
@@ -150,12 +150,17 @@ class TokenizerManager:
|
||||
if sampling_params.max_new_tokens != 0:
|
||||
sampling_params.normalize(self.tokenizer)
|
||||
sampling_params.verify()
|
||||
if obj.image_data is None:
|
||||
pixel_values, image_hash, image_size = None, None, None
|
||||
else:
|
||||
|
||||
if isinstance(obj.image_data, list) and len(obj.image_data) > 0:
|
||||
pixel_values, image_hash, image_size = await self.get_pixel_values(
|
||||
obj.image_data[0]
|
||||
)
|
||||
elif isinstance(obj.image_data, str):
|
||||
pixel_values, image_hash, image_size = await self.get_pixel_values(
|
||||
obj.image_data
|
||||
)
|
||||
else:
|
||||
pixel_values, image_hash, image_size = None, None, None
|
||||
tokenized_obj = TokenizedGenerateReqInput(
|
||||
rid=rid,
|
||||
input_text=obj.text,
|
||||
|
||||
@@ -16,7 +16,7 @@ import psutil
|
||||
import requests
|
||||
import uvicorn
|
||||
import uvloop
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.responses import Response, StreamingResponse
|
||||
from sglang.backend.runtime_endpoint import RuntimeEndpoint
|
||||
from sglang.srt.conversation import (
|
||||
@@ -190,16 +190,31 @@ async def v1_chat_completions(raw_request: Request):
|
||||
# TODO: Validate the request and return HTTPStatus.BAD_REQUEST if invalid.
|
||||
assert request.n == 1
|
||||
|
||||
# Prep the data needed for the underlying GenerateReqInput:
|
||||
# - prompt: The full prompt string.
|
||||
# - stop: Custom stop tokens.
|
||||
# - image_data: None or a list of image strings (URLs or base64 strings).
|
||||
# None skips any image processing in GenerateReqInput.
|
||||
if not isinstance(request.messages, str):
|
||||
# Apply chat template and its stop strings.
|
||||
if chat_template_name is None:
|
||||
# This flow doesn't support the full OpenAI spec. Verify messages
|
||||
# has the right type before proceeding:
|
||||
for m in request.messages:
|
||||
if not isinstance(m.content, str):
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Structured content requests not supported with HuggingFace Chat Templates. Make sure the server specifies a sglang chat template.",
|
||||
)
|
||||
prompt = tokenizer_manager.tokenizer.apply_chat_template(
|
||||
request.messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
stop = request.stop
|
||||
image_data = None
|
||||
else:
|
||||
conv = generate_chat_conv(request, chat_template_name)
|
||||
prompt = conv.get_prompt()
|
||||
image_data = conv.image_data
|
||||
stop = conv.stop_str or []
|
||||
if request.stop:
|
||||
if isinstance(request.stop, str):
|
||||
@@ -210,9 +225,11 @@ async def v1_chat_completions(raw_request: Request):
|
||||
# Use the raw prompt and stop strings if the messages is already a string.
|
||||
prompt = request.messages
|
||||
stop = request.stop
|
||||
image_data = None
|
||||
|
||||
adapted_request = GenerateReqInput(
|
||||
text=prompt,
|
||||
image_data=image_data,
|
||||
sampling_params={
|
||||
"temperature": request.temperature,
|
||||
"max_new_tokens": request.max_tokens,
|
||||
@@ -303,6 +320,7 @@ def launch_server(server_args, pipe_finish_writer):
|
||||
|
||||
# Load chat template if needed
|
||||
if server_args.chat_template is not None:
|
||||
print(server_args.chat_template)
|
||||
if not chat_template_exists(server_args.chat_template):
|
||||
if not os.path.exists(server_args.chat_template):
|
||||
raise RuntimeError(
|
||||
|
||||
46
python/sglang/test/test_conversation.py
Normal file
46
python/sglang/test/test_conversation.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from sglang.srt.conversation import generate_chat_conv
|
||||
from sglang.srt.managers.openai_protocol import (
|
||||
ChatCompletionMessageGenericParam,
|
||||
ChatCompletionMessageContentImagePart,
|
||||
ChatCompletionMessageContentImageURL,
|
||||
ChatCompletionMessageContentTextPart,
|
||||
ChatCompletionMessageUserParam,
|
||||
ChatCompletionRequest,
|
||||
)
|
||||
|
||||
|
||||
def test_chat_completion_to_conv_image():
|
||||
"""Test that we can convert a chat image request to a convo"""
|
||||
request = ChatCompletionRequest(
|
||||
model="default",
|
||||
messages=[
|
||||
ChatCompletionMessageGenericParam(
|
||||
role="system", content="You are a helpful AI assistant"
|
||||
),
|
||||
ChatCompletionMessageUserParam(
|
||||
role="user",
|
||||
content=[
|
||||
ChatCompletionMessageContentTextPart(
|
||||
type="text", text="Describe this image"
|
||||
),
|
||||
ChatCompletionMessageContentImagePart(
|
||||
type="image_url",
|
||||
image_url=ChatCompletionMessageContentImageURL(
|
||||
url="https://someurl.com"
|
||||
),
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
conv = generate_chat_conv(request, "vicuna_v1.1")
|
||||
assert conv.messages == [
|
||||
["USER", "Describe this image<image>"],
|
||||
["ASSISTANT", None],
|
||||
]
|
||||
assert conv.system_message == "You are a helpful AI assistant"
|
||||
assert conv.image_data == ["https://someurl.com"]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_chat_completion_to_conv_image()
|
||||
51
python/sglang/test/test_openai_protocol.py
Normal file
51
python/sglang/test/test_openai_protocol.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from sglang.srt.managers.openai_protocol import (
|
||||
ChatCompletionMessageGenericParam,
|
||||
ChatCompletionMessageContentImagePart,
|
||||
ChatCompletionMessageContentImageURL,
|
||||
ChatCompletionMessageContentTextPart,
|
||||
ChatCompletionMessageUserParam,
|
||||
ChatCompletionRequest,
|
||||
)
|
||||
|
||||
|
||||
def test_chat_completion_request_image():
|
||||
"""Test that Chat Completion Requests with images can be converted."""
|
||||
|
||||
image_request = {
|
||||
"model": "default",
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a helpful AI assistant"},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "Describe this image"},
|
||||
{"type": "image_url", "image_url": {"url": "https://someurl.com"}},
|
||||
],
|
||||
},
|
||||
],
|
||||
"temperature": 0,
|
||||
"max_tokens": 64,
|
||||
}
|
||||
request = ChatCompletionRequest(**image_request)
|
||||
assert len(request.messages) == 2
|
||||
assert request.messages[0] == ChatCompletionMessageGenericParam(
|
||||
role="system", content="You are a helpful AI assistant"
|
||||
)
|
||||
assert request.messages[1] == ChatCompletionMessageUserParam(
|
||||
role="user",
|
||||
content=[
|
||||
ChatCompletionMessageContentTextPart(
|
||||
type="text", text="Describe this image"
|
||||
),
|
||||
ChatCompletionMessageContentImagePart(
|
||||
type="image_url",
|
||||
image_url=ChatCompletionMessageContentImageURL(
|
||||
url="https://someurl.com"
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_chat_completion_request_image()
|
||||
Reference in New Issue
Block a user