fix: fp8 mllama4 without vision modules being quantized (#10611)

This commit is contained in:
Mick
2025-09-30 05:17:12 +08:00
committed by GitHub
parent dda34c2f93
commit 9de1320b63

View File

@@ -291,7 +291,7 @@ class Llama4UnfoldConvolution(nn.Module):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.unfold(hidden_states)
hidden_states = hidden_states.permute(0, 2, 1)
hidden_states = hidden_states.permute(0, 2, 1).contiguous()
hidden_states, _ = self.linear(hidden_states)
return hidden_states
@@ -446,9 +446,20 @@ class Llama4ForConditionalGeneration(nn.Module):
)
if self.has_vision:
# TODO: make this more general
ignore_quant_layers = getattr(config, "quantization_config", {}).get(
"ignore", {}
)
if (
"model.layers.vision_model*" in ignore_quant_layers
and "model.layers.multi_modal_projector*" in ignore_quant_layers
):
vision_quant_config = None
else:
vision_quant_config = quant_config
self.vision_model = Llama4VisionModel(
config.vision_config,
quant_config=quant_config,
quant_config=vision_quant_config,
prefix=add_prefix("vision_model", prefix),
)
@@ -560,7 +571,7 @@ class Llama4ForConditionalGeneration(nn.Module):
forward_batch=forward_batch,
language_model=self.language_model,
data_embedding_funcs={
Modality.IMAGE: self.get_image_feature,
Modality.IMAGE: image_embedding_func,
},
positions=positions,
)