model: support mllama4 (#5144)
@@ -1,10 +1,8 @@
-from typing import List, Mapping, Optional, Tuple, Union
+from typing import List, Union

 import torch
 from PIL import Image
-from transformers import Llama4Processor
-from transformers.image_utils import SizeDict
-from transformers.models.llama4.image_processing_llama4 import (
+from transformers.models.llama4.image_processing_llama4_fast import (
     find_supported_resolutions,
     get_best_fit,
 )
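Note: the two helpers keep their names; only the module changed to the fast implementation. A minimal sketch of how they select a target canvas — the argument names, the `SizeDict` usage, and the example sizes are assumptions based on the transformers Llama4 image processor, not part of this diff:

```python
import torch
from transformers.image_utils import SizeDict
from transformers.models.llama4.image_processing_llama4_fast import (
    find_supported_resolutions,
    get_best_fit,
)

# Enumerate every (height, width) canvas that can be tiled from at most
# 5 chunks of 336x336 pixels, e.g. (336, 672), (672, 336), (336, 1008), ...
possible_resolutions = find_supported_resolutions(
    max_num_chunks=5, patch_size=SizeDict(height=336, width=336)
)

# Pick the canvas that best fits an 800x1200 (H, W) image without
# stretching to the maximum canvas.
best_fit = get_best_fit(
    (800, 1200),
    torch.tensor(possible_resolutions),
    resize_to_max_canvas=False,
)
print(best_fit)  # a (height, width) pair
```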
@@ -15,7 +13,6 @@ from sglang.srt.managers.multimodal_processors.base_processor import (
 )
 from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
-from sglang.srt.models.mllama4 import Llama4ForConditionalGeneration
 from sglang.srt.utils import load_image


 class Mllama4ImageProcessor(BaseMultimodalProcessor):
@@ -25,6 +22,9 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
         super().__init__(hf_config, server_args, _processor)
         self.vision_config = hf_config.vision_config
         self.text_config = hf_config.text_config
+        self.boi_token_index = hf_config.boi_token_index
+        self.eoi_token_index = hf_config.eoi_token_index
+        self.image_token_index = hf_config.image_token_index
         self.multimodal_tokens = MultimodalSpecialTokens(
             image_token=_processor.image_token
         )
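The three token indices are now read once in `__init__` straight from the HF config, which is what lets the model-class import be dropped above. Illustrative only; the checkpoint id is an example, and the fields are standard on Llama4 multimodal configs:

```python
from transformers import AutoConfig

# Example checkpoint; any Llama4 multimodal config carries these fields.
cfg = AutoConfig.from_pretrained("meta-llama/Llama-4-Scout-17B-16E-Instruct")
print(cfg.boi_token_index, cfg.eoi_token_index, cfg.image_token_index)
```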
@@ -54,19 +54,16 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
         )

         # Process the images using the processor
-        processor = Llama4Processor.from_pretrained(
-            self.server_args.model_path, **kwargs
-        )
+        processor = self._processor

         # Process the prompt and images
-        image_inputs = processor(
-            text=processed_data.input_text,
+        processor_output = self.process_mm_data(
+            input_text=processed_data.input_text,
             images=processed_data.images,
-            return_tensors="pt",
         )

         # Handle image resolutions and aspect ratios
-        if "pixel_values" in image_inputs:
+        if "pixel_values" in processor_output:
             image_processor = processor.image_processor
             tokenizer = self._processor.tokenizer

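Two effects of the hunk above: the processor is no longer rebuilt with `Llama4Processor.from_pretrained(...)` on every request (the cached `self._processor` is reused), and the HF call is routed through the base class's `process_mm_data`. A rough sketch of what that helper is assumed to do here — not the actual base-class implementation:

```python
def process_mm_data(self, input_text, images=None, **kwargs):
    """Hypothetical sketch of the BaseMultimodalProcessor helper."""
    # Returning "pt" tensors matches what the caller indexes later,
    # e.g. processor_output["input_ids"].view(-1).
    return self._processor(
        text=input_text, images=images, return_tensors="pt", **kwargs
    )
```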
@@ -100,8 +97,8 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
             ]

             # Add to image_inputs
-            image_inputs["aspect_ratios"] = aspect_ratios
-            image_inputs["patches_per_image"] = torch.tensor(patches_per_image)
+            processor_output["aspect_ratios"] = aspect_ratios
+            processor_output["patches_per_image"] = torch.tensor(patches_per_image)

             # Process embed_is_patch
             vocab = tokenizer.get_vocab()
@@ -109,7 +106,7 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
             image_end_id = vocab.get(processor.end_of_img_token, -1)

             if patch_id != -1 and image_end_id != -1:
-                input_ids = image_inputs["input_ids"].view(-1)
+                input_ids = processor_output["input_ids"].view(-1)

                 # Remove BOS token if present
                 if input_ids.size(0) > 0 and input_ids[0] == tokenizer.bos_token_id:
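The next hunk finishes the `embed_is_patch` computation: `input_ids` is split per image and each segment is compared against the patch-token id. A self-contained sketch of that masking step with made-up token ids (3 = image patch, 9 = end of image):

```python
import torch

PATCH_ID, IMAGE_END_ID = 3, 9  # made-up ids for illustration
input_ids = torch.tensor([7, 3, 3, 9, 8, 3, 9, 5])

# Split right after each end-of-image token so every image gets a segment.
boundaries = (input_ids == IMAGE_END_ID).nonzero(as_tuple=True)[0] + 1
split_input_ids = torch.tensor_split(input_ids, boundaries.tolist())

# One boolean mask per segment, True where the token is an image patch.
embed_is_patch = [seg == PATCH_ID for seg in split_input_ids]
# First segment [7, 3, 3, 9] -> mask [False, True, True, False]
```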
@@ -129,33 +126,21 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
                 for per_image_input_ids in split_input_ids:
                     embed_is_patch.append(per_image_input_ids == patch_id)

-                image_inputs["embed_is_patch"] = embed_is_patch
+                processor_output["embed_is_patch"] = embed_is_patch

         # Convert to the format expected by SGLang
-        image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
+        processor_output["input_ids"] = processor_output["input_ids"].tolist()[0]

+        processor_output["im_start_id"] = self.boi_token_index
+        processor_output["im_end_id"] = self.eoi_token_index
+        processor_output["im_token_id"] = self.image_token_index
+
         # Add metadata for image processing
-        image_inputs["mm_items"] = [
+        processor_output["mm_items"] = [
             MultimodalDataItem(
-                pixel_values=image_inputs["pixel_values"],
+                pixel_values=processor_output["pixel_values"],
                 modality=Modality.IMAGE,
-                # Add additional metadata needed for Llama4 vision processing
-                embed_is_patch=image_inputs.get("embed_is_patch", None),
-                aspect_ratios=image_inputs.get("aspect_ratios", None),
-                patches_per_image=image_inputs.get("patches_per_image", None),
             )
         ]

-        return image_inputs
-
-    def get_patch_per_chunk(self):
-        """Calculate patches per chunk based on vision config"""
-        image_size = self.vision_config.image_size
-        patch_size = self.vision_config.patch_size
-
-        assert (
-            image_size % patch_size == 0
-        ), f"chunk size {image_size} should be multiple of patch_size {patch_size}"
-
-        ds_ratio = int(round(1.0 / (self.vision_config.pixel_shuffle_ratio**2)))
-        return (image_size // patch_size) ** 2 // ds_ratio
+        return processor_output
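For reference, the removed `get_patch_per_chunk` computed how many embeddings one image chunk contributes after pixel shuffle. Worked through with illustrative Llama4-style values (image_size=336, patch_size=14, pixel_shuffle_ratio=0.5; assumed, not read from this diff):

```python
image_size, patch_size, pixel_shuffle_ratio = 336, 14, 0.5  # assumed values

ds_ratio = int(round(1.0 / (pixel_shuffle_ratio**2)))  # 1 / 0.25 -> 4
patches = (image_size // patch_size) ** 2              # 24 ** 2 = 576
print(patches // ds_ratio)                             # 576 // 4 = 144 per chunk
```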
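After the rename, the method returns the processor output dict directly. Schematically, the shape SGLang consumes looks like the following; every key appears in this diff, but the token-id values are illustrative:

```python
processor_output = {
    "input_ids": [200000, 200080, 200092, 200081],  # flattened to a plain list
    "pixel_values": ...,        # image tiles from the HF processor
    "embed_is_patch": [...],    # one bool mask per image
    "aspect_ratios": [...],
    "patches_per_image": ...,   # torch.tensor of chunk counts
    "mm_items": [...],          # MultimodalDataItem entries
    "im_start_id": 200080,      # hf_config.boi_token_index
    "im_end_id": 200081,        # hf_config.eoi_token_index
    "im_token_id": 200092,      # hf_config.image_token_index
}
```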