[Fix] Fix llava on multi images (#1247)

This commit is contained in:
Lianmin Zheng
2024-08-28 06:33:05 -07:00
committed by GitHub
parent b1a540ec42
commit bf53bf5142
22 changed files with 272 additions and 488 deletions

View File

@@ -26,7 +26,7 @@ import struct
import time
from importlib.metadata import PackageNotFoundError, version
from io import BytesIO
from typing import List, Optional
from typing import List, Optional, Union
import numpy as np
import psutil
@@ -193,35 +193,16 @@ def allocate_init_ports(
return ret_ports[0], ret_ports[1:num_ports_needed]
def get_int_token_logit_bias(tokenizer, vocab_size):
"""Get the logit bias for integer-only tokens."""
# a bug when model's vocab size > tokenizer.vocab_size
if tokenizer == None:
return [-1e5] * vocab_size
vocab_size = tokenizer.vocab_size
logit_bias = np.zeros(vocab_size, dtype=np.float32)
for t_id in range(vocab_size):
ss = tokenizer.decode([t_id]).strip()
if not (ss.isdigit() or len(ss) == 0 or t_id == tokenizer.eos_token_id):
logit_bias[t_id] = -1e5
return logit_bias
def is_multimodal_model(model):
from sglang.srt.model_config import ModelConfig
if isinstance(model, str):
model = model.lower()
return "llava" in model or "yi-vl" in model or "llava-next" in model
if isinstance(model, ModelConfig):
model_path = model.path.lower()
return (
"llava" in model_path or "yi-vl" in model_path or "llava-next" in model_path
)
raise ValueError("unrecognized type")
def is_multimodal_model(model_architectures):
if (
"LlavaLlamaForCausalLM" in model_architectures
or "LlavaQwenForCausalLM" in model_architectures
or "LlavaMistralForCausalLM" in model_architectures
or "LlavaVidForCausalLM" in model_architectures
):
return True
else:
return False
def is_generation_model(model_architectures, is_embedding: bool = False):
@@ -317,12 +298,14 @@ def decode_video_base64(video_base64):
) # Return an empty array and size tuple if no frames were found
def load_image(image_file):
def load_image(image_file: Union[str, bytes]):
from PIL import Image
image = image_size = None
if image_file.startswith("http://") or image_file.startswith("https://"):
if isinstance(image_file, bytes):
image = Image.open(BytesIO(image_file))
elif image_file.startswith("http://") or image_file.startswith("https://"):
timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
response = requests.get(image_file, timeout=timeout)
image = Image.open(BytesIO(response.content))
@@ -334,8 +317,10 @@ def load_image(image_file):
elif image_file.startswith("video:"):
image_file = image_file.replace("video:", "")
image, image_size = decode_video_base64(image_file)
else:
elif isinstance(image_file, str):
image = Image.open(BytesIO(base64.b64decode(image_file)))
else:
raise ValueError(f"Invalid image: {image}")
return image, image_size