refactor: bug fixes and refactor for vlm (#4661)
This commit is contained in:
@@ -9,8 +9,6 @@ import PIL
|
||||
import torch
|
||||
from PIL.Image import Image
|
||||
from transformers import (
|
||||
AutoImageProcessor,
|
||||
AutoProcessor,
|
||||
BaseImageProcessor,
|
||||
BatchFeature,
|
||||
LlamaConfig,
|
||||
@@ -20,6 +18,7 @@ from transformers import (
|
||||
)
|
||||
from transformers.image_utils import to_numpy_array
|
||||
|
||||
from sglang.srt.configs.utils import register_image_processor, register_processor
|
||||
from sglang.srt.mm_utils import expand2square
|
||||
|
||||
|
||||
@@ -625,5 +624,5 @@ class VLMImageProcessorConfig(PretrainedConfig):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
AutoProcessor.register(MultiModalityConfig, VLChatProcessor, exist_ok=True)
|
||||
AutoImageProcessor.register(VLMImageProcessorConfig, None, VLMImageProcessor, None)
|
||||
register_processor(MultiModalityConfig, VLChatProcessor)
|
||||
register_image_processor(MultiModalityConfig, VLMImageProcessor)
|
||||
|
||||
@@ -460,6 +460,7 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal
|
||||
|
||||
|
||||
multimodal_model_archs = [
|
||||
"DeepseekVL2ForCausalLM",
|
||||
"LlavaLlamaForCausalLM",
|
||||
"LlavaQwenForCausalLM",
|
||||
"LlavaMistralForCausalLM",
|
||||
@@ -472,7 +473,6 @@ multimodal_model_archs = [
|
||||
"Qwen2_5_VLForConditionalGeneration",
|
||||
"MiniCPMV",
|
||||
"MultiModalityCausalLM",
|
||||
"DeepseekVL2ForCausalLM",
|
||||
]
|
||||
|
||||
|
||||
|
||||
25
python/sglang/srt/configs/utils.py
Normal file
25
python/sglang/srt/configs/utils.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from typing import Type
|
||||
|
||||
from transformers import (
|
||||
AutoImageProcessor,
|
||||
AutoProcessor,
|
||||
BaseImageProcessor,
|
||||
PretrainedConfig,
|
||||
ProcessorMixin,
|
||||
)
|
||||
|
||||
|
||||
def register_image_processor(
|
||||
config: Type[PretrainedConfig], image_processor: Type[BaseImageProcessor]
|
||||
):
|
||||
"""
|
||||
register customized hf image processor while removing hf impl
|
||||
"""
|
||||
AutoImageProcessor.register(config, None, image_processor, None, exist_ok=True)
|
||||
|
||||
|
||||
def register_processor(config: Type[PretrainedConfig], processor: Type[ProcessorMixin]):
|
||||
"""
|
||||
register customized hf processor while removing hf impl
|
||||
"""
|
||||
AutoProcessor.register(config, processor, exist_ok=True)
|
||||
@@ -653,7 +653,7 @@ register_conv_template(
|
||||
Conversation(
|
||||
name="gemma-it",
|
||||
system_message="You are a helpful assistant.",
|
||||
system_template="<bos><start_of_turn>user{system_message}\n\n",
|
||||
system_template="<start_of_turn>user{system_message}\n\n",
|
||||
roles=("<start_of_turn>user\n", "<start_of_turn>model\n"),
|
||||
sep="<end_of_turn>\n",
|
||||
sep_style=SeparatorStyle.GEMMA3,
|
||||
|
||||
@@ -143,9 +143,14 @@ class VisionAttention(nn.Module):
|
||||
if position_embeddings is not None:
|
||||
cos, sin = position_embeddings
|
||||
original_shape = q.shape
|
||||
q, k = q.view(s, head, -1), k.view(s, head, -1)
|
||||
# [total_tokens, head, head_size]
|
||||
q = q.view(-1, head, self.head_size)
|
||||
k = k.view(-1, head, self.head_size)
|
||||
|
||||
q, k = apply_rotary_pos_emb(q, k, cos, sin)
|
||||
q, k = q.reshape(original_shape), k.reshape(original_shape)
|
||||
|
||||
q = q.view(original_shape)
|
||||
k = k.view(original_shape)
|
||||
|
||||
if self.use_qkv_parallel:
|
||||
pass
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
# TODO: also move pad_input_ids into this module
|
||||
import importlib
|
||||
import inspect
|
||||
import logging
|
||||
import pkgutil
|
||||
from functools import lru_cache
|
||||
from typing import Union
|
||||
|
||||
from torch import Tensor
|
||||
from transformers import IMAGE_PROCESSOR_MAPPING
|
||||
|
||||
from sglang.srt.managers.image_processors.base_image_processor import (
|
||||
@@ -18,9 +21,7 @@ logger = logging.getLogger(__name__)
|
||||
IMAGE_PROCESSOR_MAPPING = {}
|
||||
|
||||
|
||||
def get_image_processor(
|
||||
hf_config, server_args: ServerArgs, processor
|
||||
) -> BaseImageProcessor:
|
||||
def get_image_processor(hf_config, server_args, processor) -> BaseImageProcessor:
|
||||
for model_cls, processor_cls in IMAGE_PROCESSOR_MAPPING.items():
|
||||
if model_cls.__name__ in hf_config.architectures:
|
||||
return processor_cls(hf_config, server_args, processor)
|
||||
@@ -42,13 +43,18 @@ def import_image_processors():
|
||||
try:
|
||||
module = importlib.import_module(name)
|
||||
except Exception as e:
|
||||
logger.warning(f"Ignore import error when loading {name}: " f"{e}")
|
||||
logger.warning(f" Ignore import error when loading {name}: " f"{e}")
|
||||
continue
|
||||
if hasattr(module, "ImageProcessorMapping"):
|
||||
entry = module.ImageProcessorMapping
|
||||
if isinstance(entry, dict):
|
||||
for processor_name, cls in entry.items():
|
||||
IMAGE_PROCESSOR_MAPPING[processor_name] = cls
|
||||
all_members = inspect.getmembers(module, inspect.isclass)
|
||||
classes = [
|
||||
member
|
||||
for name, member in all_members
|
||||
if member.__module__ == module.__name__
|
||||
]
|
||||
for cls in classes:
|
||||
if issubclass(cls, BaseImageProcessor):
|
||||
for arch in getattr(cls, "models"):
|
||||
IMAGE_PROCESSOR_MAPPING[arch] = cls
|
||||
|
||||
|
||||
# also register processors
|
||||
|
||||
@@ -4,14 +4,14 @@ import dataclasses
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
from typing import Optional, Union
|
||||
|
||||
import PIL
|
||||
import transformers
|
||||
from decord import VideoReader, cpu
|
||||
from openai import BadRequestError
|
||||
from PIL import Image
|
||||
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.utils import load_image
|
||||
from sglang.utils import logger
|
||||
|
||||
@@ -31,8 +31,16 @@ class BaseImageProcessorOutput:
|
||||
# input_text, with each frame of video/image represented as an image_token
|
||||
input_text: str
|
||||
|
||||
def normalize(self):
|
||||
for field_name in ["data_hashes", "image_sizes", "all_frames"]:
|
||||
field = getattr(self, field_name, None)
|
||||
if field is not None and isinstance(field, list) and len(field) == 0:
|
||||
setattr(self, field_name, None)
|
||||
|
||||
|
||||
class BaseImageProcessor(ABC):
|
||||
models = []
|
||||
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
self.hf_config = hf_config
|
||||
self._processor = _processor
|
||||
@@ -40,6 +48,9 @@ class BaseImageProcessor(ABC):
|
||||
# FIXME: not accurate, model and image specific
|
||||
self.NUM_TOKEN_PER_FRAME = 330
|
||||
|
||||
# Initialize global processor first
|
||||
init_global_processor(self, server_args)
|
||||
|
||||
self.executor = concurrent.futures.ProcessPoolExecutor(
|
||||
initializer=init_global_processor,
|
||||
mp_context=mp.get_context("fork"),
|
||||
@@ -113,7 +124,7 @@ class BaseImageProcessor(ABC):
|
||||
self,
|
||||
input_ids: list[int],
|
||||
image_data,
|
||||
image_token: str,
|
||||
image_token: Union[int, str],
|
||||
max_req_input_len: int,
|
||||
return_text: Optional[bool] = True,
|
||||
discard_alpha_channel: bool = True,
|
||||
@@ -122,9 +133,16 @@ class BaseImageProcessor(ABC):
|
||||
Each frame of video/image will be replaced by a single image token
|
||||
|
||||
Args:
|
||||
image_token: The token ID representing the image placeholder.
|
||||
discard_alpha_channel: if True, discards the alpha channel in the returned images
|
||||
|
||||
"""
|
||||
if isinstance(image_token, int):
|
||||
image_token_str = self._processor.tokenizer.convert_ids_to_tokens(
|
||||
image_token
|
||||
)
|
||||
else:
|
||||
image_token_str = image_token
|
||||
|
||||
if isinstance(input_ids, list) and return_text:
|
||||
assert len(input_ids) and isinstance(input_ids[0], int)
|
||||
@@ -190,13 +208,11 @@ class BaseImageProcessor(ABC):
|
||||
new_text += text_part
|
||||
|
||||
except Exception as e:
|
||||
import openai
|
||||
|
||||
logger.error(f"An exception occurred while loading images: {e}")
|
||||
raise BadRequestError(
|
||||
f"An exception occurred while loading images: {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
return BaseImageProcessorOutput(
|
||||
image_hashes=hashes,
|
||||
@@ -204,6 +220,8 @@ class BaseImageProcessor(ABC):
|
||||
all_frames=images,
|
||||
input_text=new_text,
|
||||
)
|
||||
out.normalize()
|
||||
return out
|
||||
|
||||
|
||||
class DummyImageProcessor(BaseImageProcessor):
|
||||
@@ -214,9 +232,7 @@ class DummyImageProcessor(BaseImageProcessor):
|
||||
return None
|
||||
|
||||
|
||||
def init_global_processor(
|
||||
sglang_image_processor: BaseImageProcessor, server_args: ServerArgs
|
||||
):
|
||||
def init_global_processor(sglang_image_processor: BaseImageProcessor, server_args):
|
||||
"""Init the global processor for multi-modal models."""
|
||||
global global_processor
|
||||
transformers.logging.set_verbosity_error()
|
||||
|
||||
@@ -16,13 +16,9 @@
|
||||
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
||||
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
||||
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
import asyncio
|
||||
import math
|
||||
from typing import List, Union
|
||||
|
||||
import torch
|
||||
from PIL import Image, ImageOps
|
||||
|
||||
from sglang.srt.managers.image_processor import BaseImageProcessor
|
||||
from sglang.srt.managers.image_processors.base_image_processor import (
|
||||
@@ -32,18 +28,24 @@ from sglang.srt.models.deepseek_vl2 import DeepseekVL2ForCausalLM
|
||||
|
||||
|
||||
class DeepseekVL2ImageProcessor(BaseImageProcessor):
|
||||
models = [DeepseekVL2ForCausalLM]
|
||||
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
# with contextlib.suppress(ValueError):
|
||||
# AutoProcessor.register("DeepseekVLV2Processor", DeepseekVLV2Processor)
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
self.IMAGE_TOKEN = "<image>"
|
||||
|
||||
@staticmethod
|
||||
def _process_images_task(image, input_text, max_req_input_len):
|
||||
return get_global_processor().__call__(
|
||||
processor = get_global_processor()
|
||||
res = processor.__call__(
|
||||
conversations=input_text, images=image, max_req_input_len=max_req_input_len
|
||||
)
|
||||
|
||||
image_token_id = processor.image_token_id
|
||||
|
||||
res["im_token_id"] = image_token_id
|
||||
return res
|
||||
|
||||
async def _process_images(self, image_data, input_text, max_req_input_len):
|
||||
if self.executor is not None:
|
||||
loop = asyncio.get_event_loop()
|
||||
@@ -70,18 +72,15 @@ class DeepseekVL2ImageProcessor(BaseImageProcessor):
|
||||
if not isinstance(image_data, list):
|
||||
image_data = [image_data]
|
||||
|
||||
images, image_hashes, image_sizes = [], [], []
|
||||
images, image_sizes = [], []
|
||||
|
||||
image_token = self.IMAGE_TOKEN
|
||||
base_output = self.load_images(
|
||||
input_ids, image_data, image_token, max_req_input_len
|
||||
)
|
||||
base_output.all_frames = [img.convert("RGB") for img in base_output.all_frames]
|
||||
res = await self._process_images(
|
||||
base_output.all_frames, base_output.input_text, max_req_input_len
|
||||
)
|
||||
pixel_values = res["images"]
|
||||
input_ids = res["input_ids"]
|
||||
images_seq_mask = res["images_seq_mask"]
|
||||
images_spatial_crop = res["images_spatial_crop"]
|
||||
batched_images_spatial_crop = []
|
||||
@@ -89,16 +88,12 @@ class DeepseekVL2ImageProcessor(BaseImageProcessor):
|
||||
batched_images_spatial_crop = torch.stack(batched_images_spatial_crop, dim=0)
|
||||
|
||||
return {
|
||||
"input_ids": input_ids.tolist(),
|
||||
"pixel_values": pixel_values,
|
||||
"image_hashes": image_hashes,
|
||||
"input_ids": res["input_ids"].tolist(),
|
||||
"pixel_values": res["images"],
|
||||
"im_token_id": res["im_token_id"],
|
||||
"image_hashes": base_output.image_hashes,
|
||||
"image_sizes": image_sizes,
|
||||
"image_seq_mask": images_seq_mask,
|
||||
"images_emb_mask": images_seq_mask,
|
||||
"image_spatial_crop": batched_images_spatial_crop,
|
||||
"modalities": request_obj.modalities or ["image"],
|
||||
}
|
||||
|
||||
|
||||
ImageProcessorMapping = {
|
||||
DeepseekVL2ForCausalLM: DeepseekVL2ImageProcessor,
|
||||
}
|
||||
|
||||
@@ -17,14 +17,15 @@ logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class Gemma3SGLangImageProcessor(SGLangBaseImageProcessor):
|
||||
models = [Gemma3ForConditionalGeneration]
|
||||
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
self.IMAGE_TOKEN = "<start_of_image>"
|
||||
self.IM_START_TOKEN_ID = hf_config.boi_token_index
|
||||
self.IM_END_TOKEN_ID = hf_config.eoi_token_index
|
||||
|
||||
@staticmethod
|
||||
def _process_images_task(images, input_text, _hf_config):
|
||||
async def _process_single_image(self, images, input_text) -> dict:
|
||||
if isinstance(images, list) and len(images) == 0:
|
||||
images = None
|
||||
processor = get_global_processor()
|
||||
@@ -46,19 +47,6 @@ class Gemma3SGLangImageProcessor(SGLangBaseImageProcessor):
|
||||
"pixel_values": pixel_values,
|
||||
}
|
||||
|
||||
async def _process_images(self, images, input_text) -> dict:
|
||||
if self.executor is not None:
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(
|
||||
self.executor,
|
||||
Gemma3SGLangImageProcessor._process_images_task,
|
||||
images,
|
||||
input_text,
|
||||
self.hf_config,
|
||||
)
|
||||
else:
|
||||
return self._process_images_task(images, input_text, self.hf_config)
|
||||
|
||||
async def process_images_async(
|
||||
self,
|
||||
image_data: List[Union[str, bytes]],
|
||||
@@ -82,7 +70,7 @@ class Gemma3SGLangImageProcessor(SGLangBaseImageProcessor):
|
||||
discard_alpha_channel=True,
|
||||
)
|
||||
|
||||
ret = await self._process_images(
|
||||
ret = await self._process_single_image(
|
||||
input_text=base_output.input_text, images=base_output.all_frames
|
||||
)
|
||||
|
||||
@@ -93,8 +81,3 @@ class Gemma3SGLangImageProcessor(SGLangBaseImageProcessor):
|
||||
"im_start_id": self.IM_START_TOKEN_ID,
|
||||
"im_end_id": self.IM_END_TOKEN_ID,
|
||||
}
|
||||
|
||||
|
||||
ImageProcessorMapping = {
|
||||
Gemma3ForConditionalGeneration: Gemma3SGLangImageProcessor,
|
||||
}
|
||||
|
||||
@@ -11,6 +11,8 @@ from sglang.srt.models.deepseek_janus_pro import MultiModalityCausalLM
|
||||
|
||||
|
||||
class JanusProProcessor(SGLangBaseImageProcessor):
|
||||
models = [MultiModalityCausalLM]
|
||||
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
|
||||
@@ -77,6 +79,3 @@ class JanusProProcessor(SGLangBaseImageProcessor):
|
||||
"im_end_id": res["im_end_id"],
|
||||
"im_token_id": res["im_token_id"],
|
||||
}
|
||||
|
||||
|
||||
ImageProcessorMapping = {MultiModalityCausalLM: JanusProProcessor}
|
||||
|
||||
@@ -15,6 +15,8 @@ from sglang.utils import get_exception_traceback
|
||||
|
||||
|
||||
class LlavaImageProcessor(BaseImageProcessor):
|
||||
models = [LlavaVidForCausalLM, LlavaQwenForCausalLM, LlavaMistralForCausalLM]
|
||||
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
|
||||
@@ -143,10 +145,3 @@ class LlavaImageProcessor(BaseImageProcessor):
|
||||
"image_sizes": image_sizes,
|
||||
"modalities": request_obj.modalities or ["image"],
|
||||
}
|
||||
|
||||
|
||||
ImageProcessorMapping = {
|
||||
LlavaVidForCausalLM: LlavaImageProcessor,
|
||||
LlavaQwenForCausalLM: LlavaImageProcessor,
|
||||
LlavaMistralForCausalLM: LlavaImageProcessor,
|
||||
}
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import asyncio
|
||||
from typing import List, Union
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.managers.image_processor import BaseImageProcessor
|
||||
from sglang.srt.managers.image_processors.base_image_processor import (
|
||||
get_global_processor,
|
||||
@@ -9,6 +11,8 @@ from sglang.srt.models.minicpmv import MiniCPMV
|
||||
|
||||
|
||||
class MiniCPMVImageProcessor(BaseImageProcessor):
|
||||
models = [MiniCPMV]
|
||||
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
self.IMAGE_TOKEN = "(<image>./</image>)"
|
||||
@@ -69,21 +73,57 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
|
||||
# Collect special token ids
|
||||
tokenizer = self._processor.tokenizer
|
||||
im_start_id = tokenizer.im_start_id
|
||||
im_token_id = tokenizer.unk_token_id
|
||||
im_end_id = tokenizer.im_end_id
|
||||
if tokenizer.slice_start_id:
|
||||
slice_start_id = tokenizer.slice_start_id
|
||||
slice_end_id = tokenizer.slice_end_id
|
||||
|
||||
pixel_values = res["pixel_values"]
|
||||
tgt_sizes = res["tgt_sizes"]
|
||||
|
||||
if not isinstance(pixel_values, (torch.Tensor, list)):
|
||||
raise ValueError(
|
||||
"Incorrect type of pixel values. " f"Got type: {type(pixel_values)}"
|
||||
)
|
||||
|
||||
if not isinstance(tgt_sizes, (torch.Tensor, list)):
|
||||
raise ValueError(
|
||||
"Incorrect type of target sizes. " f"Got type: {type(tgt_sizes)}"
|
||||
)
|
||||
|
||||
if len(pixel_values) != len(tgt_sizes):
|
||||
raise ValueError(
|
||||
"Inconsistent batch lengths, found: "
|
||||
f"{len(pixel_values)} vs. {len(tgt_sizes)}"
|
||||
)
|
||||
|
||||
# tgt_sizes = [tgt_size for tgt_size in tgt_sizes if isinstance(tgt_size, torch.Tensor)]
|
||||
# tgt_sizes = torch.vstack(tgt_sizes).type(torch.int32)
|
||||
pixel_values_flat: List[torch.Tensor] = []
|
||||
tgt_sizes_flat: List[torch.Tensor] = []
|
||||
for pixel_b, tgt_b in zip(pixel_values, tgt_sizes):
|
||||
# per image
|
||||
if len(pixel_b) != len(tgt_b):
|
||||
raise ValueError(
|
||||
"Inconsistent N lengths, found: " f"{len(pixel_b)} vs {len(tgt_b)}"
|
||||
)
|
||||
for pixel_n, tgt_n in zip(pixel_b, tgt_b):
|
||||
# per patch
|
||||
pixel_values_flat += [pixel_n]
|
||||
tgt_sizes_flat += [tgt_n]
|
||||
|
||||
pixel_values = pixel_values_flat
|
||||
tgt_sizes = torch.stack(tgt_sizes_flat)
|
||||
return {
|
||||
"input_ids": res["input_ids"].flatten().tolist(),
|
||||
"pixel_values": res["pixel_values"],
|
||||
"tgt_sizes": res["tgt_sizes"],
|
||||
"pixel_values": pixel_values,
|
||||
"tgt_sizes": tgt_sizes,
|
||||
"image_hashes": base_output.image_hashes,
|
||||
"modalities": request_obj.modalities or ["image"],
|
||||
"im_start_id": im_start_id,
|
||||
"im_token_id": im_token_id,
|
||||
"im_end_id": im_end_id,
|
||||
"slice_start_id": slice_start_id,
|
||||
"slice_end_id": slice_end_id,
|
||||
}
|
||||
|
||||
|
||||
ImageProcessorMapping = {MiniCPMV: MiniCPMVImageProcessor}
|
||||
|
||||
@@ -10,6 +10,8 @@ from sglang.srt.utils import load_image
|
||||
|
||||
|
||||
class MllamaImageProcessor(BaseImageProcessor):
|
||||
models = [MllamaForConditionalGeneration]
|
||||
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
|
||||
@@ -55,6 +57,3 @@ class MllamaImageProcessor(BaseImageProcessor):
|
||||
image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
|
||||
|
||||
return image_inputs
|
||||
|
||||
|
||||
ImageProcessorMapping = {MllamaForConditionalGeneration: MllamaImageProcessor}
|
||||
|
||||
@@ -2,6 +2,7 @@ import asyncio
|
||||
import math
|
||||
from typing import List, Union
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from sglang.srt.managers.image_processor import BaseImageProcessor
|
||||
@@ -14,6 +15,8 @@ from sglang.srt.models.qwen2_vl import Qwen2VLForConditionalGeneration
|
||||
|
||||
# Compatible with Qwen2VL and Qwen2_5VL
|
||||
class Qwen2_5VLImageProcessor(BaseImageProcessor):
|
||||
models = [Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration]
|
||||
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
self.IMAGE_TOKEN = "<|vision_start|><|image_pad|><|vision_end|>"
|
||||
@@ -43,7 +46,7 @@ class Qwen2_5VLImageProcessor(BaseImageProcessor):
|
||||
"video_grid_thws": getattr(result, "video_grid_thws", None),
|
||||
}
|
||||
|
||||
async def _process_images(self, images, input_text) -> dict:
|
||||
async def _process_single_image(self, images, input_text) -> dict:
|
||||
if self.executor is not None:
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(
|
||||
@@ -138,23 +141,23 @@ class Qwen2_5VLImageProcessor(BaseImageProcessor):
|
||||
|
||||
images = [resize_image(image) for image in base_output.all_frames]
|
||||
|
||||
ret = await self._process_images(images, base_output.input_text)
|
||||
ret = await self._process_single_image(
|
||||
images=images, input_text=base_output.input_text
|
||||
)
|
||||
|
||||
image_grid_thws = torch.concat([ret["image_grid_thw"]])
|
||||
video_grid_thws = None
|
||||
|
||||
return {
|
||||
"input_ids": ret["input_ids"].flatten().tolist(),
|
||||
"pixel_values": ret["pixel_values"],
|
||||
"image_hashes": base_output.image_hashes,
|
||||
"modalities": request_obj.modalities or ["image"],
|
||||
"image_grid_thws": ret["image_grid_thw"],
|
||||
"video_grid_thws": ret["video_grid_thws"],
|
||||
"image_grid_thws": image_grid_thws,
|
||||
"video_grid_thws": video_grid_thws,
|
||||
"im_start_id": self.IM_START_TOKEN_ID,
|
||||
"im_end_id": self.IM_END_TOKEN_ID,
|
||||
"im_token_id": self.image_token_id,
|
||||
"video_token_id": self.video_token_id,
|
||||
"second_per_grid_ts": ret["second_per_grid_ts"],
|
||||
}
|
||||
|
||||
|
||||
ImageProcessorMapping = {
|
||||
Qwen2VLForConditionalGeneration: Qwen2_5VLImageProcessor,
|
||||
Qwen2_5_VLForConditionalGeneration: Qwen2_5VLImageProcessor,
|
||||
}
|
||||
|
||||
303
python/sglang/srt/managers/mm_utils.py
Normal file
303
python/sglang/srt/managers/mm_utils.py
Normal file
@@ -0,0 +1,303 @@
|
||||
"""
|
||||
Multimodality utils
|
||||
"""
|
||||
|
||||
from abc import abstractmethod
|
||||
from typing import Callable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from sglang.srt.managers.schedule_batch import (
|
||||
ImageInputs,
|
||||
global_server_args_dict,
|
||||
logger,
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.utils import logger
|
||||
|
||||
|
||||
class MultiModalityDataPaddingPattern:
|
||||
"""
|
||||
Data tokens (like image tokens) often need special handling during padding
|
||||
to maintain model compatibility. This class provides the interface for
|
||||
implementing different padding strategies for data tokens
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def pad_input_tokens(
|
||||
self, input_ids: List[int], image_inputs: ImageInputs
|
||||
) -> List[int]:
|
||||
"""
|
||||
Pad the input ids sequence containing data tokens, and replace them with pad_values
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern):
|
||||
"""In this pattern, data tokens should be enclosed by special token pairs (e.g. <image>...</image>, data_token_pairs)
|
||||
|
||||
This strategy should be applied when data content is marked by start/end token pairs in the input sequence.
|
||||
"""
|
||||
|
||||
def __init__(self, data_token_pairs: Optional[List[Tuple[int, int]]]) -> None:
|
||||
self.data_token_id_pairs = data_token_pairs
|
||||
|
||||
def pad_input_tokens(
|
||||
self, input_ids: List[int], image_inputs: ImageInputs
|
||||
) -> List[int]:
|
||||
"""
|
||||
This function will replace the data-tokens inbetween with pad_values accordingly
|
||||
"""
|
||||
pad_values = image_inputs.pad_values
|
||||
data_token_pairs = self.data_token_id_pairs
|
||||
image_inputs.image_offsets = []
|
||||
if data_token_pairs is None:
|
||||
data_token_pairs = [image_inputs.im_start_id, image_inputs.im_end_id]
|
||||
if data_token_pairs is None:
|
||||
logger.warning(
|
||||
"No data_token_pairs provided, RadixAttention might be influenced."
|
||||
)
|
||||
return input_ids
|
||||
start_token_ids = [s for s, _e in data_token_pairs]
|
||||
end_tokens_ids = [e for _s, e in data_token_pairs]
|
||||
# First start token marks new data
|
||||
data_start_token = start_token_ids[0]
|
||||
|
||||
padded_ids = []
|
||||
last_idx = 0
|
||||
data_idx = -1
|
||||
|
||||
start_indices = [i for i, x in enumerate(input_ids) if x in start_token_ids]
|
||||
end_indices = [i for i, x in enumerate(input_ids) if x in end_tokens_ids]
|
||||
|
||||
if len(start_indices) != len(end_indices):
|
||||
return input_ids
|
||||
|
||||
for start_idx, end_idx in zip(start_indices, end_indices):
|
||||
padded_ids.extend(input_ids[last_idx : start_idx + 1])
|
||||
|
||||
if input_ids[start_idx] == data_start_token:
|
||||
data_idx += 1
|
||||
image_inputs.image_offsets += [start_idx]
|
||||
|
||||
num_tokens = end_idx - start_idx - 1
|
||||
pad_value = pad_values[data_idx]
|
||||
padded_ids.extend([pad_value] * num_tokens)
|
||||
|
||||
last_idx = end_idx
|
||||
|
||||
padded_ids.extend(input_ids[last_idx:])
|
||||
|
||||
assert len(input_ids) == len(padded_ids)
|
||||
return padded_ids
|
||||
|
||||
|
||||
class MultModalityDataPaddingPatternSingleToken(MultiModalityDataPaddingPattern):
|
||||
"""In this pattern, data is represented with a special token_id ( image_inputs.im_token_id ),
|
||||
which needs first to be expanded to multiple tokens, then replaced with their padding values
|
||||
|
||||
This strategy should be used when a single data token represents content that should
|
||||
be expanded to multiple tokens during processing.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, num_data_token_calc_func: Callable[[Tuple[int, int, int]], int]
|
||||
) -> None:
|
||||
self.num_data_token_calc_func = num_data_token_calc_func
|
||||
|
||||
def pad_input_tokens(
|
||||
self, input_ids: List[int], image_inputs: ImageInputs
|
||||
) -> List[int]:
|
||||
"""
|
||||
This function will follow the procedure of:
|
||||
1. the data token will be expanded, of which the final number will be calculated by `num_data_token_calc_func`
|
||||
2. the padded data tokens will be replaced with their pad_values
|
||||
"""
|
||||
image_grid_thws = image_inputs.image_grid_thws
|
||||
pad_values = image_inputs.pad_values
|
||||
|
||||
image_indices = [
|
||||
idx
|
||||
for idx, token in enumerate(input_ids)
|
||||
if token == image_inputs.im_token_id
|
||||
]
|
||||
|
||||
image_inputs.image_offsets = []
|
||||
|
||||
input_ids_with_image = []
|
||||
for image_cnt, _ in enumerate(image_grid_thws):
|
||||
num_image_tokens = self.num_data_token_calc_func(image_grid_thws[image_cnt])
|
||||
if image_cnt == 0:
|
||||
non_image_tokens = input_ids[: image_indices[image_cnt]]
|
||||
else:
|
||||
non_image_tokens = input_ids[
|
||||
image_indices[image_cnt - 1] + 1 : image_indices[image_cnt]
|
||||
]
|
||||
input_ids_with_image.extend(non_image_tokens)
|
||||
image_inputs.image_offsets.append(len(input_ids_with_image))
|
||||
pad_ids = pad_values * (
|
||||
(num_image_tokens + len(pad_values)) // len(pad_values)
|
||||
)
|
||||
input_ids_with_image.extend(pad_ids[:num_image_tokens])
|
||||
input_ids_with_image.extend(input_ids[image_indices[-1] + 1 :])
|
||||
|
||||
return input_ids_with_image
|
||||
|
||||
|
||||
class MultiModalityDataPaddingPatternImageTokens(MultiModalityDataPaddingPattern):
|
||||
"""In this pattern, data tokens should be represented as image tokens (e.g. <image><image>....<image>)"""
|
||||
|
||||
def __init__(self, image_token_id: torch.Tensor) -> None:
|
||||
self.image_token_id = image_token_id
|
||||
|
||||
def pad_input_tokens(self, input_ids: List[int], image_inputs) -> List[int]:
|
||||
"""
|
||||
This function will replace the data-tokens in between with pad_values accordingly
|
||||
"""
|
||||
pad_values = image_inputs.pad_values
|
||||
assert len(pad_values) != 0
|
||||
|
||||
input_ids_tensor = torch.tensor(input_ids)
|
||||
mask = torch.isin(input_ids_tensor, self.image_token_id)
|
||||
|
||||
num_image_tokens = mask.sum().item()
|
||||
repeated_pad_values = torch.tensor(pad_values).repeat(
|
||||
num_image_tokens // len(pad_values) + 1
|
||||
)[:num_image_tokens]
|
||||
|
||||
input_ids_tensor[mask] = repeated_pad_values
|
||||
return input_ids_tensor.tolist()
|
||||
|
||||
|
||||
def embed_image_inputs(
|
||||
image_input: ImageInputs,
|
||||
input_ids: torch.Tensor,
|
||||
input_embedding: nn.Embedding,
|
||||
image_embedding_func,
|
||||
placeholder_token_ids: List[int] = None,
|
||||
) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
Calculate the image embeddings if necessary, then scatter the result with
|
||||
the help of a boolean mask denoting the embed locations
|
||||
|
||||
Returns:
|
||||
final embedding: Optional[torch.Tensor]
|
||||
"""
|
||||
if image_input is None:
|
||||
return None
|
||||
|
||||
placeholder_token_ids = placeholder_token_ids or image_input.pad_values
|
||||
|
||||
# boolean masking the special tokens
|
||||
special_image_mask = torch.isin(
|
||||
input_ids,
|
||||
torch.tensor(placeholder_token_ids, device=input_ids.device),
|
||||
).unsqueeze(-1)
|
||||
|
||||
num_image_tokens_in_input_ids = special_image_mask.sum()
|
||||
|
||||
if num_image_tokens_in_input_ids == 0:
|
||||
# unexpected
|
||||
inputs_embeds = input_embedding(input_ids)
|
||||
else:
|
||||
image_embedding = image_embedding_func(image_input)
|
||||
|
||||
if image_embedding.dim() == 2:
|
||||
num_image_tokens_in_embedding = image_embedding.shape[0]
|
||||
else:
|
||||
num_image_tokens_in_embedding = (
|
||||
image_embedding.shape[0] * image_embedding.shape[1]
|
||||
)
|
||||
if num_image_tokens_in_input_ids != num_image_tokens_in_embedding:
|
||||
num_image = num_image_tokens_in_input_ids // image_embedding.shape[1]
|
||||
image_embedding = image_embedding[:num_image, :]
|
||||
logger.warning(
|
||||
f"Number of images does not match number of special image tokens in the input text. "
|
||||
f"Got {num_image_tokens_in_input_ids} image tokens in the text but {num_image_tokens_in_embedding} "
|
||||
"tokens from image embeddings."
|
||||
)
|
||||
|
||||
# TODO: chunked prefill will split special tokens from input_ids into several passes, failing the embedding
|
||||
# a fix may be cache the unfinished image embedding for future reuse, determine the tokens to embed with
|
||||
# extend_start_loc and extend_seq_lens
|
||||
if num_image_tokens_in_input_ids > num_image_tokens_in_embedding:
|
||||
chunked_prefill_size = global_server_args_dict["chunked_prefill_size"]
|
||||
if chunked_prefill_size != -1:
|
||||
logger.warning(
|
||||
"You may want to avoid this issue by raising `chunked_prefill_size`, or disabling chunked_prefill"
|
||||
)
|
||||
|
||||
vocab_size = input_embedding.num_embeddings
|
||||
# Important: clamp after getting original image regions
|
||||
# Clamp input ids. This is because the input_ids for the image tokens are
|
||||
# filled with the hash values of the image for the prefix matching in the radix attention.
|
||||
# There values are useless because their embeddings will be replaced by vision embeddings anyway.
|
||||
input_ids.clamp_(min=0, max=vocab_size - 1)
|
||||
inputs_embeds = input_embedding(input_ids)
|
||||
|
||||
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
|
||||
inputs_embeds.device
|
||||
)
|
||||
|
||||
inputs_embeds = inputs_embeds.masked_scatter(
|
||||
special_image_mask,
|
||||
image_embedding.to(inputs_embeds.device, inputs_embeds.dtype),
|
||||
)
|
||||
return inputs_embeds
|
||||
|
||||
|
||||
def embed_image_embedding(
    inputs_embeds: torch.Tensor,
    image_embedding: torch.Tensor,
    image_bounds: torch.Tensor,
) -> torch.Tensor:
    """
    Scatter `image_embedding` rows into `inputs_embeds` (in place) at the row
    positions described by `image_bounds`, a tensor of (start, end) pairs.

    Returns `inputs_embeds` (the same, mutated, tensor).
    """
    if len(image_bounds) == 0:
        # No image regions: nothing to overwrite.
        return inputs_embeds

    # One arange of target rows per (start, end) bound.
    per_image_rows = [
        torch.arange(lo, hi, dtype=torch.long) for lo, hi in image_bounds.tolist()
    ]
    target_rows = torch.stack(per_image_rows).to(inputs_embeds.device)

    hidden_size = inputs_embeds.shape[-1]
    scatter_index = target_rows.view(-1, 1).repeat(1, hidden_size)
    scatter_src = image_embedding.view(-1, image_embedding.shape[-1])
    inputs_embeds.scatter_(0, scatter_index, scatter_src)
    return inputs_embeds
|
||||
|
||||
|
||||
def general_mm_embed_routine(
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    forward_batch: ForwardBatch,
    embed_tokens: nn.Embedding,
    image_embedding_func: Callable[[ImageInputs], torch.Tensor],
    placeholder_token_ids: List[int] = None,
):
    """
    a general wrapper function to get final input embeds from multimodal models
    with a language model as causal model

    Decode steps, and batches with no image data, fall back to plain token
    embedding; otherwise all per-request image inputs are merged into one
    ImageInputs and embedded via `embed_image_inputs`.

    Args:
        input_ids: token ids for the whole batch.
        positions: position ids (not used in this routine; kept so callers
            share one signature with model forwards).
        forward_batch: queried for forward mode and image inputs; its
            `image_inputs` field is cleared after use.
        embed_tokens: the language model's token embedding module.
        image_embedding_func: maps an ImageInputs to its image embedding.
        placeholder_token_ids: token ids marking image positions; forwarded
            to `embed_image_inputs` (presumably falls back to the image pad
            values when None — confirm in embed_image_inputs).
    """
    if (
        forward_batch.forward_mode.is_decode()
        or not forward_batch.contains_image_inputs()
    ):
        # Decode step or text-only batch: plain token embedding.
        inputs_embeds = embed_tokens(input_ids)
    else:
        # Merge every request's image inputs into a single object, then embed.
        image = forward_batch.merge_image_inputs()
        inputs_embeds = embed_image_inputs(
            image_input=image,
            input_ids=input_ids,
            input_embedding=embed_tokens,
            image_embedding_func=image_embedding_func,
            placeholder_token_ids=placeholder_token_ids,
        )
        # once used, image_inputs is useless
        # just being defensive here
        forward_batch.image_inputs = None
    return inputs_embeds
|
||||
@@ -1,134 +0,0 @@
|
||||
from abc import abstractmethod
|
||||
from typing import Callable, List, Optional, Tuple
|
||||
|
||||
from sglang.srt.managers.schedule_batch import ImageInputs
|
||||
from sglang.utils import logger
|
||||
|
||||
|
||||
class MultiModalityDataPaddingPattern:
    """
    Data tokens (like image tokens) often need special handling during padding
    to maintain model compatibility. This class provides the interface for
    implementing different padding strategies for data tokens
    """

    @abstractmethod
    def pad_input_tokens(
        self, input_ids: List[int], image_inputs: "ImageInputs"
    ) -> List[int]:
        """
        Pad the input ids sequence containing data tokens, and replace them with pad_values
        """
        pass


class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern):
    """In this pattern, data tokens should be enclosed by special token pairs (e.g. <image>...</image>, data_token_pairs)

    This strategy should be applied when data content is marked by start/end token pairs in the input sequence.
    """

    def __init__(self, data_token_pairs: Optional[List[Tuple[int, int]]]) -> None:
        # Each entry is a (start_token_id, end_token_id) pair.
        self.data_token_id_pairs = data_token_pairs

    def pad_input_tokens(
        self, input_ids: List[int], image_inputs: "ImageInputs"
    ) -> List[int]:
        """
        Replace every token between a start/end pair with the matching image's
        pad value; the start/end tokens themselves are preserved.

        Side effect: records each data segment's start index in
        ``image_inputs.image_offsets``.
        """
        pad_values = image_inputs.pad_values
        data_token_pairs = self.data_token_id_pairs
        image_inputs.image_offsets = []
        if data_token_pairs is None:
            # BUG FIX: the fallback must be a list of (start, end) PAIRS; the
            # original built a flat [start, end] list, which broke the tuple
            # unpacking below (`for s, _e in data_token_pairs`). Also, the
            # original re-checked `data_token_pairs is None` after assigning a
            # list — that branch was unreachable; warn on missing ids instead.
            if image_inputs.im_start_id is None or image_inputs.im_end_id is None:
                logger.warning(
                    "No data_token_pairs provided, RadixAttention might be influenced."
                )
                return input_ids
            data_token_pairs = [(image_inputs.im_start_id, image_inputs.im_end_id)]
        start_token_ids = [s for s, _e in data_token_pairs]
        end_tokens_ids = [e for _s, e in data_token_pairs]
        # First start token marks new data
        data_start_token = start_token_ids[0]

        padded_ids = []
        last_idx = 0
        data_idx = -1

        start_indices = [i for i, x in enumerate(input_ids) if x in start_token_ids]
        end_indices = [i for i, x in enumerate(input_ids) if x in end_tokens_ids]

        if len(start_indices) != len(end_indices):
            # Unbalanced pairs (e.g. truncated input): leave untouched.
            return input_ids

        for start_idx, end_idx in zip(start_indices, end_indices):
            # Copy everything up to and including the start token.
            padded_ids.extend(input_ids[last_idx : start_idx + 1])

            if input_ids[start_idx] == data_start_token:
                data_idx += 1
                image_inputs.image_offsets += [start_idx]

            # Replace the enclosed data tokens with this image's pad value.
            num_tokens = end_idx - start_idx - 1
            pad_value = pad_values[data_idx]
            padded_ids.extend([pad_value] * num_tokens)

            last_idx = end_idx

        padded_ids.extend(input_ids[last_idx:])

        assert len(input_ids) == len(padded_ids)
        return padded_ids


class MultModalityDataPaddingPatternSingleToken(MultiModalityDataPaddingPattern):
    """In this pattern, data is represented with a special token_id ( image_inputs.im_token_id ),
    which needs first to be expanded to multiple tokens, then replaced with their padding values

    This strategy should be used when a single data token represents content that should
    be expanded to multiple tokens during processing.

    NOTE(review): class name is missing an "i" (MultModality...); kept as-is
    for backward compatibility with existing imports.
    """

    def __init__(
        self, num_data_token_calc_func: Callable[[Tuple[int, int, int]], int]
    ) -> None:
        # Maps one image's (t, h, w) grid entry to its expanded token count.
        self.num_data_token_calc_func = num_data_token_calc_func

    def pad_input_tokens(
        self, input_ids: List[int], image_inputs: "ImageInputs"
    ) -> List[int]:
        """
        This function will follow the procedure of:
            1. the data token will be expanded, of which the final number will be calculated by `num_data_token_calc_func`
            2. the padded data tokens will be replaced with their pad_values
        """
        image_grid_thws = image_inputs.image_grid_thws
        pad_values = image_inputs.pad_values

        image_indices = [
            idx
            for idx, token in enumerate(input_ids)
            if token == image_inputs.im_token_id
        ]

        image_inputs.image_offsets = []

        # Robustness fix: with no image tokens or no grid info there is
        # nothing to expand (also avoids image_indices[-1] IndexError below).
        if not image_indices or image_grid_thws is None or len(image_grid_thws) == 0:
            return input_ids

        input_ids_with_image = []
        for image_cnt, _ in enumerate(image_grid_thws):
            # (stray debug print removed)
            num_image_tokens = self.num_data_token_calc_func(image_grid_thws[image_cnt])
            if image_cnt == 0:
                non_image_tokens = input_ids[: image_indices[image_cnt]]
            else:
                non_image_tokens = input_ids[
                    image_indices[image_cnt - 1] + 1 : image_indices[image_cnt]
                ]
            input_ids_with_image.extend(non_image_tokens)
            image_inputs.image_offsets.append(len(input_ids_with_image))
            # Repeat pad_values cyclically until it covers num_image_tokens.
            pad_ids = pad_values * (
                (num_image_tokens + len(pad_values)) // len(pad_values)
            )
            input_ids_with_image.extend(pad_ids[:num_image_tokens])
        input_ids_with_image.extend(input_ids[image_indices[-1] + 1 :])

        return input_ids_with_image
|
||||
@@ -77,6 +77,7 @@ global_server_args_dict = {
|
||||
"enable_flashmla": ServerArgs.enable_flashmla,
|
||||
"disable_radix_cache": ServerArgs.disable_radix_cache,
|
||||
"flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
|
||||
"chunked_prefill_size": ServerArgs.chunked_prefill_size,
|
||||
}
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -160,7 +161,8 @@ class ImageInputs:
|
||||
aspect_ratio_mask: Optional[List[torch.Tensor]] = None
|
||||
|
||||
# QWen2-VL related
|
||||
image_grid_thws: List[Tuple[int, int, int]] = None
|
||||
# [num_of_images, t, h, w]
|
||||
image_grid_thws: torch.Tensor = None
|
||||
mrope_position_delta: Optional[torch.Tensor] = None
|
||||
# Qwen2-VL video related
|
||||
video_token_id: Optional[int] = None
|
||||
@@ -168,7 +170,7 @@ class ImageInputs:
|
||||
second_per_grid_ts: Optional[List[torch.Tensor]] = None
|
||||
|
||||
# deepseek vl2 related
|
||||
image_seq_mask: Optional[List[torch.Tensor]] = None
|
||||
images_emb_mask: Optional[List[torch.Tensor]] = None
|
||||
image_spatial_crop: Optional[List[torch.Tensor]] = None
|
||||
|
||||
# The id of the single-image placeholder token
|
||||
@@ -182,9 +184,6 @@ class ImageInputs:
|
||||
slice_end_id: Optional[int] = None
|
||||
tgt_sizes: Optional[list] = None
|
||||
|
||||
# denotes the number of valid image tokens in each image
|
||||
images_emb_mask: Optional[torch.BoolTensor] = None
|
||||
|
||||
@staticmethod
|
||||
def from_dict(obj: dict):
|
||||
ret = ImageInputs(
|
||||
@@ -204,7 +203,7 @@ class ImageInputs:
|
||||
"aspect_ratio_ids",
|
||||
"aspect_ratio_mask",
|
||||
"image_grid_thws",
|
||||
"image_seq_mask",
|
||||
"images_emb_mask",
|
||||
"image_spatial_crop",
|
||||
"im_token_id",
|
||||
"im_start_id",
|
||||
@@ -212,20 +211,58 @@ class ImageInputs:
|
||||
"slice_start_id",
|
||||
"slice_end_id",
|
||||
"tgt_sizes",
|
||||
"images_emb_mask",
|
||||
]
|
||||
for arg in optional_args:
|
||||
if arg in obj:
|
||||
setattr(ret, arg, obj[arg])
|
||||
|
||||
# validate
|
||||
assert (
|
||||
isinstance(ret.pixel_values, torch.Tensor)
|
||||
or isinstance(ret.pixel_values, np.ndarray)
|
||||
or isinstance(ret.pixel_values, list)
|
||||
)
|
||||
|
||||
return ret
|
||||
|
||||
def merge(self, other):
|
||||
def merge(self, other: ImageInputs):
|
||||
"""
|
||||
merge image inputs when requests are being merged
|
||||
"""
|
||||
assert self.pixel_values.shape[1:] == other.pixel_values.shape[1:]
|
||||
self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values])
|
||||
if isinstance(self.pixel_values, list):
|
||||
# in some rare cases, pixel values are list of patches with different shapes
|
||||
# e.g. minicpm
|
||||
self.pixel_values += other.pixel_values
|
||||
else:
|
||||
assert (
|
||||
self.pixel_values.shape[1:] == other.pixel_values.shape[1:]
|
||||
), f"{self.pixel_values.shape[1:]} vs {other.pixel_values.shape[1:]}"
|
||||
self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values])
|
||||
|
||||
# args would be stacked along first dim
|
||||
# usually these are already tensors
|
||||
stack_args = [
|
||||
# TODO: merge with image_grid_thws, basically the same thing
|
||||
"tgt_sizes",
|
||||
"image_spatial_crop",
|
||||
]
|
||||
for arg in stack_args:
|
||||
if getattr(self, arg, None) is None:
|
||||
setattr(self, arg, getattr(other, arg, None))
|
||||
elif getattr(other, arg, None) is not None:
|
||||
# self and other both not None
|
||||
setattr(
|
||||
self,
|
||||
arg,
|
||||
torch.cat([getattr(self, arg), getattr(other, arg)], dim=0),
|
||||
)
|
||||
|
||||
if self.image_grid_thws is None:
|
||||
self.image_grid_thws = other.image_grid_thws
|
||||
elif other.image_grid_thws is not None:
|
||||
self.image_grid_thws = torch.concat(
|
||||
[self.image_grid_thws, other.image_grid_thws]
|
||||
)
|
||||
|
||||
# Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
|
||||
# Please note that if the `input_ids` is later used in the model forward,
|
||||
@@ -233,7 +270,7 @@ class ImageInputs:
|
||||
# errors in cuda kernels. See also llava.py for example.
|
||||
self.image_hashes += other.image_hashes
|
||||
self.pad_values = [x % (1 << 30) for x in self.image_hashes]
|
||||
|
||||
# args needed to be merged
|
||||
optional_args = [
|
||||
"image_sizes",
|
||||
"image_offsets",
|
||||
@@ -241,13 +278,13 @@ class ImageInputs:
|
||||
# "modalities", # modalities should be ["multi-images"] (one entry) even for multiple images
|
||||
"aspect_ratio_ids",
|
||||
"aspect_ratio_mask",
|
||||
"image_grid_thws",
|
||||
"image_seq_mask",
|
||||
"image_spatial_crop",
|
||||
"images_emb_mask",
|
||||
]
|
||||
for arg in optional_args:
|
||||
if getattr(self, arg, None) is not None:
|
||||
setattr(self, arg, getattr(self, arg) + getattr(other, arg))
|
||||
self_arg = getattr(self, arg, None)
|
||||
if self_arg is not None:
|
||||
setattr(self, arg, self_arg + getattr(other, arg))
|
||||
# other args would be kept intact
|
||||
|
||||
|
||||
class Req:
|
||||
|
||||
@@ -179,7 +179,7 @@ class TokenizerManager:
|
||||
)
|
||||
|
||||
# We want to parallelize the image pre-processing so we create an executor for it
|
||||
# We creat image_processor for any skip_tokenizer_init to make sure we still encode
|
||||
# We create image_processor for any skip_tokenizer_init to make sure we still encode
|
||||
# images even with skip_tokenizer_init=False.
|
||||
self.image_processor = get_image_processor(
|
||||
self.model_config.hf_config, server_args, _processor
|
||||
|
||||
@@ -332,7 +332,7 @@ class ForwardBatch:
|
||||
|
||||
return ret
|
||||
|
||||
def get_merged_image_inputs(self) -> Optional[ImageInputs]:
|
||||
def merge_image_inputs(self) -> Optional[ImageInputs]:
|
||||
"""
|
||||
Merge all image inputs in the batch into a single ImageInputs object.
|
||||
|
||||
@@ -358,6 +358,16 @@ class ForwardBatch:
|
||||
|
||||
return merged
|
||||
|
||||
def contains_image_inputs(self) -> bool:
    """
    Return True iff any request in this batch actually carries image data.
    """
    if self.image_inputs is None:
        # BUG FIX: no image metadata at all means there are NO image inputs;
        # the original returned True here, inverting the predicate.
        return False
    return any(
        # BUG FIX: the original tested `pixel_values is not []`, an identity
        # comparison against a fresh list literal, which is always True.
        # Check emptiness instead (len() works for list/ndarray/tensor).
        image_input.pixel_values is not None and len(image_input.pixel_values) > 0
        for image_input in self.image_inputs
        if image_input is not None
    )
|
||||
|
||||
def _compute_mrope_positions(
|
||||
self, model_runner: ModelRunner, batch: ModelWorkerBatch
|
||||
):
|
||||
|
||||
@@ -273,7 +273,7 @@ class ModelRunner:
|
||||
if self.model_config.hf_config.architectures == ["DeepseekVL2ForCausalLM"]:
|
||||
# TODO: deepseek-vl2 does not support radix cache now, set disable_radix_cache=True automatically
|
||||
logger.info(
|
||||
"Automatically turn off --chunked-prefill-size and disable radix cache for deekseek-vl2."
|
||||
"Automatically turn off --chunked-prefill-size and disable radix cache for deepseek-vl2."
|
||||
)
|
||||
server_args.chunked_prefill_size = -1
|
||||
server_args.disable_radix_cache = True
|
||||
|
||||
@@ -47,8 +47,9 @@ from sglang.srt.configs.janus_pro import *
|
||||
from sglang.srt.layers.attention.vision import VisionAttention
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.quantization import QuantizationConfig
|
||||
from sglang.srt.managers.multi_modality_padding import (
|
||||
from sglang.srt.managers.mm_utils import (
|
||||
MultiModalityDataPaddingPatternTokenPairs,
|
||||
general_mm_embed_routine,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import ImageInputs
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
@@ -1958,82 +1959,8 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
|
||||
def prepare_images_seq_mask(
    self, input_ids: torch.Tensor, image_inputs: ImageInputs
) -> Optional[torch.LongTensor]:
    """
    Return a boolean mask over `input_ids` marking positions that hold one of
    the image pad values, or None when no such position exists.
    """
    pad_token_tensor = torch.tensor(
        image_inputs.pad_values, device=input_ids.device
    )
    mask = torch.isin(input_ids, pad_token_tensor)
    # sometimes image_inputs is not empty, but input_ids contain no image token because of prefix-cache
    if mask.sum() == 0:
        return None
    return mask
|
||||
|
||||
@torch.no_grad()
def forward(
    self,
    input_ids: torch.LongTensor,
    positions: torch.Tensor,
    forward_batch: ForwardBatch,
) -> torch.Tensor:
    """
    Run the multimodal causal LM: when the batch carries image inputs, build
    combined text+image embeddings and pass them to the language model;
    otherwise forward raw (clamped) token ids.
    """
    inputs_embeds = None
    if (
        forward_batch.image_inputs is not None
        and len(forward_batch.image_inputs) != 0
        and forward_batch.image_inputs[0] is not None
    ):
        # NOTE(review): only the first request's image inputs are consumed
        # here — confirm batches reaching this path hold a single request.
        image_inputs = forward_batch.image_inputs[0]

        images_seq_mask = self.prepare_images_seq_mask(
            input_ids=input_ids, image_inputs=image_inputs
        )

        if images_seq_mask is not None:
            # Image pad positions carry hash-derived ids that may exceed the
            # vocab; clamp before embedding lookup.
            input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
            inputs_embeds = self.prepare_inputs_embeds(
                input_ids=input_ids,
                pixel_values=image_inputs.pixel_values,
                images_seq_mask=images_seq_mask,
                images_emb_mask=image_inputs.images_emb_mask,
            )
            # Embeddings are now authoritative; drop ids so the language
            # model consumes input_embeds instead.
            input_ids = None

    if input_ids is not None:
        # Text-only path still needs the clamp (see comment above).
        input_ids.clamp_(min=0, max=self.config.vocab_size - 1)

    return self.language_model(
        input_ids=input_ids,
        positions=positions,
        forward_batch=forward_batch,
        input_embeds=inputs_embeds,
        get_embedding=False,
    )
|
||||
|
||||
def prepare_inputs_embeds(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
pixel_values: torch.FloatTensor,
|
||||
images_seq_mask: torch.LongTensor,
|
||||
images_emb_mask: torch.BoolTensor,
|
||||
**_kwargs,
|
||||
):
|
||||
"""
|
||||
|
||||
Args:
|
||||
input_ids (torch.LongTensor): [b, T]
|
||||
pixel_values (torch.FloatTensor): [b, n_images, 3, h, w]
|
||||
images_seq_mask (torch.BoolTensor): [b, T]
|
||||
images_emb_mask (torch.BoolTensor): [b, n_images, n_image_tokens]
|
||||
|
||||
assert torch.sum(images_seq_mask) == torch.sum(images_emb_mask)
|
||||
|
||||
Returns:
|
||||
input_embeds (torch.Tensor): [b, T, D]
|
||||
"""
|
||||
|
||||
def get_image_feature(self, image_input: ImageInputs) -> torch.Tensor:
|
||||
pixel_values = image_input.pixel_values
|
||||
bs, n = pixel_values.shape[0:2]
|
||||
pixel_values = pixel_values.to(
|
||||
device=self.vision_model.device, dtype=self.vision_model.dtype
|
||||
@@ -2045,18 +1972,35 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
|
||||
|
||||
# [b x n, T2, D] -> [b, n x T2, D]
|
||||
images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n)
|
||||
# [b, n, T2] -> [b, n x T2]
|
||||
images_emb_mask = rearrange(images_emb_mask, "b n t -> b (n t)")
|
||||
|
||||
# [b, T, D]
|
||||
# ignore the image embeddings
|
||||
input_ids[input_ids < 0] = 0
|
||||
inputs_embeds = self.language_model.model.embed_tokens(input_ids)
|
||||
return images_embeds
|
||||
|
||||
# replace with the image embeddings
|
||||
inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask]
|
||||
def get_input_embeddings(self) -> nn.Embedding:
    # Expose the language model's token embedding table so shared multimodal
    # helpers can embed text tokens with it.
    return self.language_model.model.embed_tokens
|
||||
|
||||
return inputs_embeds
|
||||
@torch.no_grad()
def forward(
    self,
    input_ids: torch.LongTensor,
    positions: torch.Tensor,
    forward_batch: ForwardBatch,
) -> torch.Tensor:
    """
    Build input embeddings via the shared multimodal routine (merging image
    features into token embeddings when the batch has images), then run the
    language model on the embeddings.
    """
    inputs_embeds = general_mm_embed_routine(
        input_ids=input_ids,
        positions=positions,
        forward_batch=forward_batch,
        embed_tokens=self.get_input_embeddings(),
        image_embedding_func=self.get_image_feature,
    )

    return self.language_model(
        input_ids=None,  # embeddings already computed above
        positions=positions,
        forward_batch=forward_batch,
        input_embeds=inputs_embeds,
        get_embedding=False,
    )
|
||||
|
||||
def prepare_gen_img_embeds(self, image_ids: torch.LongTensor):
    # Embed generated image token ids, then project them through the
    # generation aligner into the language-model embedding space.
    return self.gen_aligner(self.gen_embed(image_ids))
|
||||
|
||||
@@ -1,34 +1,16 @@
|
||||
import collections
|
||||
import itertools
|
||||
import math
|
||||
import warnings
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
from typing import Callable, Iterable, List, Optional, Tuple, Type, Union
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
from torch import nn
|
||||
|
||||
from sglang.srt.configs import DeepseekVL2Config
|
||||
from sglang.srt.configs.deepseekvl2 import (
|
||||
DeepseekVL2Config,
|
||||
DeepseekVL2MlpProjectorConfig,
|
||||
)
|
||||
from sglang.srt.layers.attention.vision import VisionAttention
|
||||
from sglang.srt.layers.layernorm import RMSNorm
|
||||
from sglang.srt.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
LinearBase,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from sglang.srt.layers.linear import ReplicatedLinear
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead,
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import ImageInputs
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
@@ -233,11 +215,11 @@ class DeepseekVL2ForCausalLM(nn.Module):
|
||||
forward_batch: ForwardBatch,
|
||||
**kwargs: object,
|
||||
):
|
||||
|
||||
input_embeds = self.language_model.model.embed_tokens(input_ids)
|
||||
if forward_batch.forward_mode.is_extend() and forward_batch.image_inputs != [
|
||||
None
|
||||
]:
|
||||
if (
|
||||
forward_batch.forward_mode.is_extend()
|
||||
and forward_batch.contains_image_inputs()
|
||||
):
|
||||
extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
|
||||
extend_seq_lens_cpu = forward_batch.extend_seq_lens.cpu().numpy()
|
||||
for idx, image in enumerate(forward_batch.image_inputs):
|
||||
@@ -245,17 +227,11 @@ class DeepseekVL2ForCausalLM(nn.Module):
|
||||
continue
|
||||
start_idx = extend_start_loc_cpu[idx]
|
||||
end_idx = start_idx + extend_seq_lens_cpu[idx]
|
||||
pixel_values = image.pixel_values.to(
|
||||
device="cuda", dtype=torch.bfloat16
|
||||
)
|
||||
image_seq_mask = image.image_seq_mask.to(device="cuda")
|
||||
image_spatial_crop = image.image_spatial_crop
|
||||
input_embeds[start_idx:end_idx] = self.prepare_inputs_embeds(
|
||||
pixel_values,
|
||||
image_seq_mask,
|
||||
image_spatial_crop,
|
||||
input_embeds[start_idx:end_idx],
|
||||
)
|
||||
images_emb_mask = image.images_emb_mask.to(device="cuda")
|
||||
image_features = self.get_image_feature(image)
|
||||
input_embeds[start_idx:end_idx] = input_embeds[
|
||||
start_idx:end_idx
|
||||
].masked_scatter(images_emb_mask.unsqueeze(-1), image_features)
|
||||
|
||||
outputs = self.language_model.forward(
|
||||
input_ids=input_ids,
|
||||
@@ -289,20 +265,17 @@ class DeepseekVL2ForCausalLM(nn.Module):
|
||||
def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
    # No pad-value substitution for DeepSeek-VL2: input ids are returned
    # unchanged; image features are merged later in forward().
    return input_ids
|
||||
|
||||
def prepare_inputs_embeds(
|
||||
self,
|
||||
pixel_values,
|
||||
images_seq_mask,
|
||||
images_spatial_crop,
|
||||
input_embeds,
|
||||
):
|
||||
def get_image_feature(self, image_input: ImageInputs):
|
||||
pixel_values = image_input.pixel_values.type(
|
||||
next(self.vision.parameters()).dtype
|
||||
).to(device=next(self.vision.parameters()).device)
|
||||
image_feature = self.vision.forward_features(pixel_values)
|
||||
images_embeds = self.projector(image_feature)
|
||||
_, hw, n_dim = images_embeds.shape
|
||||
h = w = int(hw**0.5)
|
||||
|
||||
tile_index = 0
|
||||
images_in_this_batch = []
|
||||
images_spatial_crop = image_input.image_spatial_crop
|
||||
for jdx in range(images_spatial_crop.shape[1]):
|
||||
num_width_tiles, num_height_tiles = images_spatial_crop[0, jdx]
|
||||
if num_width_tiles == 0 or num_height_tiles == 0:
|
||||
@@ -379,13 +352,7 @@ class DeepseekVL2ForCausalLM(nn.Module):
|
||||
|
||||
images_in_this_batch.append(global_local_features)
|
||||
|
||||
if len(images_in_this_batch) > 0:
|
||||
images_in_this_batch = torch.cat(images_in_this_batch, dim=0)
|
||||
input_embeds.masked_scatter_(
|
||||
images_seq_mask.unsqueeze(-1), images_in_this_batch
|
||||
)
|
||||
|
||||
return input_embeds
|
||||
return torch.cat(images_in_this_batch, dim=0)
|
||||
|
||||
|
||||
EntryClass = DeepseekVL2ForCausalLM
|
||||
|
||||
@@ -37,11 +37,8 @@ from sglang.srt.layers.linear import (
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb, get_rope
|
||||
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead,
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb
|
||||
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import (
|
||||
default_weight_loader,
|
||||
@@ -511,7 +508,7 @@ class Gemma3TextModel(PreTrainedModel):
|
||||
else:
|
||||
hidden_states = input_embeds
|
||||
|
||||
if len(positions.shape) == 1:
|
||||
if positions.dim() == 1:
|
||||
positions = einops.rearrange(positions, "s -> 1 s")
|
||||
|
||||
position_embeddings_global = self.rotary_emb(hidden_states, positions)
|
||||
@@ -609,11 +606,11 @@ class Gemma3ForCausalLM(PreTrainedModel):
|
||||
)
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
def get_input_embeddings(self) -> nn.Embedding:
|
||||
return self.model.embed_tokens
|
||||
|
||||
def dtype(self) -> torch.dtype:
|
||||
return self.model.layers[0].mlp.gate_up_proj.weight.dtype
|
||||
return next(self.parameters()).dtype
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
|
||||
@@ -34,8 +34,9 @@ from sglang.srt.hf_transformers_utils import get_processor
|
||||
from sglang.srt.layers.layernorm import Gemma3RMSNorm
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.managers.multi_modality_padding import (
|
||||
from sglang.srt.managers.mm_utils import (
|
||||
MultiModalityDataPaddingPatternTokenPairs,
|
||||
general_mm_embed_routine,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import ImageInputs
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
@@ -264,10 +265,10 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
|
||||
kwargs["local_attn_masks"] = local_attn_masks
|
||||
return kwargs
|
||||
|
||||
def get_input_embeddings(self):
|
||||
def get_input_embeddings(self) -> nn.Embedding:
|
||||
return self.language_model.get_input_embeddings()
|
||||
|
||||
def get_image_features(self, pixel_values: torch.Tensor):
|
||||
def get_image_feature(self, image_input: ImageInputs):
|
||||
"""
|
||||
Projects the last hidden state from the vision model into language model space.
|
||||
|
||||
@@ -277,6 +278,7 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
|
||||
Returns:
|
||||
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
|
||||
"""
|
||||
pixel_values = image_input.pixel_values
|
||||
pixel_values = pixel_values.to("cuda")
|
||||
pixel_values = pixel_values.to(dtype=self.language_model.dtype())
|
||||
|
||||
@@ -305,7 +307,7 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
|
||||
return inputs_embeds
|
||||
else:
|
||||
# print(f"image tokens from input_ids: {inputs_embeds[special_image_mask].numel()}")
|
||||
image_features = self.get_image_features(image_input.pixel_values)
|
||||
image_features = self.get_image_feature(image_input.pixel_values)
|
||||
|
||||
# print(f"image tokens from image embeddings: {image_features.numel()}")
|
||||
num_image_tokens_in_embedding = (
|
||||
@@ -397,20 +399,13 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
|
||||
else:
|
||||
llm_input_ids = input_ids
|
||||
|
||||
merged_image_input = forward_batch.get_merged_image_inputs()
|
||||
|
||||
if (
|
||||
not forward_batch.forward_mode.is_decode()
|
||||
and merged_image_input is not None
|
||||
):
|
||||
inputs_embeds = self.embed_image_inputs(
|
||||
input_ids=llm_input_ids,
|
||||
forward_batch=forward_batch,
|
||||
image_input=merged_image_input,
|
||||
)
|
||||
else:
|
||||
llm_input_ids.clamp_(min=0, max=self.vocab_size - 1)
|
||||
inputs_embeds = self.get_input_embeddings()(llm_input_ids)
|
||||
inputs_embeds = general_mm_embed_routine(
|
||||
input_ids=llm_input_ids,
|
||||
positions=positions,
|
||||
forward_batch=forward_batch,
|
||||
embed_tokens=self.get_input_embeddings(),
|
||||
image_embedding_func=self.get_image_feature,
|
||||
)
|
||||
|
||||
outputs = self.language_model(
|
||||
input_ids=None,
|
||||
|
||||
@@ -50,8 +50,9 @@ from sglang.srt.layers.linear import (
|
||||
)
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.managers.multi_modality_padding import (
|
||||
from sglang.srt.managers.mm_utils import (
|
||||
MultiModalityDataPaddingPatternTokenPairs,
|
||||
embed_image_inputs,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import ImageInputs
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
@@ -399,7 +400,7 @@ class Idefics2VisionTransformer(nn.Module):
|
||||
)
|
||||
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
||||
|
||||
def get_input_embeddings(self):
|
||||
def get_input_embeddings(self) -> nn.Embedding:
|
||||
return self.embeddings
|
||||
|
||||
def compute_cu_seqlens(self, tgt_sizes: torch.Tensor) -> torch.Tensor:
|
||||
@@ -762,42 +763,6 @@ class MiniCPMVBaseModel(nn.Module):
|
||||
valid_pairs_tensor = torch.tensor(valid_pairs, device=input_ids.device)
|
||||
return valid_pairs_tensor
|
||||
|
||||
def get_embedding(
    self,
    input_ids: torch.Tensor,
    image_inputs: Optional[MiniCPMVImageInputs],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Build the VLM input embedding by scattering vision hidden states into the
    token embedding at the row ranges given by image_inputs["image_bounds"].

    Returns:
        (vlm_embedding, vision_hidden_states); vision_hidden_states is an
        empty tensor when there is no image.
    """
    vlm_embedding: torch.Tensor = self.llm.get_input_embeddings(input_ids)

    if image_inputs is None:  # No image
        vision_hidden_states = torch.tensor([], device=input_ids.device)
    else:
        if image_inputs["type"] == "image_embeds":
            # Precomputed embeddings: just align dtype and device.
            vision_hidden_states = (
                image_inputs["data"]
                .type(vlm_embedding.dtype)
                .to(vlm_embedding.device)
            )
        else:
            # Raw pixel inputs: run the vision tower.
            vision_hidden_states = self.get_vision_hidden_states(image_inputs)
        # See NOTE in _parse_and_validate_inputs
        image_bounds = image_inputs["image_bounds"]
        if len(image_bounds) > 0:
            # One arange of target rows per (start, end) bound.
            # NOTE(review): torch.stack requires every bound to span the same
            # number of tokens — confirm upstream guarantees this.
            image_indices = torch.stack(
                [
                    torch.arange(start, end, dtype=torch.long)
                    for start, end in image_bounds.tolist()
                ]
            ).to(vlm_embedding.device)

            # In-place scatter of vision states over the image token rows.
            vlm_embedding.scatter_(
                0,
                image_indices.view(-1, 1).repeat(1, vlm_embedding.shape[-1]),
                vision_hidden_states.view(-1, vision_hidden_states.shape[-1]),
            )

    return vlm_embedding, vision_hidden_states
|
||||
|
||||
def _parse_and_validate_inputs(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@@ -836,46 +801,6 @@ class MiniCPMVBaseModel(nn.Module):
|
||||
type="image_embeds",
|
||||
)
|
||||
|
||||
if not isinstance(pixel_values, (torch.Tensor, list)):
|
||||
raise ValueError(
|
||||
"Incorrect type of pixel values. " f"Got type: {type(pixel_values)}"
|
||||
)
|
||||
|
||||
if not isinstance(tgt_sizes, (torch.Tensor, list)):
|
||||
raise ValueError(
|
||||
"Incorrect type of target sizes. " f"Got type: {type(tgt_sizes)}"
|
||||
)
|
||||
|
||||
if len(pixel_values) != len(tgt_sizes):
|
||||
raise ValueError(
|
||||
"Inconsistent batch lengths, found: "
|
||||
f"{len(pixel_values)} vs. {len(tgt_sizes)}"
|
||||
)
|
||||
|
||||
pixel_values_flat: List[torch.Tensor] = []
|
||||
tgt_sizes_flat: List[torch.Tensor] = []
|
||||
for pixel_b, tgt_b in zip(pixel_values, tgt_sizes):
|
||||
if len(pixel_b) != len(tgt_b):
|
||||
raise ValueError(
|
||||
"Inconsistent N lengths, found: " f"{len(pixel_b)} vs {len(tgt_b)}"
|
||||
)
|
||||
|
||||
for pixel_n, tgt_n in zip(pixel_b, tgt_b):
|
||||
pixel_values_flat += pixel_n
|
||||
tgt_sizes_flat += tgt_n
|
||||
|
||||
# NOTE: Input IDs does not contain image tokens during memory profiling,
|
||||
# so we allow it to be empty
|
||||
if len(pixel_values_flat) != len(tgt_sizes_flat):
|
||||
raise ValueError(
|
||||
"Inconsistent flattened lengths, found: "
|
||||
f"{len(pixel_values_flat)} vs. "
|
||||
f"{len(tgt_sizes_flat)}"
|
||||
)
|
||||
|
||||
if len(pixel_values_flat) == 0:
|
||||
return None
|
||||
|
||||
image_bounds = self._get_image_bounds(
|
||||
input_ids=input_ids,
|
||||
pad_values=pad_values,
|
||||
@@ -886,11 +811,50 @@ class MiniCPMVBaseModel(nn.Module):
|
||||
)
|
||||
return MiniCPMVImagePixelInputs(
|
||||
image_bounds=image_bounds.to(device=input_ids.device),
|
||||
data=pixel_values_flat,
|
||||
tgt_sizes=torch.stack(tgt_sizes_flat),
|
||||
data=pixel_values,
|
||||
tgt_sizes=tgt_sizes,
|
||||
type="pixel_values",
|
||||
)
|
||||
|
||||
def get_embedding(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
image_inputs: Optional[MiniCPMVImageInputs],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
vlm_embedding: torch.Tensor = self.llm.get_input_embeddings(input_ids)
|
||||
|
||||
if image_inputs is None: # No image
|
||||
vision_hidden_states = torch.tensor([], device=input_ids.device)
|
||||
else:
|
||||
if image_inputs["type"] == "image_embeds":
|
||||
vision_hidden_states = (
|
||||
image_inputs["data"]
|
||||
.type(vlm_embedding.dtype)
|
||||
.to(vlm_embedding.device)
|
||||
)
|
||||
else:
|
||||
vision_hidden_states = self.get_vision_hidden_states(image_inputs)
|
||||
# See NOTE in _parse_and_validate_inputs
|
||||
image_bounds = image_inputs["image_bounds"]
|
||||
if len(image_bounds) > 0:
|
||||
image_indices = torch.stack(
|
||||
[
|
||||
torch.arange(start, end, dtype=torch.long)
|
||||
for start, end in image_bounds.tolist()
|
||||
]
|
||||
).to(vlm_embedding.device)
|
||||
|
||||
vlm_embedding.scatter_(
|
||||
0,
|
||||
image_indices.view(-1, 1).repeat(1, vlm_embedding.shape[-1]),
|
||||
vision_hidden_states.view(-1, vision_hidden_states.shape[-1]),
|
||||
)
|
||||
|
||||
return vlm_embedding, vision_hidden_states
|
||||
|
||||
def get_input_embeddings(self) -> nn.Embedding:
|
||||
return self.llm.get_input_embedding()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@@ -899,58 +863,29 @@ class MiniCPMVBaseModel(nn.Module):
|
||||
**kwargs: Any,
|
||||
) -> torch.Tensor:
|
||||
if (
|
||||
forward_batch.image_inputs is not None
|
||||
and len(forward_batch.image_inputs) > 0
|
||||
and forward_batch.image_inputs[0] is not None
|
||||
forward_batch.forward_mode.is_decode()
|
||||
or not forward_batch.contains_image_inputs()
|
||||
):
|
||||
# TODO: bath
|
||||
kwargs.update(
|
||||
{
|
||||
"pixel_values": (
|
||||
None
|
||||
if forward_batch.image_inputs is None
|
||||
else [
|
||||
i.pixel_values
|
||||
for i in forward_batch.image_inputs
|
||||
if i is not None
|
||||
]
|
||||
),
|
||||
"tgt_sizes": (
|
||||
None
|
||||
if forward_batch.image_inputs is None
|
||||
else [
|
||||
i.tgt_sizes
|
||||
for i in forward_batch.image_inputs
|
||||
if i is not None
|
||||
]
|
||||
),
|
||||
"im_start_id": forward_batch.image_inputs[0].im_start_id,
|
||||
"im_end_id": forward_batch.image_inputs[0].im_end_id,
|
||||
"slice_start_id": forward_batch.image_inputs[0].slice_start_id,
|
||||
"slice_end_id": forward_batch.image_inputs[0].slice_end_id,
|
||||
"pad_values": forward_batch.image_inputs[0].pad_values,
|
||||
}
|
||||
inputs_embeds: torch.Tensor = self.llm.get_input_embeddings(input_ids)
|
||||
else:
|
||||
# Clamp input ids. This is because the input_ids for the image tokens are
|
||||
# filled with the hash values of the image for the prefix matching in the radix attention.
|
||||
# There values are useless because their embeddings will be replaced by vision embeddings anyway.
|
||||
image_inputs = forward_batch.merge_image_inputs()
|
||||
inputs_embeds = embed_image_inputs(
|
||||
image_input=image_inputs,
|
||||
input_ids=input_ids,
|
||||
input_embedding=self.get_input_embeddings(),
|
||||
image_embedding_func=self.get_image_features,
|
||||
placeholder_token_ids=[image_inputs.im_token_id]
|
||||
+ image_inputs.pad_values,
|
||||
)
|
||||
|
||||
image_inputs = self._parse_and_validate_inputs(input_ids, **kwargs)
|
||||
|
||||
# Clamp input ids. This is because the input_ids for the image tokens are
|
||||
# filled with the hash values of the image for the prefix matching in the radix attention.
|
||||
# There values are useless because their embeddings will be replaced by vision embeddings anyway.
|
||||
input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
|
||||
|
||||
vlm_embeddings, _ = self.get_embedding(input_ids, image_inputs)
|
||||
|
||||
# always pass the input via `inputs_embeds`
|
||||
# to make sure the computation graph is consistent
|
||||
# for `torch.compile` integration
|
||||
input_ids = None
|
||||
|
||||
hidden_states = self.llm.model(
|
||||
input_ids=input_ids,
|
||||
input_ids=None,
|
||||
positions=positions,
|
||||
forward_batch=forward_batch,
|
||||
input_embeds=vlm_embeddings,
|
||||
input_embeds=inputs_embeds,
|
||||
)
|
||||
|
||||
return self.logits_processor(
|
||||
@@ -990,7 +925,7 @@ class MiniCPMVBaseModel(nn.Module):
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_vision_hidden_states(self, data: MiniCPMVImageInputs) -> torch.Tensor:
|
||||
def get_image_features(self, image_inputs: ImageInputs) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -1100,12 +1035,14 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
|
||||
)
|
||||
return vision_embedding
|
||||
|
||||
def get_vision_hidden_states(
|
||||
def get_image_features(
|
||||
self,
|
||||
data: MiniCPMVImageInputs,
|
||||
image_inputs: ImageInputs,
|
||||
) -> torch.Tensor:
|
||||
pixel_values = data["data"]
|
||||
tgt_sizes = data["tgt_sizes"]
|
||||
# list of tensors
|
||||
pixel_values = image_inputs.pixel_values
|
||||
|
||||
tgt_sizes = image_inputs.tgt_sizes
|
||||
|
||||
device = self.vpm.embeddings.position_embedding.weight.device
|
||||
dtype = self.vpm.embeddings.position_embedding.weight.dtype
|
||||
|
||||
@@ -361,6 +361,9 @@ class Qwen2ForCausalLM(nn.Module):
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
|
||||
def get_input_embedding(self) -> nn.Embedding:
|
||||
return self.model.embed_tokens
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@@ -26,7 +26,6 @@ import logging
|
||||
from functools import lru_cache, partial
|
||||
from typing import Iterable, List, Optional, Tuple, Type
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
@@ -54,14 +53,15 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.pooler import Pooler, PoolingType
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from sglang.srt.managers.multi_modality_padding import (
|
||||
from sglang.srt.managers.mm_utils import (
|
||||
MultiModalityDataPaddingPatternTokenPairs,
|
||||
general_mm_embed_routine,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import ImageInputs
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.models.qwen2 import Qwen2Model
|
||||
from sglang.srt.models.qwen2_vl import Qwen2VLImageInputs, Qwen2VLVideoInputs
|
||||
from sglang.srt.models.qwen2_vl import Qwen2VLVideoInputs
|
||||
from sglang.srt.utils import add_prefix
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -326,13 +326,12 @@ class Qwen2_5_VisionTransformer(nn.Module):
|
||||
)
|
||||
|
||||
def get_window_index(self, grid_thw):
|
||||
window_index: list = []
|
||||
cu_window_seqlens: list = [0]
|
||||
window_index_id = 0
|
||||
vit_merger_window_size = (
|
||||
self.window_size // self.spatial_merge_size // self.patch_size
|
||||
)
|
||||
|
||||
window_index: list = []
|
||||
for grid_t, grid_h, grid_w in grid_thw:
|
||||
llm_grid_h, llm_grid_w = (
|
||||
grid_h // self.spatial_merge_size,
|
||||
@@ -369,7 +368,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
|
||||
cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
|
||||
window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
|
||||
window_index = torch.cat(window_index, dim=0)
|
||||
|
||||
return window_index, cu_window_seqlens
|
||||
|
||||
@property
|
||||
@@ -382,8 +380,10 @@ class Qwen2_5_VisionTransformer(nn.Module):
|
||||
|
||||
def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
|
||||
pos_ids = []
|
||||
for t, h, w in grid_thw:
|
||||
for i in range(grid_thw.size(0)):
|
||||
t, h, w = grid_thw[i].tolist()
|
||||
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
|
||||
|
||||
hpos_ids = hpos_ids.reshape(
|
||||
h // self.spatial_merge_size,
|
||||
self.spatial_merge_size,
|
||||
@@ -402,6 +402,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
|
||||
)
|
||||
wpos_ids = wpos_ids.permute(0, 2, 1, 3)
|
||||
wpos_ids = wpos_ids.flatten()
|
||||
|
||||
pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
|
||||
pos_ids = torch.cat(pos_ids, dim=0)
|
||||
max_grid_size = grid_thw[:, 1:].max()
|
||||
@@ -443,9 +444,12 @@ class Qwen2_5_VisionTransformer(nn.Module):
|
||||
position_embeddings = (emb.cos(), emb.sin())
|
||||
|
||||
# compute cu_seqlens
|
||||
cu_seqlens = torch.repeat_interleave(
|
||||
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
|
||||
).cumsum(dim=0, dtype=torch.int32)
|
||||
cu_seqlens = torch.cat(
|
||||
[
|
||||
torch.tensor([0], device=grid_thw.device),
|
||||
(grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]).cumsum(dim=0),
|
||||
]
|
||||
)
|
||||
cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
|
||||
|
||||
# transformers
|
||||
@@ -509,18 +513,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
|
||||
|
||||
def calculate_num_image_tokens(self, image_grid_thw: Tuple[int, int, int]):
|
||||
processor = cached_get_processor(self.config._name_or_path)
|
||||
grid_t, grid_h, grid_w = image_grid_thw
|
||||
num_image_tokens = (
|
||||
grid_t
|
||||
* grid_h
|
||||
* grid_w
|
||||
// processor.image_processor.merge_size
|
||||
// processor.image_processor.merge_size
|
||||
)
|
||||
return num_image_tokens
|
||||
|
||||
def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
|
||||
# Get all special token IDs
|
||||
im_start_id: int = image_inputs.im_start_id
|
||||
@@ -531,9 +523,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
|
||||
|
||||
return pattern.pad_input_tokens(input_ids, image_inputs)
|
||||
|
||||
def _process_image_input(self, image_input: Qwen2VLImageInputs) -> torch.Tensor:
|
||||
pixel_values = image_input["pixel_values"].type(self.visual.dtype)
|
||||
image_embeds = self.visual(pixel_values, grid_thw=image_input["image_grid_thw"])
|
||||
def get_image_feature(self, image_input: ImageInputs) -> torch.Tensor:
|
||||
pixel_values = image_input.pixel_values.type(self.visual.dtype)
|
||||
image_embeds = self.visual(pixel_values, grid_thw=image_input.image_grid_thws)
|
||||
return image_embeds
|
||||
|
||||
def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:
|
||||
@@ -543,6 +535,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
|
||||
)
|
||||
return video_embeds
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.model.embed_tokens
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@@ -565,86 +560,26 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
|
||||
if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
|
||||
positions = forward_batch.mrope_positions
|
||||
|
||||
image_inputs = None
|
||||
if forward_batch.image_inputs is not None:
|
||||
image_inputs = [
|
||||
img for img in forward_batch.image_inputs if img is not None
|
||||
]
|
||||
|
||||
if (
|
||||
if not (
|
||||
forward_batch.forward_mode.is_decode()
|
||||
or image_inputs is None
|
||||
or len(image_inputs) == 0
|
||||
or not forward_batch.contains_image_inputs()
|
||||
):
|
||||
inputs_embeds = self.model.embed_tokens(input_ids)
|
||||
else:
|
||||
if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
|
||||
assert positions.ndim == 2 and positions.size(0) == 3, (
|
||||
"multimodal section rotary embedding requires "
|
||||
f"(3, seq_len) positions, but got {positions.size()}"
|
||||
)
|
||||
|
||||
# Clamp input ids. This is because the input_ids for the image tokens are
|
||||
# filled with the hash values of the image for the prefix matching in the radix attention.
|
||||
# There values are useless because their embeddings will be replaced by vision embeddings anyway.
|
||||
input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
|
||||
# [B, s, hidden_size]
|
||||
inputs_embeds = self.model.embed_tokens(input_ids)
|
||||
extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
|
||||
prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
|
||||
for i, image in enumerate(forward_batch.image_inputs):
|
||||
if image is None or image.pixel_values is None:
|
||||
continue
|
||||
start_idx = extend_start_loc_cpu[i]
|
||||
prefix_len = prefix_lens_cpu[i]
|
||||
|
||||
pixel_values = image.pixel_values.to(device="cuda")
|
||||
|
||||
image_grid_thws = torch.tensor(
|
||||
np.array(image.image_grid_thws), device="cuda"
|
||||
)
|
||||
image_offsets = image.image_offsets
|
||||
image_input = Qwen2VLImageInputs(
|
||||
pixel_values=pixel_values, image_grid_thw=image_grid_thws
|
||||
)
|
||||
image_embeds = self._process_image_input(image_input)
|
||||
|
||||
image_embeds_offset = 0
|
||||
for idx, image_offset in enumerate(image_offsets):
|
||||
if image_offset < prefix_len:
|
||||
continue
|
||||
num_image_tokens = self.calculate_num_image_tokens(
|
||||
image_grid_thws[idx]
|
||||
)
|
||||
|
||||
left_idx = start_idx + (image_offset - prefix_len)
|
||||
right_idx = left_idx + num_image_tokens
|
||||
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
hidden_size = image_embeds.shape[-1]
|
||||
|
||||
if hidden_size % tp_size != 0:
|
||||
padding_size = tp_size - (hidden_size % tp_size)
|
||||
image_embeds = F.pad(image_embeds, (0, padding_size))
|
||||
inputs_embeds = F.pad(inputs_embeds, (0, padding_size))
|
||||
|
||||
hidden_chunk_size = image_embeds.shape[-1] // tp_size
|
||||
rank = get_tensor_model_parallel_rank()
|
||||
start_dim = rank * hidden_chunk_size
|
||||
end_dim = (rank + 1) * hidden_chunk_size
|
||||
inputs_embeds[left_idx:right_idx, ..., start_dim:end_dim] = (
|
||||
image_embeds[
|
||||
image_embeds_offset : image_embeds_offset
|
||||
+ num_image_tokens,
|
||||
...,
|
||||
start_dim:end_dim,
|
||||
]
|
||||
)
|
||||
image_embeds_offset += num_image_tokens
|
||||
inputs_embeds = general_mm_embed_routine(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
forward_batch=forward_batch,
|
||||
embed_tokens=self.get_input_embeddings(),
|
||||
image_embedding_func=self.get_image_feature,
|
||||
)
|
||||
|
||||
hidden_states = self.model(
|
||||
input_ids=input_ids,
|
||||
input_ids=None,
|
||||
positions=positions,
|
||||
forward_batch=forward_batch,
|
||||
input_embeds=inputs_embeds,
|
||||
|
||||
@@ -26,7 +26,6 @@ import logging
|
||||
from functools import lru_cache, partial
|
||||
from typing import Iterable, List, Optional, Tuple, Type, TypedDict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
@@ -42,8 +41,9 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.pooler import Pooler, PoolingType
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from sglang.srt.managers.multi_modality_padding import (
|
||||
from sglang.srt.managers.mm_utils import (
|
||||
MultiModalityDataPaddingPatternTokenPairs,
|
||||
general_mm_embed_routine,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import ImageInputs
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
@@ -351,7 +351,7 @@ class Qwen2VisionTransformer(nn.Module):
|
||||
|
||||
@property
|
||||
def dtype(self) -> torch.dtype:
|
||||
return self.blocks[0].mlp.fc2.weight.dtype
|
||||
return next(self.parameters()).dtype
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
@@ -359,7 +359,8 @@ class Qwen2VisionTransformer(nn.Module):
|
||||
|
||||
def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
|
||||
pos_ids = []
|
||||
for t, h, w in grid_thw:
|
||||
for i in range(grid_thw.size(0)):
|
||||
t, h, w = grid_thw[i].tolist()
|
||||
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
|
||||
wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
|
||||
hpos_ids = (
|
||||
@@ -480,9 +481,9 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
||||
pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
|
||||
return pattern.pad_input_tokens(input_ids, image_inputs)
|
||||
|
||||
def _process_image_input(self, image_input: Qwen2VLImageInputs) -> torch.Tensor:
|
||||
pixel_values = image_input["pixel_values"].type(self.visual.dtype)
|
||||
image_embeds = self.visual(pixel_values, grid_thw=image_input["image_grid_thw"])
|
||||
def get_image_feature(self, image_input: ImageInputs) -> torch.Tensor:
|
||||
pixel_values = image_input.pixel_values.type(self.visual.dtype)
|
||||
image_embeds = self.visual(pixel_values, grid_thw=image_input.image_grid_thws)
|
||||
return image_embeds
|
||||
|
||||
def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:
|
||||
@@ -492,6 +493,9 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
||||
)
|
||||
return video_embeds
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.model.embed_tokens
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@@ -514,67 +518,26 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
||||
if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
|
||||
positions = forward_batch.mrope_positions
|
||||
|
||||
image_inputs = None
|
||||
if forward_batch.image_inputs is not None:
|
||||
image_inputs = [
|
||||
img for img in forward_batch.image_inputs if img is not None
|
||||
]
|
||||
|
||||
if (
|
||||
if not (
|
||||
forward_batch.forward_mode.is_decode()
|
||||
or image_inputs is None
|
||||
or len(image_inputs) == 0
|
||||
or not forward_batch.contains_image_inputs()
|
||||
):
|
||||
inputs_embeds = self.model.embed_tokens(input_ids)
|
||||
else:
|
||||
if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
|
||||
assert positions.ndim == 2 and positions.size(0) == 3, (
|
||||
"multimodal section rotary embedding requires "
|
||||
f"(3, seq_len) positions, but got {positions.size()}"
|
||||
)
|
||||
|
||||
# Clamp input ids. This is because the input_ids for the image tokens are
|
||||
# filled with the hash values of the image for the prefix matching in the radix attention.
|
||||
# There values are useless because their embeddings will be replaced by vision embeddings anyway.
|
||||
input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
|
||||
|
||||
inputs_embeds = self.model.embed_tokens(input_ids)
|
||||
extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
|
||||
prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
|
||||
for i, image in enumerate(forward_batch.image_inputs):
|
||||
if image is None or image.pixel_values is None:
|
||||
continue
|
||||
start_idx = extend_start_loc_cpu[i]
|
||||
prefix_len = prefix_lens_cpu[i]
|
||||
pixel_values = image.pixel_values.clone()
|
||||
|
||||
image_grid_thws = torch.tensor(
|
||||
np.array(image.image_grid_thws), device="cuda"
|
||||
)
|
||||
image_offsets = image.image_offsets
|
||||
image_input = Qwen2VLImageInputs(
|
||||
pixel_values=pixel_values, image_grid_thw=image_grid_thws
|
||||
)
|
||||
image_embeds = self._process_image_input(image_input)
|
||||
|
||||
image_embeds_offset = 0
|
||||
for idx, image_offset in enumerate(image_offsets):
|
||||
if image_offset < prefix_len:
|
||||
continue
|
||||
num_image_tokens = self.calculate_num_image_tokens(
|
||||
image_grid_thws[idx]
|
||||
)
|
||||
|
||||
left_idx = start_idx + (image_offset - prefix_len + 1)
|
||||
right_idx = left_idx + num_image_tokens
|
||||
inputs_embeds[left_idx:right_idx] = image_embeds[
|
||||
image_embeds_offset : image_embeds_offset + num_image_tokens
|
||||
]
|
||||
image_embeds_offset += num_image_tokens
|
||||
input_ids = None
|
||||
inputs_embeds = general_mm_embed_routine(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
forward_batch=forward_batch,
|
||||
embed_tokens=self.get_input_embeddings(),
|
||||
image_embedding_func=self.get_image_feature,
|
||||
)
|
||||
|
||||
hidden_states = self.model(
|
||||
input_ids=input_ids,
|
||||
input_ids=None,
|
||||
positions=positions,
|
||||
forward_batch=forward_batch,
|
||||
input_embeds=inputs_embeds,
|
||||
|
||||
Reference in New Issue
Block a user