From 9de1320b637ce6f4683179fe4abbf4237648b638 Mon Sep 17 00:00:00 2001
From: Mick
Date: Tue, 30 Sep 2025 05:17:12 +0800
Subject: [PATCH] fix: fp8 mllama4 without vision modules being quantized
 (#10611)

---
 python/sglang/srt/models/mllama4.py | 17 ++++++++++++++---
 1 file changed, 14 insertions(+), 3 deletions(-)

diff --git a/python/sglang/srt/models/mllama4.py b/python/sglang/srt/models/mllama4.py
index f0184390c..72077d96a 100644
--- a/python/sglang/srt/models/mllama4.py
+++ b/python/sglang/srt/models/mllama4.py
@@ -291,7 +291,7 @@ class Llama4UnfoldConvolution(nn.Module):
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = self.unfold(hidden_states)
-        hidden_states = hidden_states.permute(0, 2, 1)
+        hidden_states = hidden_states.permute(0, 2, 1).contiguous()
         hidden_states, _ = self.linear(hidden_states)
         return hidden_states
 
@@ -446,9 +446,20 @@ class Llama4ForConditionalGeneration(nn.Module):
         )
 
         if self.has_vision:
+            # TODO: make this more general
+            ignore_quant_layers = getattr(config, "quantization_config", {}).get(
+                "ignore", {}
+            )
+            if (
+                "model.layers.vision_model*" in ignore_quant_layers
+                and "model.layers.multi_modal_projector*" in ignore_quant_layers
+            ):
+                vision_quant_config = None
+            else:
+                vision_quant_config = quant_config
             self.vision_model = Llama4VisionModel(
                 config.vision_config,
-                quant_config=quant_config,
+                quant_config=vision_quant_config,
                 prefix=add_prefix("vision_model", prefix),
             )
 
@@ -560,7 +571,7 @@ class Llama4ForConditionalGeneration(nn.Module):
             forward_batch=forward_batch,
             language_model=self.language_model,
             data_embedding_funcs={
                 Modality.IMAGE: self.get_image_feature,
+                Modality.IMAGE: image_embedding_func,
             },
             positions=positions,
         )
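
Reviewer note on the first hunk (not part of the patch): `nn.Unfold` followed by `permute` yields a non-contiguous view of the tensor, and some quantized (e.g. fp8) linear kernels expect a contiguous, row-major input, which is the likely motivation for adding `.contiguous()` here. A minimal standalone sketch; the shapes below are chosen for illustration, not taken from the model config:

import torch
import torch.nn as nn

# Patchify a square image the way Llama4UnfoldConvolution does:
# unfold to (B, C*k*k, num_patches), then permute to (B, num_patches, C*k*k).
unfold = nn.Unfold(kernel_size=14, stride=14)
pixels = torch.randn(1, 3, 336, 336)          # (B, C, H, W), illustrative

patches = unfold(pixels)                      # (1, 588, 576)
patches = patches.permute(0, 2, 1)            # (1, 576, 588)
print(patches.is_contiguous())                # False: permute returns a strided view

patches = patches.contiguous()                # materialize a dense row-major copy
print(patches.is_contiguous())                # True: safe for strict GEMM kernels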
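
Reviewer note on the second hunk (not part of the patch): the new check keys off the `ignore` list that fp8 checkpoints can carry in their `quantization_config`. A hypothetical excerpt showing a config that would satisfy the patch's condition; the two glob patterns are the ones the patch tests for, while `quant_method` and the surrounding structure are assumptions for illustration:

# Hypothetical quantization_config as it might appear in an fp8
# checkpoint's config.json (values assumed, not copied from a real checkpoint).
quantization_config = {
    "quant_method": "fp8",
    "ignore": [
        "model.layers.vision_model*",
        "model.layers.multi_modal_projector*",
    ],
}

# Mirror of the patch's decision: the vision tower is built without a
# quant_config only when BOTH patterns appear in the ignore list.
ignore_quant_layers = quantization_config.get("ignore", [])
vision_is_unquantized = (
    "model.layers.vision_model*" in ignore_quant_layers
    and "model.layers.multi_modal_projector*" in ignore_quant_layers
)
print(vision_is_unquantized)  # True -> vision_quant_config = None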