model: support mllama4 (#5144)

Author: Mick
Date: 2025-04-10 00:28:44 +08:00
Committed by: GitHub
Parent: 87eddedfa2
Commit: fbebcb7aa4
7 changed files with 145 additions and 65 deletions


@@ -466,6 +466,9 @@ class Llama4ForCausalLM(LlamaForCausalLM):
):
super().__init__(config, quant_config, prefix)
def get_input_embeddings(self):
return self.model.embed_tokens
def _init_model(
self,
config: Llama4TextConfig,


@@ -1,14 +1,19 @@
# TODO: Adapted from vllm/mllama4.py
from collections.abc import Iterable
from typing import Optional, Set, Tuple
from typing import List, Optional, Set, Tuple
import torch
from torch import nn
from transformers import Llama4Config
from transformers import Llama4Config, Llama4VisionModel
from transformers.models.llama4.modeling_llama4 import Llama4MultiModalProjector
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.quantization import QuantizationConfig
from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternImageTokens,
general_mm_embed_routine,
)
from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import add_prefix
@@ -30,6 +35,9 @@ class Llama4ForConditionalGeneration(nn.Module):
self.config = config
self.quant_config = quant_config
self.vision_model = Llama4VisionModel(config.vision_config)
self.multi_modal_projector = Llama4MultiModalProjector(config)
# Initialize the language model
from sglang.srt.models.llama4 import Llama4ForCausalLM
@@ -41,6 +49,29 @@ class Llama4ForConditionalGeneration(nn.Module):
self.logits_processor = LogitsProcessor(config.text_config)
def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
# Get all special token IDs
im_token_id: int = mm_inputs.im_token_id
pattern = MultiModalityDataPaddingPatternImageTokens(torch.tensor(im_token_id))
return pattern.pad_input_tokens(input_ids, mm_inputs)
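
Roughly, the image-token padding pattern rewrites the positions occupied by image placeholder tokens so downstream stages (e.g. prefix caching) can identify them. A rough sketch of that idea, assuming a single per-image pad value; this is illustrative only and not the mm_utils implementation:

from typing import List

def pad_image_tokens(input_ids: List[int], im_token_id: int, pad_value: int) -> List[int]:
    # Every image-token slot gets the pad value; the real pattern derives the
    # pad value from the multimodal item itself.
    return [pad_value if tok == im_token_id else tok for tok in input_ids]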
def get_image_feature(
self,
items: List[MultimodalDataItem],
) -> torch.Tensor:
pixel_values = (
torch.concat([item.pixel_values for item in items])
.to(next(self.vision_model.parameters()).device)
.type(next(self.vision_model.parameters()).dtype)
)
image_outputs = self.vision_model(pixel_values, output_hidden_states=False)
image_features = image_outputs.last_hidden_state
vision_flat = image_features.view(-1, image_features.size(-1))
projected_vision_flat = self.multi_modal_projector(vision_flat)
return projected_vision_flat
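
The projector maps vision hidden states to the text model's hidden size, so the flattened patch features can be spliced into the token embedding stream. A toy shape check mirroring get_image_feature, with made-up sizes (not Llama 4's real dimensions) and a plain Linear standing in for Llama4MultiModalProjector:

import torch

num_tiles, num_patches, vision_hidden, text_hidden = 2, 16, 32, 64
last_hidden_state = torch.randn(num_tiles, num_patches, vision_hidden)   # vision tower output
projector = torch.nn.Linear(vision_hidden, text_hidden)                  # stand-in projector

vision_flat = last_hidden_state.view(-1, last_hidden_state.size(-1))     # (num_tiles*num_patches, vision_hidden)
projected = projector(vision_flat)                                       # (num_tiles*num_patches, text_hidden)
assert projected.shape == (num_tiles * num_patches, text_hidden)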
def forward(
self,
input_ids: torch.Tensor,
@@ -49,7 +80,15 @@ class Llama4ForConditionalGeneration(nn.Module):
**kwargs: object,
) -> torch.Tensor:
return self.language_model(input_ids, positions, forward_batch)
hs = general_mm_embed_routine(
input_ids=input_ids,
forward_batch=forward_batch,
language_model=self.language_model,
image_data_embedding_func=self.get_image_feature,
positions=positions,
)
return hs
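
Conceptually, the embed routine builds token embeddings, overwrites the padded image-token positions with the projected vision features, and then runs the language model on those embeddings. The sketch below is only an outline of that idea under those assumptions, not the sglang general_mm_embed_routine, which also handles batching, caching, and other modalities:

import torch

def mm_embed_sketch(input_ids, im_pad_value, embed_tokens, image_features, language_model_forward):
    inputs_embeds = embed_tokens(input_ids)            # (seq_len, hidden)
    image_mask = input_ids == im_pad_value             # positions reserved for image tokens
    # Assumes image_features has one row per masked position.
    inputs_embeds[image_mask] = image_features.to(inputs_embeds.dtype)
    return language_model_forward(inputs_embeds)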
def permute_qk_weight_for_rotary(
self,
@@ -108,17 +147,17 @@ class Llama4ForConditionalGeneration(nn.Module):
)
for name, loaded_weight in weights:
if name.startswith("vision_model") or name.startswith(
"multi_modal_projector"
):
continue
name, loaded_weight = self.permute_qk_weight_for_rotary(name, loaded_weight)
if not "vision" in name:
name, loaded_weight = self.permute_qk_weight_for_rotary(
name, loaded_weight
)
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
if "vision" in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
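
The stacked_params_mapping step renames per-projection checkpoint weights (q_proj/k_proj/v_proj, gate_proj/up_proj) onto the fused parameters, while vision weights now bypass both the rotary QK permutation and the stacking. An illustrative remapping in that style; the exact tuples and names in the real model may differ:

stacked_params_mapping = [
    # (fused param_name, checkpoint weight_name, shard_id)
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]

name = "language_model.model.layers.0.self_attn.q_proj.weight"   # hypothetical checkpoint name
for param_name, weight_name, shard_id in stacked_params_mapping:
    if weight_name in name and "vision" not in name:
        name = name.replace(weight_name, param_name)
        break
print(name)  # language_model.model.layers.0.self_attn.qkv_proj.weight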