From 04f2abcb341037f2587e74c1d04e0b08c4ac65fb Mon Sep 17 00:00:00 2001
From: Yineng Zhang
Date: Tue, 22 Apr 2025 01:16:08 -0700
Subject: [PATCH] fix: gemma 3 not use softcap (#5622)

---
 python/sglang/srt/configs/model_config.py |  5 +++++
 python/sglang/srt/models/gemma3_causal.py |  2 +-
 python/sglang/srt/server_args.py          | 11 ++++++++++-
 python/sglang/srt/utils.py                |  1 +
 4 files changed, 17 insertions(+), 2 deletions(-)

diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py
index 28bf9c83e..a719bf32b 100644
--- a/python/sglang/srt/configs/model_config.py
+++ b/python/sglang/srt/configs/model_config.py
@@ -78,6 +78,11 @@ class ModelConfig:
                 logger.info(
                     "Multimodal is disabled for Llama4. To enable it, set --enable-llama4-multimodal."
                 )
+            elif self.hf_config.architectures[0] == "Gemma3ForConditionalGeneration":
+                enable_multimodal = False
+                logger.info(
+                    "Multimodal is disabled for Gemma3. To enable it, set --enable-gemma3-multimodal."
+                )
             else:
                 enable_multimodal = True
diff --git a/python/sglang/srt/models/gemma3_causal.py b/python/sglang/srt/models/gemma3_causal.py
index e34715571..511d9c7e8 100644
--- a/python/sglang/srt/models/gemma3_causal.py
+++ b/python/sglang/srt/models/gemma3_causal.py
@@ -189,7 +189,7 @@ class Gemma3Attention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
-            logit_cap=getattr(self.config, "attn_logit_softcapping", None),
+            logit_cap=0.0,
             # Module must also define `get_attention_sliding_window_size` to correctly initialize
             # attention backend in `ForwardBatch`.
             sliding_window_size=self.sliding_window,
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 4c2b0122f..b35dd9321 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -154,6 +154,7 @@ class ServerArgs:
     disable_outlines_disk_cache: bool = False
     disable_custom_all_reduce: bool = False
     enable_llama4_multimodal: Optional[bool] = None
+    enable_gemma3_multimodal: Optional[bool] = None
     disable_overlap_schedule: bool = False
     enable_mixed_chunk: bool = False
     enable_dp_attention: bool = False
@@ -285,7 +286,9 @@ class ServerArgs:
         if self.grammar_backend is None:
             self.grammar_backend = "xgrammar"

-        self.enable_multimodal: Optional[bool] = self.enable_llama4_multimodal
+        self.enable_multimodal: Optional[bool] = (
+            self.enable_llama4_multimodal or self.enable_gemma3_multimodal
+        )

         # Data parallelism attention
         if self.enable_dp_attention:
@@ -984,6 +987,12 @@ class ServerArgs:
             action="store_true",
             help="Enable the multimodal functionality for Llama-4.",
         )
+        parser.add_argument(
+            "--enable-gemma3-multimodal",
+            default=ServerArgs.enable_gemma3_multimodal,
+            action="store_true",
+            help="Enable the multimodal functionality for Gemma-3.",
+        )
         parser.add_argument(
             "--disable-overlap-schedule",
             action="store_true",
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index 1e9e66441..ba6bb6140 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -1971,6 +1971,7 @@ def is_fa3_default_architecture(hf_config):
         "LlamaForCausalLM",
         "MistralForCausalLM",
         "Gemma2ForCausalLM",
+        "Gemma3ForConditionalGeneration",
     }
     return architectures[0] in default_archs
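
Note (not part of the patch): after the server_args.py change, the combined
enable_multimodal setting is simply the boolean OR of the two per-model flags.
The sketch below illustrates that resolution; resolve_enable_multimodal is a
hypothetical helper written only for this note, not a function in sglang.

    from typing import Optional

    def resolve_enable_multimodal(
        enable_llama4_multimodal: Optional[bool] = None,
        enable_gemma3_multimodal: Optional[bool] = None,
    ) -> Optional[bool]:
        # Mirrors `self.enable_multimodal = enable_llama4_multimodal or
        # enable_gemma3_multimodal`: truthy if either per-model flag is set,
        # otherwise None so downstream defaults still apply.
        return enable_llama4_multimodal or enable_gemma3_multimodal

    assert resolve_enable_multimodal() is None             # neither flag given
    assert resolve_enable_multimodal(True, None) is True   # --enable-llama4-multimodal
    assert resolve_enable_multimodal(None, True) is True   # --enable-gemma3-multimodal

With the new CLI flag added in this patch, Gemma 3 multimodal support can
presumably be re-enabled at launch time by passing --enable-gemma3-multimodal,
mirroring the existing --enable-llama4-multimodal flag.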