From afc35ccc5e78d3bb18aa28d3fa46617995c899b7 Mon Sep 17 00:00:00 2001 From: Chenxi Li <41864925+ConnorLi96@users.noreply.github.com> Date: Mon, 6 Oct 2025 17:23:00 -0700 Subject: [PATCH] Fix LoRA support for multimodal models (VLMs) by implementing a consistent pattern for skipping vision components (#11261) --- python/sglang/srt/lora/lora_manager.py | 8 -------- python/sglang/srt/models/gemma3_mm.py | 16 ++++++++++++++++ python/sglang/srt/models/mllama4.py | 10 ++++++++++ 3 files changed, 26 insertions(+), 8 deletions(-) diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py index 2b90a8741..ac6f26e5d 100644 --- a/python/sglang/srt/lora/lora_manager.py +++ b/python/sglang/srt/lora/lora_manager.py @@ -418,10 +418,6 @@ class LoRAManager: replace_submodule(self.base_model, module_name, lora_module) return lora_module - def should_skip_lora_for_vision_model(self, module_name): - # TODO: support different vision models - return module_name.find("vision_model.model") != -1 - def init_lora_modules(self): # Look-up table that essentially maps (layer_index, module_name) to the corresponding LoRA module. self.lora_modules: List[Dict[str, BaseLayerWithLoRA]] = [ @@ -439,10 +435,6 @@ class LoRAManager: ) and not self.base_model.should_apply_lora(module_name): continue - # Skip vision model - if self.should_skip_lora_for_vision_model(module_name): - continue - # The module should be converted if it is included in target_names if module_name.split(".")[-1] in self.target_modules: layer_id = get_layer_id(module_name) diff --git a/python/sglang/srt/models/gemma3_mm.py b/python/sglang/srt/models/gemma3_mm.py index 8060fdee9..de2300522 100644 --- a/python/sglang/srt/models/gemma3_mm.py +++ b/python/sglang/srt/models/gemma3_mm.py @@ -16,6 +16,7 @@ # https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/gemma3_mm.py import logging +import re from functools import lru_cache from typing import Dict, Iterable, List, Optional, Set, Tuple, TypedDict @@ -154,6 +155,10 @@ class Gemma3ForConditionalGeneration(PreTrainedModel): embedding_modules = {} embedding_padding_modules = [] supports_lora = True + # Pattern to match language model layers only (skip vision_tower and multi_modal_projector) + lora_pattern = re.compile( + r"^language_model\.model\.layers\.(\d+)\.(?:self_attn|mlp)\.(?:qkv_proj|o_proj|down_proj|gate_up_proj)" + ) def __init__( self, @@ -165,6 +170,13 @@ class Gemma3ForConditionalGeneration(PreTrainedModel): self.config = config self.quant_config = quant_config + # For LoRA compatibility: expose text_config attributes at top level + # This allows LoRA code to work without special multimodal handling + if not hasattr(config, "num_hidden_layers"): + config.num_hidden_layers = config.text_config.num_hidden_layers + if not hasattr(config, "hidden_size"): + config.hidden_size = config.text_config.hidden_size + self.vision_tower = SiglipVisionModel( config=config.vision_config, quant_config=quant_config, @@ -380,6 +392,10 @@ class Gemma3ForConditionalGeneration(PreTrainedModel): return hs + def should_apply_lora(self, module_name: str) -> bool: + """Skip vision tower and multi_modal_projector for LoRA.""" + return bool(self.lora_pattern.match(module_name)) + def tie_weights(self): return self.language_model.tie_weights() diff --git a/python/sglang/srt/models/mllama4.py b/python/sglang/srt/models/mllama4.py index 4bc0e275f..e452cb2d4 100644 --- a/python/sglang/srt/models/mllama4.py +++ b/python/sglang/srt/models/mllama4.py @@ -2,6 +2,7 @@ import json as json_lib import logging import math import os +import re from collections.abc import Iterable from typing import List, Optional, Set, Tuple @@ -422,6 +423,11 @@ class Llama4ForConditionalGeneration(nn.Module): "gate_up_proj": ["gate_proj", "up_proj"], } + # Pattern to match language model layers only (skip vision_model and multi_modal_projector) + lora_pattern = re.compile( + r"^language_model\.model\.layers\.(\d+)\.(?:self_attn|mlp)\.(?:qkv_proj|o_proj|down_proj|gate_up_proj)" + ) + def __init__( self, config: Llama4Config, @@ -555,6 +561,10 @@ class Llama4ForConditionalGeneration(nn.Module): return projected_vision_flat + def should_apply_lora(self, module_name: str) -> bool: + """Skip vision model and multi_modal_projector for LoRA.""" + return bool(self.lora_pattern.match(module_name)) + def forward( self, input_ids: torch.Tensor,