fix: fp8 mllama4 without vision modules being quantized (#10611)
This commit is contained in:
@@ -291,7 +291,7 @@ class Llama4UnfoldConvolution(nn.Module):
|
|||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||||
hidden_states = self.unfold(hidden_states)
|
hidden_states = self.unfold(hidden_states)
|
||||||
hidden_states = hidden_states.permute(0, 2, 1)
|
hidden_states = hidden_states.permute(0, 2, 1).contiguous()
|
||||||
hidden_states, _ = self.linear(hidden_states)
|
hidden_states, _ = self.linear(hidden_states)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
@@ -446,9 +446,20 @@ class Llama4ForConditionalGeneration(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if self.has_vision:
|
if self.has_vision:
|
||||||
|
# TODO: make this more general
|
||||||
|
ignore_quant_layers = getattr(config, "quantization_config", {}).get(
|
||||||
|
"ignore", {}
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
"model.layers.vision_model*" in ignore_quant_layers
|
||||||
|
and "model.layers.multi_modal_projector*" in ignore_quant_layers
|
||||||
|
):
|
||||||
|
vision_quant_config = None
|
||||||
|
else:
|
||||||
|
vision_quant_config = quant_config
|
||||||
self.vision_model = Llama4VisionModel(
|
self.vision_model = Llama4VisionModel(
|
||||||
config.vision_config,
|
config.vision_config,
|
||||||
quant_config=quant_config,
|
quant_config=vision_quant_config,
|
||||||
prefix=add_prefix("vision_model", prefix),
|
prefix=add_prefix("vision_model", prefix),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -560,7 +571,7 @@ class Llama4ForConditionalGeneration(nn.Module):
|
|||||||
forward_batch=forward_batch,
|
forward_batch=forward_batch,
|
||||||
language_model=self.language_model,
|
language_model=self.language_model,
|
||||||
data_embedding_funcs={
|
data_embedding_funcs={
|
||||||
Modality.IMAGE: self.get_image_feature,
|
Modality.IMAGE: image_embedding_func,
|
||||||
},
|
},
|
||||||
positions=positions,
|
positions=positions,
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user