Urgent model support: support gemma-3-it (#4424)
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
from sglang.srt.configs.chatglm import ChatGLMConfig
|
||||
from sglang.srt.configs.dbrx import DbrxConfig
|
||||
from sglang.srt.configs.exaone import ExaoneConfig
|
||||
from sglang.srt.configs.gemma3 import Gemma3Config, Gemma3TextConfig
|
||||
from sglang.srt.configs.janus_pro import MultiModalityConfig
|
||||
from sglang.srt.configs.qwen2_5_vl_config import (
|
||||
Qwen2_5_VLConfig,
|
||||
@@ -14,4 +15,6 @@ __all__ = [
|
||||
"Qwen2_5_VLConfig",
|
||||
"Qwen2_5_VLVisionConfig",
|
||||
"MultiModalityConfig",
|
||||
"Gemma3Config",
|
||||
"Gemma3TextConfig",
|
||||
]
|
||||
|
||||
1086
python/sglang/srt/configs/gemma3.py
Normal file
1086
python/sglang/srt/configs/gemma3.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -391,9 +391,13 @@ def _get_and_verify_dtype(
|
||||
dtype = dtype.lower()
|
||||
if dtype == "auto":
|
||||
if config_dtype == torch.float32:
|
||||
if config.model_type == "gemma2":
|
||||
if config.model_type.startswith("gemma"):
|
||||
if config.model_type == "gemma":
|
||||
gemma_version = ""
|
||||
else:
|
||||
gemma_version = config.model_type[5]
|
||||
logger.info(
|
||||
"For Gemma 2, we downcast float32 to bfloat16 instead "
|
||||
f"For Gemma {gemma_version}, we downcast float32 to bfloat16 instead "
|
||||
"of float16 by default. Please specify `dtype` if you "
|
||||
"want to use float16."
|
||||
)
|
||||
@@ -453,6 +457,7 @@ multimodal_model_archs = [
|
||||
"LlavaQwenForCausalLM",
|
||||
"LlavaMistralForCausalLM",
|
||||
"LlavaVidForCausalLM",
|
||||
"Gemma3ForConditionalGeneration",
|
||||
"Grok1VForCausalLM",
|
||||
"Grok1AForCausalLM",
|
||||
"MllamaForConditionalGeneration",
|
||||
|
||||
Reference in New Issue
Block a user