diff --git a/python/sglang/srt/models/gemma2.py b/python/sglang/srt/models/gemma2.py
index 4d21901de..06a7b0302 100644
--- a/python/sglang/srt/models/gemma2.py
+++ b/python/sglang/srt/models/gemma2.py
@@ -35,7 +35,10 @@ from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.model_loader.weight_utils import (
+    default_weight_loader,
+    maybe_remap_kv_scale_name,
+)
 from sglang.srt.utils import make_layers
 
 
@@ -424,6 +427,11 @@ class Gemma2ForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                # Remap the name of the FP8 kv-scale parameter.
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
+
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
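
For context, here is a minimal sketch of the kind of remapping this diff relies on. It assumes a checkpoint stores FP8 KV-cache scales under names like `...self_attn.k_scale` while the model registers the corresponding parameters on its attention submodule; the helper name `remap_kv_scale_name_sketch` and the specific suffix mapping are illustrative assumptions, not sglang's actual implementation of `maybe_remap_kv_scale_name`:

```python
from typing import Dict, Optional


def remap_kv_scale_name_sketch(
    name: str, params_dict: Dict[str, object]
) -> Optional[str]:
    """Translate a checkpoint kv-scale name to the model's parameter name.

    Returns the (possibly remapped) name, or None when the scale has no
    matching parameter and the loader should skip it.
    """
    for suffix in (".k_scale", ".v_scale"):
        if name.endswith(suffix):
            # Hypothetical mapping: "...self_attn.k_scale" ->
            # "...self_attn.attn.k_scale".
            remapped = name[: -len(suffix)] + ".attn" + suffix
            return remapped if remapped in params_dict else None
    return name  # not a kv-scale entry; leave the name unchanged
```

This shape explains the `if name is None: continue` guard in the diff: checkpoint scale entries with no matching parameter (for example, when the model is run without an FP8 KV cache) are skipped cleanly instead of raising a `KeyError` at `params_dict[name]`.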