26
examples/quick_start/gemini_example_complete.py
Normal file
26
examples/quick_start/gemini_example_complete.py
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
from sglang import function, gen, set_default_backend, Gemini
|
||||||
|
|
||||||
|
|
||||||
|
@function
def few_shot_qa(s, question):
    """Answer ``question`` by completing a few-shot capital-city Q/A prompt.

    The generated text is stored under the ``"answer"`` key and generation
    stops at the first newline.
    """
    # Three worked examples that establish the "Q: ... / A: ..." pattern.
    few_shot_prompt = (
        "The following are questions with answers.\n"
        "Q: What is the capital of France?\n"
        "A: Paris\n"
        "Q: What is the capital of Germany?\n"
        "A: Berlin\n"
        "Q: What is the capital of Italy?\n"
        "A: Rome\n"
    )
    s += few_shot_prompt
    s += "Q: " + question + "\n"
    s += "A:" + gen("answer", stop="\n", temperature=0)
|
||||||
|
|
||||||
|
|
||||||
|
# Send every sglang program in this script to the text-only Gemini model.
backend = Gemini("gemini-pro")
set_default_backend(backend)

state = few_shot_qa.run(question="What is the capital of the United States?")
raw_answer = state["answer"]
answer = raw_answer.strip().lower()

# Sanity-check the completion before printing the full transcript.
assert "washington" in answer, f"answer: {state['answer']}"

print(state.text())
|
||||||
19
examples/quick_start/gemini_example_multimodal_chat.py
Normal file
19
examples/quick_start/gemini_example_multimodal_chat.py
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
from sglang import function, user, assistant, gen, image, set_default_backend, Gemini
|
||||||
|
|
||||||
|
|
||||||
|
@function
def image_qa(s, image_file1, image_file2, question):
    """Ask ``question`` about two images in a single user turn.

    The assistant reply is generated with at most 256 tokens and stored
    under the ``"answer_1"`` key.
    """
    # Both images precede the question inside the same user message.
    user_turn = image(image_file1) + image(image_file2) + question
    s += user(user_turn)

    reply = gen("answer_1", max_tokens=256)
    s += assistant(reply)
|
||||||
|
|
||||||
|
# The vision-capable Gemini model is required because the prompt carries images.
backend = Gemini("gemini-pro-vision")
set_default_backend(backend)

state = image_qa.run(
    question="Describe difference of the 2 images in one sentence.",
    image_file1="./images/cat.jpeg",
    image_file2="./images/dog.jpeg",
    stream=True,
)

# Print each streamed piece of text as soon as it arrives.
for chunk in state.text_iter():
    print(chunk, end="", flush=True)
|
||||||
20
examples/quick_start/gemini_example_stream.py
Normal file
20
examples/quick_start/gemini_example_stream.py
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
from sglang import function, user, assistant, gen, set_default_backend, Gemini
|
||||||
|
|
||||||
|
|
||||||
|
@function
def multi_turn_question(s, question_1, question_2):
    """Run a two-turn chat: ask ``question_1``, then ``question_2``.

    Each assistant reply is generated with at most 256 tokens and stored
    under ``"answer_1"`` and ``"answer_2"`` respectively.
    """
    # First exchange.
    s += user(question_1)
    first_reply = gen("answer_1", max_tokens=256)
    s += assistant(first_reply)

    # Second exchange, appended to the same conversation state.
    s += user(question_2)
    second_reply = gen("answer_2", max_tokens=256)
    s += assistant(second_reply)
|
||||||
|
|
||||||
|
# Send every sglang program in this script to the text-only Gemini model.
backend = Gemini("gemini-pro")
set_default_backend(backend)

state = multi_turn_question.run(
    question_1="What is the capital of the United States?",
    question_2="List two local attractions.",
    stream=True,
)

# Print each streamed piece of text as soon as it arrives.
for chunk in state.text_iter():
    print(chunk, end="", flush=True)
