From e2d66f60c8f8c90ed9491e21061b73d959c2c4d7 Mon Sep 17 00:00:00 2001
From: Ke Bao
Date: Wed, 23 Jul 2025 12:41:25 +0800
Subject: [PATCH] Skip llama4 vision module loading when multimodal disabled
 (#8272)

Co-authored-by: Mick
---
 python/sglang/srt/managers/schedule_batch.py |  1 +
 python/sglang/srt/models/mllama4.py          | 13 ++++++++++---
 2 files changed, 11 insertions(+), 3 deletions(-)

diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 536198cd2..714af6fba 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -106,6 +106,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "num_reserved_decode_tokens",
     "weight_loader_disable_mmap",
     "enable_triton_kernel_moe",
+    "enable_multimodal",
 ]
 
 # Put some global args for easy access
diff --git a/python/sglang/srt/models/mllama4.py b/python/sglang/srt/models/mllama4.py
index 8712191a9..4a2d5f7de 100644
--- a/python/sglang/srt/models/mllama4.py
+++ b/python/sglang/srt/models/mllama4.py
@@ -23,6 +23,7 @@ from sglang.srt.managers.schedule_batch import (
     Modality,
     MultimodalDataItem,
     MultimodalInputs,
+    global_server_args_dict,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -55,13 +56,17 @@ class Llama4ForConditionalGeneration(nn.Module):
         self.quant_config = quant_config
 
         # Check if this is a text-only model (modelopt fp8 llama4 has no vision components)
-        self.has_vision = self._has_vision_weights(config)
-        if not self.has_vision:
+        self.has_vision_weights = self._has_vision_weights(config)
+        if not self.has_vision_weights:
             logger.warning(
                 "No vision weights found in checkpoint. Model will run in text-only mode. "
                 "Multimodal capabilities (image processing) will be unavailable."
             )
 
+        self.has_vision = (
+            self.has_vision_weights and global_server_args_dict["enable_multimodal"]
+        )
+
         if self.has_vision:
             self.vision_model = Llama4VisionModel(config.vision_config)
             self.multi_modal_projector = Llama4MultiModalProjector(config)
@@ -269,7 +274,9 @@ class Llama4ForConditionalGeneration(nn.Module):
 
     def _should_skip_weight(self, name: str) -> bool:
         """Check if we should skip loading this weight."""
-        return "vision" in name and not self.has_vision
+        return not self.has_vision and (
+            "vision" in name or "multi_modal_projector" in name
+        )
 
     def _transform_weight_name(self, name: str) -> str:
         """Transform weight name by adding language_model prefix if needed."""