refactor: move image processors to separate files (#4229)
This commit is contained in:
@@ -191,7 +191,7 @@ class Conversation:
|
||||
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
if i % 2 == 0:
|
||||
ret += f"[Round {i//2 + round_add_n}]{self.sep}"
|
||||
ret += f"[Round {i // 2 + round_add_n}]{self.sep}"
|
||||
|
||||
if message:
|
||||
ret += f"{role}:{message}{self.sep}"
|
||||
@@ -453,7 +453,6 @@ def generate_chat_conv(
|
||||
conv.system_message = getattr(message.content[0], "text", "")
|
||||
elif msg_role == "user":
|
||||
# Handle the various types of Chat Request content types here.
|
||||
role = conv.roles[0]
|
||||
if isinstance(message.content, str):
|
||||
conv.append_message(conv.roles[0], message.content)
|
||||
else:
|
||||
|
||||
@@ -66,6 +66,7 @@ def get_config(
|
||||
config = AutoConfig.from_pretrained(
|
||||
model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
|
||||
)
|
||||
|
||||
if config.model_type in _CONFIG_REGISTRY:
|
||||
config_class = _CONFIG_REGISTRY[config.model_type]
|
||||
config = config_class.from_pretrained(model, revision=revision)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import lru_cache
|
||||
from typing import Optional
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -22,47 +22,29 @@ from sglang.srt.layers.quantization import QuantizationConfig
|
||||
from sglang.srt.utils import add_prefix
|
||||
|
||||
|
||||
def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
|
||||
if not interleaved:
|
||||
x1, x2 = x.chunk(2, dim=-1)
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
else:
|
||||
x1, x2 = x[..., ::2], x[..., 1::2]
|
||||
return rearrange(
|
||||
torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
|
||||
)
|
||||
# Copied from transformers, modeling_qwen2_vl.py
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def apply_rotary_emb_torch(
|
||||
x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
x: (batch_size, seqlen, nheads, headdim)
|
||||
cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
|
||||
"""
|
||||
ro_dim = cos.shape[-1] * 2
|
||||
assert ro_dim <= x.shape[-1]
|
||||
cos = repeat(
|
||||
cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
|
||||
)
|
||||
sin = repeat(
|
||||
sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
|
||||
)
|
||||
return torch.cat(
|
||||
[
|
||||
x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
|
||||
x[..., ro_dim:],
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
def apply_rotary_pos_emb_vision(
|
||||
q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
orig_q_dtype = q.dtype
|
||||
orig_k_dtype = k.dtype
|
||||
q, k = q.float(), k.float()
|
||||
|
||||
cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
|
||||
def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
|
||||
t_ = t.float()
|
||||
cos = freqs.cos()
|
||||
sin = freqs.sin()
|
||||
output = apply_rotary_emb_torch(t_, cos, sin).type_as(t)
|
||||
return output
|
||||
q_embed = q_embed.to(orig_q_dtype)
|
||||
k_embed = k_embed.to(orig_k_dtype)
|
||||
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
class VisionAttention(nn.Module):
|
||||
@@ -75,8 +57,8 @@ class VisionAttention(nn.Module):
|
||||
use_context_forward (bool, default to True):
|
||||
if ``True``, a flash_attn style attention will be applied
|
||||
Otherwise, a full-sequence attention will be applied.
|
||||
use_full_precision_softmax (bool, default to False):
|
||||
if ``True``, the softmax will be performed in full-precision
|
||||
softmax_in_single_precision (bool, default to False):
|
||||
if ``True``, the softmax will be performed in single-precision
|
||||
Otherwise, it will be performed in half-precision
|
||||
|
||||
"""
|
||||
@@ -90,7 +72,7 @@ class VisionAttention(nn.Module):
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
dropout: float = 0.0,
|
||||
use_context_forward: bool = True,
|
||||
use_full_precision_softmax: bool = False,
|
||||
softmax_in_single_precision: bool = False,
|
||||
flatten_batch: bool = False,
|
||||
prefix: str = "",
|
||||
):
|
||||
@@ -113,7 +95,7 @@ class VisionAttention(nn.Module):
|
||||
head_size=self.head_size,
|
||||
dropout=dropout,
|
||||
flatten_batch=flatten_batch,
|
||||
use_full_precision_softmax=use_full_precision_softmax,
|
||||
softmax_in_single_precision=softmax_in_single_precision,
|
||||
)
|
||||
|
||||
self.use_qkv_parallel = use_qkv_parallel
|
||||
@@ -143,7 +125,7 @@ class VisionAttention(nn.Module):
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
rotary_pos_emb: torch.Tensor = None,
|
||||
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
@@ -151,21 +133,17 @@ class VisionAttention(nn.Module):
|
||||
x: [b, s, embed_dim]
|
||||
cu_seqlens: [b]
|
||||
Returns:
|
||||
[s, b, num_heads * head]
|
||||
[s, b, head * head_size]
|
||||
"""
|
||||
bsz, s, _ = x.shape
|
||||
head = self.num_attention_heads_per_partition
|
||||
if self.use_qkv_parallel:
|
||||
# [b, s, embed_dim] --> [b, s, embed_dim]
|
||||
qkv, _ = self.qkv_proj(x)
|
||||
q, k, v = qkv.chunk(3, dim=-1)
|
||||
|
||||
# [b, s, embed_dim] --> [b * s, num_heads, head_size]
|
||||
q, k, v = [
|
||||
x.reshape(
|
||||
bsz * s, self.num_attention_heads_per_partition, -1
|
||||
).contiguous()
|
||||
for x in (q, k, v)
|
||||
]
|
||||
# [b, s, embed_dim] --> [b * s, head, head_size]
|
||||
q, k, v = [x.reshape(bsz * s, head, -1).contiguous() for x in (q, k, v)]
|
||||
else:
|
||||
# [b, s, embed_dim] --> [s, b, embed_dim]
|
||||
x = rearrange(x, "b s ... -> s b ...")
|
||||
@@ -173,7 +151,7 @@ class VisionAttention(nn.Module):
|
||||
qkv, _ = self.qkv_proj(x)
|
||||
# [s, b, head * 3 * head_size] --> [s, b, head, 3 * head_size]
|
||||
new_x_shape = qkv.size()[:-1] + (
|
||||
self.num_attention_heads_per_partition,
|
||||
head,
|
||||
3 * self.hidden_size_per_attention_head,
|
||||
)
|
||||
qkv = qkv.view(*new_x_shape)
|
||||
@@ -186,9 +164,12 @@ class VisionAttention(nn.Module):
|
||||
rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)
|
||||
]
|
||||
|
||||
if rotary_pos_emb is not None:
|
||||
q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
|
||||
k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)
|
||||
if position_embeddings is not None:
|
||||
cos, sin = position_embeddings
|
||||
original_shape = q.shape
|
||||
q, k = q.view(s, head, -1), k.view(s, head, -1)
|
||||
q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
|
||||
q, k = q.reshape(original_shape), k.reshape(original_shape)
|
||||
|
||||
if self.use_qkv_parallel:
|
||||
pass
|
||||
@@ -230,12 +211,12 @@ class VisionSdpaAttention(nn.Module):
|
||||
head_size: int,
|
||||
dropout: float = 0.0,
|
||||
flatten_batch: bool = False,
|
||||
use_full_precision_softmax: bool = False,
|
||||
softmax_in_single_precision: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.head_size = head_size
|
||||
self.flatten_batch = flatten_batch
|
||||
self.use_full_precision_softmax = use_full_precision_softmax
|
||||
self.softmax_in_single_precision = softmax_in_single_precision
|
||||
self.dropout = dropout
|
||||
|
||||
@staticmethod
|
||||
@@ -319,14 +300,14 @@ class VisionSdpaAttention(nn.Module):
|
||||
)
|
||||
|
||||
if attention_mask is None:
|
||||
if self.use_full_precision_softmax:
|
||||
if self.softmax_in_single_precision:
|
||||
raise RuntimeError("Empty attention mask")
|
||||
else:
|
||||
attention_mask = attention_mask.to(device=q.device)
|
||||
|
||||
q, k, v = [rearrange(x, "(b s) h d -> b h s d", b=bsz) for x in [q, k, v]]
|
||||
|
||||
if self.use_full_precision_softmax:
|
||||
if self.softmax_in_single_precision:
|
||||
scale = self.head_size**-0.5
|
||||
k_transposed = rearrange(k, "b h s d -> b h d s")
|
||||
attn_weights = torch.matmul(q, k_transposed) * scale
|
||||
|
||||
@@ -1,649 +1,55 @@
|
||||
# TODO: also move pad_input_ids into this module
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import dataclasses
|
||||
import importlib
|
||||
import logging
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional, Union
|
||||
import pkgutil
|
||||
from functools import lru_cache
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
import transformers
|
||||
from decord import VideoReader, cpu
|
||||
from PIL import Image
|
||||
from transformers import IMAGE_PROCESSOR_MAPPING
|
||||
|
||||
from sglang.srt.hf_transformers_utils import get_processor
|
||||
from sglang.srt.mm_utils import expand2square, process_anyres_image
|
||||
from sglang.srt.managers.image_processors.base_image_processor import (
|
||||
BaseImageProcessor,
|
||||
DummyImageProcessor,
|
||||
)
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils import load_image
|
||||
from sglang.utils import get_exception_traceback
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
global global_processor
|
||||
|
||||
|
||||
def init_global_processor(server_args: ServerArgs):
|
||||
"""Init the global processor for multi modal models."""
|
||||
global global_processor
|
||||
transformers.logging.set_verbosity_error()
|
||||
global_processor = get_processor(
|
||||
server_args.tokenizer_path,
|
||||
tokenizer_mode=server_args.tokenizer_mode,
|
||||
trust_remote_code=server_args.trust_remote_code,
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class BaseImageProcessorOutput:
|
||||
image_hashes: list[int]
|
||||
image_sizes: list[int]
|
||||
all_frames: [PIL.Image]
|
||||
# input_text, with each frame of video/image represented with a image_token
|
||||
input_text: str
|
||||
|
||||
|
||||
class BaseImageProcessor(ABC):
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
self.hf_config = hf_config
|
||||
self._processor = _processor
|
||||
self.server_args = server_args
|
||||
# FIXME: not accurate, model and image specific
|
||||
self.NUM_TOKEN_PER_FRAME = 330
|
||||
|
||||
self.executor = concurrent.futures.ProcessPoolExecutor(
|
||||
initializer=init_global_processor,
|
||||
mp_context=mp.get_context("fork"),
|
||||
initargs=(server_args,),
|
||||
max_workers=int(os.environ.get("SGLANG_CPU_COUNT", os.cpu_count())),
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
async def process_images_async(
|
||||
self, image_data, input_text, max_req_input_len, **kwargs
|
||||
):
|
||||
pass
|
||||
|
||||
def get_estimated_frames_list(self, image_data):
|
||||
"""
|
||||
estimate the total frame count from all visual input
|
||||
"""
|
||||
# Before processing inputs
|
||||
estimated_frames_list = []
|
||||
for image in image_data:
|
||||
if isinstance(image, str) and image.startswith("video:"):
|
||||
path = image[len("video:") :]
|
||||
# Estimate frames for the video
|
||||
vr = VideoReader(path, ctx=cpu(0))
|
||||
num_frames = len(vr)
|
||||
else:
|
||||
# For images, each contributes one frame
|
||||
num_frames = 1
|
||||
estimated_frames_list.append(num_frames)
|
||||
|
||||
return estimated_frames_list
|
||||
|
||||
def encode_video(self, video_path, frame_count_limit=None):
|
||||
if not os.path.exists(video_path):
|
||||
logger.error(f"Video {video_path} does not exist")
|
||||
return []
|
||||
|
||||
if frame_count_limit == 0:
|
||||
return []
|
||||
|
||||
def uniform_sample(l, n):
|
||||
gap = len(l) / n
|
||||
idxs = [int(i * gap + gap / 2) for i in range(n)]
|
||||
return [l[i] for i in idxs]
|
||||
|
||||
vr = VideoReader(video_path, ctx=cpu(0))
|
||||
sample_fps = round(vr.get_avg_fps() / 1) # FPS
|
||||
frame_idx = [i for i in range(0, len(vr), sample_fps)]
|
||||
if frame_count_limit is not None and len(frame_idx) > frame_count_limit:
|
||||
frame_idx = uniform_sample(frame_idx, frame_count_limit)
|
||||
frames = vr.get_batch(frame_idx).asnumpy()
|
||||
frames = [Image.fromarray(v.astype("uint8")) for v in frames]
|
||||
return frames
|
||||
|
||||
def load_images(
|
||||
self,
|
||||
max_req_input_len: int,
|
||||
input_ids: list,
|
||||
image_data,
|
||||
image_token: str,
|
||||
) -> BaseImageProcessorOutput:
|
||||
"""
|
||||
Each frame of video/image will be replaced by a single image token
|
||||
"""
|
||||
image_hashes, image_sizes = [], []
|
||||
all_frames = []
|
||||
new_text_parts = []
|
||||
|
||||
if isinstance(input_ids, list):
|
||||
assert len(input_ids) and isinstance(input_ids[0], int)
|
||||
input_text = self._processor.tokenizer.decode(input_ids)
|
||||
else:
|
||||
input_text = input_ids
|
||||
|
||||
text_parts = input_text.split(image_token)
|
||||
|
||||
# roughly calculate the max number of frames under the max_req_input_len limit
|
||||
def calculate_max_num_frames() -> int:
|
||||
ret = (max_req_input_len - len(input_ids)) // self.NUM_TOKEN_PER_FRAME
|
||||
return min(ret, 100)
|
||||
|
||||
MAX_NUM_FRAMES = calculate_max_num_frames()
|
||||
estimated_frames_list = self.get_estimated_frames_list(image_data=image_data)
|
||||
total_frame_count = sum(estimated_frames_list)
|
||||
# a heuristic value, suggesting the maximum fraction of frames to embed from all visual inputs.
|
||||
# e.g., 0.1 suggests that 1 frame out of 10 input frames should be used
|
||||
scaling_factor = min(1.0, MAX_NUM_FRAMES / total_frame_count)
|
||||
|
||||
# Process each input with allocated frames
|
||||
for image_index, (image, estimated_frames) in enumerate(
|
||||
zip(image_data, estimated_frames_list)
|
||||
):
|
||||
if len(all_frames) >= MAX_NUM_FRAMES:
|
||||
frames_to_process = 0
|
||||
else:
|
||||
frames_to_process = max(1, int(estimated_frames * scaling_factor))
|
||||
|
||||
if frames_to_process == 0:
|
||||
frames = []
|
||||
else:
|
||||
try:
|
||||
if isinstance(image, str) and image.startswith("video:"):
|
||||
path = image[len("video:") :]
|
||||
frames = self.encode_video(
|
||||
path, frame_count_limit=frames_to_process
|
||||
)
|
||||
else:
|
||||
raw_image, _size = load_image(image)
|
||||
frames = [raw_image]
|
||||
if len(frames) == 0:
|
||||
continue
|
||||
except FileNotFoundError as e:
|
||||
print(e)
|
||||
return None
|
||||
image_sizes += frames[0].size * len(frames)
|
||||
image_hashes += [hash(image)] * len(frames)
|
||||
all_frames += frames
|
||||
|
||||
new_text_parts.append(text_parts[image_index])
|
||||
if frames_to_process != 0:
|
||||
new_text_parts.append(image_token * len(frames))
|
||||
assert frames_to_process == len(frames)
|
||||
|
||||
new_text_parts.append(text_parts[-1])
|
||||
|
||||
input_text = "".join(new_text_parts)
|
||||
return BaseImageProcessorOutput(
|
||||
image_hashes, image_sizes, all_frames, input_text
|
||||
)
|
||||
|
||||
|
||||
class DummyImageProcessor(BaseImageProcessor):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
async def process_images_async(self, *args, **kwargs):
|
||||
return None
|
||||
|
||||
|
||||
class LlavaImageProcessor(BaseImageProcessor):
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
|
||||
@staticmethod
|
||||
def _process_single_image_task(
|
||||
image_data: Union[str, bytes],
|
||||
image_aspect_ratio: Optional[str] = None,
|
||||
image_grid_pinpoints: Optional[str] = None,
|
||||
image_processor=None,
|
||||
):
|
||||
image_processor = image_processor or global_processor.image_processor
|
||||
|
||||
try:
|
||||
image, image_size = load_image(image_data)
|
||||
if image_size is not None:
|
||||
# It is a video with multiple images
|
||||
image_hash = hash(image_data)
|
||||
pixel_values = image_processor(image)["pixel_values"]
|
||||
for _ in range(len(pixel_values)):
|
||||
pixel_values[_] = pixel_values[_].astype(np.float16)
|
||||
pixel_values = np.stack(pixel_values, axis=0)
|
||||
return pixel_values, image_hash, image_size
|
||||
else:
|
||||
# It is an image
|
||||
image_hash = hash(image_data)
|
||||
if image_aspect_ratio == "pad":
|
||||
image = expand2square(
|
||||
image,
|
||||
tuple(int(x * 255) for x in image_processor.image_mean),
|
||||
)
|
||||
pixel_values = image_processor(image.convert("RGB"))[
|
||||
"pixel_values"
|
||||
][0]
|
||||
elif image_aspect_ratio == "anyres" or (
|
||||
image_aspect_ratio is not None
|
||||
and "anyres_max" in image_aspect_ratio
|
||||
):
|
||||
pixel_values = process_anyres_image(
|
||||
image, image_processor, image_grid_pinpoints
|
||||
)
|
||||
else:
|
||||
pixel_values = image_processor(image)["pixel_values"][0]
|
||||
|
||||
if isinstance(pixel_values, np.ndarray):
|
||||
pixel_values = pixel_values.astype(np.float16)
|
||||
|
||||
return pixel_values, image_hash, image.size
|
||||
except Exception:
|
||||
logger.error("Exception in TokenizerManager:\n" + get_exception_traceback())
|
||||
|
||||
async def _process_single_image(
|
||||
self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str
|
||||
):
|
||||
if self.executor is not None:
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(
|
||||
self.executor,
|
||||
LlavaImageProcessor._process_single_image_task,
|
||||
image_data,
|
||||
aspect_ratio,
|
||||
grid_pinpoints,
|
||||
)
|
||||
else:
|
||||
return self._process_single_image_task(
|
||||
image_data, aspect_ratio, grid_pinpoints
|
||||
)
|
||||
|
||||
async def process_images_async(
|
||||
self,
|
||||
image_data: List[Union[str, bytes]],
|
||||
input_text,
|
||||
request_obj,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
if not image_data:
|
||||
return None
|
||||
|
||||
modalities = request_obj.modalities or ["image"]
|
||||
aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
|
||||
grid_pinpoints = (
|
||||
self.hf_config.image_grid_pinpoints
|
||||
if hasattr(self.hf_config, "image_grid_pinpoints")
|
||||
and "anyres" in aspect_ratio
|
||||
else None
|
||||
)
|
||||
|
||||
if isinstance(image_data, str):
|
||||
image_data = [image_data]
|
||||
|
||||
if isinstance(image_data, list) and len(image_data) > 0:
|
||||
if "multi-images" in modalities or "video" in modalities:
|
||||
# Multiple images
|
||||
aspect_ratio = "pad" # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
|
||||
pixel_values, image_hashes, image_sizes = [], [], []
|
||||
res = []
|
||||
for img_data in image_data:
|
||||
res.append(
|
||||
self._process_single_image(
|
||||
img_data, aspect_ratio, grid_pinpoints
|
||||
)
|
||||
)
|
||||
res = await asyncio.gather(*res)
|
||||
for pixel_v, image_h, image_s in res:
|
||||
pixel_values.append(pixel_v)
|
||||
image_hashes.append(image_h)
|
||||
image_sizes.append(image_s)
|
||||
|
||||
if isinstance(pixel_values[0], np.ndarray):
|
||||
pixel_values = np.stack(pixel_values, axis=0)
|
||||
else:
|
||||
# A single image
|
||||
pixel_values, image_hash, image_size = await self._process_single_image(
|
||||
image_data[0], aspect_ratio, grid_pinpoints
|
||||
)
|
||||
image_hashes = [image_hash]
|
||||
image_sizes = [image_size]
|
||||
else:
|
||||
raise ValueError(f"Invalid image data: {image_data}")
|
||||
|
||||
return {
|
||||
"pixel_values": pixel_values,
|
||||
"image_hashes": image_hashes,
|
||||
"image_sizes": image_sizes,
|
||||
"modalities": request_obj.modalities or ["image"],
|
||||
}
|
||||
|
||||
|
||||
class MllamaImageProcessor(BaseImageProcessor):
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
|
||||
@staticmethod
|
||||
def _process_single_image_task(images, input_text):
|
||||
# input_ids', 'attention_mask', 'pixel_values', 'aspect_ratio_ids', 'aspect_ratio_mask', 'cross_attention_mask'
|
||||
return global_processor(images, input_text, return_tensors="pt")
|
||||
|
||||
async def _process_single_image(self, images, input_text):
|
||||
if self.executor is not None:
|
||||
loop = asyncio.get_event_loop()
|
||||
image_inputs = await loop.run_in_executor(
|
||||
self.executor,
|
||||
MllamaImageProcessor._process_single_image_task,
|
||||
images,
|
||||
input_text,
|
||||
)
|
||||
else:
|
||||
image_inputs = self._processor(images, input_text, return_tensors="pt")
|
||||
|
||||
return image_inputs
|
||||
|
||||
async def process_images_async(
|
||||
self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
|
||||
):
|
||||
if not image_data:
|
||||
return None
|
||||
|
||||
if isinstance(input_text, list):
|
||||
assert len(input_text) and isinstance(input_text[0], int)
|
||||
input_text = self._processor.tokenizer.decode(input_text)
|
||||
|
||||
if not isinstance(image_data, list):
|
||||
image_data = [image_data]
|
||||
|
||||
if len(image_data) > 0:
|
||||
images = [load_image(image)[0] for image in image_data]
|
||||
else:
|
||||
images = load_image(image_data[0])[0]
|
||||
|
||||
image_inputs = await self._process_single_image(images, input_text)
|
||||
image_inputs["image_hashes"] = [hash(str(image_data))]
|
||||
image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
|
||||
|
||||
return image_inputs
|
||||
|
||||
|
||||
class MiniCPMVImageProcessor(BaseImageProcessor):
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
self.IMAGE_TOKEN = "(<image>./</image>)"
|
||||
|
||||
@staticmethod
|
||||
def _process_images_task(images, input_text):
|
||||
result = global_processor.__call__(
|
||||
text=input_text, images=images, return_tensors="pt"
|
||||
)
|
||||
return {
|
||||
"input_ids": result.input_ids,
|
||||
"pixel_values": result.pixel_values,
|
||||
"tgt_sizes": result.tgt_sizes,
|
||||
}
|
||||
|
||||
async def _process_images(self, images, input_text):
|
||||
if self.executor is not None:
|
||||
loop = asyncio.get_event_loop()
|
||||
image_inputs = await loop.run_in_executor(
|
||||
self.executor,
|
||||
MiniCPMVImageProcessor._process_images_task,
|
||||
images,
|
||||
input_text,
|
||||
)
|
||||
else:
|
||||
image_inputs = self._processor(
|
||||
images=images, text=input_text, return_tensors="pt"
|
||||
)
|
||||
|
||||
return image_inputs
|
||||
|
||||
async def process_images_async(
|
||||
self,
|
||||
image_data: List[Union[str, bytes]],
|
||||
input_ids,
|
||||
request_obj,
|
||||
max_req_input_len,
|
||||
):
|
||||
if not image_data:
|
||||
return None
|
||||
if not isinstance(image_data, list):
|
||||
image_data = [image_data]
|
||||
|
||||
base_output = self.load_images(
|
||||
max_req_input_len, input_ids, image_data, self.IMAGE_TOKEN
|
||||
)
|
||||
if base_output is None:
|
||||
return None
|
||||
|
||||
if len(base_output.all_frames) == 0:
|
||||
return None
|
||||
res = await self._process_images(
|
||||
images=base_output.all_frames, input_text=base_output.input_text
|
||||
)
|
||||
|
||||
# Collect special token ids
|
||||
tokenizer = self._processor.tokenizer
|
||||
im_start_id = [tokenizer.im_start_id]
|
||||
im_end_id = [tokenizer.im_end_id]
|
||||
if tokenizer.slice_start_id:
|
||||
slice_start_id = [tokenizer.slice_start_id]
|
||||
slice_end_id = [tokenizer.slice_end_id]
|
||||
return {
|
||||
"input_ids": res["input_ids"].flatten().tolist(),
|
||||
"pixel_values": res["pixel_values"],
|
||||
"tgt_sizes": res["tgt_sizes"],
|
||||
"image_hashes": base_output.image_hashes,
|
||||
"modalities": request_obj.modalities or ["image"],
|
||||
"im_start_id": im_start_id,
|
||||
"im_end_id": im_end_id,
|
||||
"slice_start_id": slice_start_id,
|
||||
"slice_end_id": slice_end_id,
|
||||
}
|
||||
|
||||
|
||||
class Qwen2VLImageProcessor(BaseImageProcessor):
|
||||
def __init__(self, hf_config, server_args, _image_processor):
|
||||
self.hf_config = hf_config
|
||||
self._image_processor = _image_processor
|
||||
self.executor = concurrent.futures.ProcessPoolExecutor(
|
||||
initializer=init_global_processor,
|
||||
mp_context=mp.get_context("fork"),
|
||||
initargs=(server_args,),
|
||||
max_workers=int(os.environ.get("SGLANG_CPU_COUNT", os.cpu_count())),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _process_single_image_task(
|
||||
image_data: Union[str, bytes],
|
||||
image_processor=None,
|
||||
):
|
||||
image_processor = image_processor or global_processor.image_processor
|
||||
|
||||
try:
|
||||
image, image_size = load_image(image_data)
|
||||
if image_size is not None:
|
||||
# It is a video with multiple images
|
||||
image_hash = hash(image_data)
|
||||
process_result = image_processor(image)
|
||||
pixel_values, image_grid_thws = (
|
||||
process_result["pixel_values"],
|
||||
process_result["image_grid_thw"][0],
|
||||
)
|
||||
for _ in range(len(pixel_values)):
|
||||
pixel_values[_] = pixel_values[_].astype(np.float16)
|
||||
pixel_values = np.stack(pixel_values, axis=0)
|
||||
image_grid_thws = np.stack(image_grid_thws, axis=0)
|
||||
return pixel_values, image_hash, image_size, image_grid_thws
|
||||
else:
|
||||
# It is an image
|
||||
image_hash = hash(image_data)
|
||||
process_result = image_processor(image)
|
||||
pixel_values, image_grid_thws = (
|
||||
process_result["pixel_values"],
|
||||
process_result["image_grid_thw"][0],
|
||||
)
|
||||
if isinstance(pixel_values, np.ndarray):
|
||||
pixel_values = pixel_values.astype(np.float16)
|
||||
|
||||
return pixel_values, image_hash, image.size, image_grid_thws
|
||||
except Exception:
|
||||
logger.error("Exception in TokenizerManager:\n" + get_exception_traceback())
|
||||
|
||||
async def _process_single_image(self, image_data: Union[bytes, str]):
|
||||
if self.executor is not None:
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(
|
||||
self.executor,
|
||||
Qwen2VLImageProcessor._process_single_image_task,
|
||||
image_data,
|
||||
)
|
||||
else:
|
||||
return self._process_single_image_task(image_data)
|
||||
|
||||
async def process_images_async(
|
||||
self,
|
||||
image_data: List[Union[str, bytes]],
|
||||
input_text,
|
||||
request_obj,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
if not image_data:
|
||||
return None
|
||||
|
||||
if isinstance(image_data, list) and len(image_data) > 0:
|
||||
# Multiple images
|
||||
if len(image_data) > 1:
|
||||
pixel_values, image_hashes, image_sizes, image_grid_thws = (
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
)
|
||||
res = []
|
||||
for img_data in image_data:
|
||||
res.append(self._process_single_image(img_data))
|
||||
res = await asyncio.gather(*res)
|
||||
for pixel_v, image_h, image_s, image_thw in res:
|
||||
pixel_values.append(pixel_v)
|
||||
image_hashes.append(image_h)
|
||||
image_sizes.append(image_s)
|
||||
image_grid_thws.append(image_thw)
|
||||
|
||||
if isinstance(pixel_values[0], np.ndarray):
|
||||
pixel_values = np.concatenate(pixel_values, axis=0)
|
||||
else:
|
||||
# A single image
|
||||
pixel_values, image_hash, image_size, image_grid_thw = (
|
||||
await self._process_single_image(image_data[0])
|
||||
)
|
||||
image_hashes = [image_hash]
|
||||
image_sizes = [image_size]
|
||||
image_grid_thws = [image_grid_thw]
|
||||
elif isinstance(image_data, str) or isinstance(image_data, bytes):
|
||||
# A single image
|
||||
pixel_values, image_hash, image_size, image_grid_thw = (
|
||||
await self._process_single_image(image_data)
|
||||
)
|
||||
image_hashes = [image_hash]
|
||||
image_sizes = [image_size]
|
||||
image_grid_thws = [image_grid_thw]
|
||||
else:
|
||||
|
||||
raise ValueError(f"Invalid image data: {image_data}")
|
||||
|
||||
return {
|
||||
"pixel_values": pixel_values,
|
||||
"image_hashes": image_hashes,
|
||||
"image_sizes": image_sizes,
|
||||
"modalities": request_obj.modalities or ["image"],
|
||||
"image_grid_thws": image_grid_thws,
|
||||
}
|
||||
|
||||
|
||||
class Qwen2_5VLImageProcessor(BaseImageProcessor):
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
self.IMAGE_TOKEN = "<|vision_start|><|image_pad|><|vision_end|>"
|
||||
self.IM_START_TOKEN_ID = hf_config.vision_start_token_id
|
||||
self.IM_END_TOKEN_ID = hf_config.vision_end_token_id
|
||||
self.NUM_TOKEN_PER_FRAME = 770
|
||||
|
||||
@staticmethod
|
||||
def _process_images_task(images, input_text):
|
||||
result = global_processor.__call__(
|
||||
text=input_text, images=images, return_tensors="pt"
|
||||
)
|
||||
return {
|
||||
"input_ids": result.input_ids,
|
||||
"pixel_values": result.pixel_values,
|
||||
"image_grid_thws": result.image_grid_thw,
|
||||
}
|
||||
|
||||
async def _process_images(self, images, input_text) -> dict:
|
||||
if self.executor is not None:
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(
|
||||
self.executor,
|
||||
Qwen2_5VLImageProcessor._process_images_task,
|
||||
images,
|
||||
input_text,
|
||||
)
|
||||
else:
|
||||
return self._process_images_task(images, input_text)
|
||||
|
||||
async def process_images_async(
|
||||
self,
|
||||
image_data: List[Union[str, bytes]],
|
||||
input_ids,
|
||||
request_obj,
|
||||
max_req_input_len,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
if not image_data:
|
||||
return None
|
||||
if isinstance(image_data, str):
|
||||
image_data = [image_data]
|
||||
|
||||
image_token = self.IMAGE_TOKEN
|
||||
base_output = self.load_images(
|
||||
max_req_input_len, input_ids, image_data, image_token
|
||||
)
|
||||
|
||||
ret = await self._process_images(base_output.all_frames, base_output.input_text)
|
||||
|
||||
return {
|
||||
"input_ids": ret["input_ids"].flatten().tolist(),
|
||||
"pixel_values": ret["pixel_values"],
|
||||
"image_hashes": base_output.image_hashes,
|
||||
"modalities": request_obj.modalities or ["image"],
|
||||
"image_grid_thws": ret["image_grid_thws"],
|
||||
"im_start_id": self.IM_START_TOKEN_ID,
|
||||
"im_end_id": self.IM_END_TOKEN_ID,
|
||||
}
|
||||
IMAGE_PROCESSOR_MAPPING = {}
|
||||
|
||||
|
||||
def get_image_processor(
|
||||
hf_config, server_args: ServerArgs, processor
|
||||
) -> BaseImageProcessor:
|
||||
if "MllamaForConditionalGeneration" in hf_config.architectures:
|
||||
return MllamaImageProcessor(hf_config, server_args, processor)
|
||||
elif "Qwen2VLForConditionalGeneration" in hf_config.architectures:
|
||||
|
||||
return Qwen2VLImageProcessor(hf_config, server_args, processor)
|
||||
elif "Qwen2_5_VLForConditionalGeneration" in hf_config.architectures:
|
||||
return Qwen2_5VLImageProcessor(hf_config, server_args, processor)
|
||||
|
||||
elif "MiniCPMV" in hf_config.architectures:
|
||||
return MiniCPMVImageProcessor(hf_config, server_args, processor)
|
||||
else:
|
||||
return LlavaImageProcessor(hf_config, server_args, processor.image_processor)
|
||||
for model_cls, processor_cls in IMAGE_PROCESSOR_MAPPING.items():
|
||||
if model_cls.__name__ in hf_config.architectures:
|
||||
return processor_cls(hf_config, server_args, processor)
|
||||
raise ValueError(
|
||||
f"No image processor found for architecture: {hf_config.architectures}"
|
||||
)
|
||||
|
||||
|
||||
def get_dummy_image_processor():
|
||||
return DummyImageProcessor()
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def import_image_processors():
|
||||
package_name = "sglang.srt.managers.image_processors"
|
||||
package = importlib.import_module(package_name)
|
||||
for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
|
||||
if not ispkg:
|
||||
try:
|
||||
module = importlib.import_module(name)
|
||||
except Exception as e:
|
||||
logger.warning(f"Ignore import error when loading {name}: " f"{e}")
|
||||
continue
|
||||
if hasattr(module, "ImageProcessorMapping"):
|
||||
entry = module.ImageProcessorMapping
|
||||
if isinstance(entry, dict):
|
||||
for processor_name, cls in entry.items():
|
||||
IMAGE_PROCESSOR_MAPPING[processor_name] = cls
|
||||
|
||||
|
||||
# also register processors
|
||||
import_image_processors()
|
||||
|
||||
@@ -0,0 +1,206 @@
|
||||
import concurrent
|
||||
import concurrent.futures
|
||||
import dataclasses
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
import PIL
|
||||
import transformers
|
||||
from decord import VideoReader, cpu
|
||||
from PIL import Image
|
||||
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils import load_image
|
||||
|
||||
global global_processor
|
||||
|
||||
|
||||
def get_global_processor():
|
||||
global global_processor
|
||||
return global_processor
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class BaseImageProcessorOutput:
|
||||
image_hashes: list[int]
|
||||
image_sizes: list[tuple[int, int]]
|
||||
all_frames: [PIL.Image]
|
||||
# input_text, with each frame of video/image represented as an image_token
|
||||
input_text: str
|
||||
|
||||
|
||||
class BaseImageProcessor(ABC):
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
self.hf_config = hf_config
|
||||
self._processor = _processor
|
||||
self.server_args = server_args
|
||||
# FIXME: not accurate, model and image specific
|
||||
self.NUM_TOKEN_PER_FRAME = 330
|
||||
|
||||
self.executor = concurrent.futures.ProcessPoolExecutor(
|
||||
initializer=init_global_processor,
|
||||
mp_context=mp.get_context("fork"),
|
||||
initargs=(
|
||||
self,
|
||||
server_args,
|
||||
),
|
||||
max_workers=int(os.environ.get("SGLANG_CPU_COUNT", os.cpu_count())),
|
||||
)
|
||||
|
||||
def _build_processor(self, server_args):
|
||||
"""Init the global processor for multi modal models."""
|
||||
from sglang.srt.hf_transformers_utils import get_processor
|
||||
|
||||
return get_processor(
|
||||
server_args.tokenizer_path,
|
||||
tokenizer_mode=server_args.tokenizer_mode,
|
||||
trust_remote_code=server_args.trust_remote_code,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
async def process_images_async(
|
||||
self, image_data, input_text, max_req_input_len, **kwargs
|
||||
):
|
||||
pass
|
||||
|
||||
def get_estimated_frames_list(self, image_data):
|
||||
"""
|
||||
estimate the total frame count from all visual input
|
||||
"""
|
||||
# Before processing inputs
|
||||
estimated_frames_list = []
|
||||
for image in image_data:
|
||||
if isinstance(image, str) and image.startswith("video:"):
|
||||
path = image[len("video:") :]
|
||||
# Estimate frames for the video
|
||||
vr = VideoReader(path, ctx=cpu(0))
|
||||
num_frames = len(vr)
|
||||
else:
|
||||
# For images, each contributes one frame
|
||||
num_frames = 1
|
||||
estimated_frames_list.append(num_frames)
|
||||
|
||||
return estimated_frames_list
|
||||
|
||||
@staticmethod
|
||||
def encode_video(video_path, frame_count_limit=None):
|
||||
if not os.path.exists(video_path):
|
||||
logger.error(f"Video {video_path} does not exist")
|
||||
return []
|
||||
|
||||
if frame_count_limit == 0:
|
||||
return []
|
||||
|
||||
def uniform_sample(l, n):
|
||||
gap = len(l) / n
|
||||
idxs = [int(i * gap + gap / 2) for i in range(n)]
|
||||
return [l[i] for i in idxs]
|
||||
|
||||
vr = VideoReader(video_path, ctx=cpu(0))
|
||||
sample_fps = round(vr.get_avg_fps() / 1) # FPS
|
||||
frame_indices = [i for i in range(0, len(vr), sample_fps)]
|
||||
if frame_count_limit is not None and len(frame_indices) > frame_count_limit:
|
||||
frame_indices = uniform_sample(frame_indices, frame_count_limit)
|
||||
|
||||
frames = vr.get_batch(frame_indices).asnumpy()
|
||||
frames = [Image.fromarray(v.astype("uint8")) for v in frames]
|
||||
return frames
|
||||
|
||||
def load_images(
|
||||
self,
|
||||
input_ids: list,
|
||||
image_data,
|
||||
image_token: str,
|
||||
max_req_input_len: int,
|
||||
return_text: Optional[bool] = True,
|
||||
discard_alpha_channel: bool = True,
|
||||
) -> BaseImageProcessorOutput:
|
||||
"""
|
||||
Each frame of video/image will be replaced by a single image token
|
||||
"""
|
||||
image_hashes, image_sizes = [], []
|
||||
all_frames = []
|
||||
new_text_parts = []
|
||||
|
||||
if isinstance(input_ids, list) and return_text:
|
||||
assert len(input_ids) and isinstance(input_ids[0], int)
|
||||
input_text = self._processor.tokenizer.decode(input_ids)
|
||||
else:
|
||||
input_text = input_ids
|
||||
|
||||
if return_text:
|
||||
text_parts = input_text.split(image_token)
|
||||
|
||||
# roughly calculate the max number of frames under the max_req_input_len limit
|
||||
MAX_NUM_FRAMES = 30
|
||||
estimated_frames_list = self.get_estimated_frames_list(image_data=image_data)
|
||||
total_frame_count = sum(estimated_frames_list)
|
||||
# a heuristic value, suggesting the maximum fraction of frames to embed from all visual inputs.
|
||||
# e.g., 0.1 suggests that 1 frame out of 10 input frames should be used
|
||||
scaling_factor = min(1.0, MAX_NUM_FRAMES / total_frame_count)
|
||||
|
||||
assert len(image_data) == len(estimated_frames_list)
|
||||
|
||||
# Process each input with allocated frames
|
||||
for image_index, (image, estimated_frames) in enumerate(
|
||||
zip(image_data, estimated_frames_list)
|
||||
):
|
||||
if len(all_frames) >= MAX_NUM_FRAMES:
|
||||
max_frames_to_process = 0
|
||||
else:
|
||||
max_frames_to_process = max(1, int(estimated_frames * scaling_factor))
|
||||
|
||||
if max_frames_to_process == 0:
|
||||
frames = []
|
||||
else:
|
||||
try:
|
||||
if isinstance(image, str) and image.startswith("video:"):
|
||||
path = image[len("video:") :]
|
||||
frames = BaseImageProcessor.encode_video(
|
||||
path, frame_count_limit=max_frames_to_process
|
||||
)
|
||||
else:
|
||||
raw_image, _size = load_image(image)
|
||||
if discard_alpha_channel:
|
||||
raw_image = raw_image.convert("RGB")
|
||||
frames = [raw_image]
|
||||
assert len(frames) != 0
|
||||
except FileNotFoundError as e:
|
||||
print(e)
|
||||
return None
|
||||
|
||||
image_sizes += [frames[0].size] * len(frames)
|
||||
image_hashes += [hash(image)] * len(frames)
|
||||
all_frames += frames
|
||||
|
||||
if return_text:
|
||||
new_text_parts.append(text_parts[image_index])
|
||||
if max_frames_to_process != 0:
|
||||
new_text_parts.append(image_token * len(frames))
|
||||
assert max_frames_to_process >= len(frames)
|
||||
if return_text:
|
||||
new_text_parts.append(text_parts[-1])
|
||||
|
||||
input_text = "".join(new_text_parts)
|
||||
return BaseImageProcessorOutput(
|
||||
image_hashes, image_sizes, all_frames, input_text
|
||||
)
|
||||
|
||||
|
||||
class DummyImageProcessor(BaseImageProcessor):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
async def process_images_async(self, *args, **kwargs):
|
||||
return None
|
||||
|
||||
|
||||
def init_global_processor(
|
||||
sglang_image_processor: BaseImageProcessor, server_args: ServerArgs
|
||||
):
|
||||
"""Init the global processor for multi-modal models."""
|
||||
global global_processor
|
||||
transformers.logging.set_verbosity_error()
|
||||
global_processor = sglang_image_processor._build_processor(server_args=server_args)
|
||||
152
python/sglang/srt/managers/image_processors/llava.py
Normal file
152
python/sglang/srt/managers/image_processors/llava.py
Normal file
@@ -0,0 +1,152 @@
|
||||
import asyncio
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from sglang.srt.managers.image_processor import BaseImageProcessor
|
||||
from sglang.srt.managers.image_processors.base_image_processor import (
|
||||
get_global_processor,
|
||||
)
|
||||
from sglang.srt.mm_utils import expand2square, process_anyres_image
|
||||
from sglang.srt.models.llava import LlavaMistralForCausalLM, LlavaQwenForCausalLM
|
||||
from sglang.srt.models.llavavid import LlavaVidForCausalLM
|
||||
from sglang.srt.utils import load_image, logger
|
||||
from sglang.utils import get_exception_traceback
|
||||
|
||||
|
||||
class LlavaImageProcessor(BaseImageProcessor):
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
|
||||
@staticmethod
|
||||
def _process_single_image_task(
|
||||
image_data: Union[str, bytes],
|
||||
image_aspect_ratio: Optional[str] = None,
|
||||
image_grid_pinpoints: Optional[str] = None,
|
||||
image_processor=None,
|
||||
):
|
||||
processor = get_global_processor()
|
||||
|
||||
image_processor = image_processor or processor.image_processor
|
||||
|
||||
try:
|
||||
image, image_size = load_image(image_data)
|
||||
if image_size is not None:
|
||||
# It is a video with multiple images
|
||||
image_hash = hash(image_data)
|
||||
pixel_values = image_processor(image)["pixel_values"]
|
||||
for _ in range(len(pixel_values)):
|
||||
pixel_values[_] = pixel_values[_].astype(np.float16)
|
||||
pixel_values = np.stack(pixel_values, axis=0)
|
||||
return pixel_values, image_hash, image_size
|
||||
else:
|
||||
# It is an image
|
||||
image_hash = hash(image_data)
|
||||
if image_aspect_ratio == "pad":
|
||||
image = expand2square(
|
||||
image,
|
||||
tuple(int(x * 255) for x in image_processor.image_mean),
|
||||
)
|
||||
pixel_values = image_processor(image.convert("RGB"))[
|
||||
"pixel_values"
|
||||
][0]
|
||||
elif image_aspect_ratio == "anyres" or (
|
||||
image_aspect_ratio is not None
|
||||
and "anyres_max" in image_aspect_ratio
|
||||
):
|
||||
pixel_values = process_anyres_image(
|
||||
image, image_processor, image_grid_pinpoints
|
||||
)
|
||||
else:
|
||||
pixel_values = image_processor(image)["pixel_values"][0]
|
||||
|
||||
if isinstance(pixel_values, np.ndarray):
|
||||
pixel_values = pixel_values.astype(np.float16)
|
||||
|
||||
return pixel_values, image_hash, image.size
|
||||
except Exception:
|
||||
logger.error("Exception in TokenizerManager:\n" + get_exception_traceback())
|
||||
|
||||
async def _process_single_image(
|
||||
self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str
|
||||
):
|
||||
if self.executor is not None:
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(
|
||||
self.executor,
|
||||
LlavaImageProcessor._process_single_image_task,
|
||||
image_data,
|
||||
aspect_ratio,
|
||||
grid_pinpoints,
|
||||
)
|
||||
else:
|
||||
return self._process_single_image_task(
|
||||
image_data, aspect_ratio, grid_pinpoints
|
||||
)
|
||||
|
||||
async def process_images_async(
|
||||
self,
|
||||
image_data: List[Union[str, bytes]],
|
||||
input_text,
|
||||
request_obj,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
if not image_data:
|
||||
return None
|
||||
|
||||
modalities = request_obj.modalities or ["image"]
|
||||
aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
|
||||
grid_pinpoints = (
|
||||
self.hf_config.image_grid_pinpoints
|
||||
if hasattr(self.hf_config, "image_grid_pinpoints")
|
||||
and "anyres" in aspect_ratio
|
||||
else None
|
||||
)
|
||||
|
||||
if isinstance(image_data, str):
|
||||
image_data = [image_data]
|
||||
|
||||
if isinstance(image_data, list) and len(image_data) > 0:
|
||||
if "multi-images" in modalities or "video" in modalities:
|
||||
# Multiple images
|
||||
aspect_ratio = "pad" # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
|
||||
pixel_values, image_hashes, image_sizes = [], [], []
|
||||
res = []
|
||||
for img_data in image_data:
|
||||
res.append(
|
||||
self._process_single_image(
|
||||
img_data, aspect_ratio, grid_pinpoints
|
||||
)
|
||||
)
|
||||
res = await asyncio.gather(*res)
|
||||
for pixel_v, image_h, image_s in res:
|
||||
pixel_values.append(pixel_v)
|
||||
image_hashes.append(image_h)
|
||||
image_sizes.append(image_s)
|
||||
|
||||
if isinstance(pixel_values[0], np.ndarray):
|
||||
pixel_values = np.stack(pixel_values, axis=0)
|
||||
else:
|
||||
# A single image
|
||||
pixel_values, image_hash, image_size = await self._process_single_image(
|
||||
image_data[0], aspect_ratio, grid_pinpoints
|
||||
)
|
||||
image_hashes = [image_hash]
|
||||
image_sizes = [image_size]
|
||||
else:
|
||||
raise ValueError(f"Invalid image data: {image_data}")
|
||||
|
||||
return {
|
||||
"pixel_values": pixel_values,
|
||||
"image_hashes": image_hashes,
|
||||
"image_sizes": image_sizes,
|
||||
"modalities": request_obj.modalities or ["image"],
|
||||
}
|
||||
|
||||
|
||||
ImageProcessorMapping = {
|
||||
LlavaVidForCausalLM: LlavaImageProcessor,
|
||||
LlavaQwenForCausalLM: LlavaImageProcessor,
|
||||
LlavaMistralForCausalLM: LlavaImageProcessor,
|
||||
}
|
||||
86
python/sglang/srt/managers/image_processors/minicpmv.py
Normal file
86
python/sglang/srt/managers/image_processors/minicpmv.py
Normal file
@@ -0,0 +1,86 @@
|
||||
import asyncio
|
||||
from typing import List, Union
|
||||
|
||||
from sglang.srt.managers.image_processor import BaseImageProcessor
|
||||
from sglang.srt.managers.image_processors.base_image_processor import (
|
||||
get_global_processor,
|
||||
)
|
||||
from sglang.srt.models.minicpmv import MiniCPMV
|
||||
|
||||
|
||||
class MiniCPMVImageProcessor(BaseImageProcessor):
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
self.IMAGE_TOKEN = "(<image>./</image>)"
|
||||
|
||||
@staticmethod
|
||||
def _process_images_task(images, input_text):
|
||||
processor = get_global_processor()
|
||||
result = processor.__call__(text=input_text, images=images, return_tensors="pt")
|
||||
return {
|
||||
"input_ids": result.input_ids,
|
||||
"pixel_values": result.pixel_values,
|
||||
"tgt_sizes": result.tgt_sizes,
|
||||
}
|
||||
|
||||
async def _process_images(self, images, input_text):
|
||||
if self.executor is not None:
|
||||
loop = asyncio.get_event_loop()
|
||||
image_inputs = await loop.run_in_executor(
|
||||
self.executor,
|
||||
MiniCPMVImageProcessor._process_images_task,
|
||||
images,
|
||||
input_text,
|
||||
)
|
||||
else:
|
||||
image_inputs = self._processor(
|
||||
images=images, text=input_text, return_tensors="pt"
|
||||
)
|
||||
|
||||
return image_inputs
|
||||
|
||||
async def process_images_async(
|
||||
self,
|
||||
image_data: List[Union[str, bytes]],
|
||||
input_ids,
|
||||
request_obj,
|
||||
max_req_input_len,
|
||||
):
|
||||
if not image_data:
|
||||
return None
|
||||
if not isinstance(image_data, list):
|
||||
image_data = [image_data]
|
||||
|
||||
base_output = self.load_images(
|
||||
input_ids, image_data, self.IMAGE_TOKEN, max_req_input_len
|
||||
)
|
||||
if base_output is None:
|
||||
return None
|
||||
|
||||
if len(base_output.all_frames) == 0:
|
||||
return None
|
||||
res = await self._process_images(
|
||||
images=base_output.all_frames, input_text=base_output.input_text
|
||||
)
|
||||
|
||||
# Collect special token ids
|
||||
tokenizer = self._processor.tokenizer
|
||||
im_start_id = tokenizer.im_start_id
|
||||
im_end_id = tokenizer.im_end_id
|
||||
if tokenizer.slice_start_id:
|
||||
slice_start_id = tokenizer.slice_start_id
|
||||
slice_end_id = tokenizer.slice_end_id
|
||||
return {
|
||||
"input_ids": res["input_ids"].flatten().tolist(),
|
||||
"pixel_values": res["pixel_values"],
|
||||
"tgt_sizes": res["tgt_sizes"],
|
||||
"image_hashes": base_output.image_hashes,
|
||||
"modalities": request_obj.modalities or ["image"],
|
||||
"im_start_id": im_start_id,
|
||||
"im_end_id": im_end_id,
|
||||
"slice_start_id": slice_start_id,
|
||||
"slice_end_id": slice_end_id,
|
||||
}
|
||||
|
||||
|
||||
ImageProcessorMapping = {MiniCPMV: MiniCPMVImageProcessor}
|
||||
60
python/sglang/srt/managers/image_processors/mlama.py
Normal file
60
python/sglang/srt/managers/image_processors/mlama.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import asyncio
|
||||
from typing import List, Union
|
||||
|
||||
from sglang.srt.managers.image_processor import BaseImageProcessor
|
||||
from sglang.srt.managers.image_processors.base_image_processor import (
|
||||
get_global_processor,
|
||||
)
|
||||
from sglang.srt.models.mllama import MllamaForConditionalGeneration
|
||||
from sglang.srt.utils import load_image
|
||||
|
||||
|
||||
class MllamaImageProcessor(BaseImageProcessor):
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
|
||||
@staticmethod
|
||||
def _process_single_image_task(images, input_text):
|
||||
# input_ids', 'attention_mask', 'pixel_values', 'aspect_ratio_ids', 'aspect_ratio_mask', 'cross_attention_mask'
|
||||
return get_global_processor()(images, input_text, return_tensors="pt")
|
||||
|
||||
async def _process_single_image(self, images, input_text):
|
||||
if self.executor is not None:
|
||||
loop = asyncio.get_event_loop()
|
||||
image_inputs = await loop.run_in_executor(
|
||||
self.executor,
|
||||
MllamaImageProcessor._process_single_image_task,
|
||||
images,
|
||||
input_text,
|
||||
)
|
||||
else:
|
||||
image_inputs = self._processor(images, input_text, return_tensors="pt")
|
||||
|
||||
return image_inputs
|
||||
|
||||
async def process_images_async(
|
||||
self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
|
||||
):
|
||||
if not image_data:
|
||||
return None
|
||||
|
||||
if isinstance(input_text, list):
|
||||
assert len(input_text) and isinstance(input_text[0], int)
|
||||
input_text = self._processor.tokenizer.decode(input_text)
|
||||
|
||||
if not isinstance(image_data, list):
|
||||
image_data = [image_data]
|
||||
|
||||
if len(image_data) > 0:
|
||||
images = [load_image(image)[0] for image in image_data]
|
||||
else:
|
||||
images = load_image(image_data[0])[0]
|
||||
|
||||
image_inputs = await self._process_single_image(images, input_text)
|
||||
image_inputs["image_hashes"] = [hash(str(image_data))]
|
||||
image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
|
||||
|
||||
return image_inputs
|
||||
|
||||
|
||||
ImageProcessorMapping = {MllamaForConditionalGeneration: MllamaImageProcessor}
|
||||
161
python/sglang/srt/managers/image_processors/qwen_vl.py
Normal file
161
python/sglang/srt/managers/image_processors/qwen_vl.py
Normal file
@@ -0,0 +1,161 @@
|
||||
import asyncio
|
||||
import math
|
||||
from typing import List, Union
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from sglang.srt.managers.image_processor import BaseImageProcessor
|
||||
from sglang.srt.managers.image_processors.base_image_processor import (
|
||||
get_global_processor,
|
||||
)
|
||||
from sglang.srt.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
|
||||
from sglang.srt.models.qwen2_vl import Qwen2VLForConditionalGeneration
|
||||
|
||||
|
||||
# Compatible with Qwen2VL and Qwen2_5VL
|
||||
class Qwen2_5VLImageProcessor(BaseImageProcessor):
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
self.IMAGE_TOKEN = "<|vision_start|><|image_pad|><|vision_end|>"
|
||||
self.IM_START_TOKEN_ID = hf_config.vision_start_token_id
|
||||
self.IM_END_TOKEN_ID = hf_config.vision_end_token_id
|
||||
self.image_token_id = hf_config.image_token_id
|
||||
self.video_token_id = hf_config.video_token_id
|
||||
self.NUM_TOKEN_PER_FRAME = 770
|
||||
self.IMAGE_FACTOR = 28
|
||||
self.MIN_PIXELS = 4 * 28 * 28
|
||||
self.MAX_PIXELS = 16384 * 28 * 28
|
||||
self.MAX_PIXELS = 16384 * 28 * 28
|
||||
self.MAX_RATIO = 200
|
||||
|
||||
@staticmethod
|
||||
def _process_images_task(images, input_text, _hf_config):
|
||||
if isinstance(images, list) and len(images) == 0:
|
||||
images = None
|
||||
result = get_global_processor().__call__(
|
||||
text=[input_text], images=images, padding=True, return_tensors="pt"
|
||||
)
|
||||
|
||||
return {
|
||||
"input_ids": result.input_ids,
|
||||
"pixel_values": getattr(result, "pixel_values", None),
|
||||
"image_grid_thw": getattr(result, "image_grid_thw", None),
|
||||
"second_per_grid_ts": getattr(result, "second_per_grid_ts", None),
|
||||
"video_grid_thws": getattr(result, "video_grid_thws", None),
|
||||
}
|
||||
|
||||
async def _process_images(self, images, input_text) -> dict:
|
||||
if self.executor is not None:
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(
|
||||
self.executor,
|
||||
Qwen2_5VLImageProcessor._process_images_task,
|
||||
images,
|
||||
input_text,
|
||||
self.hf_config,
|
||||
)
|
||||
else:
|
||||
return self._process_images_task(images, input_text, self.hf_config)
|
||||
|
||||
async def process_images_async(
|
||||
self,
|
||||
image_data: List[Union[str, bytes]],
|
||||
input_ids,
|
||||
request_obj,
|
||||
max_req_input_len,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
if not image_data:
|
||||
return None
|
||||
if isinstance(image_data, str):
|
||||
image_data = [image_data]
|
||||
|
||||
image_token = self.IMAGE_TOKEN
|
||||
base_output = self.load_images(
|
||||
input_ids,
|
||||
image_data,
|
||||
image_token,
|
||||
max_req_input_len,
|
||||
)
|
||||
|
||||
def smart_resize(
|
||||
height: int,
|
||||
width: int,
|
||||
factor: int = self.IMAGE_FACTOR,
|
||||
min_pixels: int = self.MIN_PIXELS,
|
||||
max_pixels: int = self.MAX_PIXELS,
|
||||
) -> tuple[int, int]:
|
||||
"""
|
||||
Rescales the image so that the following conditions are met:
|
||||
|
||||
1. Both dimensions (height and width) are divisible by 'factor'.
|
||||
|
||||
2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
|
||||
|
||||
3. The aspect ratio of the image is maintained as closely as possible.
|
||||
"""
|
||||
if max(height, width) / min(height, width) > self.MAX_RATIO:
|
||||
raise ValueError(
|
||||
f"absolute aspect ratio must be smaller than {self.MAX_RATIO}, got {max(height, width) / min(height, width)}"
|
||||
)
|
||||
h_bar = max(factor, round_by_factor(height, factor))
|
||||
w_bar = max(factor, round_by_factor(width, factor))
|
||||
if h_bar * w_bar > max_pixels:
|
||||
beta = math.sqrt((height * width) / max_pixels)
|
||||
h_bar = floor_by_factor(height / beta, factor)
|
||||
w_bar = floor_by_factor(width / beta, factor)
|
||||
elif h_bar * w_bar < min_pixels:
|
||||
beta = math.sqrt(min_pixels / (height * width))
|
||||
h_bar = ceil_by_factor(height * beta, factor)
|
||||
w_bar = ceil_by_factor(width * beta, factor)
|
||||
return h_bar, w_bar
|
||||
|
||||
def resize_image(image, size_factor: int = self.IMAGE_FACTOR) -> Image.Image:
|
||||
width, height = image.size
|
||||
min_pixels = self.MIN_PIXELS
|
||||
max_pixels = self.MAX_PIXELS
|
||||
resized_height, resized_width = smart_resize(
|
||||
height,
|
||||
width,
|
||||
factor=size_factor,
|
||||
min_pixels=min_pixels,
|
||||
max_pixels=max_pixels,
|
||||
)
|
||||
image = image.resize((resized_width, resized_height))
|
||||
return image
|
||||
|
||||
def round_by_factor(number: int, factor: int) -> int:
|
||||
"""Returns the closest integer to 'number' that is divisible by 'factor'."""
|
||||
return round(number / factor) * factor
|
||||
|
||||
def ceil_by_factor(number: int, factor: int) -> int:
|
||||
"""Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
|
||||
return math.ceil(number / factor) * factor
|
||||
|
||||
def floor_by_factor(number: int, factor: int) -> int:
|
||||
"""Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
|
||||
return math.floor(number / factor) * factor
|
||||
|
||||
images = [resize_image(image) for image in base_output.all_frames]
|
||||
|
||||
ret = await self._process_images(images, base_output.input_text)
|
||||
return {
|
||||
"input_ids": ret["input_ids"].flatten().tolist(),
|
||||
"pixel_values": ret["pixel_values"],
|
||||
"image_hashes": base_output.image_hashes,
|
||||
"modalities": request_obj.modalities or ["image"],
|
||||
"image_grid_thws": ret["image_grid_thw"],
|
||||
"video_grid_thws": ret["video_grid_thws"],
|
||||
"im_start_id": self.IM_START_TOKEN_ID,
|
||||
"im_end_id": self.IM_END_TOKEN_ID,
|
||||
"im_token_id": self.image_token_id,
|
||||
"video_token_id": self.video_token_id,
|
||||
"second_per_grid_ts": ret["second_per_grid_ts"],
|
||||
}
|
||||
|
||||
|
||||
ImageProcessorMapping = {
|
||||
Qwen2VLForConditionalGeneration: Qwen2_5VLImageProcessor,
|
||||
Qwen2_5_VLForConditionalGeneration: Qwen2_5VLImageProcessor,
|
||||
}
|
||||
134
python/sglang/srt/managers/multi_modality_padding.py
Normal file
134
python/sglang/srt/managers/multi_modality_padding.py
Normal file
@@ -0,0 +1,134 @@
from abc import abstractmethod
from typing import Callable, List, Optional, Tuple

from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.utils import logger


class MultiModalityDataPaddingPattern:
    """
    Data tokens (like image tokens) often need special handling during padding
    to maintain model compatibility. This class provides the interface for
    implementing different padding strategies for data tokens.
    """

    @abstractmethod
    def pad_input_tokens(
        self, input_ids: List[int], image_inputs: ImageInputs
    ) -> List[int]:
        """
        Pad the input ids sequence containing data tokens, and replace them with pad_values
        """
        pass


class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern):
    """In this pattern, data tokens are enclosed by special token pairs (e.g. <image>...</image>, data_token_pairs)

    This strategy should be applied when data content is marked by start/end token pairs in the input sequence.
    """

    def __init__(self, data_token_pairs: Optional[List[Tuple[int, int]]]) -> None:
        self.data_token_id_pairs = data_token_pairs

    def pad_input_tokens(
        self, input_ids: List[int], image_inputs: ImageInputs
    ) -> List[int]:
        """
        Replace the data tokens between each start/end token pair with pad_values accordingly
        """
        pad_values = image_inputs.pad_values
        data_token_pairs = self.data_token_id_pairs
        image_inputs.image_offsets = []
        if data_token_pairs is None:
            data_token_pairs = [(image_inputs.im_start_id, image_inputs.im_end_id)]
        if data_token_pairs is None:
            logger.warning(
                "No data_token_pairs provided, RadixAttention may be affected."
            )
            return input_ids
        start_token_ids = [s for s, _e in data_token_pairs]
        end_tokens_ids = [e for _s, e in data_token_pairs]
        # First start token marks new data
        data_start_token = start_token_ids[0]

        padded_ids = []
        last_idx = 0
        data_idx = -1

        start_indices = [i for i, x in enumerate(input_ids) if x in start_token_ids]
        end_indices = [i for i, x in enumerate(input_ids) if x in end_tokens_ids]

        if len(start_indices) != len(end_indices):
            return input_ids

        for start_idx, end_idx in zip(start_indices, end_indices):
            padded_ids.extend(input_ids[last_idx : start_idx + 1])

            if input_ids[start_idx] == data_start_token:
                data_idx += 1
                image_inputs.image_offsets += [start_idx]

            num_tokens = end_idx - start_idx - 1
            pad_value = pad_values[data_idx]
            padded_ids.extend([pad_value] * num_tokens)

            last_idx = end_idx

        padded_ids.extend(input_ids[last_idx:])

        assert len(input_ids) == len(padded_ids)
        return padded_ids


class MultiModalityDataPaddingPatternSingleToken(MultiModalityDataPaddingPattern):
    """In this pattern, data is represented with a special token_id (image_inputs.im_token_id),
    which first needs to be expanded to multiple tokens, then replaced with their padding values

    This strategy should be used when a single data token represents content that should
    be expanded to multiple tokens during processing.
    """

    def __init__(
        self, num_data_token_calc_func: Callable[[Tuple[int, int, int]], int]
    ) -> None:
        self.num_data_token_calc_func = num_data_token_calc_func

    def pad_input_tokens(
        self, input_ids: List[int], image_inputs: ImageInputs
    ) -> List[int]:
        """
        This function follows the procedure of:
        1. each data token is expanded, with the final count computed by `num_data_token_calc_func`
        2. the expanded data tokens are replaced with their pad_values
        """
        image_grid_thws = image_inputs.image_grid_thws
        pad_values = image_inputs.pad_values

        image_indices = [
            idx
            for idx, token in enumerate(input_ids)
            if token == image_inputs.im_token_id
        ]

        image_inputs.image_offsets = []

        input_ids_with_image = []
        for image_cnt, _ in enumerate(image_grid_thws):
            print(f"image_cnt {image_cnt}")
            num_image_tokens = self.num_data_token_calc_func(image_grid_thws[image_cnt])
            if image_cnt == 0:
                non_image_tokens = input_ids[: image_indices[image_cnt]]
            else:
                non_image_tokens = input_ids[
                    image_indices[image_cnt - 1] + 1 : image_indices[image_cnt]
                ]
            input_ids_with_image.extend(non_image_tokens)
            image_inputs.image_offsets.append(len(input_ids_with_image))
            pad_ids = pad_values * (
                (num_image_tokens + len(pad_values)) // len(pad_values)
            )
            input_ids_with_image.extend(pad_ids[:num_image_tokens])
        input_ids_with_image.extend(input_ids[image_indices[-1] + 1 :])

        return input_ids_with_image
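A minimal sketch of how the token-pair pattern above is meant to be used, with made-up token ids and a stub in place of the real ImageInputs from schedule_batch:

from types import SimpleNamespace

# hypothetical ids for illustration: 151652 = image start, 151653 = image end
IM_START, IM_END = 151652, 151653
inputs = SimpleNamespace(pad_values=[42], image_offsets=[])

pattern = MultiModalityDataPaddingPatternTokenPairs([(IM_START, IM_END)])
input_ids = [701, 702, IM_START, 1, 1, 1, IM_END, 703]
padded = pattern.pad_input_tokens(input_ids, inputs)
# The three placeholder tokens between the pair are replaced by the pad value 42:
# [701, 702, 151652, 42, 42, 42, 151653, 703]; the sequence length is unchanged,
# and inputs.image_offsets records [2], the position of the start token.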
@@ -158,15 +158,19 @@ class ImageInputs:
    image_grid_thws: List[Tuple[int, int, int]] = None
    mrope_position_delta: Optional[torch.Tensor] = None

    # MiniCPMV related
    # The id of the single-image placeholder token
    im_token_id: Optional[torch.Tensor] = None
    # All the images in the batch should share the same special image
    # bound token ids.
    im_start_id: Optional[torch.Tensor] = None
    im_end_id: Optional[torch.Tensor] = None
    slice_start_id: Optional[torch.Tensor] = None
    slice_end_id: Optional[torch.Tensor] = None
    im_start_id: Optional[int] = None
    im_end_id: Optional[int] = None
    slice_start_id: Optional[int] = None
    slice_end_id: Optional[int] = None
    tgt_sizes: Optional[list] = None

    # denotes the number of valid image tokens in each image
    images_emb_mask: Optional[torch.BoolTensor] = None

    @staticmethod
    def from_dict(obj: dict):
        ret = ImageInputs(
@@ -186,11 +190,13 @@ class ImageInputs:
            "aspect_ratio_ids",
            "aspect_ratio_mask",
            "image_grid_thws",
            "im_token_id",
            "im_start_id",
            "im_end_id",
            "slice_start_id",
            "slice_end_id",
            "tgt_sizes",
            "images_emb_mask",
        ]
        for arg in optional_args:
            if arg in obj:
@@ -455,7 +455,7 @@ def pt_weights_iterator(
        disable=not enable_tqdm,
        bar_format=_BAR_FORMAT,
    ):
        state = torch.load(bin_file, map_location="cpu")
        state = torch.load(bin_file, map_location="cpu", weights_only=True)
        yield from state.items()
        del state
        torch.cuda.empty_cache()
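The weights_only=True switch makes torch.load refuse to unpickle arbitrary Python objects and only reconstruct tensors and primitive containers, which is the safer behavior for checkpoint files. A minimal sketch with a hypothetical checkpoint path:

import torch

# hypothetical path; only tensors and plain containers are allowed back from the pickle
state = torch.load("pytorch_model-00001-of-00002.bin", map_location="cpu", weights_only=True)
for name, tensor in state.items():
    print(name, tuple(tensor.shape))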
@@ -41,7 +41,6 @@ from torch import nn
from torch.nn.init import trunc_normal_
from transformers import PretrainedConfig

from sglang.srt.distributed import divide, get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import get_act_fn
from sglang.srt.layers.attention.vision import VisionAttention
from sglang.srt.layers.linear import (
@@ -51,6 +50,9 @@ from sglang.srt.layers.linear import (
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.multi_modality_padding import (
    MultiModalityDataPaddingPatternTokenPairs,
)
from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.utils import set_default_torch_dtype
@@ -186,19 +188,16 @@ class Idefics2EncoderLayer(nn.Module):
    ) -> None:
        super().__init__()
        self.embed_dim = config.hidden_size

        self.num_heads = config.num_attention_heads
        tp_size = get_tensor_model_parallel_world_size()
        num_heads_per_partition = divide(self.num_heads, tp_size)
        self.self_attn = VisionAttention(
            embed_dim=config.hidden_size,
            num_heads=num_heads_per_partition,
            num_heads=self.num_heads,
            projection_size=config.intermediate_size,
            use_qkv_parallel=True,
            quant_config=quant_config,
            dropout=config.attention_dropout,
            use_context_forward=False,
            use_full_precision_softmax=True,
            softmax_in_single_precision=True,
            flatten_batch=False,
            prefix=add_prefix("self_attn", prefix),
        )
@@ -708,21 +707,21 @@ class MiniCPMVBaseModel(nn.Module):
        self,
        input_ids: torch.Tensor,
        pad_values: List[int],
        im_start_id: torch.Tensor,
        im_end_id: torch.Tensor,
        slice_start_id: Optional[torch.Tensor] = None,
        slice_end_id: Optional[torch.Tensor] = None,
        im_start_id: int,
        im_end_id: int,
        slice_start_id: Optional[int] = None,
        slice_end_id: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Returns a tensor indicating the bounds (start and end token ids) of the images
        """
        # All the images in the batch should share the same special image
        # bound token ids.
        start_cond = input_ids == im_start_id[0]
        end_cond = input_ids == im_end_id[0]
        start_cond = input_ids == im_start_id
        end_cond = input_ids == im_end_id
        if slice_start_id is not None:
            start_cond |= input_ids == slice_start_id[0]
            end_cond |= input_ids == slice_end_id[0]
            start_cond |= input_ids == slice_start_id
            end_cond |= input_ids == slice_end_id

        (image_start_tokens,) = torch.where(start_cond)
        image_start_tokens += 1
@@ -733,6 +732,8 @@ class MiniCPMVBaseModel(nn.Module):
        if (
            len(image_start_tokens) + 1 == len(image_end_tokens)
            and input_ids[0] in pad_values
            and len(image_start_tokens) != 0
            and len(image_end_tokens) != 0
            and image_end_tokens[0] < image_start_tokens[0]
        ):
            image_start_tokens = torch.cat(
@@ -897,9 +898,12 @@ class MiniCPMVBaseModel(nn.Module):
        forward_batch: ForwardBatch,
        **kwargs: Any,
    ) -> torch.Tensor:
        if forward_batch.image_inputs is not None and forward_batch.image_inputs != [
            None
        ]:
        if (
            forward_batch.image_inputs is not None
            and len(forward_batch.image_inputs) > 0
            and forward_batch.image_inputs[0] is not None
        ):
            # TODO: batch
            kwargs.update(
                {
                    "pixel_values": (
@@ -1135,81 +1139,16 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
        return self.resampler(vision_embedding, tgt_sizes)

    def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
        if not isinstance(image_inputs.im_start_id, list) or not isinstance(
            image_inputs.im_end_id, list
        ):
            return input_ids

        new_input_ids = []
        last_idx = 0
        image_idx = -1
        image_inputs.image_offsets = []

        # Get all special token IDs
        im_start_id = (
            image_inputs.im_start_id[0].item()
            if isinstance(image_inputs.im_start_id[0], torch.Tensor)
            else image_inputs.im_start_id[0]
        )
        im_end_id = (
            image_inputs.im_end_id[0].item()
            if isinstance(image_inputs.im_end_id[0], torch.Tensor)
            else image_inputs.im_end_id[0]
        )
        slice_start_id = (
            image_inputs.slice_start_id[0].item()
            if isinstance(image_inputs.slice_start_id[0], torch.Tensor)
            else image_inputs.slice_start_id[0]
        )
        slice_end_id = (
            image_inputs.slice_end_id[0].item()
            if isinstance(image_inputs.slice_end_id[0], torch.Tensor)
            else image_inputs.slice_end_id[0]
        )
        im_start_id: int = image_inputs.im_start_id
        im_end_id: int = image_inputs.im_end_id
        slice_start_id: int = image_inputs.slice_start_id
        slice_end_id: int = image_inputs.slice_end_id

        # Find all start and end positions for both types
        start_indices = [
            i
            for i, x in enumerate(input_ids)
            if x == im_start_id or x == slice_start_id
        ]
        end_indices = [
            i for i, x in enumerate(input_ids) if x == im_end_id or x == slice_end_id
        ]
        media_token_pairs = [(im_start_id, im_end_id), (slice_start_id, slice_end_id)]
        pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)

        if len(start_indices) != len(end_indices):
            return input_ids
        # Process each region (both image and slice)
        for start_idx, end_idx in zip(start_indices, end_indices):
            # Add non-image tokens before this region
            new_input_ids.extend(
                input_ids[last_idx : start_idx + 1]
            )  # include start token

            is_image_start = input_ids[start_idx] == im_start_id

            if is_image_start:
                image_inputs.image_offsets += [start_idx]
                image_idx += 1

            num_tokens = end_idx - start_idx - 1  # exclude start and end tokens

            # Generate pad_ids
            pad_values = [image_inputs.pad_values[image_idx]]

            pad_ids = pad_values * ((num_tokens + len(pad_values)) // len(pad_values))
            pad_ids = pad_ids[:num_tokens]

            # Add pad_ids
            new_input_ids.extend(pad_ids)

            # Update last_idx to after end token
            last_idx = end_idx

        # Add remaining tokens after last region
        new_input_ids.extend(input_ids[last_idx:])
        assert len(input_ids) == len(new_input_ids)
        return new_input_ids
        return pattern.pad_input_tokens(input_ids, image_inputs)


_SUPPORT_VERSION = {(2, 6): MiniCPMV2_6}
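After the refactor, MiniCPM-V only declares its two start/end pairs (full image and slice) and delegates the token replacement to the shared pattern. A small sketch with made-up ids; note that in pad_input_tokens only the first pair's start token advances the pad-value index, so slice regions reuse the pad value of the image they belong to:

# hypothetical MiniCPM-V special-token ids, for illustration only
im_start_id, im_end_id = 101, 102
slice_start_id, slice_end_id = 103, 104

media_token_pairs = [(im_start_id, im_end_id), (slice_start_id, slice_end_id)]
pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
# pattern.pad_input_tokens(input_ids, image_inputs) now handles both
# full-image regions and slice regions through one code path.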
@@ -202,7 +202,7 @@ class MllamaVisionEncoderLayer(nn.Module):
            quant_config=None,
            dropout=0.0,
            use_context_forward=False,
            use_full_precision_softmax=False,
            softmax_in_single_precision=False,
            flatten_batch=False,
            prefix=add_prefix("self_attn", prefix),
        )
@@ -47,6 +47,9 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.multi_modality_padding import (
    MultiModalityDataPaddingPatternTokenPairs,
)
from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -121,12 +124,12 @@ class Qwen2_5_VisionBlock(nn.Module):
        self.norm2 = Qwen2RMSNorm(dim, eps=1e-6)
        if attn_implementation == "sdpa":
            use_context_forward = False
            use_full_precision_softmax = False
            softmax_in_single_precision = False
        elif attn_implementation == "flash_attention_2":
            use_full_precision_softmax = False
            softmax_in_single_precision = False
            use_context_forward = True
        elif attn_implementation == "eager":
            use_full_precision_softmax = True
            softmax_in_single_precision = True
            use_context_forward = False

        self.attn = VisionAttention(
@@ -135,7 +138,7 @@ class Qwen2_5_VisionBlock(nn.Module):
            projection_size=dim,
            use_qkv_parallel=False,
            use_context_forward=use_context_forward,
            use_full_precision_softmax=use_full_precision_softmax,
            softmax_in_single_precision=softmax_in_single_precision,
            flatten_batch=True,
            quant_config=quant_config,
            prefix=add_prefix("attn", prefix),
@@ -149,12 +152,17 @@ class Qwen2_5_VisionBlock(nn.Module):
        )

    def forward(
        self, x: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        position_embeddings: torch.Tensor,
    ) -> torch.Tensor:
        hidden_states = self.norm1(x)
        hidden_states = rearrange(hidden_states, "s b ... -> b s ...")
        attn = self.attn(
            hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
            hidden_states,
            cu_seqlens=cu_seqlens,
            position_embeddings=position_embeddings,
        )
        attn = rearrange(attn, "b s ... -> s b ...")
        x = x + attn
@@ -443,6 +451,8 @@ class Qwen2_5_VisionTransformer(nn.Module):
        )
        rotary_pos_emb = rotary_pos_emb[window_index, :, :]
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        position_embeddings = (emb.cos(), emb.sin())

        # compute cu_seqlens
        cu_seqlens = torch.repeat_interleave(
@@ -457,7 +467,9 @@ class Qwen2_5_VisionTransformer(nn.Module):
                cu_seqlens_now = cu_seqlens
            else:
                cu_seqlens_now = cu_window_seqlens
            x = blk(x, cu_seqlens=cu_seqlens_now, rotary_pos_emb=rotary_pos_emb)
            x = blk(
                x, cu_seqlens=cu_seqlens_now, position_embeddings=position_embeddings
            )

        # adapter
        x = self.merger(x)
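The change above computes the rotary table once per forward pass and hands every vision block the same (cos, sin) pair instead of the raw frequencies. A shape-level sketch of that hoisting, with assumed sizes rather than the model's real dimensions:

import torch

seq_len, head_dim = 16, 64  # assumed sizes for illustration
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim // 4) / (head_dim // 4)))
freqs = torch.outer(torch.arange(seq_len), inv_freq)  # (seq_len, head_dim // 4)

emb = torch.cat((freqs, freqs), dim=-1)               # duplicated, as in the diff
position_embeddings = (emb.cos(), emb.sin())          # computed once ...

for _ in range(3):                                    # ... and reused by every block
    cos, sin = position_embeddings
    assert cos.shape == (seq_len, head_dim // 2)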
@@ -522,50 +534,14 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
        return num_image_tokens

    def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
        new_input_ids = []
        last_idx = 0
        image_idx = -1
        image_inputs.image_offsets = []

        # Get all special token IDs
        im_start_id = image_inputs.im_start_id
        im_end_id = image_inputs.im_end_id
        im_start_id: int = image_inputs.im_start_id
        im_end_id: int = image_inputs.im_end_id

        # Find all start and end positions for both types
        start_indices = [i for i, x in enumerate(input_ids) if x == im_start_id]
        end_indices = [i for i, x in enumerate(input_ids) if x == im_end_id]
        media_token_pairs = [(im_start_id, im_end_id)]
        pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)

        if len(start_indices) != len(end_indices):
            return input_ids
        # Process each region (both image and slice)
        for start_idx, end_idx in zip(start_indices, end_indices):
            # Add non-image tokens before this region
            new_input_ids.extend(input_ids[last_idx : start_idx + 1])

            is_image_start = input_ids[start_idx] == im_start_id

            if is_image_start:
                image_inputs.image_offsets += [start_idx]
                image_idx += 1

            num_tokens = end_idx - start_idx - 1  # exclude start and end tokens

            # Generate pad_ids
            pad_values = [image_inputs.pad_values[image_idx]]

            pad_ids = pad_values * ((num_tokens + len(pad_values)) // len(pad_values))
            pad_ids = pad_ids[:num_tokens]

            # Add pad_ids
            new_input_ids.extend(pad_ids)

            # Update last_idx to after end token
            last_idx = end_idx

        # Add remaining tokens after last region
        new_input_ids.extend(input_ids[last_idx:])
        assert len(input_ids) == len(new_input_ids)
        return new_input_ids
        return pattern.pad_input_tokens(input_ids, image_inputs)

    def _process_image_input(self, image_input: Qwen2VLImageInputs) -> torch.Tensor:
        pixel_values = image_input["pixel_values"].type(self.visual.dtype)
@@ -629,7 +605,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
        extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
        prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
        for i, image in enumerate(forward_batch.image_inputs):
            if image is None:
            if image is None or image.pixel_values is None:
                continue
            start_idx = extend_start_loc_cpu[i]
            prefix_len = prefix_lens_cpu[i]
@@ -42,6 +42,9 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.multi_modality_padding import (
    MultiModalityDataPaddingPatternTokenPairs,
)
from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -137,12 +140,12 @@ class Qwen2VisionBlock(nn.Module):
        mlp_hidden_dim = int(dim * mlp_ratio)
        if attn_implementation == "sdpa":
            use_context_forward = False
            use_full_precision_softmax = False
            softmax_in_single_precision = False
        elif attn_implementation == "flash_attention_2":
            use_full_precision_softmax = False
            softmax_in_single_precision = False
            use_context_forward = True
        elif attn_implementation == "eager":
            use_full_precision_softmax = True
            softmax_in_single_precision = True
            use_context_forward = False

        self.attn = VisionAttention(
@@ -151,7 +154,7 @@ class Qwen2VisionBlock(nn.Module):
            projection_size=dim,
            use_qkv_parallel=False,
            use_context_forward=use_context_forward,
            use_full_precision_softmax=use_full_precision_softmax,
            softmax_in_single_precision=softmax_in_single_precision,
            flatten_batch=True,
            quant_config=quant_config,
            prefix=add_prefix("attn", prefix),
@@ -165,12 +168,17 @@ class Qwen2VisionBlock(nn.Module):
        )

    def forward(
        self, x: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        position_embeddings: torch.Tensor,
    ) -> torch.Tensor:
        hidden_states = self.norm1(x)
        hidden_states = rearrange(hidden_states, "s b ... -> b s ...")
        attn = self.attn(
            hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
            hidden_states,
            cu_seqlens=cu_seqlens,
            position_embeddings=position_embeddings,
        )
        attn = rearrange(attn, "b s ... -> s b ...")
        x = x + attn
@@ -392,7 +400,8 @@ class Qwen2VisionTransformer(nn.Module):

        # compute position embedding
        rotary_pos_emb = self.rot_pos_emb(grid_thw)

        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        position_embeddings = (emb.cos(), emb.sin())
        # compute cu_seqlens
        cu_seqlens = torch.repeat_interleave(
            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
@@ -402,7 +411,7 @@ class Qwen2VisionTransformer(nn.Module):
        # transformers
        x = x.unsqueeze(1)
        for blk in self.blocks:
            x = blk(x, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
            x = blk(x, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings)

        # adapter
        x = self.merger(x)
@@ -425,40 +434,6 @@ class Qwen2VLForConditionalGeneration(nn.Module):
        )
        return num_image_tokens

    # Use grid_t * grid_w * grid_h to pad tokens for each image
    # add replaced padding by unique image hash
    def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
        image_grid_thws = image_inputs.image_grid_thws
        pad_values = image_inputs.pad_values

        image_indices = [
            idx
            for idx, token in enumerate(input_ids)
            if token == self.config.image_token_id
        ]
        image_inputs.image_offsets = []

        input_ids_with_image = []
        for image_cnt, _ in enumerate(image_grid_thws):
            num_image_tokens = self.calculate_num_image_tokens(
                image_grid_thws[image_cnt]
            )
            if image_cnt == 0:
                non_image_tokens = input_ids[: image_indices[image_cnt]]
            else:
                non_image_tokens = input_ids[
                    image_indices[image_cnt - 1] + 1 : image_indices[image_cnt]
                ]
            input_ids_with_image.extend(non_image_tokens)
            image_inputs.image_offsets.append(len(input_ids_with_image))
            pad_ids = pad_values * (
                (num_image_tokens + len(pad_values)) // len(pad_values)
            )
            input_ids_with_image.extend(pad_ids[:num_image_tokens])
        input_ids_with_image.extend(input_ids[image_indices[-1] + 1 :])

        return input_ids_with_image

    def __init__(
        self,
        config: Qwen2VLConfig,
@@ -494,6 +469,17 @@ class Qwen2VLForConditionalGeneration(nn.Module):
        self.logits_processor = LogitsProcessor(config)
        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

    # Use grid_t * grid_w * grid_h to pad tokens for each image
    # add replaced padding by unique image hash
    def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
        # Get all special token IDs
        im_start_id: int = image_inputs.im_start_id
        im_end_id: int = image_inputs.im_end_id

        media_token_pairs = [(im_start_id, im_end_id)]
        pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
        return pattern.pad_input_tokens(input_ids, image_inputs)

    def _process_image_input(self, image_input: Qwen2VLImageInputs) -> torch.Tensor:
        pixel_values = image_input["pixel_values"].type(self.visual.dtype)
        image_embeds = self.visual(pixel_values, grid_thw=image_input["image_grid_thw"])
@@ -556,12 +542,12 @@ class Qwen2VLForConditionalGeneration(nn.Module):
        extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
        prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
        for i, image in enumerate(forward_batch.image_inputs):
            if image is None:
            if image is None or image.pixel_values is None:
                continue
            start_idx = extend_start_loc_cpu[i]
            prefix_len = prefix_lens_cpu[i]
            pixel_values = image.pixel_values.clone()

            pixel_values = torch.tensor(image.pixel_values, device="cuda")
            image_grid_thws = torch.tensor(
                np.array(image.image_grid_thws), device="cuda"
            )
@@ -579,15 +565,13 @@ class Qwen2VLForConditionalGeneration(nn.Module):
                    image_grid_thws[idx]
                )

                left_idx = start_idx + (image_offset - prefix_len)
                right_idx = (
                    start_idx + (image_offset - prefix_len) + num_image_tokens
                )

                left_idx = start_idx + (image_offset - prefix_len + 1)
                right_idx = left_idx + num_image_tokens
                inputs_embeds[left_idx:right_idx] = image_embeds[
                    image_embeds_offset : image_embeds_offset + num_image_tokens
                ]
                image_embeds_offset += num_image_tokens
            input_ids = None

        hidden_states = self.model(
            input_ids=input_ids,
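The image offsets recorded by the new token-pair pattern point at the im_start token itself, so the scatter above shifts the window by one to land on the first padded position after the start token. A worked sketch of that index arithmetic with assumed numbers:

# assumed values for one request in the batch, for illustration only
start_idx = 100          # where this request begins in the flattened extend batch
prefix_len = 20          # tokens already covered by the cached prefix
image_offset = 57        # position of the im_start token, as recorded by pad_input_tokens
num_image_tokens = 256   # visual tokens produced for this image

left_idx = start_idx + (image_offset - prefix_len + 1)   # 138: first padded slot after im_start
right_idx = left_idx + num_image_tokens                  # 394
# inputs_embeds[left_idx:right_idx] then receives this image's 256 visual embeddings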