add gemma3
This commit is contained in:
@@ -226,7 +226,7 @@ class ModelConfig:
|
||||
sliding_window = getattr(self.hf_text_config, "sliding_window", None)
|
||||
has_interleaved_attention = (sliding_window is not None) and (
|
||||
isinstance(sliding_window, list) or
|
||||
(self.hf_text_config.model_type in ["gemma2"]))
|
||||
(self.hf_text_config.model_type in ["gemma2", "gemma3"]))
|
||||
|
||||
if (not self.disable_sliding_window and has_interleaved_attention):
|
||||
sliding_window_len_min = get_min_sliding_window(
|
||||
@@ -1854,9 +1854,9 @@ def _get_and_verify_dtype(
|
||||
dtype = dtype.lower()
|
||||
if dtype == "auto":
|
||||
if config_dtype == torch.float32:
|
||||
if config.model_type == "gemma2":
|
||||
if config.model_type in ("gemma2", "gemma3"):
|
||||
logger.info(
|
||||
"For Gemma 2, we downcast float32 to bfloat16 instead "
|
||||
"For Gemma 2/3, we downcast float32 to bfloat16 instead "
|
||||
"of float16 by default. Please specify `dtype` if you "
|
||||
"want to use float16.")
|
||||
torch_dtype = torch.bfloat16
|
||||
|
||||
Reference in New Issue
Block a user