From ea93079b3038dd156e6168cee1f3bf2defb37f51 Mon Sep 17 00:00:00 2001 From: Wenchen Lo Date: Sat, 2 Aug 2025 00:39:40 -0700 Subject: [PATCH] model: adapt mllama4 to VisionAttention (#8512) Co-authored-by: root --- python/sglang/srt/hf_transformers_utils.py | 35 +- python/sglang/srt/layers/attention/vision.py | 37 +- .../sglang/srt/managers/tokenizer_manager.py | 31 +- python/sglang/srt/models/llama4.py | 13 +- python/sglang/srt/models/mllama4.py | 447 +++++++++++++++++- .../multimodal/processors/base_processor.py | 7 +- 6 files changed, 518 insertions(+), 52 deletions(-) diff --git a/python/sglang/srt/hf_transformers_utils.py b/python/sglang/srt/hf_transformers_utils.py index bf16addc5..e4c87d573 100644 --- a/python/sglang/srt/hf_transformers_utils.py +++ b/python/sglang/srt/hf_transformers_utils.py @@ -14,7 +14,6 @@ """Utilities for Huggingface Transformers.""" import contextlib -import logging import os import warnings from pathlib import Path @@ -45,7 +44,7 @@ from sglang.srt.configs import ( ) from sglang.srt.configs.internvl import InternVLChatConfig from sglang.srt.connector import create_remote_connector -from sglang.srt.utils import is_remote_url, lru_cache_frozenset +from sglang.srt.utils import is_remote_url, logger, lru_cache_frozenset _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = { ChatGLMConfig.model_type: ChatGLMConfig, @@ -317,15 +316,31 @@ def get_processor( if config.model_type not in {"llava", "clip"}: kwargs["use_fast"] = use_fast + try: + processor = AutoProcessor.from_pretrained( + tokenizer_name, + *args, + trust_remote_code=trust_remote_code, + revision=revision, + **kwargs, + ) - processor = AutoProcessor.from_pretrained( - tokenizer_name, - *args, - trust_remote_code=trust_remote_code, - revision=revision, - **kwargs, - ) - + except ValueError as e: + error_message = str(e) + if "does not have a slow version" in error_message: + logger.info( + f"Processor {tokenizer_name} does not have a slow version. Automatically use fast version" + ) + kwargs["use_fast"] = True + processor = AutoProcessor.from_pretrained( + tokenizer_name, + *args, + trust_remote_code=trust_remote_code, + revision=revision, + **kwargs, + ) + else: + raise e tokenizer = get_tokenizer_from_processor(processor) attach_additional_stop_token_ids(tokenizer) diff --git a/python/sglang/srt/layers/attention/vision.py b/python/sglang/srt/layers/attention/vision.py index c7bbd3ea6..ed7a36cdb 100644 --- a/python/sglang/srt/layers/attention/vision.py +++ b/python/sglang/srt/layers/attention/vision.py @@ -4,7 +4,7 @@ import dataclasses import functools import math from functools import lru_cache, partial -from typing import Any, Optional, Tuple, Union +from typing import Any, Callable, Optional, Tuple, Union import torch import torch.nn as nn @@ -308,6 +308,7 @@ class VisionFlash3Attention(nn.Module): cu_seqlens = cu_seqlens.to(dtype=torch.int32).to(q.device) seq_lens = cu_seqlens[1:] - cu_seqlens[:-1] max_seqlen = seq_lens.max().item() + output = flash_attn_varlen_func( q, k, @@ -358,6 +359,9 @@ class VisionAttention(nn.Module): qkv_bias: bool = True, qk_normalization: bool = False, layer_norm_eps: float = 1e-06, + customized_position_embedding_applier: Callable[ + [torch.Tensor, torch.Tensor, Any, Any], Tuple[torch.Tensor, torch.Tensor] + ] = None, **kwargs, ): super().__init__() @@ -392,6 +396,7 @@ class VisionAttention(nn.Module): self.dummy_dim, eps=layer_norm_eps, var_hidden_size=embed_dim ) + # priority: server_args > passed qkv_backend > sdpa if global_server_args_dict["mm_attention_backend"] is None: if qkv_backend is None: qkv_backend = "sdpa" @@ -401,6 +406,9 @@ class VisionAttention(nn.Module): print_info_once(f"Using {qkv_backend} as multimodal attention backend.") + self.customized_position_embedding_applier = ( + customized_position_embedding_applier + ) self.qkv_backend = QKV_BACKEND_IMPL[qkv_backend]( head_dim=self.head_size, num_heads=self.num_attention_heads_per_partition, @@ -473,13 +481,13 @@ class VisionAttention(nn.Module): if x.dim() == 2: x = x.unsqueeze(0) assert x.dim() == 3, x.shape - bsz, s, _ = x.shape + x_shape = x.shape + bsz, s, _ = x_shape head = self.num_attention_heads_per_partition kv_head = self.num_attention_kv_heads_per_partition if self.use_qkv_parallel: # [b, s, embed_dim] --> [b, s, embed_dim] qkv, _ = self.qkv_proj(x) - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) # [b, s, embed_dim] --> [b * s, head, head_size] @@ -508,16 +516,25 @@ class VisionAttention(nn.Module): ] if position_embeddings is not None: - cos, sin = position_embeddings original_shape = q.shape - # [total_tokens, head, head_size] - q = q.view(-1, head, self.head_size) - k = k.view(-1, head, self.head_size) - q, k = apply_rotary_pos_emb(q, k, cos, sin) + if self.customized_position_embedding_applier is not None: + q, k = self.customized_position_embedding_applier( + q, k, position_embeddings, x_shape + ) + q = q.view(original_shape) + k = k.view(original_shape) + else: + cos, sin = position_embeddings - q = q.view(original_shape) - k = k.view(original_shape) + # [total_tokens, head, head_size] + q = q.view(-1, head, self.head_size) + k = k.view(-1, head, self.head_size) + + q, k = apply_rotary_pos_emb(q, k, cos, sin) + + q = q.view(original_shape) + k = k.view(original_shape) if q.dim() == 4: # [b, s, head, head_size] --> [b * s, head, head_size] diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 46fd967e5..76a31e334 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -70,7 +70,6 @@ from sglang.srt.managers.io_struct import ( BatchMultimodalOut, BatchStrOut, BatchTokenIDOut, - BlockReqType, CloseSessionReqInput, ConfigureLoggingReq, EmbeddingReqInput, @@ -202,13 +201,29 @@ class TokenizerManager: if self.model_config.is_multimodal: import_processors() - _processor = get_processor( - server_args.tokenizer_path, - tokenizer_mode=server_args.tokenizer_mode, - trust_remote_code=server_args.trust_remote_code, - revision=server_args.revision, - use_fast=not server_args.disable_fast_image_processor, - ) + try: + _processor = get_processor( + server_args.tokenizer_path, + tokenizer_mode=server_args.tokenizer_mode, + trust_remote_code=server_args.trust_remote_code, + revision=server_args.revision, + use_fast=not server_args.disable_fast_image_processor, + ) + except ValueError as e: + error_message = str(e) + if "does not have a slow version" in error_message: + logger.info( + f"Processor {server_args.tokenizer_path} does not have a slow version. Automatically use fast version" + ) + _processor = get_processor( + server_args.tokenizer_path, + tokenizer_mode=server_args.tokenizer_mode, + trust_remote_code=server_args.trust_remote_code, + revision=server_args.revision, + use_fast=True, + ) + else: + raise e transport_mode = _determine_tensor_transport_mode(self.server_args) # We want to parallelize the image pre-processing so we create an executor for it diff --git a/python/sglang/srt/models/llama4.py b/python/sglang/srt/models/llama4.py index 265a9391d..16cdd9e80 100644 --- a/python/sglang/srt/models/llama4.py +++ b/python/sglang/srt/models/llama4.py @@ -241,13 +241,22 @@ class Llama4Attention(nn.Module): if self.use_qk_norm else None ) + + qkv_quant_config = quant_config + o_quant_config = quant_config + if quant_config and hasattr(quant_config, "ignore") and quant_config.ignore: + if add_prefix("q_proj", prefix) in quant_config.ignore: + qkv_quant_config = None + if add_prefix("o_proj", prefix) in quant_config.ignore: + o_quant_config = None + self.qkv_proj = QKVParallelLinear( hidden_size=hidden_size, head_size=self.head_dim, total_num_heads=self.total_num_heads, total_num_kv_heads=self.total_num_kv_heads, bias=bias, - quant_config=quant_config, + quant_config=qkv_quant_config, prefix=add_prefix("qkv_proj", prefix), tp_rank=attn_tp_rank, tp_size=attn_tp_size, @@ -257,7 +266,7 @@ class Llama4Attention(nn.Module): input_size=self.total_num_heads * self.head_dim, output_size=hidden_size, bias=bias_o_proj, - quant_config=quant_config, + quant_config=o_quant_config, prefix=add_prefix("o_proj", prefix), tp_rank=attn_tp_rank, tp_size=attn_tp_size, diff --git a/python/sglang/srt/models/mllama4.py b/python/sglang/srt/models/mllama4.py index 4a2d5f7de..b57d637f0 100644 --- a/python/sglang/srt/models/mllama4.py +++ b/python/sglang/srt/models/mllama4.py @@ -1,17 +1,24 @@ import json as json_lib import logging +import math import os from collections.abc import Iterable from typing import List, Optional, Set, Tuple import torch from torch import nn -from transformers import Llama4Config +from transformers import Llama4Config, Llama4VisionConfig from transformers.models.llama4.modeling_llama4 import ( Llama4MultiModalProjector, - Llama4VisionModel, + vision_apply_rotary_emb, ) +from sglang.srt.layers.attention.vision import VisionAttention +from sglang.srt.layers.linear import ( + ColumnParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.moe.fused_moe_triton import FusedMoE from sglang.srt.layers.quantization import QuantizationConfig @@ -26,10 +33,10 @@ from sglang.srt.managers.schedule_batch import ( global_server_args_dict, ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch -from sglang.srt.model_loader.weight_utils import default_weight_loader -from sglang.srt.utils import add_prefix, is_cpu +from sglang.srt.utils import is_cpu _is_cpu = is_cpu() + from sglang.srt.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name, @@ -39,6 +46,376 @@ from sglang.srt.utils import add_prefix logger = logging.getLogger(__name__) +class Llama4VisionMLP(nn.Module): + + def __init__( + self, + input_size: int, + intermediate_size: int, + output_size: int, + bias: bool, + output_activation: bool, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + use_data_parallel: bool = False, + ): + super().__init__() + cls_fc1 = ReplicatedLinear if use_data_parallel else ColumnParallelLinear + self.fc1 = cls_fc1( + input_size=input_size, + output_size=intermediate_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.fc1", + ) + cls_fc2 = ReplicatedLinear if use_data_parallel else RowParallelLinear + self.fc2 = cls_fc2( + input_size=intermediate_size, + output_size=output_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.fc2", + ) + self.activation_fn = nn.GELU() + self.output_activation = output_activation + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states, _ = self.fc2(hidden_states) + if self.output_activation: + return self.activation_fn(hidden_states) + return hidden_states + + +def pixel_shuffle(input_tensor, shuffle_ratio): + # input_tensor: [batch_size, num_patches, channels] + batch_size, num_patches, channels = input_tensor.shape + patch_size = int(math.sqrt(num_patches)) + + input_tensor = input_tensor.view(batch_size, patch_size, patch_size, -1) + batch_size, height, width, channels = input_tensor.size() + + reshaped_tensor = input_tensor.view( + batch_size, height, int(width * shuffle_ratio), int(channels / shuffle_ratio) + ) + reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous() + + reshaped_tensor = reshaped_tensor.view( + batch_size, + int(height * shuffle_ratio), + int(width * shuffle_ratio), + int(channels / (shuffle_ratio**2)), + ) + reshaped_tensor = reshaped_tensor.permute(0, 2, 1, 3).contiguous() + + output_tensor = reshaped_tensor.view(batch_size, -1, reshaped_tensor.shape[-1]) + return output_tensor + + +class Llama4VisionPixelShuffleMLP(nn.Module): + + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + use_data_parallel: bool = False, + ): + super().__init__() + self.pixel_shuffle_ratio = config.pixel_shuffle_ratio + self.mlp = Llama4VisionMLP( + input_size=config.intermediate_size, + intermediate_size=config.projector_input_dim, + output_size=config.projector_output_dim, + bias=config.multi_modal_projector_bias, + output_activation=True, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + use_data_parallel=use_data_parallel, + ) + + def forward(self, encoded_patches: torch.Tensor) -> torch.Tensor: + encoded_patches = pixel_shuffle(encoded_patches, self.pixel_shuffle_ratio) + return self.mlp(encoded_patches) + + +def apply_position_embedding(q, k, freqs_ci, shape): + # [batch_size_times_num_tiles, num_channels] + input_shape = shape[:2] + # [batch_size_times_num_tiles, num_channels, num_heads, head_dim] + hidden_shape = (*input_shape, *q.shape[-2:]) + q = q.view(hidden_shape) + k = k.view(hidden_shape) + q, k = vision_apply_rotary_emb(q, k, freqs_ci) + return q, k + + +class Llama4VisionEncoderLayer(nn.Module): + + def __init__( + self, + config: Llama4VisionConfig, + quant_config: Optional[QuantizationConfig], + prefix: str = "", + use_data_parallel: bool = False, + ): + super().__init__() + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.intermediate_size = config.intermediate_size + + self.self_attn = VisionAttention( + self.hidden_size, + self.num_attention_heads, + self.hidden_size, + use_qkv_parallel=True, + # vision_model is explicitly ignored in Maverick-17B-128E-Instruct-FP8 + quant_config=None, + dropout=0.0, + qkv_backend="sdpa", + softmax_in_single_precision=False, + flatten_batch=False, + prefix=add_prefix("self_attn", prefix), + qkv_bias=True, + customized_position_embedding_applier=apply_position_embedding, + ) + self.mlp = Llama4VisionMLP( + input_size=config.hidden_size, + intermediate_size=config.intermediate_size, + output_size=config.hidden_size, + bias=True, + output_activation=False, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + use_data_parallel=use_data_parallel, + ) + + self.input_layernorm = nn.LayerNorm(config.hidden_size) + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size) + + def forward( + self, + hidden_state: torch.Tensor, + freqs_ci: torch.Tensor, + ): + # Self Attention + residual = hidden_state + hidden_state = self.input_layernorm(hidden_state) + hidden_state = self.self_attn(hidden_state, position_embeddings=freqs_ci) + hidden_state = residual + hidden_state + + # Feed forward + residual = hidden_state + hidden_state = self.post_attention_layernorm(hidden_state) + hidden_state = self.mlp(hidden_state) + hidden_state = residual + hidden_state + + outputs = hidden_state + return outputs + + +class Llama4VisionEncoder(nn.Module): + + def __init__( + self, + config: Llama4VisionConfig, + quant_config: Optional[QuantizationConfig], + prefix: str = "", + use_data_parallel: bool = False, + ): + super().__init__() + self.config = config + self.layers = nn.ModuleList( + [ + Llama4VisionEncoderLayer( + config, + quant_config=quant_config, + prefix=f"{prefix}.layers.{layer_idx}", + use_data_parallel=use_data_parallel, + ) + for layer_idx in range(config.num_hidden_layers) + ] + ) + + def forward( + self, + hidden_states: torch.Tensor, + freqs_ci: torch.Tensor, # TODO: move this to an attribute instead of keeping it around + ) -> torch.Tensor: + r""" + Args: + hidden_states (`torch.FloatTensor` of shape + `(batch_size, sequence_length, hidden_size)`): + Optionally, instead of passing `input_ids` you can choose to + directly pass an embedded representation. This is useful if you + want more control over how to convert `input_ids` indices into + associated vectors than the model's internal embedding + lookup matrix. + """ + + for encoder_layer in self.layers: + layer_outputs = encoder_layer(hidden_states, freqs_ci=freqs_ci) + hidden_states = layer_outputs + + return hidden_states + + +class Llama4UnfoldConvolution(nn.Module): + + def __init__( + self, + config: Llama4VisionConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + use_data_parallel: bool = False, + ): + super().__init__() + kernel_size = config.patch_size + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size) + self.unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=config.patch_size) + params = { + "input_size": config.num_channels * kernel_size[0] * kernel_size[1], + "output_size": config.hidden_size, + "bias": False, + "quant_config": quant_config, + "prefix": f"{prefix}.linear", + } + if use_data_parallel: + cls = ReplicatedLinear + else: + cls = ColumnParallelLinear + params["gather_output"] = True + self.linear = cls(**params) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.unfold(hidden_states) + hidden_states = hidden_states.permute(0, 2, 1) + hidden_states, _ = self.linear(hidden_states) + return hidden_states + + +class Llama4VisionRotaryEmbedding(nn.Module): + def __init__(self, config): + super().__init__() + idx = config.image_size // config.patch_size + img_idx = torch.arange(idx**2, dtype=torch.int32).reshape(idx**2, 1) + img_idx = torch.cat([img_idx, img_idx[:1]], dim=0) + img_idx[-1, -1] = -2 # ID_CLS_TOKEN + frequencies_x = img_idx % idx # get the coordinates of the 2d matrix along x + frequencies_y = img_idx // idx # get the coordinates of the 2d matrix along y + freq_dim = config.hidden_size // config.num_attention_heads // 2 + rope_freq = 1.0 / ( + config.rope_theta + ** (torch.arange(0, freq_dim, 2)[: (freq_dim // 2)].float() / freq_dim) + ) + freqs_x = ( + (frequencies_x + 1)[..., None] * rope_freq[None, None, :] + ).repeat_interleave(2, dim=-1) + freqs_y = ( + (frequencies_y + 1)[..., None] * rope_freq[None, None, :] + ).repeat_interleave(2, dim=-1) + freqs = torch.cat([freqs_x, freqs_y], dim=-1).float().contiguous()[..., ::2] + freqs = freqs.masked_fill(img_idx.reshape(-1, 1, 1) < 0, 0) + freq_cis = torch.view_as_complex( + torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1) + ) + self.freqs_ci = freq_cis # idx**2, idx**2, idx * 2 + + def forward(self, hidden_states): + return self.freqs_ci.to(hidden_states.device) + + +class Llama4VisionModel(nn.Module): + + def __init__( + self, + config: Llama4VisionConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.config = config + self.image_size = config.image_size + self.patch_size = config.patch_size + self.hidden_size = config.hidden_size + self.num_channels = config.num_channels + + self.num_patches = (self.image_size // self.patch_size) ** 2 + 1 + self.scale = config.hidden_size**-0.5 + + self.patch_embedding = Llama4UnfoldConvolution( + config, + quant_config=quant_config, + prefix=f"{prefix}.patch_embedding", + ) + + self.class_embedding = nn.Parameter(self.scale * torch.randn(self.hidden_size)) + self.positional_embedding_vlm = nn.Parameter( + self.scale * torch.randn(self.num_patches, self.hidden_size) + ) + + self.rotary_embedding = Llama4VisionRotaryEmbedding(config) + + # layer norms + self.layernorm_pre = nn.LayerNorm(self.hidden_size, eps=1e-5) + self.layernorm_post = nn.LayerNorm(self.hidden_size, eps=1e-5) + + # encoders + self.model = Llama4VisionEncoder( + config, + quant_config=quant_config, + prefix=f"{prefix}.model", + ) + self.vision_adapter = Llama4VisionPixelShuffleMLP( + config, + quant_config, + prefix=f"{prefix}.vision_adapter", + ) + + def forward( + self, + pixel_values: torch.Tensor, + ) -> torch.Tensor: + # Patch embedding + hidden_state = self.patch_embedding(pixel_values) + num_tiles, num_patches, hidden_dim = hidden_state.shape + + # Add cls token + class_embedding = self.class_embedding.expand( + hidden_state.shape[0], 1, hidden_state.shape[-1] + ) + hidden_state = torch.cat([hidden_state, class_embedding], dim=1) + num_patches += 1 + + # Position embeddings + hidden_state = hidden_state.reshape( + num_tiles, + 1, + num_patches, + hidden_dim, + ) + positional_embedding = self.positional_embedding_vlm.to( + dtype=hidden_state.dtype, device=hidden_state.device + ) + hidden_state = hidden_state + positional_embedding + hidden_state = self.layernorm_pre(hidden_state) + hidden_state = hidden_state.view(num_tiles, -1, hidden_dim) + freqs_ci = self.rotary_embedding(pixel_values) + # Apply encoder + hidden_state = self.model(hidden_state, freqs_ci=freqs_ci) + hidden_state = self.layernorm_post(hidden_state) + + # Remove CLS token output + hidden_state = hidden_state[:, :-1, :] + + # now, we use Llama4VisionPixelShuffle + mlp to project embeddings + hidden_state = self.vision_adapter(hidden_state) + + return hidden_state + + class Llama4ForConditionalGeneration(nn.Module): packed_modules_mapping = { "qkv_proj": ["q_proj", "k_proj", "v_proj"], @@ -60,7 +437,8 @@ class Llama4ForConditionalGeneration(nn.Module): if not self.has_vision_weights: logger.warning( "No vision weights found in checkpoint. Model will run in text-only mode. " - "Multimodal capabilities (image processing) will be unavailable." + "Multimodal capabilities (vision understanding) will be unavailable. " + "Please not that this warning might be inaccurate if the weights haven't been fully downloaded" ) self.has_vision = ( @@ -68,7 +446,12 @@ class Llama4ForConditionalGeneration(nn.Module): ) if self.has_vision: - self.vision_model = Llama4VisionModel(config.vision_config) + self.vision_model = Llama4VisionModel( + config.vision_config, + quant_config=quant_config, + prefix=add_prefix("vision_model", prefix), + ) + self.multi_modal_projector = Llama4MultiModalProjector(config) else: self.vision_model = None @@ -112,7 +495,6 @@ class Llama4ForConditionalGeneration(nn.Module): filename="model.safetensors.index.json", cache_dir=None, ) - if index_file_path and os.path.exists(index_file_path): return self._check_vision_weights_in_index(index_file_path) @@ -120,7 +502,7 @@ class Llama4ForConditionalGeneration(nn.Module): # If we can't access the cache, fall back to config-based detection pass - # Fallback, assume text-only + # Fallback, assume text-only return False def _check_vision_weights_in_index(self, index_file: str) -> bool: @@ -131,7 +513,6 @@ class Llama4ForConditionalGeneration(nn.Module): vision_patterns = ["vision_model", "vision_tower", "multi_modal_projector"] weight_names = index_data.get("weight_map", {}).keys() - return any( pattern in weight_name for weight_name in weight_names @@ -150,17 +531,17 @@ class Llama4ForConditionalGeneration(nn.Module): # For text-only models, return None or raise an error if not self.has_vision or self.vision_model is None: raise ValueError("Vision model not available for text-only checkpoint") - pixel_values = ( torch.concat([item.feature for item in items]) .to(next(self.vision_model.parameters()).device) .type(next(self.vision_model.parameters()).dtype) ) + image_features = self.vision_model(pixel_values) - image_outputs = self.vision_model(pixel_values, output_hidden_states=False) - image_features = image_outputs.last_hidden_state vision_flat = image_features.view(-1, image_features.size(-1)) + projected_vision_flat = self.multi_modal_projector(vision_flat) + return projected_vision_flat def forward( @@ -246,31 +627,47 @@ class Llama4ForConditionalGeneration(nn.Module): num_experts=num_experts, ) + loaded_params = set() + for name, loaded_weight in weights: if self._should_skip_weight(name): continue name = self._transform_weight_name(name) - if "vision" not in name: + if "vision" in name: + name = name.replace(".self_attn.o_proj", ".self_attn.proj") + else: name, loaded_weight = self.permute_qk_weight_for_rotary( name, loaded_weight ) if self._handle_scale_remapping(name, params_dict): + loaded_params.add(name) continue if self._handle_stacked_params( - name, loaded_weight, stacked_params_mapping, params_dict + name, loaded_weight, stacked_params_mapping, params_dict, loaded_params ): continue if self._handle_expert_weights( - name, loaded_weight, expert_params_mapping, params_dict, num_experts + name, + loaded_weight, + expert_params_mapping, + params_dict, + num_experts, + loaded_params, ): continue + loaded_params.add(name) self._handle_default_weight(name, loaded_weight, params_dict) + unloaded_params = params_dict.keys() - loaded_params + if unloaded_params: + logger.warning( + f"Some weights are not initialized from checkpoints {unloaded_params}" + ) def _should_skip_weight(self, name: str) -> bool: """Check if we should skip loading this weight.""" @@ -301,11 +698,13 @@ class Llama4ForConditionalGeneration(nn.Module): loaded_weight: torch.Tensor, stacked_params_mapping: list, params_dict: dict, + loaded_params: set, ) -> bool: """Handle stacked parameter loading. Returns True if handled.""" for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name in name and "vision" not in name: + if weight_name in name: transformed_name = name.replace(weight_name, param_name) + loaded_params.add(transformed_name) param = params_dict[transformed_name] param.weight_loader(param, loaded_weight, shard_id) return True @@ -318,6 +717,7 @@ class Llama4ForConditionalGeneration(nn.Module): expert_params_mapping: list, params_dict: dict, num_experts: int, + loaded_params: set, ) -> bool: """Handle expert weight loading for MoE (Mixture of Experts) layers. @@ -336,16 +736,16 @@ class Llama4ForConditionalGeneration(nn.Module): if "experts.gate_up_proj" not in name and "experts.down_proj" not in name: return self._handle_other_expert_params( - name, loaded_weight, expert_params_mapping, params_dict + name, loaded_weight, expert_params_mapping, params_dict, loaded_params ) if "scale" in name: return self._handle_expert_scale_params( - name, loaded_weight, params_dict, num_experts + name, loaded_weight, params_dict, num_experts, loaded_params ) else: return self._handle_expert_weight_params( - name, loaded_weight, params_dict, num_experts + name, loaded_weight, params_dict, num_experts, loaded_params ) def _handle_other_expert_params( @@ -354,6 +754,7 @@ class Llama4ForConditionalGeneration(nn.Module): loaded_weight: torch.Tensor, expert_params_mapping: list, params_dict: dict, + loaded_params: set, ) -> bool: """Handle expert parameters that are not gate_up_proj or down_proj weights. @@ -362,6 +763,7 @@ class Llama4ForConditionalGeneration(nn.Module): loaded_weight: The weight tensor to be loaded expert_params_mapping: List of tuples mapping checkpoint names to model parameters params_dict: Dictionary of model parameters + loaded_params: Set of loaded parameter names Returns: bool: True if parameter was found and handled, False otherwise @@ -373,6 +775,7 @@ class Llama4ForConditionalGeneration(nn.Module): param.weight_loader( param, loaded_weight, name, shard_id=shard_id, expert_id=expert_id ) + loaded_params.add(transformed_name) return True return False @@ -411,6 +814,7 @@ class Llama4ForConditionalGeneration(nn.Module): loaded_weight: torch.Tensor, params_dict: dict, num_experts: int, + loaded_params: set, ) -> bool: """Handle quantization scale parameters for expert weights. @@ -419,6 +823,7 @@ class Llama4ForConditionalGeneration(nn.Module): loaded_weight: Scale tensor to be loaded params_dict: Dictionary of model parameters num_experts: Total number of experts for broadcast operations + loaded_params: Set of loaded parameter names Returns: bool: True (always handles scale parameters) @@ -447,6 +852,7 @@ class Llama4ForConditionalGeneration(nn.Module): # Load the same scale for all experts for expert_id in range(num_experts): param.data[expert_id] = loaded_weight + loaded_params.add(transformed_name) return True @@ -456,6 +862,7 @@ class Llama4ForConditionalGeneration(nn.Module): loaded_weight: torch.Tensor, params_dict: dict, num_experts: int, + loaded_params: set, ) -> bool: """Handle actual weight tensors for expert layers (gate_up_proj and down_proj). @@ -464,6 +871,7 @@ class Llama4ForConditionalGeneration(nn.Module): loaded_weight: Weight tensor(s) to be loaded params_dict: Dictionary of model parameters num_experts: Total number of experts for tensor distribution + loaded_params: Set of loaded parameter names Returns: bool: True (always handles weight parameters) @@ -486,6 +894,7 @@ class Llama4ForConditionalGeneration(nn.Module): param = params_dict[param_name] weight_loader = param.weight_loader + loaded_params.add(param_name) # Handle the case where loaded_weight might be a single tensor for all experts if weight_chunk.dim() == 2: diff --git a/python/sglang/srt/multimodal/processors/base_processor.py b/python/sglang/srt/multimodal/processors/base_processor.py index 06e5c0da0..760d3c26f 100644 --- a/python/sglang/srt/multimodal/processors/base_processor.py +++ b/python/sglang/srt/multimodal/processors/base_processor.py @@ -12,7 +12,6 @@ import torch from PIL import Image from transformers import BaseImageProcessorFast -from sglang.srt.managers.mm_utils import TransportProxyTensor from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem from sglang.srt.utils import load_audio, load_image, load_video, logger @@ -218,8 +217,10 @@ class BaseMultimodalProcessor(ABC): kwargs["audio"] = audios processor = self._processor - if hasattr(processor, "image_processor") and isinstance( - processor.image_processor, BaseImageProcessorFast + if ( + hasattr(processor, "image_processor") + and isinstance(processor.image_processor, BaseImageProcessorFast) + and not self.server_args.disable_fast_image_processor ): kwargs["device"] = "cuda" result = processor.__call__(