diff --git a/README.md b/README.md
index c33c6512f..7218f10b8 100644
--- a/README.md
+++ b/README.md
@@ -357,7 +357,7 @@ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port
 - Mistral
 - Mixtral
 - LLaVA
-  - `python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --port 30000`
+  - `python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --chat-template vicuna_v1.1 --port 30000`
 - Qwen / Qwen 2
 - AWQ quantization
 
diff --git a/python/sglang/srt/conversation.py b/python/sglang/srt/conversation.py
index 24c84a5c9..df872f77a 100644
--- a/python/sglang/srt/conversation.py
+++ b/python/sglang/srt/conversation.py
@@ -2,7 +2,7 @@
 # https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
 import dataclasses
 from enum import IntEnum, auto
-from typing import Dict, List, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union
 
 from sglang.srt.managers.openai_protocol import ChatCompletionRequest
 
@@ -52,6 +52,7 @@ class Conversation:
     sep2: str = None
     # Stop criteria (the default one is EOS token)
     stop_str: Union[str, List[str]] = None
+    image_data: Optional[List[str]] = None
 
     def get_prompt(self) -> str:
         """Get the prompt for generation."""
@@ -251,6 +252,10 @@ class Conversation:
         """Append a new message."""
         self.messages.append([role, message])
 
+    def append_image(self, image: str):
+        """Append a new image."""
+        self.image_data.append(image)
+
     def update_last_message(self, message: str):
         """Update the last output.
 
@@ -341,18 +346,31 @@ def generate_chat_conv(
         sep=conv.sep,
         sep2=conv.sep2,
         stop_str=conv.stop_str,
+        image_data=[],
     )
 
     if isinstance(request.messages, str):
         raise ValueError("The messages should be a list of dict.")
 
     for message in request.messages:
-        msg_role = message["role"]
+        msg_role = message.role
         if msg_role == "system":
-            conv.system_message = message["content"]
+            conv.system_message = message.content
         elif msg_role == "user":
-            conv.append_message(conv.roles[0], message["content"])
+            # Handle the various types of Chat Request content types here.
+            role = conv.roles[0]
+            if isinstance(message.content, str):
+                conv.append_message(role, message.content)
+            else:
+                real_content = ""
+                for content in message.content:
+                    if content.type == "text":
+                        real_content += content.text
+                    elif content.type == "image_url":
+                        real_content += ""
+                        conv.append_image(content.image_url.url)
+                conv.append_message(role, real_content)
         elif msg_role == "assistant":
-            conv.append_message(conv.roles[1], message["content"])
+            conv.append_message(conv.roles[1], message.content)
         else:
             raise ValueError(f"Unknown role: {msg_role}")
diff --git a/python/sglang/srt/managers/openai_protocol.py b/python/sglang/srt/managers/openai_protocol.py
index 974e38a91..f4ef99dd9 100644
--- a/python/sglang/srt/managers/openai_protocol.py
+++ b/python/sglang/srt/managers/openai_protocol.py
@@ -1,5 +1,6 @@
 import time
 from typing import Dict, List, Optional, Union
+from typing_extensions import Literal
 
 from pydantic import BaseModel, Field
 
@@ -68,9 +69,44 @@ class CompletionStreamResponse(BaseModel):
     usage: UsageInfo
 
 
+class ChatCompletionMessageGenericParam(BaseModel):
+    role: Literal["system", "assistant"]
+    content: str
+
+
+class ChatCompletionMessageContentTextPart(BaseModel):
+    type: Literal["text"]
+    text: str
+
+
+class ChatCompletionMessageContentImageURL(BaseModel):
+    url: str
+    detail: Optional[Literal["auto", "low", "high"]] = "auto"
+
+
+class ChatCompletionMessageContentImagePart(BaseModel):
+    type: Literal["image_url"]
+    image_url: ChatCompletionMessageContentImageURL
+
+
+ChatCompletionMessageContentPart = Union[
+    ChatCompletionMessageContentTextPart, ChatCompletionMessageContentImagePart
+]
+
+
+class ChatCompletionMessageUserParam(BaseModel):
+    role: Literal["user"]
+    content: Union[str, List[ChatCompletionMessageContentPart]]
+
+
+ChatCompletionMessageParam = Union[
+    ChatCompletionMessageGenericParam, ChatCompletionMessageUserParam
+]
+
+
 class ChatCompletionRequest(BaseModel):
     model: str
-    messages: Union[str, List[Dict[str, str]]]
+    messages: Union[str, List[ChatCompletionMessageParam]]
     temperature: Optional[float] = 0.7
     top_p: Optional[float] = 1.0
     n: Optional[int] = 1
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 2213858bf..d67cb49ea 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -150,12 +150,17 @@ class TokenizerManager:
         if sampling_params.max_new_tokens != 0:
             sampling_params.normalize(self.tokenizer)
             sampling_params.verify()
-        if obj.image_data is None:
-            pixel_values, image_hash, image_size = None, None, None
-        else:
+
+        if isinstance(obj.image_data, list) and len(obj.image_data) > 0:
+            pixel_values, image_hash, image_size = await self.get_pixel_values(
+                obj.image_data[0]
+            )
+        elif isinstance(obj.image_data, str):
             pixel_values, image_hash, image_size = await self.get_pixel_values(
                 obj.image_data
             )
+        else:
+            pixel_values, image_hash, image_size = None, None, None
         tokenized_obj = TokenizedGenerateReqInput(
             rid=rid,
             input_text=obj.text,
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index 06d1b3def..56b62ee95 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -16,7 +16,7 @@ import psutil
 import requests
 import uvicorn
 import uvloop
-from fastapi import FastAPI, Request
+from fastapi import FastAPI, HTTPException, Request
 from fastapi.responses import Response, StreamingResponse
 from sglang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.srt.conversation import (
@@ -190,16 +190,31 @@ async def v1_chat_completions(raw_request: Request):
     # TODO: Validate the request and return HTTPStatus.BAD_REQUEST if invalid.
     assert request.n == 1
 
+    # Prep the data needed for the underlying GenerateReqInput:
+    #  - prompt: The full prompt string.
+    #  - stop: Custom stop tokens.
+    #  - image_data: None or a list of image strings (URLs or base64 strings).
+    #    None skips any image processing in GenerateReqInput.
     if not isinstance(request.messages, str):
         # Apply chat template and its stop strings.
         if chat_template_name is None:
+            # This flow doesn't support the full OpenAI spec. Verify messages
+            # has the right type before proceeding:
+            for m in request.messages:
+                if not isinstance(m.content, str):
+                    raise HTTPException(
+                        status_code=503,
+                        detail="Structured content requests not supported with HuggingFace Chat Templates. Make sure the server specifies a sglang chat template.",
+                    )
             prompt = tokenizer_manager.tokenizer.apply_chat_template(
                 request.messages, tokenize=False, add_generation_prompt=True
             )
             stop = request.stop
+            image_data = None
         else:
             conv = generate_chat_conv(request, chat_template_name)
             prompt = conv.get_prompt()
+            image_data = conv.image_data
             stop = conv.stop_str or []
             if request.stop:
                 if isinstance(request.stop, str):
@@ -210,9 +225,11 @@
         # Use the raw prompt and stop strings if the messages is already a string.
         prompt = request.messages
         stop = request.stop
+        image_data = None
 
     adapted_request = GenerateReqInput(
         text=prompt,
+        image_data=image_data,
         sampling_params={
             "temperature": request.temperature,
             "max_new_tokens": request.max_tokens,
diff --git a/python/sglang/test/test_conversation.py b/python/sglang/test/test_conversation.py
new file mode 100644
index 000000000..4f4f956fe
--- /dev/null
+++ b/python/sglang/test/test_conversation.py
@@ -0,0 +1,46 @@
+from sglang.srt.conversation import generate_chat_conv
+from sglang.srt.managers.openai_protocol import (
+    ChatCompletionMessageGenericParam,
+    ChatCompletionMessageContentImagePart,
+    ChatCompletionMessageContentImageURL,
+    ChatCompletionMessageContentTextPart,
+    ChatCompletionMessageUserParam,
+    ChatCompletionRequest,
+)
+
+
+def test_chat_completion_to_conv_image():
+    """Test that we can convert a chat image request to a convo"""
+    request = ChatCompletionRequest(
+        model="default",
+        messages=[
+            ChatCompletionMessageGenericParam(
+                role="system", content="You are a helpful AI assistant"
+            ),
+            ChatCompletionMessageUserParam(
+                role="user",
+                content=[
+                    ChatCompletionMessageContentTextPart(
+                        type="text", text="Describe this image"
+                    ),
+                    ChatCompletionMessageContentImagePart(
+                        type="image_url",
+                        image_url=ChatCompletionMessageContentImageURL(
+                            url="https://someurl.com"
+                        ),
+                    ),
+                ],
+            ),
+        ],
+    )
+    conv = generate_chat_conv(request, "vicuna_v1.1")
+    assert conv.messages == [
+        ["USER", "Describe this image"],
+        ["ASSISTANT", None],
+    ]
+    assert conv.system_message == "You are a helpful AI assistant"
+    assert conv.image_data == ["https://someurl.com"]
+
+
+if __name__ == "__main__":
+    test_chat_completion_to_conv_image()
diff --git a/python/sglang/test/test_openai_protocol.py b/python/sglang/test/test_openai_protocol.py
new file mode 100644
index 000000000..ed18e428a
--- /dev/null
+++ b/python/sglang/test/test_openai_protocol.py
@@ -0,0 +1,51 @@
+from sglang.srt.managers.openai_protocol import (
+    ChatCompletionMessageGenericParam,
+    ChatCompletionMessageContentImagePart,
+    ChatCompletionMessageContentImageURL,
+    ChatCompletionMessageContentTextPart,
+    ChatCompletionMessageUserParam,
+    ChatCompletionRequest,
+)
+
+
+def test_chat_completion_request_image():
+    """Test that Chat Completion Requests with images can be converted."""
+
+    image_request = {
+        "model": "default",
+        "messages": [
+            {"role": "system", "content": "You are a helpful AI assistant"},
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": "Describe this image"},
+                    {"type": "image_url", "image_url": {"url": "https://someurl.com"}},
+                ],
+            },
+        ],
+        "temperature": 0,
+        "max_tokens": 64,
+    }
+    request = ChatCompletionRequest(**image_request)
+    assert len(request.messages) == 2
+    assert request.messages[0] == ChatCompletionMessageGenericParam(
+        role="system", content="You are a helpful AI assistant"
+    )
+    assert request.messages[1] == ChatCompletionMessageUserParam(
+        role="user",
+        content=[
+            ChatCompletionMessageContentTextPart(
+                type="text", text="Describe this image"
+            ),
+            ChatCompletionMessageContentImagePart(
+                type="image_url",
+                image_url=ChatCompletionMessageContentImageURL(
+                    url="https://someurl.com"
+                ),
+            ),
+        ],
+    )
+
+
+if __name__ == "__main__":
+    test_chat_completion_request_image()
diff --git a/test/srt/test_openai_server.py b/test/srt/test_openai_server.py
index 33d5b0672..f0dc078e2 100644
--- a/test/srt/test_openai_server.py
+++ b/test/srt/test_openai_server.py
@@ -71,6 +71,36 @@ def test_chat_completion(args):
     assert response.usage.total_tokens > 0
 
 
+def test_chat_completion_image(args):
+    client = openai.Client(api_key="EMPTY", base_url=args.base_url)
+    response = client.chat.completions.create(
+        model="default",
+        messages=[
+            {"role": "system", "content": "You are a helpful AI assistant"},
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": "Describe this image"},
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/mixtral_8x7b.jpg"
+                        },
+                    },
+                ],
+            },
+        ],
+        temperature=0,
+        max_tokens=32,
+    )
+    print(response.choices[0].message.content)
+    assert response.id
+    assert response.created
+    assert response.usage.prompt_tokens > 0
+    assert response.usage.completion_tokens > 0
+    assert response.usage.total_tokens > 0
+
+
 def test_chat_completion_stream(args):
     client = openai.Client(api_key="EMPTY", base_url=args.base_url)
     response = client.chat.completions.create(
@@ -100,9 +130,14 @@
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--base-url", type=str, default="http://127.0.0.1:30000/v1")
+    parser.add_argument(
+        "--test-image", action="store_true", help="Enables testing image inputs"
+    )
     args = parser.parse_args()
 
     test_completion(args)
     test_completion_stream(args)
     test_chat_completion(args)
     test_chat_completion_stream(args)
+    if args.test_image:
+        test_chat_completion_image(args)