@function
def few_shot_qa(s, question):
    """Answer `question` after priming the model with three solved examples.

    The generated text is stored under the name "answer"; generation stops
    at the first newline and uses temperature 0.
    """
    # Few-shot preamble, spelled out line by line.
    preamble = (
        "The following are questions with answers.\n"
        "Q: What is the capital of France?\n"
        "A: Paris\n"
        "Q: What is the capital of Germany?\n"
        "A: Berlin\n"
        "Q: What is the capital of Italy?\n"
        "A: Rome\n"
    )
    s += preamble
    s += "Q: " + question + "\n"
    answer = gen("answer", stop="\n", temperature=0)
    s += "A:" + answer


@function
def image_qa(s, image_file1, image_file2, question):
    """Ask `question` about two images in a single user turn.

    The assistant reply is stored under the name "answer_1" and is limited
    to 256 tokens.
    """
    first_image = image(image_file1)
    second_image = image(image_file2)
    s += user(first_image + second_image + question)
    reply = gen("answer_1", max_tokens=256)
    s += assistant(reply)
@function
def multi_turn_question(s, question_1, question_2):
    """Run a two-turn chat: ask each question and capture each reply.

    Replies are stored under "answer_1" and "answer_2", each limited to
    256 tokens.
    """
    turns = (
        (question_1, "answer_1"),
        (question_2, "answer_2"),
    )
    # Turns are appended strictly in order: user prompt, then assistant reply.
    for prompt, answer_name in turns:
        s += user(prompt)
        s += assistant(gen(answer_name, max_tokens=256))
