[Fix] Fix llava on multi images (#1247)
This commit is contained in:
@@ -26,7 +26,7 @@ import struct
|
||||
import time
|
||||
from importlib.metadata import PackageNotFoundError, version
|
||||
from io import BytesIO
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import psutil
|
||||
@@ -193,35 +193,16 @@ def allocate_init_ports(
|
||||
return ret_ports[0], ret_ports[1:num_ports_needed]
|
||||
|
||||
|
||||
def get_int_token_logit_bias(tokenizer, vocab_size):
|
||||
"""Get the logit bias for integer-only tokens."""
|
||||
# a bug when model's vocab size > tokenizer.vocab_size
|
||||
if tokenizer == None:
|
||||
return [-1e5] * vocab_size
|
||||
vocab_size = tokenizer.vocab_size
|
||||
logit_bias = np.zeros(vocab_size, dtype=np.float32)
|
||||
for t_id in range(vocab_size):
|
||||
ss = tokenizer.decode([t_id]).strip()
|
||||
if not (ss.isdigit() or len(ss) == 0 or t_id == tokenizer.eos_token_id):
|
||||
logit_bias[t_id] = -1e5
|
||||
|
||||
return logit_bias
|
||||
|
||||
|
||||
def is_multimodal_model(model):
|
||||
from sglang.srt.model_config import ModelConfig
|
||||
|
||||
if isinstance(model, str):
|
||||
model = model.lower()
|
||||
return "llava" in model or "yi-vl" in model or "llava-next" in model
|
||||
|
||||
if isinstance(model, ModelConfig):
|
||||
model_path = model.path.lower()
|
||||
return (
|
||||
"llava" in model_path or "yi-vl" in model_path or "llava-next" in model_path
|
||||
)
|
||||
|
||||
raise ValueError("unrecognized type")
|
||||
def is_multimodal_model(model_architectures):
|
||||
if (
|
||||
"LlavaLlamaForCausalLM" in model_architectures
|
||||
or "LlavaQwenForCausalLM" in model_architectures
|
||||
or "LlavaMistralForCausalLM" in model_architectures
|
||||
or "LlavaVidForCausalLM" in model_architectures
|
||||
):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def is_generation_model(model_architectures, is_embedding: bool = False):
|
||||
@@ -317,12 +298,14 @@ def decode_video_base64(video_base64):
|
||||
) # Return an empty array and size tuple if no frames were found
|
||||
|
||||
|
||||
def load_image(image_file):
|
||||
def load_image(image_file: Union[str, bytes]):
|
||||
from PIL import Image
|
||||
|
||||
image = image_size = None
|
||||
|
||||
if image_file.startswith("http://") or image_file.startswith("https://"):
|
||||
if isinstance(image_file, bytes):
|
||||
image = Image.open(BytesIO(image_file))
|
||||
elif image_file.startswith("http://") or image_file.startswith("https://"):
|
||||
timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
|
||||
response = requests.get(image_file, timeout=timeout)
|
||||
image = Image.open(BytesIO(response.content))
|
||||
@@ -334,8 +317,10 @@ def load_image(image_file):
|
||||
elif image_file.startswith("video:"):
|
||||
image_file = image_file.replace("video:", "")
|
||||
image, image_size = decode_video_base64(image_file)
|
||||
else:
|
||||
elif isinstance(image_file, str):
|
||||
image = Image.open(BytesIO(base64.b64decode(image_file)))
|
||||
else:
|
||||
raise ValueError(f"Invalid image: {image}")
|
||||
|
||||
return image, image_size
|
||||
|
||||
|
||||
Reference in New Issue
Block a user