From 2614adf9caf1118e1c594ae7fe4744f489755b46 Mon Sep 17 00:00:00 2001
From: Antonin Vidon
Date: Fri, 17 Oct 2025 18:39:57 -0400
Subject: [PATCH] [Fix] Skip visual layers when applying LoRA to Qwen2VL
 modules (#11519)

---
 python/sglang/srt/models/qwen2_vl.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/python/sglang/srt/models/qwen2_vl.py b/python/sglang/srt/models/qwen2_vl.py
index 7a42829e8..943846210 100644
--- a/python/sglang/srt/models/qwen2_vl.py
+++ b/python/sglang/srt/models/qwen2_vl.py
@@ -28,7 +28,6 @@ from typing import Iterable, List, Optional, Tuple, Type, TypedDict
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 from einops import rearrange
 from transformers import Qwen2VLConfig
 from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLVisionConfig
@@ -514,6 +513,10 @@ class Qwen2VLForConditionalGeneration(nn.Module):
     def get_input_embeddings(self):
         return self.model.embed_tokens
 
+    def should_apply_lora(self, module_name: str) -> bool:
+        # skip visual tower
+        return not module_name.startswith("visual")
+
     def forward(
         self,
         input_ids: torch.Tensor,