diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 536198cd2..714af6fba 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -106,6 +106,7 @@ GLOBAL_SERVER_ARGS_KEYS = [ "num_reserved_decode_tokens", "weight_loader_disable_mmap", "enable_triton_kernel_moe", + "enable_multimodal", ] # Put some global args for easy access diff --git a/python/sglang/srt/models/mllama4.py b/python/sglang/srt/models/mllama4.py index 8712191a9..4a2d5f7de 100644 --- a/python/sglang/srt/models/mllama4.py +++ b/python/sglang/srt/models/mllama4.py @@ -23,6 +23,7 @@ from sglang.srt.managers.schedule_batch import ( Modality, MultimodalDataItem, MultimodalInputs, + global_server_args_dict, ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader @@ -55,13 +56,17 @@ class Llama4ForConditionalGeneration(nn.Module): self.quant_config = quant_config # Check if this is a text-only model (modelopt fp8 llama4 has no vision components) - self.has_vision = self._has_vision_weights(config) - if not self.has_vision: + self.has_vision_weights = self._has_vision_weights(config) + if not self.has_vision_weights: logger.warning( "No vision weights found in checkpoint. Model will run in text-only mode. " "Multimodal capabilities (image processing) will be unavailable." ) + self.has_vision = ( + self.has_vision_weights and global_server_args_dict["enable_multimodal"] + ) + if self.has_vision: self.vision_model = Llama4VisionModel(config.vision_config) self.multi_modal_projector = Llama4MultiModalProjector(config) @@ -269,7 +274,9 @@ class Llama4ForConditionalGeneration(nn.Module): def _should_skip_weight(self, name: str) -> bool: """Check if we should skip loading this weight.""" - return "vision" in name and not self.has_vision + return not self.has_vision and ( + "vision" in name or "multi_modal_projector" in name + ) def _transform_weight_name(self, name: str) -> str: """Transform weight name by adding language_model prefix if needed."""