model : add grok-2 support (#15539)

* add grok-2 support

* type fix

* type fix

* type fix

* "fix" vocab for invalid sequences

* fix expert tensor mapping and spaces in vocab

* add chat template

* fix norm tensor mapping

* rename layer_out_norm to ffn_post_norm

* ensure ffn_post_norm is mapped

* fix experts merging

* remove erroneous FFN_GATE entry

* concatenate split tensors and add more metadata

* process all expert layers and try cat instead of hstack

* add support for community BPE vocab

* fix expert feed forward length and ffn_down concat

* commit this too

* add ffn_up/gate/down, unsure if sequence is right

* add ffn_gate/down/up to tensor names

* correct residual moe (still not working)

* mess--

* fix embedding scale being applied twice

* add built in chat template

* change beta fast for grok if default value

* remove spm vocab in favor of community bpe vocab

* change attention temp length metadata type to integer

* update attention temp length metadata

* remove comment

* replace M_SQRT2 with std::sqrt(2)

* add yarn metadata, move defaults to hparams
This commit is contained in:
Sigbjørn Skjæret
2025-09-14 23:00:59 +02:00
committed by GitHub
parent 6c019cb04e
commit b8e09f08b9
16 changed files with 281 additions and 96 deletions

View File

@@ -111,6 +111,7 @@ class Keys:
DECODER_START_TOKEN_ID = "{arch}.decoder_start_token_id"
DECODER_BLOCK_COUNT = "{arch}.decoder_block_count"
ATTN_LOGIT_SOFTCAPPING = "{arch}.attn_logit_softcapping"
ROUTER_LOGIT_SOFTCAPPING = "{arch}.router_logit_softcapping"
FINAL_LOGIT_SOFTCAPPING = "{arch}.final_logit_softcapping"
SWIN_NORM = "{arch}.swin_norm"
RESCALE_EVERY_N_LAYERS = "{arch}.rescale_every_n_layers"
@@ -146,21 +147,27 @@ class Keys:
REL_BUCKETS_COUNT = "{arch}.attention.relative_buckets_count"
SLIDING_WINDOW = "{arch}.attention.sliding_window"
SCALE = "{arch}.attention.scale"
OUTPUT_SCALE = "{arch}.attention.output_scale"
TEMPERATURE_LENGTH = "{arch}.attention.temperature_length"
KEY_LENGTH_MLA = "{arch}.attention.key_length_mla"
VALUE_LENGTH_MLA = "{arch}.attention.value_length_mla"
SHARED_KV_LAYERS = "{arch}.attention.shared_kv_layers"
SLIDING_WINDOW_PATTERN = "{arch}.attention.sliding_window_pattern"
class Rope:
DIMENSION_COUNT = "{arch}.rope.dimension_count"
DIMENSION_SECTIONS = "{arch}.rope.dimension_sections"
FREQ_BASE = "{arch}.rope.freq_base"
SCALING_TYPE = "{arch}.rope.scaling.type"
SCALING_FACTOR = "{arch}.rope.scaling.factor"
SCALING_ATTN_FACTOR = "{arch}.rope.scaling.attn_factor"
SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length"
SCALING_FINETUNED = "{arch}.rope.scaling.finetuned"
SCALING_YARN_LOG_MUL = "{arch}.rope.scaling.yarn_log_multiplier"
DIMENSION_COUNT = "{arch}.rope.dimension_count"
DIMENSION_SECTIONS = "{arch}.rope.dimension_sections"
FREQ_BASE = "{arch}.rope.freq_base"
SCALING_TYPE = "{arch}.rope.scaling.type"
SCALING_FACTOR = "{arch}.rope.scaling.factor"
SCALING_ATTN_FACTOR = "{arch}.rope.scaling.attn_factor"
SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length"
SCALING_FINETUNED = "{arch}.rope.scaling.finetuned"
SCALING_YARN_LOG_MUL = "{arch}.rope.scaling.yarn_log_multiplier"
SCALING_YARN_EXT_FACTOR = "{arch}.rope.scaling.yarn_ext_factor"
SCALING_YARN_ATTN_FACTOR = "{arch}.rope.scaling.yarn_attn_factor"
SCALING_YARN_BETA_FAST = "{arch}.rope.scaling.yarn_beta_fast"
SCALING_YARN_BETA_SLOW = "{arch}.rope.scaling.yarn_beta_slow"
class Split:
LLM_KV_SPLIT_NO = "split.no"
@@ -1114,6 +1121,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
MODEL_TENSOR.FFN_POST_NORM,
MODEL_TENSOR.LAYER_OUT_NORM,
],
MODEL_ARCH.GPTNEOX: [

View File

@@ -733,6 +733,9 @@ class GGUFWriter:
def add_attn_logit_softcapping(self, value: float) -> None:
self.add_float32(Keys.LLM.ATTN_LOGIT_SOFTCAPPING.format(arch=self.arch), value)
def add_router_logit_softcapping(self, value: float) -> None:
self.add_float32(Keys.LLM.ROUTER_LOGIT_SOFTCAPPING.format(arch=self.arch), value)
def add_final_logit_softcapping(self, value: float) -> None:
self.add_float32(Keys.LLM.FINAL_LOGIT_SOFTCAPPING.format(arch=self.arch), value)
@@ -829,6 +832,12 @@ class GGUFWriter:
def add_attention_scale(self, value: float) -> None:
self.add_float32(Keys.Attention.SCALE.format(arch=self.arch), value)
def add_attn_output_scale(self, value: float) -> None:
self.add_float32(Keys.Attention.OUTPUT_SCALE.format(arch=self.arch), value)
def add_attn_temperature_length(self, value: int) -> None:
self.add_uint32(Keys.Attention.TEMPERATURE_LENGTH.format(arch=self.arch), value)
def add_pooling_type(self, value: PoolingType) -> None:
self.add_uint32(Keys.LLM.POOLING_TYPE.format(arch=self.arch), value.value)
@@ -859,6 +868,18 @@ class GGUFWriter:
def add_rope_scaling_yarn_log_mul(self, value: float) -> None:
self.add_float32(Keys.Rope.SCALING_YARN_LOG_MUL.format(arch=self.arch), value)
def add_rope_scaling_yarn_ext_factor(self, value: float) -> None:
self.add_float32(Keys.Rope.SCALING_YARN_EXT_FACTOR.format(arch=self.arch), value)
def add_rope_scaling_yarn_attn_factor(self, value: float) -> None:
self.add_float32(Keys.Rope.SCALING_YARN_ATTN_FACTOR.format(arch=self.arch), value)
def add_rope_scaling_yarn_beta_fast(self, value: float) -> None:
self.add_float32(Keys.Rope.SCALING_YARN_BETA_FAST.format(arch=self.arch), value)
def add_rope_scaling_yarn_beta_slow(self, value: float) -> None:
self.add_float32(Keys.Rope.SCALING_YARN_BETA_SLOW.format(arch=self.arch), value)
def add_ssm_conv_kernel(self, value: int) -> None:
self.add_uint32(Keys.SSM.CONV_KERNEL.format(arch=self.arch), value)

View File

@@ -136,6 +136,7 @@ class TensorNameMap:
"model.layers.{bid}.norm", # mamba-qbert
"backbone.layers.{bid}.norm", # mamba
"transformer.decoder_layer.{bid}.rms_norm", # Grok
"model.layers.{bid}.pre_attn_norm", # grok-2
"transformer.blocks.{bid}.norm_attn_norm.norm_1", # dbrx
"encoder.layers.{bid}.input_layernorm", # chatglm
"transformer.layers.{bid}.attn_norm", # openelm
@@ -278,6 +279,7 @@ class TensorNameMap:
"transformer.layer.{bid}.sa_layer_norm", # distillbert
"encoder.layers.{bid}.norm1", # nomic-bert
"transformer.decoder_layer.{bid}.rms_norm_1", # Grok
"model.layers.{bid}.post_attn_norm", # grok-2
"transformer.blocks.{bid}.norm_attn_norm.norm_2", # dbrx
),
@@ -313,6 +315,7 @@ class TensorNameMap:
"h.{bid}.ln_2", # gpt2
"model.layers.{bid}.ffn_norm", # internlm2
"transformer.decoder_layer.{bid}.rms_norm_2", # Grok
"model.layers.{bid}.pre_moe_norm", # grok-2
"encoder.layers.{bid}.post_attention_layernorm", # chatglm
"transformer.layers.{bid}.ffn_norm", # openelm
"model.layers.{bid}.pre_ff_layernorm", # jamba granite-hybrid
@@ -333,11 +336,12 @@ class TensorNameMap:
# Post feed-forward norm
MODEL_TENSOR.FFN_POST_NORM: (
"model.layers.{bid}.post_feedforward_layernorm", # gemma2 olmo2
"layers.{bid}.post_feedforward_layernorm", # embeddinggemma
"model.layers.{bid}.post_mlp_layernorm", # glm-4-0414
"model.layers.{bid}.post_feedforward_layernorm", # gemma2 olmo2
"layers.{bid}.post_feedforward_layernorm", # embeddinggemma
"model.layers.{bid}.post_mlp_layernorm", # glm-4-0414
"model.layers.layers.{bid}.post_mlp_norm.weight", # plamo2
"model.layers.{bid}.feed_forward.up_proj",
"model.layers.{bid}.post_moe_norm", # grok-2
),
MODEL_TENSOR.FFN_GATE_INP: (