class Gemini(BaseBackend):
    """SGLang backend that sends prompts to Gemini models through Vertex AI."""

    # OpenAI-style chat roles mapped to the role names used in Gemini payloads.
    # "system" is handled separately because it is emulated with two messages.
    _ROLE_MAP = {"user": "user", "assistant": "model"}

    def __init__(self, model_name):
        """Initialise the Vertex AI SDK and remember which model to call.

        Args:
            model_name: Gemini model identifier, e.g. "gemini-pro" or
                "gemini-pro-vision".

        Raises:
            ImportError: the `vertexai` package could not be imported.
            KeyError: `GCP_PROJECT_ID` or `GCP_LOCATION` is not set in the
                environment.
        """
        super().__init__()

        # The module-level import keeps the ImportError in place of the class
        # so that importing this module never fails; surface it on first use.
        if isinstance(GenerativeModel, Exception):
            raise GenerativeModel

        project_id = os.environ["GCP_PROJECT_ID"]
        location = os.environ["GCP_LOCATION"]
        vertexai.init(project=project_id, location=location)

        self.model_name = model_name
        self.chat_template = get_chat_template("default")

    def get_chat_template(self):
        """Return the chat template selected in `__init__`."""
        return self.chat_template

    def _build_prompt(self, s):
        """Build the `generate_content` input from the executor state.

        Chat messages take priority; otherwise the raw text is used, split
        around images when the executor holds any.
        """
        if s.messages_:
            return self.messages_to_gemini_input(s.messages_)
        # single-turn
        if s.cur_images:
            return self.text_to_gemini_input(s.text_, s.cur_images)
        return s.text_

    @staticmethod
    def _generation_config(sampling_params):
        """Translate SGLang sampling parameters into a `GenerationConfig`."""
        kwargs = sampling_params.to_gemini_kwargs()
        # `stop_sequences` is documented as a list of strings. A bare string
        # (e.g. gen(..., stop="\n")) is wrapped so a multi-character stop
        # string cannot be treated as a sequence of single characters.
        # NOTE(review): assumes the stop value can still be a plain str at
        # this point - confirm against SglSamplingParams.
        stop = kwargs.get("stop_sequences")
        if isinstance(stop, str):
            kwargs["stop_sequences"] = [stop]
        return GenerationConfig(**kwargs)

    def generate(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Generate one completion for the current executor state.

        Returns:
            A `(text, meta)` tuple; `meta` is always an empty dict.
        """
        prompt = self._build_prompt(s)
        ret = GenerativeModel(self.model_name).generate_content(
            prompt,
            generation_config=self._generation_config(sampling_params),
        )
        return ret.text, {}

    def generate_stream(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Stream a completion for the current executor state.

        Yields:
            `(text_chunk, meta)` tuples, one per streamed response; `meta`
            is always an empty dict.
        """
        prompt = self._build_prompt(s)
        generator = GenerativeModel(self.model_name).generate_content(
            prompt,
            stream=True,
            generation_config=self._generation_config(sampling_params),
        )
        for ret in generator:
            yield ret.text, {}

    def text_to_gemini_input(self, text, images):
        """Interleave text segments and images for a single-turn prompt.

        `text` is split on the chat template's image token; the i-th image is
        placed after the i-th text segment. Empty segments are dropped.

        Args:
            text: Prompt text containing one image token per image.
            images: Iterable of `(image_path, image_data)` pairs.

        Returns:
            A list mixing `str` segments and `Image` objects.

        Raises:
            IndexError: `text` contains fewer image tokens than `images`.
        """
        images = list(images)
        # split with image token
        text_segs = text.split(self.chat_template.image_token)

        parts = []
        for index, (_image_path, image_base64_data) in enumerate(images):
            if text_segs[index] != "":
                parts.append(text_segs[index])
            # NOTE(review): the data is forwarded unchanged; if the caller
            # supplies a base64 string rather than raw bytes it presumably
            # needs decoding first - confirm against the producer of
            # `cur_images`.
            parts.append(Image.from_bytes(image_base64_data))

        # Text that follows the last image.
        trailing = text_segs[len(images)]
        if trailing != "":
            parts.append(trailing)
        return parts

    @staticmethod
    def _inline_image(data_url):
        """Turn a `data:<mime>;base64,<payload>` URL into an inline_data part.

        The MIME type is taken from the URL header when present and falls
        back to "image/jpeg" otherwise.

        Raises:
            IndexError: the URL contains no comma separating header and data.
        """
        pieces = data_url.split(",", 1)
        payload = pieces[1]
        header = pieces[0]

        mime_type = "image/jpeg"
        if header.startswith("data:"):
            declared = header[len("data:"):].split(";", 1)[0]
            if declared:
                mime_type = declared
        return {"inline_data": {"data": payload, "mime_type": mime_type}}

    def messages_to_gemini_input(self, messages):
        """Convert OpenAI-format chat messages to Gemini message dicts.

        A system message is emulated with a user message prefixed by
        "System prompt: " followed by a canned model acknowledgement, and a
        warning is emitted.

        Args:
            messages: List of dicts with "role" and "content"; content is
                either a string or a list whose first item carries "text" and
                whose remaining items are "image_url" entries.

        Returns:
            A list of dicts with "role" ("user" or "model") and "parts".

        Raises:
            ValueError: a message has an unknown role, or a non-leading
                content item is not of type "image_url".
        """
        gemini_message = []
        # from openai message format to gemini message format
        for msg in messages:
            content = msg["content"]
            text = content if isinstance(content, str) else content[0]["text"]
            role = msg["role"]

            if role == "system":
                warnings.warn("Warning: system prompt is not supported in Gemini.")
                gemini_message.append(
                    {
                        "role": "user",
                        "parts": [{"text": "System prompt: " + text}],
                    }
                )
                gemini_message.append(
                    {
                        "role": "model",
                        "parts": [{"text": "Understood."}],
                    }
                )
                continue

            # Reject unknown roles explicitly; otherwise the previous
            # iteration's message dict would be silently reused.
            if role not in self._ROLE_MAP:
                raise ValueError(f"Unsupported message role for Gemini: {role!r}")
            gemini_msg = {
                "role": self._ROLE_MAP[role],
                "parts": [{"text": text}],
            }

            # images
            if isinstance(content, list):
                for image in content[1:]:
                    if image["type"] != "image_url":
                        raise ValueError(
                            f"Unsupported content type for Gemini: {image['type']!r}"
                        )
                    gemini_msg["parts"].append(
                        self._inline_image(image["image_url"]["url"])
                    )

            gemini_message.append(gemini_msg)
        return gemini_message
class TestGeminiBackend(unittest.TestCase):
    """Runs the shared SGLang test programs against the Gemini backend."""

    # Backends are created once and shared by every test through the class.
    backend = None
    chat_backend = None
    chat_vision_backend = None

    def setUp(self):
        """Create the shared backends the first time any test runs."""
        owner = type(self)
        if owner.backend is not None:
            return

        owner.backend = Gemini("gemini-pro")
        owner.chat_backend = Gemini("gemini-pro")
        owner.chat_vision_backend = Gemini("gemini-pro-vision")

    def _run_with(self, backend, program):
        """Make `backend` the default backend, then execute `program`."""
        set_default_backend(backend)
        program()

    def test_few_shot_qa(self):
        self._run_with(self.backend, test_few_shot_qa)

    def test_mt_bench(self):
        self._run_with(self.chat_backend, test_mt_bench)

    def test_expert_answer(self):
        self._run_with(self.backend, test_expert_answer)

    def test_parallel_decoding(self):
        self._run_with(self.backend, test_parallel_decoding)

    def test_parallel_encoding(self):
        self._run_with(self.backend, test_parallel_encoding)

    def test_image_qa(self):
        self._run_with(self.chat_vision_backend, test_image_qa)

    def test_stream(self):
        self._run_with(self.backend, test_stream)