model: support mllama4 (#5144)

Mick
2025-04-10 00:28:44 +08:00
committed by GitHub
parent 87eddedfa2
commit fbebcb7aa4
7 changed files with 145 additions and 65 deletions

View File

@@ -148,7 +148,8 @@ def get_embedding_and_mask(
         placeholder_tensor,
     ).unsqueeze(-1)
-    num_mm_tokens_in_input_ids = special_multimodal_mask.sum()
+    num_mm_tokens_in_input_ids = special_multimodal_mask.sum().item()
     if num_mm_tokens_in_input_ids != num_mm_tokens_in_embedding:
         logger.warning(
             f"Number of tokens in multimodal embedding does not match those in the input text."
@@ -172,7 +173,7 @@ def get_embedding_and_mask(
             embedding = embedding[-num_multimodal:, :]
         else:
             raise RuntimeError(
-                "Insufficient multimodal embedding length. This is an internal error"
+                f"Insufficient multimodal embedding length: {num_mm_tokens_in_input_ids=} vs {num_mm_tokens_in_embedding=}. This is an internal error"
             )
     return embedding, special_multimodal_mask
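Note: the only functional change in this file is .sum() becoming .sum().item(), so the comparison and the new error message see a plain Python int instead of a 0-dim tensor. A minimal standalone sketch of the difference (illustrative names only, not from the codebase):

    import torch

    mask = torch.tensor([True, False, True, True])

    as_tensor = mask.sum()      # 0-dim tensor: tensor(3)
    as_int = mask.sum().item()  # plain Python int: 3

    # The f-string debug syntax used in the new error message only reads
    # cleanly with the int form:
    print(f"{as_tensor=}")  # as_tensor=tensor(3)
    print(f"{as_int=}")     # as_int=3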

View File

@@ -1,10 +1,8 @@
-from typing import List, Mapping, Optional, Tuple, Union
+from typing import List, Union

 import torch
 from PIL import Image
 from transformers import Llama4Processor
 from transformers.image_utils import SizeDict
-from transformers.models.llama4.image_processing_llama4 import (
+from transformers.models.llama4.image_processing_llama4_fast import (
     find_supported_resolutions,
     get_best_fit,
 )
@@ -15,7 +13,6 @@ from sglang.srt.managers.multimodal_processors.base_processor import (
 )
 from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.models.mllama4 import Llama4ForConditionalGeneration
-from sglang.srt.utils import load_image

 class Mllama4ImageProcessor(BaseMultimodalProcessor):
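The import swap follows transformers, which moved these helpers into the fast Llama4 image-processing module. For orientation, a rough usage sketch; the exact signatures are an assumption based on the transformers source this PR targets, not something the diff shows:

    import torch
    from transformers.image_utils import SizeDict
    from transformers.models.llama4.image_processing_llama4_fast import (
        find_supported_resolutions,
        get_best_fit,
    )

    # All (height, width) canvases expressible as at most 5 tiles of 336x336.
    possible = find_supported_resolutions(
        max_num_chunks=5, patch_size=SizeDict(height=336, width=336)
    )

    # Pick the canvas that best matches an 800x1100 input image.
    best = get_best_fit((800, 1100), torch.tensor(possible))
    print(best)  # one (height, width) pair from the supported set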
@@ -25,6 +22,9 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
         super().__init__(hf_config, server_args, _processor)
         self.vision_config = hf_config.vision_config
         self.text_config = hf_config.text_config
+        self.boi_token_index = hf_config.boi_token_index
+        self.eoi_token_index = hf_config.eoi_token_index
+        self.image_token_index = hf_config.image_token_index
         self.multimodal_tokens = MultimodalSpecialTokens(
             image_token=_processor.image_token
         )
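These three config fields are cached so the processor can stamp its output with the begin-of-image, end-of-image, and image-token ids (surfaced as im_start_id / im_end_id / im_token_id later in this diff). A toy sketch with placeholder values; the real ids come from the model's config.json:

    from types import SimpleNamespace

    # Hypothetical stand-in for hf_config; the token ids below are made up.
    hf_config = SimpleNamespace(
        boi_token_index=200080,
        eoi_token_index=200081,
        image_token_index=200092,
    )

    processor_output = {
        "im_start_id": hf_config.boi_token_index,
        "im_end_id": hf_config.eoi_token_index,
        "im_token_id": hf_config.image_token_index,
    }
    print(processor_output)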
@@ -54,19 +54,16 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
         )

         # Process the images using the processor
-        processor = Llama4Processor.from_pretrained(
-            self.server_args.model_path, **kwargs
-        )
+        processor = self._processor

         # Process the prompt and images
-        image_inputs = processor(
-            text=processed_data.input_text,
+        processor_output = self.process_mm_data(
+            input_text=processed_data.input_text,
             images=processed_data.images,
-            return_tensors="pt",
         )

         # Handle image resolutions and aspect ratios
-        if "pixel_values" in image_inputs:
+        if "pixel_values" in processor_output:
             image_processor = processor.image_processor
             tokenizer = self._processor.tokenizer
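Two things change in this hunk: the processor is no longer re-loaded from disk on every request (self._processor is the instance the base class already constructed at startup), and the raw HF call is routed through process_mm_data, which supplies return_tensors="pt" itself, so the explicit kwarg disappears from the call site. A sketch of what such a wrapper plausibly does; the real helper lives in BaseMultimodalProcessor and may differ in detail:

    def process_mm_data(processor, input_text, images=None, **kwargs):
        # Assumed shape of the wrapper: forward to the HF processor and
        # pin the tensor format in one place.
        if images is not None:
            kwargs["images"] = images
        return processor(text=[input_text], return_tensors="pt", **kwargs)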
@@ -100,8 +97,8 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
             ]

             # Add to image_inputs
-            image_inputs["aspect_ratios"] = aspect_ratios
-            image_inputs["patches_per_image"] = torch.tensor(patches_per_image)
+            processor_output["aspect_ratios"] = aspect_ratios
+            processor_output["patches_per_image"] = torch.tensor(patches_per_image)

             # Process embed_is_patch
             vocab = tokenizer.get_vocab()
@@ -109,7 +106,7 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
            image_end_id = vocab.get(processor.end_of_img_token, -1)

            if patch_id != -1 and image_end_id != -1:
-                input_ids = image_inputs["input_ids"].view(-1)
+                input_ids = processor_output["input_ids"].view(-1)

                # Remove BOS token if present
                if input_ids.size(0) > 0 and input_ids[0] == tokenizer.bos_token_id:
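This hunk and the next one build embed_is_patch: the flat input_ids are split after each end-of-image token, and each per-image slice becomes a bool mask marking patch positions. A self-contained sketch with toy token ids (the real ids come from the tokenizer vocab, as above):

    import torch

    PATCH_ID, IMG_END_ID, TEXT_ID = 7, 9, 1  # toy ids, not Llama4's

    # Two images in one sequence:
    # [text, patch, patch, img_end, text, patch, img_end]
    input_ids = torch.tensor(
        [TEXT_ID, PATCH_ID, PATCH_ID, IMG_END_ID, TEXT_ID, PATCH_ID, IMG_END_ID]
    )

    # Split after each end-of-image token, mirroring split_input_ids below.
    ends = (input_ids == IMG_END_ID).nonzero().view(-1).tolist()
    split_input_ids, start = [], 0
    for pos in ends:
        split_input_ids.append(input_ids[start : pos + 1])
        start = pos + 1

    embed_is_patch = [chunk == PATCH_ID for chunk in split_input_ids]
    print(embed_is_patch)
    # [tensor([False,  True,  True, False]), tensor([False,  True, False])]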
@@ -129,33 +126,21 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
                 for per_image_input_ids in split_input_ids:
                     embed_is_patch.append(per_image_input_ids == patch_id)

-                image_inputs["embed_is_patch"] = embed_is_patch
+                processor_output["embed_is_patch"] = embed_is_patch

         # Convert to the format expected by SGLang
-        image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
+        processor_output["input_ids"] = processor_output["input_ids"].tolist()[0]
+
+        processor_output["im_start_id"] = self.boi_token_index
+        processor_output["im_end_id"] = self.eoi_token_index
+        processor_output["im_token_id"] = self.image_token_index

         # Add metadata for image processing
-        image_inputs["mm_items"] = [
+        processor_output["mm_items"] = [
             MultimodalDataItem(
-                pixel_values=image_inputs["pixel_values"],
+                pixel_values=processor_output["pixel_values"],
                 modality=Modality.IMAGE,
                 # Add additional metadata needed for Llama4 vision processing
-                embed_is_patch=image_inputs.get("embed_is_patch", None),
-                aspect_ratios=image_inputs.get("aspect_ratios", None),
-                patches_per_image=image_inputs.get("patches_per_image", None),
             )
         ]

-        return image_inputs
-
-    def get_patch_per_chunk(self):
-        """Calculate patches per chunk based on vision config"""
-        image_size = self.vision_config.image_size
-        patch_size = self.vision_config.patch_size
-        assert (
-            image_size % patch_size == 0
-        ), f"chunk size {image_size} should be multiple of patch_size {patch_size}"
-        ds_ratio = int(round(1.0 / (self.vision_config.pixel_shuffle_ratio**2)))
-        return (image_size // patch_size) ** 2 // ds_ratio
+        return processor_output
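The deleted get_patch_per_chunk helper encoded the pixel-shuffle arithmetic: each square chunk is cut into (image_size / patch_size)^2 patches, and pixel shuffle then folds every 1/ratio^2 patches into one embedding. Plugging in typical Llama 4 vision values (336 px chunks, 14 px patches, ratio 0.5; assumed here, not shown in the diff):

    image_size, patch_size, pixel_shuffle_ratio = 336, 14, 0.5

    assert image_size % patch_size == 0
    ds_ratio = int(round(1.0 / (pixel_shuffle_ratio**2)))  # 1 / 0.25 -> 4

    patches = (image_size // patch_size) ** 2  # 24 * 24 = 576 raw patches
    print(patches // ds_ratio)                 # 576 // 4 = 144 tokens per chunk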

View File

@@ -1,5 +1,6 @@
 from __future__ import annotations

+import hashlib
 from enum import Enum, auto

 # Copyright 2023-2024 SGLang Team
@@ -157,7 +158,7 @@ class Modality(Enum):
 @dataclasses.dataclass
 class MultimodalDataItem:
     """
-    A single multimodal data, from a single image/video/audio or other
+    A single multimodal data, from a single image/video/audio or others
     """

     modality: Modality
@@ -195,25 +196,54 @@
     def set_pad_value(self):
         """
-        Set the pad value after first hashign the data
+        Set the pad value after first hashing the data
         """

-        def tensor_hash(f):
-            f_list = flatten_nested_list(f)
-            f_list = [x.flatten() if isinstance(x, torch.Tensor) else x for x in f_list]
-            f_cat = torch.concat(f_list).contiguous().numpy().tobytes()
-            return hash(f_cat)
+        def data_hash(data) -> int:
+            hash_bytes = hashlib.sha256(data).digest()[:8]
+            return int.from_bytes(hash_bytes, byteorder="big", signed=False)
+
+        def tensor_hash(tensor_list) -> int:
+            """
+            hash a tensor or a tensor list
+            """
+            tensor = tensor_list
+            if isinstance(tensor_list, list):
+                tensor_list = flatten_nested_list(tensor_list)
+                tensor_list = [
+                    x.flatten() if isinstance(x, torch.Tensor) else x
+                    for x in tensor_list
+                ]
+                tensor = torch.concat(tensor_list)
+
+            tensor = tensor.detach().contiguous()
+            if tensor.dtype == torch.bfloat16:
+                # memoryview() doesn't support PyTorch's BFloat16 dtype
+                tensor = tensor.float()
+
+            if tensor.is_cuda:
+                tensor_cpu = torch.frombuffer(
+                    tensor.storage().untyped(), dtype=tensor.dtype, count=tensor.numel()
+                ).clone()
+            else:
+                tensor_cpu = tensor
+
+            mv = memoryview(tensor_cpu.numpy())
+            return data_hash(mv.tobytes())

         def hash_feature(f):
             if isinstance(f, list):
                 if isinstance(f[0], torch.Tensor):
                     return tensor_hash(f)
-                return hash(tuple(flatten_nested_list(f)))
+                return data_hash(tuple(flatten_nested_list(f)))
             elif isinstance(f, np.ndarray):
                 arr = np.ascontiguousarray(f)
                 arr_bytes = arr.tobytes()
-                return hash(arr_bytes)
-            return hash(f)
+                return data_hash(arr_bytes)
+            elif isinstance(f, torch.Tensor):
+                return tensor_hash([f])
+            return data_hash(f)

         if self.is_audio():
             self.hash = hash_feature(self.audio_features)
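The switch from Python's built-in hash() to a sha256 prefix matters because hash() is salted per process (PYTHONHASHSEED), so it cannot serve as a pad value that different workers and runs must agree on. A trimmed, CPU-only sketch of the new scheme (the committed version additionally handles CUDA tensors and nested lists):

    import hashlib

    import torch

    def data_hash(data) -> int:
        # First 8 bytes of sha256: a stable unsigned 64-bit int,
        # identical across processes and runs, unlike built-in hash().
        return int.from_bytes(hashlib.sha256(data).digest()[:8], byteorder="big")

    def tensor_hash(t: torch.Tensor) -> int:
        t = t.detach().contiguous()
        if t.dtype == torch.bfloat16:
            t = t.float()  # memoryview()/numpy cannot represent bfloat16
        return data_hash(memoryview(t.cpu().numpy()).tobytes())

    a = torch.arange(6, dtype=torch.bfloat16)
    print(tensor_hash(a) == tensor_hash(a.clone()))  # True: content-addressed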
@@ -256,7 +286,7 @@
     mrope_position_delta: Optional[torch.Tensor] = None

     # image
-    im_token_id: Optional[torch.Tensor] = None
+    im_token_id: Optional[int] = None
     im_start_id: Optional[int] = None
     im_end_id: Optional[int] = None
     slice_start_id: Optional[int] = None
@@ -330,10 +360,8 @@
         # args needed to be merged
         optional_args = [
-            "items",
-            "image_offsets",
-            "image_pad_len",
+            "mm_items",
             # "modalities", # modalities should be ["multi-images"] (one entry) even for multiple images
         ]
         for arg in optional_args:
             self_arg = getattr(self, arg, None)
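For context, optional_args feeds a concat-if-both-present merge, so renaming the entry to mm_items means the per-request item lists are what get concatenated when two MultimodalInputs are merged. A hypothetical distillation of that loop body (the real method covers more fields):

    def merge_optional(self_arg, other_arg):
        # Keep whichever side exists; concatenate lists when both do.
        if self_arg is None:
            return other_arg
        if other_arg is None:
            return self_arg
        return self_arg + other_arg

    print(merge_optional([1, 2], [3]))  # [1, 2, 3]
    print(merge_optional(None, [3]))    # [3]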