Urgent model support: support gemma-3-it (#4424)
This commit is contained in:
@@ -391,9 +391,13 @@ def _get_and_verify_dtype(
     dtype = dtype.lower()
     if dtype == "auto":
         if config_dtype == torch.float32:
-            if config.model_type == "gemma2":
+            if config.model_type.startswith("gemma"):
+                if config.model_type == "gemma":
+                    gemma_version = ""
+                else:
+                    gemma_version = config.model_type[5]
                 logger.info(
-                    "For Gemma 2, we downcast float32 to bfloat16 instead "
+                    f"For Gemma {gemma_version}, we downcast float32 to bfloat16 instead "
                     "of float16 by default. Please specify `dtype` if you "
                     "want to use float16."
                 )
@@ -453,6 +457,7 @@ multimodal_model_archs = [
     "LlavaQwenForCausalLM",
     "LlavaMistralForCausalLM",
     "LlavaVidForCausalLM",
+    "Gemma3ForConditionalGeneration",
     "Grok1VForCausalLM",
     "Grok1AForCausalLM",
     "MllamaForConditionalGeneration",
Reference in New Issue
Block a user