refactor: bug fixes and refactor for vlm (#4661)

This commit is contained in:
Mick
2025-03-23 13:48:49 +08:00
committed by GitHub
parent ca75741e86
commit 11577cedb7
31 changed files with 770 additions and 735 deletions

View File

@@ -9,8 +9,6 @@ import PIL
import torch
from PIL.Image import Image
from transformers import (
AutoImageProcessor,
AutoProcessor,
BaseImageProcessor,
BatchFeature,
LlamaConfig,
@@ -20,6 +18,7 @@ from transformers import (
)
from transformers.image_utils import to_numpy_array
from sglang.srt.configs.utils import register_image_processor, register_processor
from sglang.srt.mm_utils import expand2square
@@ -625,5 +624,5 @@ class VLMImageProcessorConfig(PretrainedConfig):
super().__init__(**kwargs)
AutoProcessor.register(MultiModalityConfig, VLChatProcessor, exist_ok=True)
AutoImageProcessor.register(VLMImageProcessorConfig, None, VLMImageProcessor, None)
register_processor(MultiModalityConfig, VLChatProcessor)
register_image_processor(MultiModalityConfig, VLMImageProcessor)

View File

@@ -460,6 +460,7 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal
multimodal_model_archs = [
"DeepseekVL2ForCausalLM",
"LlavaLlamaForCausalLM",
"LlavaQwenForCausalLM",
"LlavaMistralForCausalLM",
@@ -472,7 +473,6 @@ multimodal_model_archs = [
"Qwen2_5_VLForConditionalGeneration",
"MiniCPMV",
"MultiModalityCausalLM",
"DeepseekVL2ForCausalLM",
]

View File

@@ -0,0 +1,25 @@
from typing import Type
from transformers import (
AutoImageProcessor,
AutoProcessor,
BaseImageProcessor,
PretrainedConfig,
ProcessorMixin,
)
def register_image_processor(
    config: Type[PretrainedConfig], image_processor: Type[BaseImageProcessor]
):
    """
    Register a customized HF image processor for the given config class,
    overriding any implementation HF already registered (exist_ok=True).

    NOTE(review): the two positional `None` arguments presumably correspond to
    the slow/fast processor-class slots of `AutoImageProcessor.register` —
    confirm against the installed transformers version.
    """
    AutoImageProcessor.register(config, None, image_processor, None, exist_ok=True)
def register_processor(config: Type[PretrainedConfig], processor: Type[ProcessorMixin]):
    """
    Register a customized HF processor for the given config class,
    overriding any implementation HF already registered (exist_ok=True).
    """
    AutoProcessor.register(config, processor, exist_ok=True)

View File

@@ -653,7 +653,7 @@ register_conv_template(
Conversation(
name="gemma-it",
system_message="You are a helpful assistant.",
system_template="<bos><start_of_turn>user{system_message}\n\n",
system_template="<start_of_turn>user{system_message}\n\n",
roles=("<start_of_turn>user\n", "<start_of_turn>model\n"),
sep="<end_of_turn>\n",
sep_style=SeparatorStyle.GEMMA3,

View File

@@ -143,9 +143,14 @@ class VisionAttention(nn.Module):
if position_embeddings is not None:
cos, sin = position_embeddings
original_shape = q.shape
q, k = q.view(s, head, -1), k.view(s, head, -1)
# [total_tokens, head, head_size]
q = q.view(-1, head, self.head_size)
k = k.view(-1, head, self.head_size)
q, k = apply_rotary_pos_emb(q, k, cos, sin)
q, k = q.reshape(original_shape), k.reshape(original_shape)
q = q.view(original_shape)
k = k.view(original_shape)
if self.use_qkv_parallel:
pass

View File

@@ -1,9 +1,12 @@
# TODO: also move pad_input_ids into this module
import importlib
import inspect
import logging
import pkgutil
from functools import lru_cache
from typing import Union
from torch import Tensor
from transformers import IMAGE_PROCESSOR_MAPPING
from sglang.srt.managers.image_processors.base_image_processor import (
@@ -18,9 +21,7 @@ logger = logging.getLogger(__name__)
IMAGE_PROCESSOR_MAPPING = {}
def get_image_processor(
hf_config, server_args: ServerArgs, processor
) -> BaseImageProcessor:
def get_image_processor(hf_config, server_args, processor) -> BaseImageProcessor:
for model_cls, processor_cls in IMAGE_PROCESSOR_MAPPING.items():
if model_cls.__name__ in hf_config.architectures:
return processor_cls(hf_config, server_args, processor)
@@ -42,13 +43,18 @@ def import_image_processors():
try:
module = importlib.import_module(name)
except Exception as e:
logger.warning(f"Ignore import error when loading {name}: " f"{e}")
logger.warning(f" Ignore import error when loading {name}: " f"{e}")
continue
if hasattr(module, "ImageProcessorMapping"):
entry = module.ImageProcessorMapping
if isinstance(entry, dict):
for processor_name, cls in entry.items():
IMAGE_PROCESSOR_MAPPING[processor_name] = cls
all_members = inspect.getmembers(module, inspect.isclass)
classes = [
member
for name, member in all_members
if member.__module__ == module.__name__
]
for cls in classes:
if issubclass(cls, BaseImageProcessor):
for arch in getattr(cls, "models"):
IMAGE_PROCESSOR_MAPPING[arch] = cls
# also register processors

View File

@@ -4,14 +4,14 @@ import dataclasses
import multiprocessing as mp
import os
from abc import ABC, abstractmethod
from typing import Optional
from typing import Optional, Union
import PIL
import transformers
from decord import VideoReader, cpu
from openai import BadRequestError
from PIL import Image
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import load_image
from sglang.utils import logger
@@ -31,8 +31,16 @@ class BaseImageProcessorOutput:
# input_text, with each frame of video/image represented as an image_token
input_text: str
def normalize(self):
for field_name in ["data_hashes", "image_sizes", "all_frames"]:
field = getattr(self, field_name, None)
if field is not None and isinstance(field, list) and len(field) == 0:
setattr(self, field_name, None)
class BaseImageProcessor(ABC):
models = []
def __init__(self, hf_config, server_args, _processor):
self.hf_config = hf_config
self._processor = _processor
@@ -40,6 +48,9 @@ class BaseImageProcessor(ABC):
# FIXME: not accurate, model and image specific
self.NUM_TOKEN_PER_FRAME = 330
# Initialize global processor first
init_global_processor(self, server_args)
self.executor = concurrent.futures.ProcessPoolExecutor(
initializer=init_global_processor,
mp_context=mp.get_context("fork"),
@@ -113,7 +124,7 @@ class BaseImageProcessor(ABC):
self,
input_ids: list[int],
image_data,
image_token: str,
image_token: Union[int, str],
max_req_input_len: int,
return_text: Optional[bool] = True,
discard_alpha_channel: bool = True,
@@ -122,9 +133,16 @@ class BaseImageProcessor(ABC):
Each frame of video/image will be replaced by a single image token
Args:
image_token: The token ID representing the image placeholder.
discard_alpha_channel: if True, discards the alpha channel in the returned images
"""
if isinstance(image_token, int):
image_token_str = self._processor.tokenizer.convert_ids_to_tokens(
image_token
)
else:
image_token_str = image_token
if isinstance(input_ids, list) and return_text:
assert len(input_ids) and isinstance(input_ids[0], int)
@@ -190,13 +208,11 @@ class BaseImageProcessor(ABC):
new_text += text_part
except Exception as e:
import openai
logger.error(f"An exception occurred while loading images: {e}")
raise BadRequestError(
f"An exception occurred while loading images: {e}"
)
continue
return BaseImageProcessorOutput(
image_hashes=hashes,
@@ -204,6 +220,8 @@ class BaseImageProcessor(ABC):
all_frames=images,
input_text=new_text,
)
out.normalize()
return out
class DummyImageProcessor(BaseImageProcessor):
@@ -214,9 +232,7 @@ class DummyImageProcessor(BaseImageProcessor):
return None
def init_global_processor(
sglang_image_processor: BaseImageProcessor, server_args: ServerArgs
):
def init_global_processor(sglang_image_processor: BaseImageProcessor, server_args):
"""Init the global processor for multi-modal models."""
global global_processor
transformers.logging.set_verbosity_error()

View File

@@ -16,13 +16,9 @@
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import asyncio
import math
from typing import List, Union
import torch
from PIL import Image, ImageOps
from sglang.srt.managers.image_processor import BaseImageProcessor
from sglang.srt.managers.image_processors.base_image_processor import (
@@ -32,18 +28,24 @@ from sglang.srt.models.deepseek_vl2 import DeepseekVL2ForCausalLM
class DeepseekVL2ImageProcessor(BaseImageProcessor):
models = [DeepseekVL2ForCausalLM]
def __init__(self, hf_config, server_args, _processor):
# with contextlib.suppress(ValueError):
# AutoProcessor.register("DeepseekVLV2Processor", DeepseekVLV2Processor)
super().__init__(hf_config, server_args, _processor)
self.IMAGE_TOKEN = "<image>"
@staticmethod
def _process_images_task(image, input_text, max_req_input_len):
return get_global_processor().__call__(
processor = get_global_processor()
res = processor.__call__(
conversations=input_text, images=image, max_req_input_len=max_req_input_len
)
image_token_id = processor.image_token_id
res["im_token_id"] = image_token_id
return res
async def _process_images(self, image_data, input_text, max_req_input_len):
if self.executor is not None:
loop = asyncio.get_event_loop()
@@ -70,18 +72,15 @@ class DeepseekVL2ImageProcessor(BaseImageProcessor):
if not isinstance(image_data, list):
image_data = [image_data]
images, image_hashes, image_sizes = [], [], []
images, image_sizes = [], []
image_token = self.IMAGE_TOKEN
base_output = self.load_images(
input_ids, image_data, image_token, max_req_input_len
)
base_output.all_frames = [img.convert("RGB") for img in base_output.all_frames]
res = await self._process_images(
base_output.all_frames, base_output.input_text, max_req_input_len
)
pixel_values = res["images"]
input_ids = res["input_ids"]
images_seq_mask = res["images_seq_mask"]
images_spatial_crop = res["images_spatial_crop"]
batched_images_spatial_crop = []
@@ -89,16 +88,12 @@ class DeepseekVL2ImageProcessor(BaseImageProcessor):
batched_images_spatial_crop = torch.stack(batched_images_spatial_crop, dim=0)
return {
"input_ids": input_ids.tolist(),
"pixel_values": pixel_values,
"image_hashes": image_hashes,
"input_ids": res["input_ids"].tolist(),
"pixel_values": res["images"],
"im_token_id": res["im_token_id"],
"image_hashes": base_output.image_hashes,
"image_sizes": image_sizes,
"image_seq_mask": images_seq_mask,
"images_emb_mask": images_seq_mask,
"image_spatial_crop": batched_images_spatial_crop,
"modalities": request_obj.modalities or ["image"],
}
ImageProcessorMapping = {
DeepseekVL2ForCausalLM: DeepseekVL2ImageProcessor,
}

View File

@@ -17,14 +17,15 @@ logger = logging.get_logger(__name__)
class Gemma3SGLangImageProcessor(SGLangBaseImageProcessor):
models = [Gemma3ForConditionalGeneration]
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
self.IMAGE_TOKEN = "<start_of_image>"
self.IM_START_TOKEN_ID = hf_config.boi_token_index
self.IM_END_TOKEN_ID = hf_config.eoi_token_index
@staticmethod
def _process_images_task(images, input_text, _hf_config):
async def _process_single_image(self, images, input_text) -> dict:
if isinstance(images, list) and len(images) == 0:
images = None
processor = get_global_processor()
@@ -46,19 +47,6 @@ class Gemma3SGLangImageProcessor(SGLangBaseImageProcessor):
"pixel_values": pixel_values,
}
async def _process_images(self, images, input_text) -> dict:
if self.executor is not None:
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
self.executor,
Gemma3SGLangImageProcessor._process_images_task,
images,
input_text,
self.hf_config,
)
else:
return self._process_images_task(images, input_text, self.hf_config)
async def process_images_async(
self,
image_data: List[Union[str, bytes]],
@@ -82,7 +70,7 @@ class Gemma3SGLangImageProcessor(SGLangBaseImageProcessor):
discard_alpha_channel=True,
)
ret = await self._process_images(
ret = await self._process_single_image(
input_text=base_output.input_text, images=base_output.all_frames
)
@@ -93,8 +81,3 @@ class Gemma3SGLangImageProcessor(SGLangBaseImageProcessor):
"im_start_id": self.IM_START_TOKEN_ID,
"im_end_id": self.IM_END_TOKEN_ID,
}
ImageProcessorMapping = {
Gemma3ForConditionalGeneration: Gemma3SGLangImageProcessor,
}

View File

@@ -11,6 +11,8 @@ from sglang.srt.models.deepseek_janus_pro import MultiModalityCausalLM
class JanusProProcessor(SGLangBaseImageProcessor):
models = [MultiModalityCausalLM]
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
@@ -77,6 +79,3 @@ class JanusProProcessor(SGLangBaseImageProcessor):
"im_end_id": res["im_end_id"],
"im_token_id": res["im_token_id"],
}
ImageProcessorMapping = {MultiModalityCausalLM: JanusProProcessor}

View File

@@ -15,6 +15,8 @@ from sglang.utils import get_exception_traceback
class LlavaImageProcessor(BaseImageProcessor):
models = [LlavaVidForCausalLM, LlavaQwenForCausalLM, LlavaMistralForCausalLM]
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
@@ -143,10 +145,3 @@ class LlavaImageProcessor(BaseImageProcessor):
"image_sizes": image_sizes,
"modalities": request_obj.modalities or ["image"],
}
ImageProcessorMapping = {
LlavaVidForCausalLM: LlavaImageProcessor,
LlavaQwenForCausalLM: LlavaImageProcessor,
LlavaMistralForCausalLM: LlavaImageProcessor,
}

View File

@@ -1,6 +1,8 @@
import asyncio
from typing import List, Union
import torch
from sglang.srt.managers.image_processor import BaseImageProcessor
from sglang.srt.managers.image_processors.base_image_processor import (
get_global_processor,
@@ -9,6 +11,8 @@ from sglang.srt.models.minicpmv import MiniCPMV
class MiniCPMVImageProcessor(BaseImageProcessor):
models = [MiniCPMV]
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
self.IMAGE_TOKEN = "(<image>./</image>)"
@@ -69,21 +73,57 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
# Collect special token ids
tokenizer = self._processor.tokenizer
im_start_id = tokenizer.im_start_id
im_token_id = tokenizer.unk_token_id
im_end_id = tokenizer.im_end_id
if tokenizer.slice_start_id:
slice_start_id = tokenizer.slice_start_id
slice_end_id = tokenizer.slice_end_id
pixel_values = res["pixel_values"]
tgt_sizes = res["tgt_sizes"]
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError(
"Incorrect type of pixel values. " f"Got type: {type(pixel_values)}"
)
if not isinstance(tgt_sizes, (torch.Tensor, list)):
raise ValueError(
"Incorrect type of target sizes. " f"Got type: {type(tgt_sizes)}"
)
if len(pixel_values) != len(tgt_sizes):
raise ValueError(
"Inconsistent batch lengths, found: "
f"{len(pixel_values)} vs. {len(tgt_sizes)}"
)
# tgt_sizes = [tgt_size for tgt_size in tgt_sizes if isinstance(tgt_size, torch.Tensor)]
# tgt_sizes = torch.vstack(tgt_sizes).type(torch.int32)
pixel_values_flat: List[torch.Tensor] = []
tgt_sizes_flat: List[torch.Tensor] = []
for pixel_b, tgt_b in zip(pixel_values, tgt_sizes):
# per image
if len(pixel_b) != len(tgt_b):
raise ValueError(
"Inconsistent N lengths, found: " f"{len(pixel_b)} vs {len(tgt_b)}"
)
for pixel_n, tgt_n in zip(pixel_b, tgt_b):
# per patch
pixel_values_flat += [pixel_n]
tgt_sizes_flat += [tgt_n]
pixel_values = pixel_values_flat
tgt_sizes = torch.stack(tgt_sizes_flat)
return {
"input_ids": res["input_ids"].flatten().tolist(),
"pixel_values": res["pixel_values"],
"tgt_sizes": res["tgt_sizes"],
"pixel_values": pixel_values,
"tgt_sizes": tgt_sizes,
"image_hashes": base_output.image_hashes,
"modalities": request_obj.modalities or ["image"],
"im_start_id": im_start_id,
"im_token_id": im_token_id,
"im_end_id": im_end_id,
"slice_start_id": slice_start_id,
"slice_end_id": slice_end_id,
}
ImageProcessorMapping = {MiniCPMV: MiniCPMVImageProcessor}

View File

@@ -10,6 +10,8 @@ from sglang.srt.utils import load_image
class MllamaImageProcessor(BaseImageProcessor):
models = [MllamaForConditionalGeneration]
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
@@ -55,6 +57,3 @@ class MllamaImageProcessor(BaseImageProcessor):
image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
return image_inputs
ImageProcessorMapping = {MllamaForConditionalGeneration: MllamaImageProcessor}

View File

@@ -2,6 +2,7 @@ import asyncio
import math
from typing import List, Union
import torch
from PIL import Image
from sglang.srt.managers.image_processor import BaseImageProcessor
@@ -14,6 +15,8 @@ from sglang.srt.models.qwen2_vl import Qwen2VLForConditionalGeneration
# Compatible with Qwen2VL and Qwen2_5VL
class Qwen2_5VLImageProcessor(BaseImageProcessor):
models = [Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration]
def __init__(self, hf_config, server_args, _processor):
super().__init__(hf_config, server_args, _processor)
self.IMAGE_TOKEN = "<|vision_start|><|image_pad|><|vision_end|>"
@@ -43,7 +46,7 @@ class Qwen2_5VLImageProcessor(BaseImageProcessor):
"video_grid_thws": getattr(result, "video_grid_thws", None),
}
async def _process_images(self, images, input_text) -> dict:
async def _process_single_image(self, images, input_text) -> dict:
if self.executor is not None:
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
@@ -138,23 +141,23 @@ class Qwen2_5VLImageProcessor(BaseImageProcessor):
images = [resize_image(image) for image in base_output.all_frames]
ret = await self._process_images(images, base_output.input_text)
ret = await self._process_single_image(
images=images, input_text=base_output.input_text
)
image_grid_thws = torch.concat([ret["image_grid_thw"]])
video_grid_thws = None
return {
"input_ids": ret["input_ids"].flatten().tolist(),
"pixel_values": ret["pixel_values"],
"image_hashes": base_output.image_hashes,
"modalities": request_obj.modalities or ["image"],
"image_grid_thws": ret["image_grid_thw"],
"video_grid_thws": ret["video_grid_thws"],
"image_grid_thws": image_grid_thws,
"video_grid_thws": video_grid_thws,
"im_start_id": self.IM_START_TOKEN_ID,
"im_end_id": self.IM_END_TOKEN_ID,
"im_token_id": self.image_token_id,
"video_token_id": self.video_token_id,
"second_per_grid_ts": ret["second_per_grid_ts"],
}
ImageProcessorMapping = {
Qwen2VLForConditionalGeneration: Qwen2_5VLImageProcessor,
Qwen2_5_VLForConditionalGeneration: Qwen2_5VLImageProcessor,
}

View File

@@ -0,0 +1,303 @@
"""
Multimodality utils
"""
from abc import abstractmethod
from typing import Callable, List, Optional, Tuple
import torch
from torch import nn
from sglang.srt.managers.schedule_batch import (
ImageInputs,
global_server_args_dict,
logger,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.utils import logger
class MultiModalityDataPaddingPattern:
    """
    Data tokens (like image tokens) often need special handling during padding
    to maintain model compatibility. This class provides the interface for
    implementing different padding strategies for data tokens
    """

    @abstractmethod
    def pad_input_tokens(
        self, input_ids: List[int], image_inputs: ImageInputs
    ) -> List[int]:
        """
        Pad the input ids sequence containing data tokens, and replace them with pad_values

        Implementations must preserve all non-data tokens and return a list of
        the same token ids with data regions substituted by pad values.
        """
        pass
class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern):
    """In this pattern, data tokens should be enclosed by special token pairs (e.g. <image>...</image>, data_token_pairs)

    This strategy should be applied when data content is marked by start/end token pairs in the input sequence.
    """

    def __init__(self, data_token_pairs: Optional[List[Tuple[int, int]]]) -> None:
        # List of (start_token_id, end_token_id) pairs delimiting data regions.
        self.data_token_id_pairs = data_token_pairs

    def pad_input_tokens(
        self, input_ids: List[int], image_inputs: ImageInputs
    ) -> List[int]:
        """
        This function will replace the data-tokens inbetween with pad_values accordingly.

        The output has the same length as the input; only the tokens strictly
        between each start/end pair are replaced by the image's pad value.
        Records the index of each data start token in image_inputs.image_offsets.
        """
        pad_values = image_inputs.pad_values
        data_token_pairs = self.data_token_id_pairs
        image_inputs.image_offsets = []
        if data_token_pairs is None:
            # BUG FIX: the fallback must be a list of (start, end) *pairs*.
            # The previous flat `[im_start_id, im_end_id]` list broke the
            # tuple-unpacking below, and the subsequent `is None` check could
            # never fire because a list is never None.
            if image_inputs.im_start_id is None or image_inputs.im_end_id is None:
                logger.warning(
                    "No data_token_pairs provided, RadixAttention might be influenced."
                )
                return input_ids
            data_token_pairs = [(image_inputs.im_start_id, image_inputs.im_end_id)]
        start_token_ids = [s for s, _e in data_token_pairs]
        end_tokens_ids = [e for _s, e in data_token_pairs]
        # First start token marks new data
        data_start_token = start_token_ids[0]

        padded_ids = []
        last_idx = 0
        data_idx = -1

        start_indices = [i for i, x in enumerate(input_ids) if x in start_token_ids]
        end_indices = [i for i, x in enumerate(input_ids) if x in end_tokens_ids]

        # Unbalanced start/end markers: bail out rather than corrupt the sequence.
        if len(start_indices) != len(end_indices):
            return input_ids

        for start_idx, end_idx in zip(start_indices, end_indices):
            padded_ids.extend(input_ids[last_idx : start_idx + 1])

            if input_ids[start_idx] == data_start_token:
                data_idx += 1
                image_inputs.image_offsets += [start_idx]

            num_tokens = end_idx - start_idx - 1
            pad_value = pad_values[data_idx]
            padded_ids.extend([pad_value] * num_tokens)

            last_idx = end_idx
        padded_ids.extend(input_ids[last_idx:])

        assert len(input_ids) == len(padded_ids)
        return padded_ids
class MultModalityDataPaddingPatternSingleToken(MultiModalityDataPaddingPattern):
    """In this pattern, data is represented with a special token_id ( image_inputs.im_token_id ),
    which needs first to be expanded to multiple tokens, then replaced with their padding values

    This strategy should be used when a single data token represents content that should
    be expanded to multiple tokens during processing.
    """

    def __init__(
        self, num_data_token_calc_func: Callable[[Tuple[int, int, int]], int]
    ) -> None:
        # Maps one image's (t, h, w) grid entry to the number of tokens it expands to.
        self.num_data_token_calc_func = num_data_token_calc_func

    def pad_input_tokens(
        self, input_ids: List[int], image_inputs: ImageInputs
    ) -> List[int]:
        """
        This function will follow the procedure of:
            1. the data token will be expanded, of which the final number will be calculated by `num_data_token_calc_func`
            2. the padded data tokens will be replaced with their pad_values

        NOTE(review): assumes input_ids contains at least as many im_token_id
        occurrences as there are entries in image_grid_thws; otherwise the
        image_indices lookups below raise IndexError — confirm upstream.
        """
        image_grid_thws = image_inputs.image_grid_thws
        pad_values = image_inputs.pad_values

        # Positions of every single-image placeholder token in the input.
        image_indices = [
            idx
            for idx, token in enumerate(input_ids)
            if token == image_inputs.im_token_id
        ]

        image_inputs.image_offsets = []

        input_ids_with_image = []
        for image_cnt, _ in enumerate(image_grid_thws):
            num_image_tokens = self.num_data_token_calc_func(image_grid_thws[image_cnt])
            # Copy the text tokens between the previous placeholder and this one.
            if image_cnt == 0:
                non_image_tokens = input_ids[: image_indices[image_cnt]]
            else:
                non_image_tokens = input_ids[
                    image_indices[image_cnt - 1] + 1 : image_indices[image_cnt]
                ]
            input_ids_with_image.extend(non_image_tokens)
            image_inputs.image_offsets.append(len(input_ids_with_image))
            # Repeat pad_values cyclically until it covers num_image_tokens entries.
            pad_ids = pad_values * (
                (num_image_tokens + len(pad_values)) // len(pad_values)
            )
            input_ids_with_image.extend(pad_ids[:num_image_tokens])
        # Append the tail after the last placeholder.
        input_ids_with_image.extend(input_ids[image_indices[-1] + 1 :])

        return input_ids_with_image
class MultiModalityDataPaddingPatternImageTokens(MultiModalityDataPaddingPattern):
    """In this pattern, data tokens should be represented as image tokens (e.g. <image><image>....<image>)"""

    def __init__(self, image_token_id: torch.Tensor) -> None:
        self.image_token_id = image_token_id

    def pad_input_tokens(self, input_ids: List[int], image_inputs) -> List[int]:
        """
        This function will replace the data-tokens in between with pad_values accordingly
        """
        pad_values = image_inputs.pad_values
        assert len(pad_values) != 0

        ids = torch.tensor(input_ids)
        is_image = torch.isin(ids, self.image_token_id)
        n_image_tokens = is_image.sum().item()

        # Cycle pad_values until it is long enough, then trim to exact size.
        cycles = n_image_tokens // len(pad_values) + 1
        fill = torch.tensor(pad_values).repeat(cycles)[:n_image_tokens]

        ids[is_image] = fill
        return ids.tolist()
def embed_image_inputs(
    image_input: ImageInputs,
    input_ids: torch.Tensor,
    input_embedding: nn.Embedding,
    image_embedding_func,
    placeholder_token_ids: List[int] = None,
) -> Optional[torch.Tensor]:
    """
    Calculate the image embeddings if necessary, then scatter the result with
    the help of a boolean mask denoting the embed locations

    Returns:
        final embedding: Optional[torch.Tensor]  (None when image_input is None)
    """
    if image_input is None:
        return None
    # Image regions are marked either by explicit placeholder ids or, by
    # default, by the per-image pad values written into input_ids upstream.
    placeholder_token_ids = placeholder_token_ids or image_input.pad_values

    # boolean masking the special tokens
    special_image_mask = torch.isin(
        input_ids,
        torch.tensor(placeholder_token_ids, device=input_ids.device),
    ).unsqueeze(-1)

    num_image_tokens_in_input_ids = special_image_mask.sum()
    if num_image_tokens_in_input_ids == 0:
        # unexpected: no placeholder tokens found, fall back to text embedding
        inputs_embeds = input_embedding(input_ids)
    else:
        image_embedding = image_embedding_func(image_input)
        if image_embedding.dim() == 2:
            # assumes shape [num_tokens, hidden] — TODO confirm upstream contract
            num_image_tokens_in_embedding = image_embedding.shape[0]
        else:
            # assumes shape [num_images, tokens_per_image, hidden]
            num_image_tokens_in_embedding = (
                image_embedding.shape[0] * image_embedding.shape[1]
            )
        if num_image_tokens_in_input_ids != num_image_tokens_in_embedding:
            # NOTE(review): dividing by shape[1] presumes the 3-D layout above;
            # in the 2-D branch shape[1] is the hidden size — confirm intent.
            num_image = num_image_tokens_in_input_ids // image_embedding.shape[1]
            image_embedding = image_embedding[:num_image, :]
            logger.warning(
                f"Number of images does not match number of special image tokens in the input text. "
                f"Got {num_image_tokens_in_input_ids} image tokens in the text but {num_image_tokens_in_embedding} "
                "tokens from image embeddings."
            )
            # TODO: chunked prefill will split special tokens from input_ids into several passes, failing the embedding
            # a fix may be cache the unfinished image embedding for future reuse, determine the tokens to embed with
            # extend_start_loc and extend_seq_lens
            if num_image_tokens_in_input_ids > num_image_tokens_in_embedding:
                chunked_prefill_size = global_server_args_dict["chunked_prefill_size"]
                if chunked_prefill_size != -1:
                    logger.warning(
                        "You may want to avoid this issue by raising `chunked_prefill_size`, or disabling chunked_prefill"
                    )

        vocab_size = input_embedding.num_embeddings
        # Important: clamp after getting original image regions
        # Clamp input ids. This is because the input_ids for the image tokens are
        # filled with the hash values of the image for the prefix matching in the radix attention.
        # There values are useless because their embeddings will be replaced by vision embeddings anyway.
        input_ids.clamp_(min=0, max=vocab_size - 1)
        inputs_embeds = input_embedding(input_ids)

        special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
            inputs_embeds.device
        )
        inputs_embeds = inputs_embeds.masked_scatter(
            special_image_mask,
            image_embedding.to(inputs_embeds.device, inputs_embeds.dtype),
        )
    return inputs_embeds
def embed_image_embedding(
    inputs_embeds: torch.Tensor,
    image_embedding: torch.Tensor,
    image_bounds: torch.Tensor,
) -> torch.Tensor:
    """
    scatter image_embedding into inputs_embeds according to image_bounds

    Each (start, end) row of image_bounds selects a half-open row range of
    inputs_embeds that is overwritten in place by the next rows of
    image_embedding. Returns inputs_embeds (mutated when bounds are non-empty).
    """
    if len(image_bounds) == 0:
        return inputs_embeds

    # Expand every (start, end) pair into the explicit row indices it covers.
    spans = [
        torch.arange(lo, hi, dtype=torch.long) for lo, hi in image_bounds.tolist()
    ]
    target_rows = torch.stack(spans).to(inputs_embeds.device)

    hidden = inputs_embeds.shape[-1]
    inputs_embeds.scatter_(
        0,
        target_rows.view(-1, 1).repeat(1, hidden),
        image_embedding.view(-1, hidden),
    )
    return inputs_embeds
def general_mm_embed_routine(
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    forward_batch: ForwardBatch,
    embed_tokens: nn.Embedding,
    image_embedding_func: Callable[[ImageInputs], torch.Tensor],
    placeholder_token_ids: List[int] = None,
):
    """
    a general wrapper function to get final input embeds from multimodal models
    with a language model as causal model

    NOTE(review): `positions` is unused in this body; presumably kept for
    signature parity with model forward methods — confirm with callers.
    """
    if (
        forward_batch.forward_mode.is_decode()
        or not forward_batch.contains_image_inputs()
    ):
        # Decode steps or image-free batches need plain text embeddings only.
        inputs_embeds = embed_tokens(input_ids)
    else:
        # Merge per-request image inputs into one batch-level ImageInputs.
        image = forward_batch.merge_image_inputs()
        inputs_embeds = embed_image_inputs(
            image_input=image,
            input_ids=input_ids,
            input_embedding=embed_tokens,
            image_embedding_func=image_embedding_func,
            placeholder_token_ids=placeholder_token_ids,
        )
        # once used, image_inputs is useless
        # just being defensive here
        forward_batch.image_inputs = None
    return inputs_embeds

View File

@@ -1,134 +0,0 @@
from abc import abstractmethod
from typing import Callable, List, Optional, Tuple
from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.utils import logger
class MultiModalityDataPaddingPattern:
    """
    Data tokens (like image tokens) often need special handling during padding
    to maintain model compatibility. This class provides the interface for
    implementing different padding strategies for data tokens
    """

    @abstractmethod
    def pad_input_tokens(
        self, input_ids: List[int], image_inputs: ImageInputs
    ) -> List[int]:
        """
        Pad the input ids sequence containing data tokens, and replace them with pad_values

        Implementations must preserve all non-data tokens and return a list of
        the same token ids with data regions substituted by pad values.
        """
        pass
class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern):
    """In this pattern, data tokens should be enclosed by special token pairs (e.g. <image>...</image>, data_token_pairs)

    This strategy should be applied when data content is marked by start/end token pairs in the input sequence.
    """

    def __init__(self, data_token_pairs: Optional[List[Tuple[int, int]]]) -> None:
        # List of (start_token_id, end_token_id) pairs delimiting data regions.
        self.data_token_id_pairs = data_token_pairs

    def pad_input_tokens(
        self, input_ids: List[int], image_inputs: ImageInputs
    ) -> List[int]:
        """
        This function will replace the data-tokens inbetween with pad_values accordingly

        Output length equals input length; only tokens strictly between each
        start/end pair are replaced by the image's pad value.
        """
        pad_values = image_inputs.pad_values
        data_token_pairs = self.data_token_id_pairs
        image_inputs.image_offsets = []
        if data_token_pairs is None:
            # NOTE(review): this fallback builds a flat [start, end] list, but
            # the tuple-unpacking below expects (start, end) *pairs* — likely a
            # latent bug on this path.
            data_token_pairs = [image_inputs.im_start_id, image_inputs.im_end_id]
        if data_token_pairs is None:
            # NOTE(review): unreachable after the assignment above (a list is
            # never None) — the intent was probably to check im_start_id/im_end_id.
            logger.warning(
                "No data_token_pairs provided, RadixAttention might be influenced."
            )
            return input_ids
        start_token_ids = [s for s, _e in data_token_pairs]
        end_tokens_ids = [e for _s, e in data_token_pairs]
        # First start token marks new data
        data_start_token = start_token_ids[0]

        padded_ids = []
        last_idx = 0
        data_idx = -1

        start_indices = [i for i, x in enumerate(input_ids) if x in start_token_ids]
        end_indices = [i for i, x in enumerate(input_ids) if x in end_tokens_ids]

        # Unbalanced start/end markers: return the input unchanged.
        if len(start_indices) != len(end_indices):
            return input_ids

        for start_idx, end_idx in zip(start_indices, end_indices):
            padded_ids.extend(input_ids[last_idx : start_idx + 1])

            if input_ids[start_idx] == data_start_token:
                data_idx += 1
                image_inputs.image_offsets += [start_idx]

            num_tokens = end_idx - start_idx - 1
            pad_value = pad_values[data_idx]
            padded_ids.extend([pad_value] * num_tokens)

            last_idx = end_idx
        padded_ids.extend(input_ids[last_idx:])

        assert len(input_ids) == len(padded_ids)
        return padded_ids
class MultModalityDataPaddingPatternSingleToken(MultiModalityDataPaddingPattern):
    """In this pattern, data is represented with a special token_id ( image_inputs.im_token_id ),
    which needs first to be expanded to multiple tokens, then replaced with their padding values

    This strategy should be used when a single data token represents content that should
    be expanded to multiple tokens during processing.
    """

    def __init__(
        self, num_data_token_calc_func: Callable[[Tuple[int, int, int]], int]
    ) -> None:
        # Maps one image's (t, h, w) grid entry to the number of tokens it expands to.
        self.num_data_token_calc_func = num_data_token_calc_func

    def pad_input_tokens(
        self, input_ids: List[int], image_inputs: ImageInputs
    ) -> List[int]:
        """
        This function will follow the procedure of:
            1. the data token will be expanded, of which the final number will be calculated by `num_data_token_calc_func`
            2. the padded data tokens will be replaced with their pad_values

        Fix: removed a leftover debug `print(f"image_cnt ...")` that wrote to
        stdout on every image of every request.
        """
        image_grid_thws = image_inputs.image_grid_thws
        pad_values = image_inputs.pad_values

        # Positions of every single-image placeholder token in the input.
        image_indices = [
            idx
            for idx, token in enumerate(input_ids)
            if token == image_inputs.im_token_id
        ]

        image_inputs.image_offsets = []

        input_ids_with_image = []
        for image_cnt, _ in enumerate(image_grid_thws):
            num_image_tokens = self.num_data_token_calc_func(image_grid_thws[image_cnt])
            # Copy the text tokens between the previous placeholder and this one.
            if image_cnt == 0:
                non_image_tokens = input_ids[: image_indices[image_cnt]]
            else:
                non_image_tokens = input_ids[
                    image_indices[image_cnt - 1] + 1 : image_indices[image_cnt]
                ]
            input_ids_with_image.extend(non_image_tokens)
            image_inputs.image_offsets.append(len(input_ids_with_image))
            # Repeat pad_values cyclically until it covers num_image_tokens entries.
            pad_ids = pad_values * (
                (num_image_tokens + len(pad_values)) // len(pad_values)
            )
            input_ids_with_image.extend(pad_ids[:num_image_tokens])
        # Append the tail after the last placeholder.
        input_ids_with_image.extend(input_ids[image_indices[-1] + 1 :])

        return input_ids_with_image

View File

@@ -77,6 +77,7 @@ global_server_args_dict = {
"enable_flashmla": ServerArgs.enable_flashmla,
"disable_radix_cache": ServerArgs.disable_radix_cache,
"flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
"chunked_prefill_size": ServerArgs.chunked_prefill_size,
}
logger = logging.getLogger(__name__)
@@ -160,7 +161,8 @@ class ImageInputs:
aspect_ratio_mask: Optional[List[torch.Tensor]] = None
# QWen2-VL related
image_grid_thws: List[Tuple[int, int, int]] = None
# [num_of_images, t, h, w]
image_grid_thws: torch.Tensor = None
mrope_position_delta: Optional[torch.Tensor] = None
# Qwen2-VL video related
video_token_id: Optional[int] = None
@@ -168,7 +170,7 @@ class ImageInputs:
second_per_grid_ts: Optional[List[torch.Tensor]] = None
# deepseek vl2 related
image_seq_mask: Optional[List[torch.Tensor]] = None
images_emb_mask: Optional[List[torch.Tensor]] = None
image_spatial_crop: Optional[List[torch.Tensor]] = None
# The id of the single-image placeholder token
@@ -182,9 +184,6 @@ class ImageInputs:
slice_end_id: Optional[int] = None
tgt_sizes: Optional[list] = None
# denotes the number of valid image tokens in each image
images_emb_mask: Optional[torch.BoolTensor] = None
@staticmethod
def from_dict(obj: dict):
ret = ImageInputs(
@@ -204,7 +203,7 @@ class ImageInputs:
"aspect_ratio_ids",
"aspect_ratio_mask",
"image_grid_thws",
"image_seq_mask",
"images_emb_mask",
"image_spatial_crop",
"im_token_id",
"im_start_id",
@@ -212,20 +211,58 @@ class ImageInputs:
"slice_start_id",
"slice_end_id",
"tgt_sizes",
"images_emb_mask",
]
for arg in optional_args:
if arg in obj:
setattr(ret, arg, obj[arg])
# validate
assert (
isinstance(ret.pixel_values, torch.Tensor)
or isinstance(ret.pixel_values, np.ndarray)
or isinstance(ret.pixel_values, list)
)
return ret
def merge(self, other):
def merge(self, other: ImageInputs):
"""
merge image inputs when requests are being merged
"""
assert self.pixel_values.shape[1:] == other.pixel_values.shape[1:]
self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values])
if isinstance(self.pixel_values, list):
# in some rare cases, pixel values are list of patches with different shapes
# e.g. minicpm
self.pixel_values += other.pixel_values
else:
assert (
self.pixel_values.shape[1:] == other.pixel_values.shape[1:]
), f"{self.pixel_values.shape[1:]} vs {other.pixel_values.shape[1:]}"
self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values])
# args would be stacked along first dim
# usually these are already tensors
stack_args = [
# TODO: merge with image_grid_thws, basically the same thing
"tgt_sizes",
"image_spatial_crop",
]
for arg in stack_args:
if getattr(self, arg, None) is None:
setattr(self, arg, getattr(other, arg, None))
elif getattr(other, arg, None) is not None:
# self and other both not None
setattr(
self,
arg,
torch.cat([getattr(self, arg), getattr(other, arg)], dim=0),
)
if self.image_grid_thws is None:
self.image_grid_thws = other.image_grid_thws
elif other.image_grid_thws is not None:
self.image_grid_thws = torch.concat(
[self.image_grid_thws, other.image_grid_thws]
)
# Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
# Please note that if the `input_ids` is later used in the model forward,
@@ -233,7 +270,7 @@ class ImageInputs:
# errors in cuda kernels. See also llava.py for example.
self.image_hashes += other.image_hashes
self.pad_values = [x % (1 << 30) for x in self.image_hashes]
# args needed to be merged
optional_args = [
"image_sizes",
"image_offsets",
@@ -241,13 +278,13 @@ class ImageInputs:
# "modalities", # modalities should be ["multi-images"] (one entry) even for multiple images
"aspect_ratio_ids",
"aspect_ratio_mask",
"image_grid_thws",
"image_seq_mask",
"image_spatial_crop",
"images_emb_mask",
]
for arg in optional_args:
if getattr(self, arg, None) is not None:
setattr(self, arg, getattr(self, arg) + getattr(other, arg))
self_arg = getattr(self, arg, None)
if self_arg is not None:
setattr(self, arg, self_arg + getattr(other, arg))
# other args would be kept intact
class Req:

View File

@@ -179,7 +179,7 @@ class TokenizerManager:
)
# We want to parallelize the image pre-processing so we create an executor for it
# We creat image_processor for any skip_tokenizer_init to make sure we still encode
# We create image_processor for any skip_tokenizer_init to make sure we still encode
# images even with skip_tokenizer_init=False.
self.image_processor = get_image_processor(
self.model_config.hf_config, server_args, _processor

View File

@@ -332,7 +332,7 @@ class ForwardBatch:
return ret
def get_merged_image_inputs(self) -> Optional[ImageInputs]:
def merge_image_inputs(self) -> Optional[ImageInputs]:
"""
Merge all image inputs in the batch into a single ImageInputs object.
@@ -358,6 +358,16 @@ class ForwardBatch:
return merged
def contains_image_inputs(self) -> bool:
    """Return True if any request in this batch carries usable image data.

    Callers use `not forward_batch.contains_image_inputs()` to take the
    text-only embedding path, so an absent `image_inputs` must report False
    (the original returned True for None, sending image-less batches down
    the vision path).
    """
    if self.image_inputs is None:
        return False
    return any(
        # NOTE: the original also tested `pixel_values is not []`, an
        # identity comparison against a fresh list that is always True (a
        # no-op), so the effective condition is just the None check.
        image_input.pixel_values is not None
        for image_input in self.image_inputs
        if image_input is not None
    )
def _compute_mrope_positions(
self, model_runner: ModelRunner, batch: ModelWorkerBatch
):

View File

@@ -273,7 +273,7 @@ class ModelRunner:
if self.model_config.hf_config.architectures == ["DeepseekVL2ForCausalLM"]:
# TODO: deepseek-vl2 does not support radix cache now, set disable_radix_cache=True automatically
logger.info(
"Automatically turn off --chunked-prefill-size and disable radix cache for deekseek-vl2."
"Automatically turn off --chunked-prefill-size and disable radix cache for deepseek-vl2."
)
server_args.chunked_prefill_size = -1
server_args.disable_radix_cache = True

View File

@@ -47,8 +47,9 @@ from sglang.srt.configs.janus_pro import *
from sglang.srt.layers.attention.vision import VisionAttention
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization import QuantizationConfig
from sglang.srt.managers.multi_modality_padding import (
from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternTokenPairs,
general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -1958,82 +1959,8 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
)
self.logits_processor = LogitsProcessor(config)
def prepare_images_seq_mask(
self, input_ids: torch.Tensor, image_inputs: ImageInputs
) -> Optional[torch.LongTensor]:
images_seq_mask = torch.isin(
input_ids, torch.tensor(image_inputs.pad_values, device=input_ids.device)
)
if images_seq_mask.sum() == 0:
# sometimes image_inputs is not empty, but input_ids contain no image token because of prefix-cache
return None
else:
return images_seq_mask
@torch.no_grad()
def forward(
self,
input_ids: torch.LongTensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
) -> torch.Tensor:
inputs_embeds = None
if (
forward_batch.image_inputs is not None
and len(forward_batch.image_inputs) != 0
and forward_batch.image_inputs[0] is not None
):
image_inputs = forward_batch.image_inputs[0]
images_seq_mask = self.prepare_images_seq_mask(
input_ids=input_ids, image_inputs=image_inputs
)
if images_seq_mask is not None:
input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
inputs_embeds = self.prepare_inputs_embeds(
input_ids=input_ids,
pixel_values=image_inputs.pixel_values,
images_seq_mask=images_seq_mask,
images_emb_mask=image_inputs.images_emb_mask,
)
input_ids = None
if input_ids is not None:
input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
return self.language_model(
input_ids=input_ids,
positions=positions,
forward_batch=forward_batch,
input_embeds=inputs_embeds,
get_embedding=False,
)
def prepare_inputs_embeds(
self,
input_ids: torch.LongTensor,
pixel_values: torch.FloatTensor,
images_seq_mask: torch.LongTensor,
images_emb_mask: torch.BoolTensor,
**_kwargs,
):
"""
Args:
input_ids (torch.LongTensor): [b, T]
pixel_values (torch.FloatTensor): [b, n_images, 3, h, w]
images_seq_mask (torch.BoolTensor): [b, T]
images_emb_mask (torch.BoolTensor): [b, n_images, n_image_tokens]
assert torch.sum(images_seq_mask) == torch.sum(images_emb_mask)
Returns:
input_embeds (torch.Tensor): [b, T, D]
"""
def get_image_feature(self, image_input: ImageInputs) -> torch.Tensor:
pixel_values = image_input.pixel_values
bs, n = pixel_values.shape[0:2]
pixel_values = pixel_values.to(
device=self.vision_model.device, dtype=self.vision_model.dtype
@@ -2045,18 +1972,35 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
# [b x n, T2, D] -> [b, n x T2, D]
images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n)
# [b, n, T2] -> [b, n x T2]
images_emb_mask = rearrange(images_emb_mask, "b n t -> b (n t)")
# [b, T, D]
# ignore the image embeddings
input_ids[input_ids < 0] = 0
inputs_embeds = self.language_model.model.embed_tokens(input_ids)
return images_embeds
# replace with the image embeddings
inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask]
def get_input_embeddings(self) -> nn.Embedding:
return self.language_model.model.embed_tokens
return inputs_embeds
@torch.no_grad()
def forward(
self,
input_ids: torch.LongTensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
) -> torch.Tensor:
inputs_embeds = general_mm_embed_routine(
input_ids=input_ids,
positions=positions,
forward_batch=forward_batch,
embed_tokens=self.get_input_embeddings(),
image_embedding_func=self.get_image_feature,
)
return self.language_model(
input_ids=None,
positions=positions,
forward_batch=forward_batch,
input_embeds=inputs_embeds,
get_embedding=False,
)
def prepare_gen_img_embeds(self, image_ids: torch.LongTensor):
return self.gen_aligner(self.gen_embed(image_ids))

View File

@@ -1,34 +1,16 @@
import collections
import itertools
import math
import warnings
from enum import Enum
from functools import partial
from typing import Callable, Iterable, List, Optional, Tuple, Type, Union
from typing import Iterable, List, Optional, Tuple
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import nn
from sglang.srt.configs import DeepseekVL2Config
from sglang.srt.configs.deepseekvl2 import (
DeepseekVL2Config,
DeepseekVL2MlpProjectorConfig,
)
from sglang.srt.layers.attention.vision import VisionAttention
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
ColumnParallelLinear,
LinearBase,
ReplicatedLinear,
RowParallelLinear,
)
from sglang.srt.layers.linear import ReplicatedLinear
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -233,11 +215,11 @@ class DeepseekVL2ForCausalLM(nn.Module):
forward_batch: ForwardBatch,
**kwargs: object,
):
input_embeds = self.language_model.model.embed_tokens(input_ids)
if forward_batch.forward_mode.is_extend() and forward_batch.image_inputs != [
None
]:
if (
forward_batch.forward_mode.is_extend()
and forward_batch.contains_image_inputs()
):
extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
extend_seq_lens_cpu = forward_batch.extend_seq_lens.cpu().numpy()
for idx, image in enumerate(forward_batch.image_inputs):
@@ -245,17 +227,11 @@ class DeepseekVL2ForCausalLM(nn.Module):
continue
start_idx = extend_start_loc_cpu[idx]
end_idx = start_idx + extend_seq_lens_cpu[idx]
pixel_values = image.pixel_values.to(
device="cuda", dtype=torch.bfloat16
)
image_seq_mask = image.image_seq_mask.to(device="cuda")
image_spatial_crop = image.image_spatial_crop
input_embeds[start_idx:end_idx] = self.prepare_inputs_embeds(
pixel_values,
image_seq_mask,
image_spatial_crop,
input_embeds[start_idx:end_idx],
)
images_emb_mask = image.images_emb_mask.to(device="cuda")
image_features = self.get_image_feature(image)
input_embeds[start_idx:end_idx] = input_embeds[
start_idx:end_idx
].masked_scatter(images_emb_mask.unsqueeze(-1), image_features)
outputs = self.language_model.forward(
input_ids=input_ids,
@@ -289,20 +265,17 @@ class DeepseekVL2ForCausalLM(nn.Module):
def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
return input_ids
def prepare_inputs_embeds(
self,
pixel_values,
images_seq_mask,
images_spatial_crop,
input_embeds,
):
def get_image_feature(self, image_input: ImageInputs):
pixel_values = image_input.pixel_values.type(
next(self.vision.parameters()).dtype
).to(device=next(self.vision.parameters()).device)
image_feature = self.vision.forward_features(pixel_values)
images_embeds = self.projector(image_feature)
_, hw, n_dim = images_embeds.shape
h = w = int(hw**0.5)
tile_index = 0
images_in_this_batch = []
images_spatial_crop = image_input.image_spatial_crop
for jdx in range(images_spatial_crop.shape[1]):
num_width_tiles, num_height_tiles = images_spatial_crop[0, jdx]
if num_width_tiles == 0 or num_height_tiles == 0:
@@ -379,13 +352,7 @@ class DeepseekVL2ForCausalLM(nn.Module):
images_in_this_batch.append(global_local_features)
if len(images_in_this_batch) > 0:
images_in_this_batch = torch.cat(images_in_this_batch, dim=0)
input_embeds.masked_scatter_(
images_seq_mask.unsqueeze(-1), images_in_this_batch
)
return input_embeds
return torch.cat(images_in_this_batch, dim=0)
EntryClass = DeepseekVL2ForCausalLM

View File

@@ -37,11 +37,8 @@ from sglang.srt.layers.linear import (
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb, get_rope
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import (
default_weight_loader,
@@ -511,7 +508,7 @@ class Gemma3TextModel(PreTrainedModel):
else:
hidden_states = input_embeds
if len(positions.shape) == 1:
if positions.dim() == 1:
positions = einops.rearrange(positions, "s -> 1 s")
position_embeddings_global = self.rotary_emb(hidden_states, positions)
@@ -609,11 +606,11 @@ class Gemma3ForCausalLM(PreTrainedModel):
)
self.post_init()
def get_input_embeddings(self):
def get_input_embeddings(self) -> nn.Embedding:
return self.model.embed_tokens
def dtype(self) -> torch.dtype:
return self.model.layers[0].mlp.gate_up_proj.weight.dtype
return next(self.parameters()).dtype
@torch.no_grad()
def forward(

View File

@@ -34,8 +34,9 @@ from sglang.srt.hf_transformers_utils import get_processor
from sglang.srt.layers.layernorm import Gemma3RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.multi_modality_padding import (
from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternTokenPairs,
general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -264,10 +265,10 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
kwargs["local_attn_masks"] = local_attn_masks
return kwargs
def get_input_embeddings(self):
def get_input_embeddings(self) -> nn.Embedding:
return self.language_model.get_input_embeddings()
def get_image_features(self, pixel_values: torch.Tensor):
def get_image_feature(self, image_input: ImageInputs):
"""
Projects the last hidden state from the vision model into language model space.
@@ -277,6 +278,7 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
Returns:
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
"""
pixel_values = image_input.pixel_values
pixel_values = pixel_values.to("cuda")
pixel_values = pixel_values.to(dtype=self.language_model.dtype())
@@ -305,7 +307,7 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
return inputs_embeds
else:
# print(f"image tokens from input_ids: {inputs_embeds[special_image_mask].numel()}")
image_features = self.get_image_features(image_input.pixel_values)
image_features = self.get_image_feature(image_input.pixel_values)
# print(f"image tokens from image embeddings: {image_features.numel()}")
num_image_tokens_in_embedding = (
@@ -397,20 +399,13 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
else:
llm_input_ids = input_ids
merged_image_input = forward_batch.get_merged_image_inputs()
if (
not forward_batch.forward_mode.is_decode()
and merged_image_input is not None
):
inputs_embeds = self.embed_image_inputs(
input_ids=llm_input_ids,
forward_batch=forward_batch,
image_input=merged_image_input,
)
else:
llm_input_ids.clamp_(min=0, max=self.vocab_size - 1)
inputs_embeds = self.get_input_embeddings()(llm_input_ids)
inputs_embeds = general_mm_embed_routine(
input_ids=llm_input_ids,
positions=positions,
forward_batch=forward_batch,
embed_tokens=self.get_input_embeddings(),
image_embedding_func=self.get_image_feature,
)
outputs = self.language_model(
input_ids=None,

View File

@@ -50,8 +50,9 @@ from sglang.srt.layers.linear import (
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.multi_modality_padding import (
from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternTokenPairs,
embed_image_inputs,
)
from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -399,7 +400,7 @@ class Idefics2VisionTransformer(nn.Module):
)
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
def get_input_embeddings(self):
def get_input_embeddings(self) -> nn.Embedding:
return self.embeddings
def compute_cu_seqlens(self, tgt_sizes: torch.Tensor) -> torch.Tensor:
@@ -762,42 +763,6 @@ class MiniCPMVBaseModel(nn.Module):
valid_pairs_tensor = torch.tensor(valid_pairs, device=input_ids.device)
return valid_pairs_tensor
def get_embedding(
self,
input_ids: torch.Tensor,
image_inputs: Optional[MiniCPMVImageInputs],
) -> Tuple[torch.Tensor, torch.Tensor]:
vlm_embedding: torch.Tensor = self.llm.get_input_embeddings(input_ids)
if image_inputs is None: # No image
vision_hidden_states = torch.tensor([], device=input_ids.device)
else:
if image_inputs["type"] == "image_embeds":
vision_hidden_states = (
image_inputs["data"]
.type(vlm_embedding.dtype)
.to(vlm_embedding.device)
)
else:
vision_hidden_states = self.get_vision_hidden_states(image_inputs)
# See NOTE in _parse_and_validate_inputs
image_bounds = image_inputs["image_bounds"]
if len(image_bounds) > 0:
image_indices = torch.stack(
[
torch.arange(start, end, dtype=torch.long)
for start, end in image_bounds.tolist()
]
).to(vlm_embedding.device)
vlm_embedding.scatter_(
0,
image_indices.view(-1, 1).repeat(1, vlm_embedding.shape[-1]),
vision_hidden_states.view(-1, vision_hidden_states.shape[-1]),
)
return vlm_embedding, vision_hidden_states
def _parse_and_validate_inputs(
self,
input_ids: torch.Tensor,
@@ -836,46 +801,6 @@ class MiniCPMVBaseModel(nn.Module):
type="image_embeds",
)
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError(
"Incorrect type of pixel values. " f"Got type: {type(pixel_values)}"
)
if not isinstance(tgt_sizes, (torch.Tensor, list)):
raise ValueError(
"Incorrect type of target sizes. " f"Got type: {type(tgt_sizes)}"
)
if len(pixel_values) != len(tgt_sizes):
raise ValueError(
"Inconsistent batch lengths, found: "
f"{len(pixel_values)} vs. {len(tgt_sizes)}"
)
pixel_values_flat: List[torch.Tensor] = []
tgt_sizes_flat: List[torch.Tensor] = []
for pixel_b, tgt_b in zip(pixel_values, tgt_sizes):
if len(pixel_b) != len(tgt_b):
raise ValueError(
"Inconsistent N lengths, found: " f"{len(pixel_b)} vs {len(tgt_b)}"
)
for pixel_n, tgt_n in zip(pixel_b, tgt_b):
pixel_values_flat += pixel_n
tgt_sizes_flat += tgt_n
# NOTE: Input IDs does not contain image tokens during memory profiling,
# so we allow it to be empty
if len(pixel_values_flat) != len(tgt_sizes_flat):
raise ValueError(
"Inconsistent flattened lengths, found: "
f"{len(pixel_values_flat)} vs. "
f"{len(tgt_sizes_flat)}"
)
if len(pixel_values_flat) == 0:
return None
image_bounds = self._get_image_bounds(
input_ids=input_ids,
pad_values=pad_values,
@@ -886,11 +811,50 @@ class MiniCPMVBaseModel(nn.Module):
)
return MiniCPMVImagePixelInputs(
image_bounds=image_bounds.to(device=input_ids.device),
data=pixel_values_flat,
tgt_sizes=torch.stack(tgt_sizes_flat),
data=pixel_values,
tgt_sizes=tgt_sizes,
type="pixel_values",
)
def get_embedding(
self,
input_ids: torch.Tensor,
image_inputs: Optional[MiniCPMVImageInputs],
) -> Tuple[torch.Tensor, torch.Tensor]:
vlm_embedding: torch.Tensor = self.llm.get_input_embeddings(input_ids)
if image_inputs is None: # No image
vision_hidden_states = torch.tensor([], device=input_ids.device)
else:
if image_inputs["type"] == "image_embeds":
vision_hidden_states = (
image_inputs["data"]
.type(vlm_embedding.dtype)
.to(vlm_embedding.device)
)
else:
vision_hidden_states = self.get_vision_hidden_states(image_inputs)
# See NOTE in _parse_and_validate_inputs
image_bounds = image_inputs["image_bounds"]
if len(image_bounds) > 0:
image_indices = torch.stack(
[
torch.arange(start, end, dtype=torch.long)
for start, end in image_bounds.tolist()
]
).to(vlm_embedding.device)
vlm_embedding.scatter_(
0,
image_indices.view(-1, 1).repeat(1, vlm_embedding.shape[-1]),
vision_hidden_states.view(-1, vision_hidden_states.shape[-1]),
)
return vlm_embedding, vision_hidden_states
def get_input_embeddings(self) -> nn.Embedding:
return self.llm.get_input_embedding()
def forward(
self,
input_ids: torch.Tensor,
@@ -899,58 +863,29 @@ class MiniCPMVBaseModel(nn.Module):
**kwargs: Any,
) -> torch.Tensor:
if (
forward_batch.image_inputs is not None
and len(forward_batch.image_inputs) > 0
and forward_batch.image_inputs[0] is not None
forward_batch.forward_mode.is_decode()
or not forward_batch.contains_image_inputs()
):
# TODO: bath
kwargs.update(
{
"pixel_values": (
None
if forward_batch.image_inputs is None
else [
i.pixel_values
for i in forward_batch.image_inputs
if i is not None
]
),
"tgt_sizes": (
None
if forward_batch.image_inputs is None
else [
i.tgt_sizes
for i in forward_batch.image_inputs
if i is not None
]
),
"im_start_id": forward_batch.image_inputs[0].im_start_id,
"im_end_id": forward_batch.image_inputs[0].im_end_id,
"slice_start_id": forward_batch.image_inputs[0].slice_start_id,
"slice_end_id": forward_batch.image_inputs[0].slice_end_id,
"pad_values": forward_batch.image_inputs[0].pad_values,
}
inputs_embeds: torch.Tensor = self.llm.get_input_embeddings(input_ids)
else:
# Clamp input ids. This is because the input_ids for the image tokens are
# filled with the hash values of the image for the prefix matching in the radix attention.
# There values are useless because their embeddings will be replaced by vision embeddings anyway.
image_inputs = forward_batch.merge_image_inputs()
inputs_embeds = embed_image_inputs(
image_input=image_inputs,
input_ids=input_ids,
input_embedding=self.get_input_embeddings(),
image_embedding_func=self.get_image_features,
placeholder_token_ids=[image_inputs.im_token_id]
+ image_inputs.pad_values,
)
image_inputs = self._parse_and_validate_inputs(input_ids, **kwargs)
# Clamp input ids. This is because the input_ids for the image tokens are
# filled with the hash values of the image for the prefix matching in the radix attention.
# Their values are useless because their embeddings will be replaced by vision embeddings anyway.
input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
vlm_embeddings, _ = self.get_embedding(input_ids, image_inputs)
# always pass the input via `inputs_embeds`
# to make sure the computation graph is consistent
# for `torch.compile` integration
input_ids = None
hidden_states = self.llm.model(
input_ids=input_ids,
input_ids=None,
positions=positions,
forward_batch=forward_batch,
input_embeds=vlm_embeddings,
input_embeds=inputs_embeds,
)
return self.logits_processor(
@@ -990,7 +925,7 @@ class MiniCPMVBaseModel(nn.Module):
) -> torch.Tensor:
raise NotImplementedError
def get_vision_hidden_states(self, data: MiniCPMVImageInputs) -> torch.Tensor:
def get_image_features(self, image_inputs: ImageInputs) -> torch.Tensor:
raise NotImplementedError
@@ -1100,12 +1035,14 @@ class MiniCPMV2_6(MiniCPMVBaseModel):
)
return vision_embedding
def get_vision_hidden_states(
def get_image_features(
self,
data: MiniCPMVImageInputs,
image_inputs: ImageInputs,
) -> torch.Tensor:
pixel_values = data["data"]
tgt_sizes = data["tgt_sizes"]
# list of tensors
pixel_values = image_inputs.pixel_values
tgt_sizes = image_inputs.tgt_sizes
device = self.vpm.embeddings.position_embedding.weight.device
dtype = self.vpm.embeddings.position_embedding.weight.dtype

View File

@@ -361,6 +361,9 @@ class Qwen2ForCausalLM(nn.Module):
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
def get_input_embedding(self) -> nn.Embedding:
return self.model.embed_tokens
@torch.no_grad()
def forward(
self,

View File

@@ -26,7 +26,6 @@ import logging
from functools import lru_cache, partial
from typing import Iterable, List, Optional, Tuple, Type
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -54,14 +53,15 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.multi_modality_padding import (
from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternTokenPairs,
general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen2 import Qwen2Model
from sglang.srt.models.qwen2_vl import Qwen2VLImageInputs, Qwen2VLVideoInputs
from sglang.srt.models.qwen2_vl import Qwen2VLVideoInputs
from sglang.srt.utils import add_prefix
logger = logging.getLogger(__name__)
@@ -326,13 +326,12 @@ class Qwen2_5_VisionTransformer(nn.Module):
)
def get_window_index(self, grid_thw):
window_index: list = []
cu_window_seqlens: list = [0]
window_index_id = 0
vit_merger_window_size = (
self.window_size // self.spatial_merge_size // self.patch_size
)
window_index: list = []
for grid_t, grid_h, grid_w in grid_thw:
llm_grid_h, llm_grid_w = (
grid_h // self.spatial_merge_size,
@@ -369,7 +368,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
window_index = torch.cat(window_index, dim=0)
return window_index, cu_window_seqlens
@property
@@ -382,8 +380,10 @@ class Qwen2_5_VisionTransformer(nn.Module):
def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
pos_ids = []
for t, h, w in grid_thw:
for i in range(grid_thw.size(0)):
t, h, w = grid_thw[i].tolist()
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
hpos_ids = hpos_ids.reshape(
h // self.spatial_merge_size,
self.spatial_merge_size,
@@ -402,6 +402,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
)
wpos_ids = wpos_ids.permute(0, 2, 1, 3)
wpos_ids = wpos_ids.flatten()
pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
pos_ids = torch.cat(pos_ids, dim=0)
max_grid_size = grid_thw[:, 1:].max()
@@ -443,9 +444,12 @@ class Qwen2_5_VisionTransformer(nn.Module):
position_embeddings = (emb.cos(), emb.sin())
# compute cu_seqlens
cu_seqlens = torch.repeat_interleave(
grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
).cumsum(dim=0, dtype=torch.int32)
cu_seqlens = torch.cat(
[
torch.tensor([0], device=grid_thw.device),
(grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]).cumsum(dim=0),
]
)
cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
# transformers
@@ -509,18 +513,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
self.logits_processor = LogitsProcessor(config)
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
def calculate_num_image_tokens(self, image_grid_thw: Tuple[int, int, int]):
processor = cached_get_processor(self.config._name_or_path)
grid_t, grid_h, grid_w = image_grid_thw
num_image_tokens = (
grid_t
* grid_h
* grid_w
// processor.image_processor.merge_size
// processor.image_processor.merge_size
)
return num_image_tokens
def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
# Get all special token IDs
im_start_id: int = image_inputs.im_start_id
@@ -531,9 +523,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
return pattern.pad_input_tokens(input_ids, image_inputs)
def _process_image_input(self, image_input: Qwen2VLImageInputs) -> torch.Tensor:
pixel_values = image_input["pixel_values"].type(self.visual.dtype)
image_embeds = self.visual(pixel_values, grid_thw=image_input["image_grid_thw"])
def get_image_feature(self, image_input: ImageInputs) -> torch.Tensor:
pixel_values = image_input.pixel_values.type(self.visual.dtype)
image_embeds = self.visual(pixel_values, grid_thw=image_input.image_grid_thws)
return image_embeds
def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:
@@ -543,6 +535,9 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
)
return video_embeds
def get_input_embeddings(self):
return self.model.embed_tokens
def forward(
self,
input_ids: torch.Tensor,
@@ -565,86 +560,26 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
positions = forward_batch.mrope_positions
image_inputs = None
if forward_batch.image_inputs is not None:
image_inputs = [
img for img in forward_batch.image_inputs if img is not None
]
if (
if not (
forward_batch.forward_mode.is_decode()
or image_inputs is None
or len(image_inputs) == 0
or not forward_batch.contains_image_inputs()
):
inputs_embeds = self.model.embed_tokens(input_ids)
else:
if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
assert positions.ndim == 2 and positions.size(0) == 3, (
"multimodal section rotary embedding requires "
f"(3, seq_len) positions, but got {positions.size()}"
)
# Clamp input ids. This is because the input_ids for the image tokens are
# filled with the hash values of the image for the prefix matching in the radix attention.
# There values are useless because their embeddings will be replaced by vision embeddings anyway.
input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
# [B, s, hidden_size]
inputs_embeds = self.model.embed_tokens(input_ids)
extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
for i, image in enumerate(forward_batch.image_inputs):
if image is None or image.pixel_values is None:
continue
start_idx = extend_start_loc_cpu[i]
prefix_len = prefix_lens_cpu[i]
pixel_values = image.pixel_values.to(device="cuda")
image_grid_thws = torch.tensor(
np.array(image.image_grid_thws), device="cuda"
)
image_offsets = image.image_offsets
image_input = Qwen2VLImageInputs(
pixel_values=pixel_values, image_grid_thw=image_grid_thws
)
image_embeds = self._process_image_input(image_input)
image_embeds_offset = 0
for idx, image_offset in enumerate(image_offsets):
if image_offset < prefix_len:
continue
num_image_tokens = self.calculate_num_image_tokens(
image_grid_thws[idx]
)
left_idx = start_idx + (image_offset - prefix_len)
right_idx = left_idx + num_image_tokens
tp_size = get_tensor_model_parallel_world_size()
hidden_size = image_embeds.shape[-1]
if hidden_size % tp_size != 0:
padding_size = tp_size - (hidden_size % tp_size)
image_embeds = F.pad(image_embeds, (0, padding_size))
inputs_embeds = F.pad(inputs_embeds, (0, padding_size))
hidden_chunk_size = image_embeds.shape[-1] // tp_size
rank = get_tensor_model_parallel_rank()
start_dim = rank * hidden_chunk_size
end_dim = (rank + 1) * hidden_chunk_size
inputs_embeds[left_idx:right_idx, ..., start_dim:end_dim] = (
image_embeds[
image_embeds_offset : image_embeds_offset
+ num_image_tokens,
...,
start_dim:end_dim,
]
)
image_embeds_offset += num_image_tokens
inputs_embeds = general_mm_embed_routine(
input_ids=input_ids,
positions=positions,
forward_batch=forward_batch,
embed_tokens=self.get_input_embeddings(),
image_embedding_func=self.get_image_feature,
)
hidden_states = self.model(
input_ids=input_ids,
input_ids=None,
positions=positions,
forward_batch=forward_batch,
input_embeds=inputs_embeds,

View File

@@ -26,7 +26,6 @@ import logging
from functools import lru_cache, partial
from typing import Iterable, List, Optional, Tuple, Type, TypedDict
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -42,8 +41,9 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.multi_modality_padding import (
from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternTokenPairs,
general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -351,7 +351,7 @@ class Qwen2VisionTransformer(nn.Module):
@property
def dtype(self) -> torch.dtype:
return self.blocks[0].mlp.fc2.weight.dtype
return next(self.parameters()).dtype
@property
def device(self) -> torch.device:
@@ -359,7 +359,8 @@ class Qwen2VisionTransformer(nn.Module):
def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
pos_ids = []
for t, h, w in grid_thw:
for i in range(grid_thw.size(0)):
t, h, w = grid_thw[i].tolist()
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
hpos_ids = (
@@ -480,9 +481,9 @@ class Qwen2VLForConditionalGeneration(nn.Module):
pattern = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
return pattern.pad_input_tokens(input_ids, image_inputs)
def _process_image_input(self, image_input: Qwen2VLImageInputs) -> torch.Tensor:
pixel_values = image_input["pixel_values"].type(self.visual.dtype)
image_embeds = self.visual(pixel_values, grid_thw=image_input["image_grid_thw"])
def get_image_feature(self, image_input: ImageInputs) -> torch.Tensor:
pixel_values = image_input.pixel_values.type(self.visual.dtype)
image_embeds = self.visual(pixel_values, grid_thw=image_input.image_grid_thws)
return image_embeds
def _process_video_input(self, video_input: Qwen2VLVideoInputs) -> torch.Tensor:
@@ -492,6 +493,9 @@ class Qwen2VLForConditionalGeneration(nn.Module):
)
return video_embeds
def get_input_embeddings(self):
return self.model.embed_tokens
def forward(
self,
input_ids: torch.Tensor,
@@ -514,67 +518,26 @@ class Qwen2VLForConditionalGeneration(nn.Module):
if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
positions = forward_batch.mrope_positions
image_inputs = None
if forward_batch.image_inputs is not None:
image_inputs = [
img for img in forward_batch.image_inputs if img is not None
]
if (
if not (
forward_batch.forward_mode.is_decode()
or image_inputs is None
or len(image_inputs) == 0
or not forward_batch.contains_image_inputs()
):
inputs_embeds = self.model.embed_tokens(input_ids)
else:
if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
assert positions.ndim == 2 and positions.size(0) == 3, (
"multimodal section rotary embedding requires "
f"(3, seq_len) positions, but got {positions.size()}"
)
# Clamp input ids. This is because the input_ids for the image tokens are
# filled with the hash values of the image for the prefix matching in the radix attention.
# There values are useless because their embeddings will be replaced by vision embeddings anyway.
input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
inputs_embeds = self.model.embed_tokens(input_ids)
extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
for i, image in enumerate(forward_batch.image_inputs):
if image is None or image.pixel_values is None:
continue
start_idx = extend_start_loc_cpu[i]
prefix_len = prefix_lens_cpu[i]
pixel_values = image.pixel_values.clone()
image_grid_thws = torch.tensor(
np.array(image.image_grid_thws), device="cuda"
)
image_offsets = image.image_offsets
image_input = Qwen2VLImageInputs(
pixel_values=pixel_values, image_grid_thw=image_grid_thws
)
image_embeds = self._process_image_input(image_input)
image_embeds_offset = 0
for idx, image_offset in enumerate(image_offsets):
if image_offset < prefix_len:
continue
num_image_tokens = self.calculate_num_image_tokens(
image_grid_thws[idx]
)
left_idx = start_idx + (image_offset - prefix_len + 1)
right_idx = left_idx + num_image_tokens
inputs_embeds[left_idx:right_idx] = image_embeds[
image_embeds_offset : image_embeds_offset + num_image_tokens
]
image_embeds_offset += num_image_tokens
input_ids = None
inputs_embeds = general_mm_embed_routine(
input_ids=input_ids,
positions=positions,
forward_batch=forward_batch,
embed_tokens=self.get_input_embeddings(),
image_embedding_func=self.get_image_feature,
)
hidden_states = self.model(
input_ids=input_ids,
input_ids=None,
positions=positions,
forward_batch=forward_batch,
input_embeds=inputs_embeds,