model: support mllama4 (#5144)
This commit is contained in:
@@ -466,6 +466,9 @@ class Llama4ForCausalLM(LlamaForCausalLM):
|
||||
):
|
||||
super().__init__(config, quant_config, prefix)
|
||||
|
||||
def get_input_embeddings(self):
    """Return the `embed_tokens` attribute of `self.model`."""
    embedding_layer = self.model.embed_tokens
    return embedding_layer
|
||||
|
||||
def _init_model(
|
||||
self,
|
||||
config: Llama4TextConfig,
|
||||
|
||||
@@ -1,14 +1,19 @@
|
||||
# TODO: add Adapted from vllm/mllama4.py
|
||||
from collections.abc import Iterable
|
||||
from typing import Optional, Set, Tuple
|
||||
from typing import List, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import Llama4Config
|
||||
from transformers import Llama4Config, Llama4VisionModel
|
||||
from transformers.models.llama4.modeling_llama4 import Llama4MultiModalProjector
|
||||
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||
from sglang.srt.layers.quantization import QuantizationConfig
|
||||
from sglang.srt.managers.mm_utils import (
|
||||
MultiModalityDataPaddingPatternImageTokens,
|
||||
general_mm_embed_routine,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.utils import add_prefix
|
||||
@@ -30,6 +35,9 @@ class Llama4ForConditionalGeneration(nn.Module):
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
|
||||
self.vision_model = Llama4VisionModel(config.vision_config)
|
||||
self.multi_modal_projector = Llama4MultiModalProjector(config)
|
||||
|
||||
# Initialize the language model
|
||||
from sglang.srt.models.llama4 import Llama4ForCausalLM
|
||||
|
||||
@@ -41,6 +49,29 @@ class Llama4ForConditionalGeneration(nn.Module):
|
||||
|
||||
self.logits_processor = LogitsProcessor(config.text_config)
|
||||
|
||||
def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
    """Pad `input_ids` using an image-token padding pattern.

    The image token id is read from `mm_inputs.im_token_id`, wrapped in a
    tensor, and handed to `MultiModalityDataPaddingPatternImageTokens`.
    The result of that pattern's `pad_input_tokens(input_ids, mm_inputs)`
    call is returned unchanged.
    """
    image_token = torch.tensor(mm_inputs.im_token_id)
    padding_pattern = MultiModalityDataPaddingPatternImageTokens(image_token)
    return padding_pattern.pad_input_tokens(input_ids, mm_inputs)
|
||||
|
||||
def get_image_feature(
    self,
    items: List[MultimodalDataItem],
) -> torch.Tensor:
    """Encode the pixel values of `items` and project them.

    The `pixel_values` of every item are concatenated, moved to the device
    and dtype of the vision model's first parameter, run through
    `self.vision_model`, flattened to two dimensions (keeping the last
    axis), and passed through `self.multi_modal_projector`.

    Returns:
        The output of `self.multi_modal_projector` on the flattened
        `last_hidden_state` of the vision model.
    """
    stacked = torch.concat([entry.pixel_values for entry in items])

    # The first parameter of the vision model is the device/dtype reference.
    reference_param = next(self.vision_model.parameters())
    stacked = stacked.to(reference_param.device)
    stacked = stacked.type(reference_param.dtype)

    vision_output = self.vision_model(stacked, output_hidden_states=False)
    hidden = vision_output.last_hidden_state

    # Collapse every leading dimension; only the last axis is kept as-is.
    flattened = hidden.view(-1, hidden.size(-1))
    return self.multi_modal_projector(flattened)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@@ -49,7 +80,15 @@ class Llama4ForConditionalGeneration(nn.Module):
|
||||
**kwargs: object,
|
||||
) -> torch.Tensor:
|
||||
|
||||
return self.language_model(input_ids, positions, forward_batch)
|
||||
hs = general_mm_embed_routine(
|
||||
input_ids=input_ids,
|
||||
forward_batch=forward_batch,
|
||||
language_model=self.language_model,
|
||||
image_data_embedding_func=self.get_image_feature,
|
||||
positions=positions,
|
||||
)
|
||||
|
||||
return hs
|
||||
|
||||
def permute_qk_weight_for_rotary(
|
||||
self,
|
||||
@@ -108,17 +147,17 @@ class Llama4ForConditionalGeneration(nn.Module):
|
||||
)
|
||||
|
||||
for name, loaded_weight in weights:
|
||||
|
||||
if name.startswith("vision_model") or name.startswith(
|
||||
"multi_modal_projector"
|
||||
):
|
||||
continue
|
||||
|
||||
name, loaded_weight = self.permute_qk_weight_for_rotary(name, loaded_weight)
|
||||
if not "vision" in name:
|
||||
name, loaded_weight = self.permute_qk_weight_for_rotary(
|
||||
name, loaded_weight
|
||||
)
|
||||
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
|
||||
if "vision" in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
|
||||
Reference in New Issue
Block a user