"""
Usage: python3 srt_example_llava.py
"""
|
from PIL import ImageFile

import sglang as sgl
from sglang.lang.chat_template import get_chat_template
from sglang.srt.utils import load_image

ImageFile.LOAD_TRUNCATED_IMAGES = True  # Allow loading of truncated images
@sgl.function
def image_qa(s, image, question):
    """SGLang program: ask *question* about *image* and generate a reply.

    The generated text is stored in the program state under the key
    "answer" (callers read it via ``state["answer"]`` or stream it with
    ``state.text_iter("answer")``).
    """
    s += sgl.user(sgl.image(image) + question)
    s += sgl.assistant(sgl.gen("answer"))
def single():
    """Run one non-streaming image-QA request and print the answer."""
    url = "https://farm4.staticflickr.com/3175/2653711032_804ff86d81_z.jpg"
    img, _ = load_image(url)
    # Blocking call: returns only once the full answer has been generated.
    state = image_qa.run(image=img, question="What is this?", max_new_tokens=512)
    print(state["answer"], "\n")
def stream():
    """Run an image-captioning request and print the answer token by token."""
    url = "https://farm4.staticflickr.com/3175/2653711032_804ff86d81_z.jpg"
    img, _ = load_image(url)
    state = image_qa.run(
        image=img,
        question="Please generate short caption for this image.",
        max_new_tokens=512,
        temperature=0,
        stream=True,
    )

    # Consume the incremental output as it arrives instead of waiting
    # for the whole generation to finish.
    for piece in state.text_iter("answer"):
        print(piece, end="", flush=True)
    print()
def batch():
    """Submit two identical image-QA requests as one batch and print each answer."""
    url = "https://farm4.staticflickr.com/3175/2653711032_804ff86d81_z.jpg"
    img, _ = load_image(url)
    # Two copies of the same request, executed together by run_batch.
    requests = [{"image": img, "question": "What is this?"} for _ in range(2)]
    states = image_qa.run_batch(requests, max_new_tokens=512)
    for st in states:
        print(st["answer"], "\n")
if __name__ == "__main__":
    import multiprocessing as mp

    # NOTE(review): "spawn" is forced here — presumably because the runtime
    # launches CUDA-using subprocesses, which fork is unsafe for; confirm.
    mp.set_start_method("spawn", force=True)

    # Start a local sglang runtime serving the LLaVA-NeXT 8B model.
    runtime = sgl.Runtime(
        model_path="lmms-lab/llama3-llava-next-8b",
        tokenizer_path="lmms-lab/llama3-llava-next-8b-tokenizer",
    )
    runtime.endpoint.chat_template = get_chat_template("llama-3-instruct")
    # runtime = sgl.Runtime(
    #     model_path="lmms-lab/llava-next-72b",
    #     tokenizer_path="lmms-lab/llavanext-qwen-tokenizer",
    # )
    # runtime.endpoint.chat_template = get_chat_template("chatml-llava")
    sgl.set_default_backend(runtime)
    print(f"chat template: {runtime.endpoint.chat_template.name}")

    # Or you can use API models
    # sgl.set_default_backend(sgl.OpenAI("gpt-4-vision-preview"))
    # sgl.set_default_backend(sgl.VertexAI("gemini-pro-vision"))

    # Run a single request
    print("\n========== single ==========\n")
    single()

    # Stream output
    print("\n========== stream ==========\n")
    stream()

    # Run a batch of requests
    print("\n========== batch ==========\n")
    batch()

    # Release the local runtime's resources before exiting.
    runtime.shutdown()