|
||||||
BIN
examples/quick_start/images/cat.jpeg
Normal file
BIN
examples/quick_start/images/cat.jpeg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 337 KiB |
BIN
examples/quick_start/images/dog.jpeg
Normal file
BIN
examples/quick_start/images/dog.jpeg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 407 KiB |
@@ -4,6 +4,7 @@ from typing import Callable, List, Optional, Union
|
|||||||
|
|
||||||
from sglang.backend.anthropic import Anthropic
|
from sglang.backend.anthropic import Anthropic
|
||||||
from sglang.backend.base_backend import BaseBackend
|
from sglang.backend.base_backend import BaseBackend
|
||||||
|
from sglang.backend.gemini import Gemini
|
||||||
from sglang.backend.openai import OpenAI
|
from sglang.backend.openai import OpenAI
|
||||||
from sglang.backend.runtime_endpoint import RuntimeEndpoint
|
from sglang.backend.runtime_endpoint import RuntimeEndpoint
|
||||||
from sglang.global_config import global_config
|
from sglang.global_config import global_config
|
||||||
|
|||||||
152
python/sglang/backend/gemini.py
Normal file
152
python/sglang/backend/gemini.py
Normal file
@@ -0,0 +1,152 @@
|
|||||||
|
import os
|
||||||
|
import warnings
|
||||||
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from sglang.backend.base_backend import BaseBackend
|
||||||
|
from sglang.lang.chat_template import get_chat_template
|
||||||
|
from sglang.lang.interpreter import StreamExecutor
|
||||||
|
from sglang.lang.ir import SglSamplingParams
|
||||||
|
|
||||||
|
try:
|
||||||
|
import vertexai
|
||||||
|
from vertexai.preview.generative_models import (
|
||||||
|
GenerationConfig,
|
||||||
|
GenerativeModel,
|
||||||
|
Image,
|
||||||
|
)
|
||||||
|
except ImportError as e:
|
||||||
|
GenerativeModel = e
|
||||||
|
|
||||||
|
# Gemini model identifiers used with this backend (text-only and vision).
GEMINI_MODEL_NAMES = ["gemini-pro", "gemini-pro-vision"]
|
||||||
|
|
||||||
|
|
||||||
|
class Gemini(BaseBackend):
    """Backend that runs sglang programs against Gemini via the ``vertexai`` SDK.

    Vertex AI is initialised from the ``GCP_PROJECT_ID`` and ``GCP_LOCATION``
    environment variables when the backend is constructed.
    """

    def __init__(self, model_name):
        """Initialise Vertex AI and remember which model to call.

        Args:
            model_name: Identifier handed to ``GenerativeModel`` on every
                request (e.g. ``"gemini-pro"`` or ``"gemini-pro-vision"``).

        Raises:
            ImportError: If importing ``vertexai`` failed at module load.
            KeyError: If ``GCP_PROJECT_ID`` or ``GCP_LOCATION`` is not set.
        """
        super().__init__()

        # The module-level import guard stores the ImportError in place of
        # the class, so the failure only surfaces when the backend is used.
        if isinstance(GenerativeModel, Exception):
            raise GenerativeModel

        project_id = os.environ["GCP_PROJECT_ID"]
        location = os.environ["GCP_LOCATION"]
        vertexai.init(project=project_id, location=location)

        self.model_name = model_name
        self.chat_template = get_chat_template("default")

    def get_chat_template(self):
        """Return the chat template selected in ``__init__``."""
        return self.chat_template

    def _build_prompt(self, s):
        """Turn the executor state into the ``generate_content`` prompt argument.

        Chat messages take precedence; otherwise the raw text is used,
        interleaved with images when the executor holds any.
        """
        if s.messages_:
            return self.messages_to_gemini_input(s.messages_)
        # single-turn
        if s.cur_images:
            return self.text_to_gemini_input(s.text_, s.cur_images)
        return s.text_

    def generate(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Run one blocking generation request.

        Returns:
            A ``(completion_text, meta_info)`` tuple; ``meta_info`` is always
            an empty dict.
        """
        ret = GenerativeModel(self.model_name).generate_content(
            self._build_prompt(s),
            generation_config=GenerationConfig(**sampling_params.to_gemini_kwargs()),
        )
        return ret.text, {}

    def generate_stream(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Run one streaming generation request.

        Yields:
            ``(text_chunk, meta_info)`` tuples, one per streamed response;
            ``meta_info`` is always an empty dict.
        """
        generator = GenerativeModel(self.model_name).generate_content(
            self._build_prompt(s),
            stream=True,
            generation_config=GenerationConfig(**sampling_params.to_gemini_kwargs()),
        )
        for ret in generator:
            yield ret.text, {}

    def text_to_gemini_input(self, text, images):
        """Interleave text segments and images for a single-turn prompt.

        ``text`` is split on the chat template's image token and each image
        is placed where its token stood. Empty text segments are dropped.

        Args:
            text: Prompt text containing one image token per image.
            images: Sequence of ``(image_path, image_data)`` pairs.

        Returns:
            A list mixing ``str`` segments and ``Image`` objects.

        Raises:
            IndexError: If ``text`` has fewer image tokens than ``images``.
        """
        # split with image token
        text_segs = text.split(self.chat_template.image_token)

        parts = []
        for text_seg, (_image_path, image_base64_data) in zip(text_segs, images):
            if text_seg != "":
                parts.append(text_seg)
            # NOTE(review): the variable name suggests base64 text, while
            # Image.from_bytes is named for raw bytes -- confirm what the
            # interpreter stores in cur_images.
            parts.append(Image.from_bytes(image_base64_data))

        # The segment after the last image; indexing (rather than zip) keeps
        # the IndexError when tokens and images do not line up.
        trailing_seg = text_segs[len(images)]
        if trailing_seg != "":
            parts.append(trailing_seg)
        return parts

    def messages_to_gemini_input(self, messages):
        """Convert OpenAI-style chat messages into Gemini content dicts.

        Role mapping: ``user`` stays ``user`` and ``assistant`` becomes
        ``model``. A ``system`` message triggers a warning and is emulated
        with a user turn ("System prompt: ...") followed by a model turn
        ("Understood.").

        Args:
            messages: List of dicts with ``role`` and ``content``. ``content``
                is either a string or a list whose first item carries
                ``"text"`` and whose remaining items are ``image_url`` entries.

        Returns:
            A list of ``{"role": ..., "parts": [...]}`` dicts.

        Raises:
            ValueError: If a message has a role other than ``system``,
                ``user`` or ``assistant``.
        """
        gemini_message = []
        # from openai message format to gemini message format
        for msg in messages:
            content = msg["content"]
            text = content if isinstance(content, str) else content[0]["text"]
            role = msg["role"]

            if role == "system":
                warnings.warn("Warning: system prompt is not supported in Gemini.")
                gemini_message.append(
                    {
                        "role": "user",
                        "parts": [{"text": "System prompt: " + text}],
                    }
                )
                gemini_message.append(
                    {
                        "role": "model",
                        "parts": [{"text": "Understood."}],
                    }
                )
                continue

            if role == "user":
                gemini_role = "user"
            elif role == "assistant":
                gemini_role = "model"
            else:
                # Previously an unknown role either hit an unbound local or
                # silently re-appended (and mutated) the previous message.
                raise ValueError(f"Unsupported message role for Gemini: {role!r}")

            gemini_msg = {
                "role": gemini_role,
                "parts": [{"text": text}],
            }

            # images
            if isinstance(content, list) and len(content) > 1:
                for image_part in content[1:]:
                    assert image_part["type"] == "image_url"
                    gemini_msg["parts"].append(
                        {
                            "inline_data": {
                                # Keep only what follows the first comma of the URL.
                                "data": image_part["image_url"]["url"].split(",")[1],
                                "mime_type": "image/jpeg",
                            }
                        }
                    )

            gemini_message.append(gemini_msg)
        return gemini_message
|
||||||
@@ -428,6 +428,7 @@ class StreamExecutor:
|
|||||||
self.messages_.append(last_msg)
|
self.messages_.append(last_msg)
|
||||||
self.cur_images = []
|
self.cur_images = []
|
||||||
else:
|
else:
|
||||||
|
# OpenAI chat API format
|
||||||
self.messages_.append({"role": expr.role, "content": new_text})
|
self.messages_.append({"role": expr.role, "content": new_text})
|
||||||
|
|
||||||
self.cur_role = None
|
self.cur_role = None
|
||||||
|
|||||||
@@ -49,6 +49,16 @@ class SglSamplingParams:
|
|||||||
"presence_penalty": self.presence_penalty,
|
"presence_penalty": self.presence_penalty,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def to_gemini_kwargs(self):
    """Build keyword arguments for Vertex AI's ``GenerationConfig``.

    Returns:
        A dict with ``candidate_count`` (always 1), ``max_output_tokens``,
        ``stop_sequences``, ``temperature``, ``top_p`` and ``top_k``.
        ``top_k`` is ``None`` when the configured value is not positive.
    """
    stop_sequences = self.stop
    # GenerationConfig documents stop_sequences as a list of strings; wrap a
    # bare string so it is passed as one stop sequence, not as an iterable
    # of characters.
    if isinstance(stop_sequences, str):
        stop_sequences = [stop_sequences]

    return {
        "candidate_count": 1,
        "max_output_tokens": self.max_new_tokens,
        "stop_sequences": stop_sequences,
        "temperature": self.temperature,
        "top_p": self.top_p,
        "top_k": self.top_k if self.top_k > 0 else None,
    }
|
||||||
|
|
||||||
def to_anthropic_kwargs(self):
|
def to_anthropic_kwargs(self):
|
||||||
# Anthropic does not support frequency_penalty or presence_penalty, so we drop it here
|
# Anthropic does not support frequency_penalty or presence_penalty, so we drop it here
|
||||||
return {
|
return {
|
||||||
|
|||||||
@@ -355,7 +355,7 @@ class MixtralForCausalLM(nn.Module):
|
|||||||
):
|
):
|
||||||
if "rotary_emb.inv_freq" in name:
|
if "rotary_emb.inv_freq" in name:
|
||||||
continue
|
continue
|
||||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||||
if weight_name not in name:
|
if weight_name not in name:
|
||||||
continue
|
continue
|
||||||
name = name.replace(weight_name, param_name)
|
name = name.replace(weight_name, param_name)
|
||||||
|
|||||||
@@ -304,7 +304,10 @@ def test_image_qa():
|
|||||||
temperature=0,
|
temperature=0,
|
||||||
max_new_tokens=64,
|
max_new_tokens=64,
|
||||||
)
|
)
|
||||||
assert "taxi" in state.messages()[-1]["content"]
|
assert (
|
||||||
|
"taxi" in state.messages()[-1]["content"]
|
||||||
|
or "car" in state.messages()[-1]["content"]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_stream():
|
def test_stream():
|
||||||
|
|||||||
66
test/lang/test_gemini_backend.py
Normal file
66
test/lang/test_gemini_backend.py
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
import unittest
|
||||||
|
|
||||||
|
from sglang.test.test_programs import (
|
||||||
|
test_expert_answer,
|
||||||
|
test_few_shot_qa,
|
||||||
|
test_image_qa,
|
||||||
|
test_mt_bench,
|
||||||
|
test_parallel_decoding,
|
||||||
|
test_parallel_encoding,
|
||||||
|
test_stream,
|
||||||
|
)
|
||||||
|
|
||||||
|
from sglang import Gemini, set_default_backend
|
||||||
|
|
||||||
|
|
||||||
|
class TestGeminiBackend(unittest.TestCase):
    """Run the shared sglang test programs against the Gemini backend."""

    # Backends are created once and cached on the class so that every test
    # method reuses them.
    backend = None
    chat_backend = None
    chat_vision_backend = None

    def setUp(self):
        """Create the shared backends the first time any test runs."""
        klass = type(self)
        if klass.backend is not None:
            return

        klass.backend = Gemini("gemini-pro")
        klass.chat_backend = Gemini("gemini-pro")
        klass.chat_vision_backend = Gemini("gemini-pro-vision")

    # Inside each method the bare name refers to the module-level test
    # program imported from sglang.test.test_programs, not to the method.

    def test_few_shot_qa(self):
        set_default_backend(self.backend)
        test_few_shot_qa()

    def test_mt_bench(self):
        set_default_backend(self.chat_backend)
        test_mt_bench()

    def test_expert_answer(self):
        set_default_backend(self.backend)
        test_expert_answer()

    def test_parallel_decoding(self):
        set_default_backend(self.backend)
        test_parallel_decoding()

    def test_parallel_encoding(self):
        set_default_backend(self.backend)
        test_parallel_encoding()

    def test_image_qa(self):
        set_default_backend(self.chat_vision_backend)
        test_image_qa()

    def test_stream(self):
        set_default_backend(self.backend)
        test_stream()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main(warnings="ignore")
|
||||||
|
|
||||||
|
# from sglang.global_config import global_config
|
||||||
|
|
||||||
|
# global_config.verbosity = 2
|
||||||
|
# t = TestGeminiBackend()
|
||||||
|
# t.setUp()
|
||||||
|
# t.test_stream()
|
||||||
@@ -88,4 +88,15 @@ if __name__ == "__main__":
|
|||||||
# global_config.verbosity = 2
|
# global_config.verbosity = 2
|
||||||
# t = TestOpenAIBackend()
|
# t = TestOpenAIBackend()
|
||||||
# t.setUp()
|
# t.setUp()
|
||||||
|
# t.test_few_shot_qa()
|
||||||
|
# t.test_mt_bench()
|
||||||
|
# t.test_select()
|
||||||
|
# t.test_decode_int()
|
||||||
# t.test_decode_json()
|
# t.test_decode_json()
|
||||||
|
# t.test_expert_answer()
|
||||||
|
# t.test_tool_use()
|
||||||
|
# t.test_react()
|
||||||
|
# t.test_parallel_decoding()
|
||||||
|
# t.test_parallel_encoding()
|
||||||
|
# t.test_image_qa()
|
||||||
|
# t.test_stream()
|
||||||
|
|||||||
Reference in New Issue
Block a user