From 7cb20754faca1779860723c3fbd9c1a19acacac8 Mon Sep 17 00:00:00 2001
From: Lifu Huang
Date: Mon, 4 Aug 2025 17:11:46 -0700
Subject: [PATCH] [Fix] Fix several issues preventing gemma3n LoRA support.
 (#8776)

---
 python/sglang/srt/lora/lora_manager.py |  7 +++++
 python/sglang/srt/models/gemma3n_mm.py | 39 ++++++++++++++++++++++++++
 python/sglang/srt/server_args.py       | 10 +++++--
 3 files changed, 54 insertions(+), 2 deletions(-)

diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py
index c1d6439a0..e4fe1d0d1 100644
--- a/python/sglang/srt/lora/lora_manager.py
+++ b/python/sglang/srt/lora/lora_manager.py
@@ -386,6 +386,13 @@ class LoRAManager:
         else:
             self.target_modules = set()
             for config in self.configs.values():
+                if not isinstance(config.target_modules, list):
+                    raise ValueError(
+                        f"SGLang currently only supports inferring LoRA target modules when a list of "
+                        "suffixes is provided in `target_modules` field of PEFT config. Please explicitly "
+                        "specify `--lora-target-modules` during server startup. You can specify `all` to "
+                        "enable all support modules types. "
+                    )
                 self.target_modules.update(config.target_modules)
 
         if max_lora_rank is not None:
diff --git a/python/sglang/srt/models/gemma3n_mm.py b/python/sglang/srt/models/gemma3n_mm.py
index 5139a9c2d..b4bf2ba75 100644
--- a/python/sglang/srt/models/gemma3n_mm.py
+++ b/python/sglang/srt/models/gemma3n_mm.py
@@ -492,5 +492,44 @@ class Gemma3nForConditionalGeneration(PreTrainedModel):
                 loaded_params.add(name)
         return loaded_params
 
+    lora_pattern = re.compile(
+        r"^language_model\.layers\.(\d+)\.(?:self_attn|mlp)\.(?:qkv_proj|o_proj|down_proj|gate_up_proj)"
+    )
+
+    def should_apply_lora(self, module_name: str) -> bool:
+        return bool(self.lora_pattern.match(module_name))
+
+    def get_hidden_dim(self, module_name):
+        # return input_dim, output_dim
+        if module_name in ["q_proj", "qkv_proj"]:
+            return (
+                self.config.hidden_size,
+                self.config.head_dim * self.config.num_attention_heads,
+            )
+        elif module_name in ["o_proj"]:
+            return (
+                self.config.head_dim * self.config.num_attention_heads,
+                self.config.hidden_size,
+            )
+        elif module_name in ["kv_proj"]:
+            return (
+                self.config.hidden_size,
+                self.config.head_dim * self.config.num_key_value_heads,
+            )
+        elif module_name == "gate_up_proj":
+            assert len(set(self.config.intermediate_size)) == 1, (
+                "Currently SGLang requires uniform intermediate size for all layers. "
+                "Please file an issue if you need support for non-uniform intermediate sizes."
+            )
+            return self.config.hidden_size, self.config.intermediate_size[0]
+        elif module_name == "down_proj":
+            assert len(set(self.config.intermediate_size)) == 1, (
+                "Currently SGLang requires uniform intermediate size for all layers. "
+                "Please file an issue if you need support for non-uniform intermediate sizes."
+            )
+            return self.config.intermediate_size[0], self.config.hidden_size
+        else:
+            raise NotImplementedError()
+
 
 EntryClass = Gemma3nForConditionalGeneration
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index fb3f80f87..aacaaf1cd 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -1943,10 +1943,16 @@ class ServerArgs:
         if "Llama4" in model_arch:
             assert self.attention_backend == "fa3", "fa3 is required for Llama4 model"
 
-        if "Gemma2ForCausalLM" in model_arch:
+        if model_arch in [
+            "Gemma2ForCausalLM",
+            "Gemma3nForCausalLM",
+            "Gemma3nForConditionalGeneration",
+        ]:
             # FIXME: https://github.com/sgl-project/sglang/pull/7367 is not compatible with gemma2 model.
             # It failed at this test: https://github.com/sgl-project/sglang/actions/runs/16255155597/job/45890331952#step:4:736
-            logger.warning("Disable hybrid SWA memory for Gemma2ForCausalLM.")
+            logger.warning(
+                f"Disable hybrid SWA memory for {model_arch} as it is not yet supported."
+            )
             self.disable_hybrid_swa_memory = True
 
         # Check LoRA