model: support mllama4 (#5144)
@@ -486,8 +486,8 @@ multimodal_model_archs = [
     "Gemma3ForConditionalGeneration",
     "Grok1VForCausalLM",
     "Grok1AForCausalLM",
-    # TODO: add multimodal support for "Llama4ForConditionalGeneration",
     "LlavaLlamaForCausalLM",
+    "Llama4ForConditionalGeneration",
     "LlavaMistralForCausalLM",
     "LlavaQwenForCausalLM",
     "LlavaVidForCausalLM",
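Registering the architecture name in this list is what routes Llama 4 requests through the multimodal path. A minimal sketch of how such a registry is typically consulted when a model config is loaded (illustrative only; the helper name is assumed, not taken from this patch):

    # Illustrative sketch, not from this patch.
    multimodal_model_archs = [
        "Gemma3ForConditionalGeneration",
        "Llama4ForConditionalGeneration",
    ]

    def is_multimodal_model(hf_architectures: list[str]) -> bool:
        # True if any architecture declared by the HF config is registered.
        return any(arch in multimodal_model_archs for arch in hf_architectures)

    assert is_multimodal_model(["Llama4ForConditionalGeneration"])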
@@ -148,7 +148,8 @@ def get_embedding_and_mask(
         placeholder_tensor,
     ).unsqueeze(-1)

-    num_mm_tokens_in_input_ids = special_multimodal_mask.sum()
+    num_mm_tokens_in_input_ids = special_multimodal_mask.sum().item()

     if num_mm_tokens_in_input_ids != num_mm_tokens_in_embedding:
         logger.warning(
             f"Number of tokens in multimodal embedding does not match those in the input text."
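The `.item()` call matters because `Tensor.sum()` returns a zero-dimensional tensor, so the count comparison and the warning's f-string would otherwise format a tensor rather than a plain integer. A small standalone sketch (illustrative, not from the patch):

    import torch

    mask = torch.tensor([True, True, False])
    as_tensor = mask.sum()        # tensor(2), a 0-dim tensor
    as_int = mask.sum().item()    # 2, a plain Python int

    # Comparisons work either way, but logging and serialization are cleaner with an int.
    print(f"{as_tensor=}")  # as_tensor=tensor(2)
    print(f"{as_int=}")     # as_int=2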
@@ -172,7 +173,7 @@ def get_embedding_and_mask(
             embedding = embedding[-num_multimodal:, :]
         else:
             raise RuntimeError(
-                "Insufficient multimodal embedding length. This is an internal error"
+                f"Insufficient multimodal embedding length: {num_mm_tokens_in_input_ids=} vs {num_mm_tokens_in_embedding=}. This is an internal error"
             )

     return embedding, special_multimodal_mask
@@ -1,10 +1,8 @@
-from typing import List, Mapping, Optional, Tuple, Union
+from typing import List, Union

 import torch
-from PIL import Image
-from transformers import Llama4Processor
 from transformers.image_utils import SizeDict
-from transformers.models.llama4.image_processing_llama4 import (
+from transformers.models.llama4.image_processing_llama4_fast import (
     find_supported_resolutions,
     get_best_fit,
 )
@@ -15,7 +13,6 @@ from sglang.srt.managers.multimodal_processors.base_processor import (
 )
 from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.models.mllama4 import Llama4ForConditionalGeneration
-from sglang.srt.utils import load_image


 class Mllama4ImageProcessor(BaseMultimodalProcessor):
@@ -25,6 +22,9 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
         super().__init__(hf_config, server_args, _processor)
         self.vision_config = hf_config.vision_config
         self.text_config = hf_config.text_config
+        self.boi_token_index = hf_config.boi_token_index
+        self.eoi_token_index = hf_config.eoi_token_index
+        self.image_token_index = hf_config.image_token_index
         self.multimodal_tokens = MultimodalSpecialTokens(
             image_token=_processor.image_token
         )
@@ -54,19 +54,16 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
         )

         # Process the images using the processor
-        processor = Llama4Processor.from_pretrained(
-            self.server_args.model_path, **kwargs
-        )
+        processor = self._processor

         # Process the prompt and images
-        image_inputs = processor(
-            text=processed_data.input_text,
+        processor_output = self.process_mm_data(
+            input_text=processed_data.input_text,
             images=processed_data.images,
             return_tensors="pt",
         )

         # Handle image resolutions and aspect ratios
-        if "pixel_values" in image_inputs:
+        if "pixel_values" in processor_output:
             image_processor = processor.image_processor
             tokenizer = self._processor.tokenizer
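Instead of re-instantiating `Llama4Processor` from disk on every request, the refactor reuses the processor the base class already holds and funnels the call through `process_mm_data`. A hedged sketch of what that helper amounts to (the real `BaseMultimodalProcessor.process_mm_data` signature may differ):

    # Sketch only: process_mm_data as a thin wrapper over the already-loaded HF processor.
    class _ProcessorSketch:
        def __init__(self, hf_processor):
            self._processor = hf_processor  # loaded once at server start

        def process_mm_data(self, input_text, images=None, **kwargs):
            # Forwards to the HF processor; no per-request from_pretrained() call.
            return self._processor(
                text=input_text, images=images, return_tensors="pt", **kwargs
            )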
@@ -100,8 +97,8 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
             ]

             # Add to image_inputs
-            image_inputs["aspect_ratios"] = aspect_ratios
-            image_inputs["patches_per_image"] = torch.tensor(patches_per_image)
+            processor_output["aspect_ratios"] = aspect_ratios
+            processor_output["patches_per_image"] = torch.tensor(patches_per_image)

             # Process embed_is_patch
             vocab = tokenizer.get_vocab()
@@ -109,7 +106,7 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
             image_end_id = vocab.get(processor.end_of_img_token, -1)

             if patch_id != -1 and image_end_id != -1:
-                input_ids = image_inputs["input_ids"].view(-1)
+                input_ids = processor_output["input_ids"].view(-1)

                 # Remove BOS token if present
                 if input_ids.size(0) > 0 and input_ids[0] == tokenizer.bos_token_id:
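The `embed_is_patch` bookkeeping that follows splits the flattened `input_ids` after each end-of-image token and marks which positions are patch tokens. A toy illustration with made-up token ids (purely hypothetical values, not Llama 4's vocabulary):

    import torch

    PATCH_ID, IMG_END_ID = 200, 201          # hypothetical token ids
    input_ids = torch.tensor([7, 200, 200, 201, 8, 200, 201])

    # Split after each end-of-image token, then mark patch positions per image.
    end_positions = (input_ids == IMG_END_ID).nonzero(as_tuple=True)[0]
    split_input_ids = torch.tensor_split(input_ids, (end_positions + 1).tolist())
    embed_is_patch = [chunk == PATCH_ID for chunk in split_input_ids if chunk.numel() > 0]
    # embed_is_patch[0] -> tensor([False,  True,  True, False])
    # embed_is_patch[1] -> tensor([False,  True, False])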
@@ -129,33 +126,21 @@ class Mllama4ImageProcessor(BaseMultimodalProcessor):
                 for per_image_input_ids in split_input_ids:
                     embed_is_patch.append(per_image_input_ids == patch_id)

-                image_inputs["embed_is_patch"] = embed_is_patch
+                processor_output["embed_is_patch"] = embed_is_patch

         # Convert to the format expected by SGLang
-        image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
+        processor_output["input_ids"] = processor_output["input_ids"].tolist()[0]
+
+        processor_output["im_start_id"] = self.boi_token_index
+        processor_output["im_end_id"] = self.eoi_token_index
+        processor_output["im_token_id"] = self.image_token_index

         # Add metadata for image processing
-        image_inputs["mm_items"] = [
+        processor_output["mm_items"] = [
             MultimodalDataItem(
-                pixel_values=image_inputs["pixel_values"],
+                pixel_values=processor_output["pixel_values"],
                 modality=Modality.IMAGE,
-                # Add additional metadata needed for Llama4 vision processing
-                embed_is_patch=image_inputs.get("embed_is_patch", None),
-                aspect_ratios=image_inputs.get("aspect_ratios", None),
-                patches_per_image=image_inputs.get("patches_per_image", None),
             )
         ]

-        return image_inputs
-
-    def get_patch_per_chunk(self):
-        """Calculate patches per chunk based on vision config"""
-        image_size = self.vision_config.image_size
-        patch_size = self.vision_config.patch_size
-
-        assert (
-            image_size % patch_size == 0
-        ), f"chunk size {image_size} should be multiple of patch_size {patch_size}"
-
-        ds_ratio = int(round(1.0 / (self.vision_config.pixel_shuffle_ratio**2)))
-        return (image_size // patch_size) ** 2 // ds_ratio
+        return processor_output
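For reference, the removed `get_patch_per_chunk` derived the per-chunk token count from the vision config. With image_size 336, patch_size 14, and pixel_shuffle_ratio 0.5 (assumed here as representative Llama 4 defaults, for illustration only), the arithmetic works out as:

    image_size, patch_size, pixel_shuffle_ratio = 336, 14, 0.5  # assumed defaults

    ds_ratio = int(round(1.0 / (pixel_shuffle_ratio**2)))        # 4
    patches_per_chunk = (image_size // patch_size) ** 2 // ds_ratio
    # (336 // 14) ** 2 // 4 = 24 ** 2 // 4 = 576 // 4 = 144
    print(patches_per_chunk)  # 144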
@@ -1,5 +1,6 @@
 from __future__ import annotations

+import hashlib
 from enum import Enum, auto

 # Copyright 2023-2024 SGLang Team
@@ -157,7 +158,7 @@ class Modality(Enum):
 @dataclasses.dataclass
 class MultimodalDataItem:
     """
-    A single multimodal data, from a single image/video/audio or other
+    A single multimodal data, from a single image/video/audio or others
     """

     modality: Modality
@@ -195,25 +196,54 @@ class MultimodalDataItem:

     def set_pad_value(self):
         """
-        Set the pad value after first hashign the data
+        Set the pad value after first hashing the data
         """

-        def tensor_hash(f):
-            f_list = flatten_nested_list(f)
-            f_list = [x.flatten() if isinstance(x, torch.Tensor) else x for x in f_list]
-            f_cat = torch.concat(f_list).contiguous().numpy().tobytes()
-            return hash(f_cat)
+        def data_hash(data) -> int:
+            hash_bytes = hashlib.sha256(data).digest()[:8]
+            return int.from_bytes(hash_bytes, byteorder="big", signed=False)
+
+        def tensor_hash(tensor_list) -> int:
+            """
+            hash a tensor or a tensor list
+            """
+            tensor = tensor_list
+            if isinstance(tensor_list, list):
+                tensor_list = flatten_nested_list(tensor_list)
+                tensor_list = [
+                    x.flatten() if isinstance(x, torch.Tensor) else x
+                    for x in tensor_list
+                ]
+                tensor = torch.concat(tensor_list)
+
+            tensor = tensor.detach().contiguous()
+
+            if tensor.dtype == torch.bfloat16:
+                # memoryview() doesn't support PyTorch's BFloat16 dtype
+                tensor = tensor.float()
+
+            if tensor.is_cuda:
+                tensor_cpu = torch.frombuffer(
+                    tensor.storage().untyped(), dtype=tensor.dtype, count=tensor.numel()
+                ).clone()
+            else:
+                tensor_cpu = tensor
+
+            mv = memoryview(tensor_cpu.numpy())
+            return data_hash(mv.tobytes())

         def hash_feature(f):
             if isinstance(f, list):
                 if isinstance(f[0], torch.Tensor):
                     return tensor_hash(f)
-                return hash(tuple(flatten_nested_list(f)))
+                return data_hash(tuple(flatten_nested_list(f)))
             elif isinstance(f, np.ndarray):
                 arr = np.ascontiguousarray(f)
                 arr_bytes = arr.tobytes()
-                return hash(arr_bytes)
-            return hash(f)
+                return data_hash(arr_bytes)
+            elif isinstance(f, torch.Tensor):
+                return tensor_hash([f])
+            return data_hash(f)

         if self.is_audio():
             self.hash = hash_feature(self.audio_features)
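The switch from Python's built-in `hash()` to a SHA-256 based `data_hash` makes the computed pad values stable across processes and runs: `hash()` on `bytes` and `str` is salted by `PYTHONHASHSEED`, which breaks cross-process agreement on cache keys. A standalone sketch of the same idea (not the exact code above):

    import hashlib
    import torch

    def data_hash(data: bytes) -> int:
        # First 8 bytes of SHA-256 as an unsigned big-endian integer:
        # deterministic across runs, unlike the salted built-in hash() on bytes.
        return int.from_bytes(hashlib.sha256(data).digest()[:8], "big", signed=False)

    t = torch.arange(6, dtype=torch.float32)
    pad_value = data_hash(t.contiguous().numpy().tobytes())
    print(pad_value)  # same value in every process / run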
@@ -256,7 +286,7 @@ class MultimodalInputs:
     mrope_position_delta: Optional[torch.Tensor] = None

     # image
-    im_token_id: Optional[torch.Tensor] = None
+    im_token_id: Optional[int] = None
     im_start_id: Optional[int] = None
     im_end_id: Optional[int] = None
     slice_start_id: Optional[int] = None
@@ -330,10 +360,8 @@ class MultimodalInputs:

         # args needed to be merged
         optional_args = [
-            "items",
-            "image_offsets",
-            "mm_items",
+            "image_pad_len",
             # "modalities", # modalities should be ["multi-images"] (one entry) even for multiple images
         ]
         for arg in optional_args:
             self_arg = getattr(self, arg, None)
@@ -466,6 +466,9 @@ class Llama4ForCausalLM(LlamaForCausalLM):
     ):
         super().__init__(config, quant_config, prefix)

+    def get_input_embeddings(self):
+        return self.model.embed_tokens
+
     def _init_model(
         self,
         config: Llama4TextConfig,
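Exposing `get_input_embeddings` lets the multimodal wrapper embed the text tokens itself before splicing in image features. A self-contained sketch with toy sizes of what the embedding table is used for (illustrative, not SGLang code):

    import torch
    from torch import nn

    # Toy stand-in for model.embed_tokens returned by get_input_embeddings().
    embed_tokens = nn.Embedding(num_embeddings=128, embedding_dim=16)
    input_ids = torch.tensor([1, 5, 7])
    inputs_embeds = embed_tokens(input_ids)  # shape [3, 16]
    # In the multimodal path, rows of inputs_embeds at image-token positions are
    # replaced with projected vision features before running the decoder.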
@@ -1,14 +1,19 @@
 # TODO: add Aapted from vllm/mllama4.py
 from collections.abc import Iterable
-from typing import Optional, Set, Tuple
+from typing import List, Optional, Set, Tuple

 import torch
 from torch import nn
-from transformers import Llama4Config
+from transformers import Llama4Config, Llama4VisionModel
+from transformers.models.llama4.modeling_llama4 import Llama4MultiModalProjector

 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization import QuantizationConfig
+from sglang.srt.managers.mm_utils import (
+    MultiModalityDataPaddingPatternImageTokens,
+    general_mm_embed_routine,
+)
+from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import add_prefix
@@ -30,6 +35,9 @@ class Llama4ForConditionalGeneration(nn.Module):
         self.config = config
         self.quant_config = quant_config

+        self.vision_model = Llama4VisionModel(config.vision_config)
+        self.multi_modal_projector = Llama4MultiModalProjector(config)
+
         # Initialize the language model
         from sglang.srt.models.llama4 import Llama4ForCausalLM

@@ -41,6 +49,29 @@ class Llama4ForConditionalGeneration(nn.Module):

         self.logits_processor = LogitsProcessor(config.text_config)

+    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
+        # Get all special token IDs
+        im_token_id: int = mm_inputs.im_token_id
+
+        pattern = MultiModalityDataPaddingPatternImageTokens(torch.tensor(im_token_id))
+        return pattern.pad_input_tokens(input_ids, mm_inputs)
+
+    def get_image_feature(
+        self,
+        items: List[MultimodalDataItem],
+    ) -> torch.Tensor:
+        pixel_values = (
+            torch.concat([item.pixel_values for item in items])
+            .to(next(self.vision_model.parameters()).device)
+            .type(next(self.vision_model.parameters()).dtype)
+        )
+
+        image_outputs = self.vision_model(pixel_values, output_hidden_states=False)
+        image_features = image_outputs.last_hidden_state
+        vision_flat = image_features.view(-1, image_features.size(-1))
+        projected_vision_flat = self.multi_modal_projector(vision_flat)
+        return projected_vision_flat
+
     def forward(
         self,
         input_ids: torch.Tensor,
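`get_image_feature` concatenates the pixel patches of all image items, runs the HF `Llama4VisionModel`, and flattens the per-patch hidden states before projecting them into the text hidden size. A shape-only sketch with toy modules (sizes are illustrative, not Llama 4's real dimensions):

    import torch
    from torch import nn

    # Toy stand-ins for the vision tower output and projector, to show the data flow only.
    num_chunks, num_patches, vision_dim, text_dim = 2, 144, 32, 64
    vision_hidden = torch.randn(num_chunks, num_patches, vision_dim)  # "last_hidden_state"
    projector = nn.Linear(vision_dim, text_dim)  # stand-in for Llama4MultiModalProjector

    vision_flat = vision_hidden.view(-1, vision_hidden.size(-1))  # [num_chunks * num_patches, vision_dim]
    projected = projector(vision_flat)                            # [num_chunks * num_patches, text_dim]
    print(projected.shape)  # torch.Size([288, 64])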
@@ -49,7 +80,15 @@ class Llama4ForConditionalGeneration(nn.Module):
         **kwargs: object,
     ) -> torch.Tensor:

-        return self.language_model(input_ids, positions, forward_batch)
+        hs = general_mm_embed_routine(
+            input_ids=input_ids,
+            forward_batch=forward_batch,
+            language_model=self.language_model,
+            image_data_embedding_func=self.get_image_feature,
+            positions=positions,
+        )
+
+        return hs

     def permute_qk_weight_for_rotary(
         self,
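Conceptually, `general_mm_embed_routine` embeds the text tokens, calls the supplied `image_data_embedding_func` for the image items in the batch, and scatters those vision embeddings into the positions occupied by image tokens before invoking the language model. A minimal standalone sketch of that splice step (not SGLang's implementation; the token id is hypothetical):

    import torch

    IM_TOKEN_ID = 99  # hypothetical image-token id

    def splice_image_embeddings(input_ids, text_embeds, image_embeds):
        # Replace rows of text_embeds at image-token positions with image_embeds rows.
        mask = input_ids == IM_TOKEN_ID  # [num_tokens]
        assert mask.sum().item() == image_embeds.size(0)
        out = text_embeds.clone()
        out[mask] = image_embeds
        return out

    input_ids = torch.tensor([1, 99, 99, 2])
    text_embeds = torch.zeros(4, 8)
    image_embeds = torch.ones(2, 8)
    print(splice_image_embeddings(input_ids, text_embeds, image_embeds)[1, 0])  # tensor(1.)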
@@ -108,17 +147,17 @@ class Llama4ForConditionalGeneration(nn.Module):
         )

         for name, loaded_weight in weights:

-            if name.startswith("vision_model") or name.startswith(
-                "multi_modal_projector"
-            ):
-                continue
-
-            name, loaded_weight = self.permute_qk_weight_for_rotary(name, loaded_weight)
+            if not "vision" in name:
+                name, loaded_weight = self.permute_qk_weight_for_rotary(
+                    name, loaded_weight
+                )

             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue

+                if "vision" in name:
+                    continue
                 name = name.replace(weight_name, param_name)
                 param = params_dict[name]
                 weight_loader = param.weight_loader
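The weight-loading change stops discarding the vision tower and projector tensors and restricts the rotary QK permutation and stacked-parameter remapping to language-model weights. A simplified sketch of that per-tensor routing rule (logic only, example names are illustrative):

    # Illustrative sketch: only non-vision weights get the rotary QK permutation
    # and the stacked_params_mapping remap; vision weights are loaded directly.
    def needs_text_remap(name: str) -> bool:
        return "vision" not in name

    print(needs_text_remap("vision_model.patch_embedding.linear.weight"))             # False
    print(needs_text_remap("language_model.model.layers.0.self_attn.q_proj.weight"))  # True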