Fix LoRA support for multimodal models (VLMs) by implementing a consistent pattern for skipping vision components (#11261)
This commit is contained in:
@@ -418,10 +418,6 @@ class LoRAManager:
|
||||
replace_submodule(self.base_model, module_name, lora_module)
|
||||
return lora_module
|
||||
|
||||
def should_skip_lora_for_vision_model(self, module_name):
|
||||
# TODO: support different vision models
|
||||
return module_name.find("vision_model.model") != -1
|
||||
|
||||
def init_lora_modules(self):
|
||||
# Look-up table that essentially maps (layer_index, module_name) to the corresponding LoRA module.
|
||||
self.lora_modules: List[Dict[str, BaseLayerWithLoRA]] = [
|
||||
@@ -439,10 +435,6 @@ class LoRAManager:
|
||||
) and not self.base_model.should_apply_lora(module_name):
|
||||
continue
|
||||
|
||||
# Skip vision model
|
||||
if self.should_skip_lora_for_vision_model(module_name):
|
||||
continue
|
||||
|
||||
# The module should be converted if it is included in target_names
|
||||
if module_name.split(".")[-1] in self.target_modules:
|
||||
layer_id = get_layer_id(module_name)
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
# https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/gemma3_mm.py
|
||||
|
||||
import logging
|
||||
import re
|
||||
from functools import lru_cache
|
||||
from typing import Dict, Iterable, List, Optional, Set, Tuple, TypedDict
|
||||
|
||||
@@ -154,6 +155,10 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
|
||||
embedding_modules = {}
|
||||
embedding_padding_modules = []
|
||||
supports_lora = True
|
||||
# Pattern to match language model layers only (skip vision_tower and multi_modal_projector)
|
||||
lora_pattern = re.compile(
|
||||
r"^language_model\.model\.layers\.(\d+)\.(?:self_attn|mlp)\.(?:qkv_proj|o_proj|down_proj|gate_up_proj)"
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -165,6 +170,13 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
|
||||
# For LoRA compatibility: expose text_config attributes at top level
|
||||
# This allows LoRA code to work without special multimodal handling
|
||||
if not hasattr(config, "num_hidden_layers"):
|
||||
config.num_hidden_layers = config.text_config.num_hidden_layers
|
||||
if not hasattr(config, "hidden_size"):
|
||||
config.hidden_size = config.text_config.hidden_size
|
||||
|
||||
self.vision_tower = SiglipVisionModel(
|
||||
config=config.vision_config,
|
||||
quant_config=quant_config,
|
||||
@@ -380,6 +392,10 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
|
||||
|
||||
return hs
|
||||
|
||||
def should_apply_lora(self, module_name: str) -> bool:
|
||||
"""Skip vision tower and multi_modal_projector for LoRA."""
|
||||
return bool(self.lora_pattern.match(module_name))
|
||||
|
||||
def tie_weights(self):
|
||||
return self.language_model.tie_weights()
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ import json as json_lib
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
from collections.abc import Iterable
|
||||
from typing import List, Optional, Set, Tuple
|
||||
|
||||
@@ -422,6 +423,11 @@ class Llama4ForConditionalGeneration(nn.Module):
|
||||
"gate_up_proj": ["gate_proj", "up_proj"],
|
||||
}
|
||||
|
||||
# Pattern to match language model layers only (skip vision_model and multi_modal_projector)
|
||||
lora_pattern = re.compile(
|
||||
r"^language_model\.model\.layers\.(\d+)\.(?:self_attn|mlp)\.(?:qkv_proj|o_proj|down_proj|gate_up_proj)"
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Llama4Config,
|
||||
@@ -555,6 +561,10 @@ class Llama4ForConditionalGeneration(nn.Module):
|
||||
|
||||
return projected_vision_flat
|
||||
|
||||
def should_apply_lora(self, module_name: str) -> bool:
|
||||
"""Skip vision model and multi_modal_projector for LoRA."""
|
||||
return bool(self.lora_pattern.match(module_name))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
|
||||
Reference in New Issue
Block a user