Fix LoRA support for multimodal models (VLMs) by implementing a consistent pattern for skipping vision components (#11261)
This commit is contained in:
@@ -418,10 +418,6 @@ class LoRAManager:
|
|||||||
replace_submodule(self.base_model, module_name, lora_module)
|
replace_submodule(self.base_model, module_name, lora_module)
|
||||||
return lora_module
|
return lora_module
|
||||||
|
|
||||||
def should_skip_lora_for_vision_model(self, module_name):
|
|
||||||
# TODO: support different vision models
|
|
||||||
return module_name.find("vision_model.model") != -1
|
|
||||||
|
|
||||||
def init_lora_modules(self):
|
def init_lora_modules(self):
|
||||||
# Look-up table that essentially maps (layer_index, module_name) to the corresponding LoRA module.
|
# Look-up table that essentially maps (layer_index, module_name) to the corresponding LoRA module.
|
||||||
self.lora_modules: List[Dict[str, BaseLayerWithLoRA]] = [
|
self.lora_modules: List[Dict[str, BaseLayerWithLoRA]] = [
|
||||||
@@ -439,10 +435,6 @@ class LoRAManager:
|
|||||||
) and not self.base_model.should_apply_lora(module_name):
|
) and not self.base_model.should_apply_lora(module_name):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Skip vision model
|
|
||||||
if self.should_skip_lora_for_vision_model(module_name):
|
|
||||||
continue
|
|
||||||
|
|
||||||
# The module should be converted if it is included in target_names
|
# The module should be converted if it is included in target_names
|
||||||
if module_name.split(".")[-1] in self.target_modules:
|
if module_name.split(".")[-1] in self.target_modules:
|
||||||
layer_id = get_layer_id(module_name)
|
layer_id = get_layer_id(module_name)
|
||||||
|
|||||||
@@ -16,6 +16,7 @@
|
|||||||
# https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/gemma3_mm.py
|
# https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/gemma3_mm.py
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from typing import Dict, Iterable, List, Optional, Set, Tuple, TypedDict
|
from typing import Dict, Iterable, List, Optional, Set, Tuple, TypedDict
|
||||||
|
|
||||||
@@ -154,6 +155,10 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
|
|||||||
embedding_modules = {}
|
embedding_modules = {}
|
||||||
embedding_padding_modules = []
|
embedding_padding_modules = []
|
||||||
supports_lora = True
|
supports_lora = True
|
||||||
|
# Pattern to match language model layers only (skip vision_tower and multi_modal_projector)
|
||||||
|
lora_pattern = re.compile(
|
||||||
|
r"^language_model\.model\.layers\.(\d+)\.(?:self_attn|mlp)\.(?:qkv_proj|o_proj|down_proj|gate_up_proj)"
|
||||||
|
)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -165,6 +170,13 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
|
|
||||||
|
# For LoRA compatibility: expose text_config attributes at top level
|
||||||
|
# This allows LoRA code to work without special multimodal handling
|
||||||
|
if not hasattr(config, "num_hidden_layers"):
|
||||||
|
config.num_hidden_layers = config.text_config.num_hidden_layers
|
||||||
|
if not hasattr(config, "hidden_size"):
|
||||||
|
config.hidden_size = config.text_config.hidden_size
|
||||||
|
|
||||||
self.vision_tower = SiglipVisionModel(
|
self.vision_tower = SiglipVisionModel(
|
||||||
config=config.vision_config,
|
config=config.vision_config,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
@@ -380,6 +392,10 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
|
|||||||
|
|
||||||
return hs
|
return hs
|
||||||
|
|
||||||
|
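def should_apply_lora(self, module_name: str) -> bool:
|
||||||
|
"""Return True only if module_name matches lora_pattern (language-model self_attn/mlp projections); the vision tower and multi_modal_projector are skipped for LoRA."""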
|
||||||
|
return bool(self.lora_pattern.match(module_name))
|
||||||
|
|
||||||
def tie_weights(self):
|
def tie_weights(self):
|
||||||
return self.language_model.tie_weights()
|
return self.language_model.tie_weights()
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import json as json_lib
|
|||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
from typing import List, Optional, Set, Tuple
|
from typing import List, Optional, Set, Tuple
|
||||||
|
|
||||||
@@ -422,6 +423,11 @@ class Llama4ForConditionalGeneration(nn.Module):
|
|||||||
"gate_up_proj": ["gate_proj", "up_proj"],
|
"gate_up_proj": ["gate_proj", "up_proj"],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Pattern to match language model layers only (skip vision_model and multi_modal_projector)
|
||||||
|
lora_pattern = re.compile(
|
||||||
|
r"^language_model\.model\.layers\.(\d+)\.(?:self_attn|mlp)\.(?:qkv_proj|o_proj|down_proj|gate_up_proj)"
|
||||||
|
)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: Llama4Config,
|
config: Llama4Config,
|
||||||
@@ -555,6 +561,10 @@ class Llama4ForConditionalGeneration(nn.Module):
|
|||||||
|
|
||||||
return projected_vision_flat
|
return projected_vision_flat
|
||||||
|
|
||||||
|
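def should_apply_lora(self, module_name: str) -> bool:
|
||||||
|
"""Return True only if module_name matches lora_pattern (language-model self_attn/mlp projections); the vision model and multi_modal_projector are skipped for LoRA."""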
|
||||||
|
return bool(self.lora_pattern.match(module_name))
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
|
|||||||
Reference in New Issue
Block a user