From 8e48ca8cc1c7409a66eaff61685cd4be40d93908 Mon Sep 17 00:00:00 2001 From: Hui Liu <96135754+hliuca@users.noreply.github.com> Date: Sat, 25 Jan 2025 18:29:14 -0800 Subject: [PATCH] enable kv_scale for Gemma2 (#3113) --- python/sglang/srt/models/gemma2.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/models/gemma2.py b/python/sglang/srt/models/gemma2.py index 4d21901de..06a7b0302 100644 --- a/python/sglang/srt/models/gemma2.py +++ b/python/sglang/srt/models/gemma2.py @@ -35,7 +35,10 @@ from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding from sglang.srt.model_executor.forward_batch_info import ForwardBatch -from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.model_loader.weight_utils import ( + default_weight_loader, + maybe_remap_kv_scale_name, +) from sglang.srt.utils import make_layers @@ -424,6 +427,11 @@ class Gemma2ForCausalLM(nn.Module): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight)