fix: fp8 mllama4 without vision modules being quantized (#10611)
@@ -291,7 +291,7 @@ class Llama4UnfoldConvolution(nn.Module):
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = self.unfold(hidden_states)
-        hidden_states = hidden_states.permute(0, 2, 1)
+        hidden_states = hidden_states.permute(0, 2, 1).contiguous()
         hidden_states, _ = self.linear(hidden_states)
         return hidden_states
 
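Context for the hunk above: torch.Tensor.permute returns a strided view rather than a copy, so its result is generally not contiguous in memory, and quantized (e.g. fp8) linear kernels typically assume a dense row-major input. A minimal standalone sketch of that behavior, with illustrative shapes only:

import torch

x = torch.randn(2, 16, 8)   # (batch, channels, patches); shapes illustrative
y = x.permute(0, 2, 1)      # swaps strides only, no data movement
print(y.is_contiguous())    # False: strides no longer match a dense layout
z = y.contiguous()          # materializes a row-major copy
print(z.is_contiguous())    # True: safe for kernels that assume dense memory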
@@ -446,9 +446,20 @@ class Llama4ForConditionalGeneration(nn.Module):
         )
 
         if self.has_vision:
+            # TODO: make this more general
+            ignore_quant_layers = getattr(config, "quantization_config", {}).get(
+                "ignore", {}
+            )
+            if (
+                "model.layers.vision_model*" in ignore_quant_layers
+                and "model.layers.multi_modal_projector*" in ignore_quant_layers
+            ):
+                vision_quant_config = None
+            else:
+                vision_quant_config = quant_config
             self.vision_model = Llama4VisionModel(
                 config.vision_config,
-                quant_config=quant_config,
+                quant_config=vision_quant_config,
                 prefix=add_prefix("vision_model", prefix),
             )
 
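The gating above keys off the checkpoint's quantization config: when the vision tower and projector are listed under "ignore", the vision model is built unquantized. As a hedged sketch, the dict below is a hypothetical excerpt of such a checkpoint's quantization_config (exact values depend on whatever tool produced the checkpoint), followed by the same selection logic the commit adds:

# Hypothetical excerpt of an fp8 checkpoint's quantization_config.
quantization_config = {
    "quant_method": "fp8",
    "ignore": [
        "model.layers.vision_model*",
        "model.layers.multi_modal_projector*",
    ],
}

# Mirrors the commit's logic: only hand a quant config to the vision
# tower when its modules are not excluded from quantization.
ignore = quantization_config.get("ignore", [])
if (
    "model.layers.vision_model*" in ignore
    and "model.layers.multi_modal_projector*" in ignore
):
    vision_quant_config = None   # vision weights stay in high precision
else:
    vision_quant_config = quantization_config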
@@ -560,7 +571,7 @@ class Llama4ForConditionalGeneration(nn.Module):
             forward_batch=forward_batch,
             language_model=self.language_model,
             data_embedding_funcs={
-                Modality.IMAGE: self.get_image_feature,
+                Modality.IMAGE: image_embedding_func,
             },
             positions=positions,
         )
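The last hunk swaps the bound method self.get_image_feature for image_embedding_func, a callable defined earlier in the (unshown) method body. A hedged sketch of the dispatch pattern only, with placeholder definitions since Modality and the real embedding function live in sglang:

from enum import Enum, auto
from typing import Callable, Dict

class Modality(Enum):   # placeholder for sglang's Modality enum
    IMAGE = auto()

def image_embedding_func(items):
    """Hypothetical stand-in for the embedding callable the commit wires in."""
    ...

data_embedding_funcs: Dict[Modality, Callable] = {
    Modality.IMAGE: image_embedding_func,
}

# The multimodal helper dispatches on modality, so changing the dict
# value changes how image inputs are embedded.
embed_fn = data_embedding_funcs[Modality.IMAGE]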