[Feature] Adds basic support for image content in OpenAI chat routes (#113)
This commit is contained in:
@@ -357,7 +357,7 @@ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port
|
|||||||
- Mistral
|
- Mistral
|
||||||
- Mixtral
|
- Mixtral
|
||||||
- LLaVA
|
- LLaVA
|
||||||
- `python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --port 30000`
|
- `python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --chat-template vicuna_v1.1 --port 30000`
|
||||||
- Qwen / Qwen 2
|
- Qwen / Qwen 2
|
||||||
- AWQ quantization
|
- AWQ quantization
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
# https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
|
# https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
|
||||||
import dataclasses
|
import dataclasses
|
||||||
from enum import IntEnum, auto
|
from enum import IntEnum, auto
|
||||||
from typing import Dict, List, Tuple, Union
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
from sglang.srt.managers.openai_protocol import ChatCompletionRequest
|
from sglang.srt.managers.openai_protocol import ChatCompletionRequest
|
||||||
|
|
||||||
@@ -52,6 +52,7 @@ class Conversation:
|
|||||||
sep2: str = None
|
sep2: str = None
|
||||||
# Stop criteria (the default one is EOS token)
|
# Stop criteria (the default one is EOS token)
|
||||||
stop_str: Union[str, List[str]] = None
|
stop_str: Union[str, List[str]] = None
|
||||||
|
image_data: Optional[List[str]] = None
|
||||||
|
|
||||||
def get_prompt(self) -> str:
|
def get_prompt(self) -> str:
|
||||||
"""Get the prompt for generation."""
|
"""Get the prompt for generation."""
|
||||||
@@ -251,6 +252,10 @@ class Conversation:
|
|||||||
"""Append a new message."""
|
"""Append a new message."""
|
||||||
self.messages.append([role, message])
|
self.messages.append([role, message])
|
||||||
|
|
||||||
|
def append_image(self, image: str):
|
||||||
|
"""Append a new image."""
|
||||||
|
self.image_data.append(image)
|
||||||
|
|
||||||
def update_last_message(self, message: str):
|
def update_last_message(self, message: str):
|
||||||
"""Update the last output.
|
"""Update the last output.
|
||||||
|
|
||||||
@@ -341,18 +346,31 @@ def generate_chat_conv(
|
|||||||
sep=conv.sep,
|
sep=conv.sep,
|
||||||
sep2=conv.sep2,
|
sep2=conv.sep2,
|
||||||
stop_str=conv.stop_str,
|
stop_str=conv.stop_str,
|
||||||
|
image_data=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(request.messages, str):
|
if isinstance(request.messages, str):
|
||||||
raise ValueError("The messages should be a list of dict.")
|
raise ValueError("The messages should be a list of dict.")
|
||||||
for message in request.messages:
|
for message in request.messages:
|
||||||
msg_role = message["role"]
|
msg_role = message.role
|
||||||
if msg_role == "system":
|
if msg_role == "system":
|
||||||
conv.system_message = message["content"]
|
conv.system_message = message.content
|
||||||
elif msg_role == "user":
|
elif msg_role == "user":
|
||||||
conv.append_message(conv.roles[0], message["content"])
|
# Handle the various Chat Request content types here.
|
||||||
|
role = conv.roles[0]
|
||||||
|
if isinstance(message.content, str):
|
||||||
|
conv.append_message(conv.roles[0], message.content)
|
||||||
|
else:
|
||||||
|
real_content = ""
|
||||||
|
for content in message.content:
|
||||||
|
if content.type == "text":
|
||||||
|
real_content += content.text
|
||||||
|
elif content.type == "image_url":
|
||||||
|
real_content += "<image>"
|
||||||
|
conv.append_image(content.image_url.url)
|
||||||
|
conv.append_message(conv.roles[0], real_content)
|
||||||
elif msg_role == "assistant":
|
elif msg_role == "assistant":
|
||||||
conv.append_message(conv.roles[1], message["content"])
|
conv.append_message(conv.roles[1], message.content)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown role: {msg_role}")
|
raise ValueError(f"Unknown role: {msg_role}")
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import time
|
import time
|
||||||
from typing import Dict, List, Optional, Union
|
from typing import Dict, List, Optional, Union
|
||||||
|
from typing_extensions import Literal
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
@@ -68,9 +69,44 @@ class CompletionStreamResponse(BaseModel):
|
|||||||
usage: UsageInfo
|
usage: UsageInfo
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletionMessageGenericParam(BaseModel):
|
||||||
|
role: Literal["system", "assistant"]
|
||||||
|
content: str
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletionMessageContentTextPart(BaseModel):
|
||||||
|
type: Literal["text"]
|
||||||
|
text: str
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletionMessageContentImageURL(BaseModel):
|
||||||
|
url: str
|
||||||
|
detail: Optional[Literal["auto", "low", "high"]] = "auto"
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletionMessageContentImagePart(BaseModel):
|
||||||
|
type: Literal["image_url"]
|
||||||
|
image_url: ChatCompletionMessageContentImageURL
|
||||||
|
|
||||||
|
|
||||||
|
ChatCompletionMessageContentPart = Union[
|
||||||
|
ChatCompletionMessageContentTextPart, ChatCompletionMessageContentImagePart
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletionMessageUserParam(BaseModel):
|
||||||
|
role: Literal["user"]
|
||||||
|
content: Union[str, List[ChatCompletionMessageContentPart]]
|
||||||
|
|
||||||
|
|
||||||
|
ChatCompletionMessageParam = Union[
|
||||||
|
ChatCompletionMessageGenericParam, ChatCompletionMessageUserParam
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class ChatCompletionRequest(BaseModel):
|
class ChatCompletionRequest(BaseModel):
|
||||||
model: str
|
model: str
|
||||||
messages: Union[str, List[Dict[str, str]]]
|
messages: Union[str, List[ChatCompletionMessageParam]]
|
||||||
temperature: Optional[float] = 0.7
|
temperature: Optional[float] = 0.7
|
||||||
top_p: Optional[float] = 1.0
|
top_p: Optional[float] = 1.0
|
||||||
n: Optional[int] = 1
|
n: Optional[int] = 1
|
||||||
|
|||||||
@@ -150,12 +150,17 @@ class TokenizerManager:
|
|||||||
if sampling_params.max_new_tokens != 0:
|
if sampling_params.max_new_tokens != 0:
|
||||||
sampling_params.normalize(self.tokenizer)
|
sampling_params.normalize(self.tokenizer)
|
||||||
sampling_params.verify()
|
sampling_params.verify()
|
||||||
if obj.image_data is None:
|
|
||||||
pixel_values, image_hash, image_size = None, None, None
|
if isinstance(obj.image_data, list) and len(obj.image_data) > 0:
|
||||||
else:
|
pixel_values, image_hash, image_size = await self.get_pixel_values(
|
||||||
|
obj.image_data[0]
|
||||||
|
)
|
||||||
|
elif isinstance(obj.image_data, str):
|
||||||
pixel_values, image_hash, image_size = await self.get_pixel_values(
|
pixel_values, image_hash, image_size = await self.get_pixel_values(
|
||||||
obj.image_data
|
obj.image_data
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
pixel_values, image_hash, image_size = None, None, None
|
||||||
tokenized_obj = TokenizedGenerateReqInput(
|
tokenized_obj = TokenizedGenerateReqInput(
|
||||||
rid=rid,
|
rid=rid,
|
||||||
input_text=obj.text,
|
input_text=obj.text,
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ import psutil
|
|||||||
import requests
|
import requests
|
||||||
import uvicorn
|
import uvicorn
|
||||||
import uvloop
|
import uvloop
|
||||||
from fastapi import FastAPI, Request
|
from fastapi import FastAPI, HTTPException, Request
|
||||||
from fastapi.responses import Response, StreamingResponse
|
from fastapi.responses import Response, StreamingResponse
|
||||||
from sglang.backend.runtime_endpoint import RuntimeEndpoint
|
from sglang.backend.runtime_endpoint import RuntimeEndpoint
|
||||||
from sglang.srt.conversation import (
|
from sglang.srt.conversation import (
|
||||||
@@ -190,16 +190,31 @@ async def v1_chat_completions(raw_request: Request):
|
|||||||
# TODO: Validate the request and return HTTPStatus.BAD_REQUEST if invalid.
|
# TODO: Validate the request and return HTTPStatus.BAD_REQUEST if invalid.
|
||||||
assert request.n == 1
|
assert request.n == 1
|
||||||
|
|
||||||
|
# Prep the data needed for the underlying GenerateReqInput:
|
||||||
|
# - prompt: The full prompt string.
|
||||||
|
# - stop: Custom stop tokens.
|
||||||
|
# - image_data: None or a list of image strings (URLs or base64 strings).
|
||||||
|
# None skips any image processing in GenerateReqInput.
|
||||||
if not isinstance(request.messages, str):
|
if not isinstance(request.messages, str):
|
||||||
# Apply chat template and its stop strings.
|
# Apply chat template and its stop strings.
|
||||||
if chat_template_name is None:
|
if chat_template_name is None:
|
||||||
|
# This flow doesn't support the full OpenAI spec. Verify messages
|
||||||
|
# has the right type before proceeding:
|
||||||
|
for m in request.messages:
|
||||||
|
if not isinstance(m.content, str):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=503,
|
||||||
|
detail="Structured content requests not supported with HuggingFace Chat Templates. Make sure the server specifies a sglang chat template.",
|
||||||
|
)
|
||||||
prompt = tokenizer_manager.tokenizer.apply_chat_template(
|
prompt = tokenizer_manager.tokenizer.apply_chat_template(
|
||||||
request.messages, tokenize=False, add_generation_prompt=True
|
request.messages, tokenize=False, add_generation_prompt=True
|
||||||
)
|
)
|
||||||
stop = request.stop
|
stop = request.stop
|
||||||
|
image_data = None
|
||||||
else:
|
else:
|
||||||
conv = generate_chat_conv(request, chat_template_name)
|
conv = generate_chat_conv(request, chat_template_name)
|
||||||
prompt = conv.get_prompt()
|
prompt = conv.get_prompt()
|
||||||
|
image_data = conv.image_data
|
||||||
stop = conv.stop_str or []
|
stop = conv.stop_str or []
|
||||||
if request.stop:
|
if request.stop:
|
||||||
if isinstance(request.stop, str):
|
if isinstance(request.stop, str):
|
||||||
@@ -210,9 +225,11 @@ async def v1_chat_completions(raw_request: Request):
|
|||||||
# Use the raw prompt and stop strings if the messages is already a string.
|
# Use the raw prompt and stop strings if the messages is already a string.
|
||||||
prompt = request.messages
|
prompt = request.messages
|
||||||
stop = request.stop
|
stop = request.stop
|
||||||
|
image_data = None
|
||||||
|
|
||||||
adapted_request = GenerateReqInput(
|
adapted_request = GenerateReqInput(
|
||||||
text=prompt,
|
text=prompt,
|
||||||
|
image_data=image_data,
|
||||||
sampling_params={
|
sampling_params={
|
||||||
"temperature": request.temperature,
|
"temperature": request.temperature,
|
||||||
"max_new_tokens": request.max_tokens,
|
"max_new_tokens": request.max_tokens,
|
||||||
@@ -303,6 +320,7 @@ def launch_server(server_args, pipe_finish_writer):
|
|||||||
|
|
||||||
# Load chat template if needed
|
# Load chat template if needed
|
||||||
if server_args.chat_template is not None:
|
if server_args.chat_template is not None:
|
||||||
|
print(server_args.chat_template)
|
||||||
if not chat_template_exists(server_args.chat_template):
|
if not chat_template_exists(server_args.chat_template):
|
||||||
if not os.path.exists(server_args.chat_template):
|
if not os.path.exists(server_args.chat_template):
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
|
|||||||
46
python/sglang/test/test_conversation.py
Normal file
46
python/sglang/test/test_conversation.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
from sglang.srt.conversation import generate_chat_conv
|
||||||
|
from sglang.srt.managers.openai_protocol import (
|
||||||
|
ChatCompletionMessageGenericParam,
|
||||||
|
ChatCompletionMessageContentImagePart,
|
||||||
|
ChatCompletionMessageContentImageURL,
|
||||||
|
ChatCompletionMessageContentTextPart,
|
||||||
|
ChatCompletionMessageUserParam,
|
||||||
|
ChatCompletionRequest,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_completion_to_conv_image():
|
||||||
|
"""Test that we can convert a chat image request to a conversation."""
|
||||||
|
request = ChatCompletionRequest(
|
||||||
|
model="default",
|
||||||
|
messages=[
|
||||||
|
ChatCompletionMessageGenericParam(
|
||||||
|
role="system", content="You are a helpful AI assistant"
|
||||||
|
),
|
||||||
|
ChatCompletionMessageUserParam(
|
||||||
|
role="user",
|
||||||
|
content=[
|
||||||
|
ChatCompletionMessageContentTextPart(
|
||||||
|
type="text", text="Describe this image"
|
||||||
|
),
|
||||||
|
ChatCompletionMessageContentImagePart(
|
||||||
|
type="image_url",
|
||||||
|
image_url=ChatCompletionMessageContentImageURL(
|
||||||
|
url="https://someurl.com"
|
||||||
|
),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
conv = generate_chat_conv(request, "vicuna_v1.1")
|
||||||
|
assert conv.messages == [
|
||||||
|
["USER", "Describe this image<image>"],
|
||||||
|
["ASSISTANT", None],
|
||||||
|
]
|
||||||
|
assert conv.system_message == "You are a helpful AI assistant"
|
||||||
|
assert conv.image_data == ["https://someurl.com"]
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_chat_completion_to_conv_image()
|
||||||
51
python/sglang/test/test_openai_protocol.py
Normal file
51
python/sglang/test/test_openai_protocol.py
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
from sglang.srt.managers.openai_protocol import (
|
||||||
|
ChatCompletionMessageGenericParam,
|
||||||
|
ChatCompletionMessageContentImagePart,
|
||||||
|
ChatCompletionMessageContentImageURL,
|
||||||
|
ChatCompletionMessageContentTextPart,
|
||||||
|
ChatCompletionMessageUserParam,
|
||||||
|
ChatCompletionRequest,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_completion_request_image():
|
||||||
|
"""Test that Chat Completion Requests with images can be converted."""
|
||||||
|
|
||||||
|
image_request = {
|
||||||
|
"model": "default",
|
||||||
|
"messages": [
|
||||||
|
{"role": "system", "content": "You are a helpful AI assistant"},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "Describe this image"},
|
||||||
|
{"type": "image_url", "image_url": {"url": "https://someurl.com"}},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"temperature": 0,
|
||||||
|
"max_tokens": 64,
|
||||||
|
}
|
||||||
|
request = ChatCompletionRequest(**image_request)
|
||||||
|
assert len(request.messages) == 2
|
||||||
|
assert request.messages[0] == ChatCompletionMessageGenericParam(
|
||||||
|
role="system", content="You are a helpful AI assistant"
|
||||||
|
)
|
||||||
|
assert request.messages[1] == ChatCompletionMessageUserParam(
|
||||||
|
role="user",
|
||||||
|
content=[
|
||||||
|
ChatCompletionMessageContentTextPart(
|
||||||
|
type="text", text="Describe this image"
|
||||||
|
),
|
||||||
|
ChatCompletionMessageContentImagePart(
|
||||||
|
type="image_url",
|
||||||
|
image_url=ChatCompletionMessageContentImageURL(
|
||||||
|
url="https://someurl.com"
|
||||||
|
),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_chat_completion_request_image()
|
||||||
@@ -71,6 +71,36 @@ def test_chat_completion(args):
|
|||||||
assert response.usage.total_tokens > 0
|
assert response.usage.total_tokens > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_completion_image(args):
|
||||||
|
client = openai.Client(api_key="EMPTY", base_url=args.base_url)
|
||||||
|
response = client.chat.completions.create(
|
||||||
|
model="default",
|
||||||
|
messages=[
|
||||||
|
{"role": "system", "content": "You are a helpful AI assistant"},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": "Describe this image"},
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {
|
||||||
|
"url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/mixtral_8x7b.jpg"
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
temperature=0,
|
||||||
|
max_tokens=32,
|
||||||
|
)
|
||||||
|
print(response.choices[0].message.content)
|
||||||
|
assert response.id
|
||||||
|
assert response.created
|
||||||
|
assert response.usage.prompt_tokens > 0
|
||||||
|
assert response.usage.completion_tokens > 0
|
||||||
|
assert response.usage.total_tokens > 0
|
||||||
|
|
||||||
|
|
||||||
def test_chat_completion_stream(args):
|
def test_chat_completion_stream(args):
|
||||||
client = openai.Client(api_key="EMPTY", base_url=args.base_url)
|
client = openai.Client(api_key="EMPTY", base_url=args.base_url)
|
||||||
response = client.chat.completions.create(
|
response = client.chat.completions.create(
|
||||||
@@ -100,9 +130,14 @@ def test_chat_completion_stream(args):
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--base-url", type=str, default="http://127.0.0.1:30000/v1")
|
parser.add_argument("--base-url", type=str, default="http://127.0.0.1:30000/v1")
|
||||||
|
parser.add_argument(
|
||||||
|
"--test-image", action="store_true", help="Enables testing image inputs"
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
test_completion(args)
|
test_completion(args)
|
||||||
test_completion_stream(args)
|
test_completion_stream(args)
|
||||||
test_chat_completion(args)
|
test_chat_completion(args)
|
||||||
test_chat_completion_stream(args)
|
test_chat_completion_stream(args)
|
||||||
|
if args.test_image:
|
||||||
|
test_chat_completion_image(args)
|
||||||
|
|||||||
Reference in New Issue
Block a user