llama : add gpt-oss (#15091)
* oai moe * compat with new checkpoint * add attn sink impl * add rope scaling yarn * logits match with latest transformers code * wip chat template * rm trailing space * use ggml_scale_bias * rm redundant is_swa_all * convert interleaved gate_up * graph : fix activation function to match reference (#7) * vocab : handle o200k_harmony special tokens * ggml : add attention sinks support (#1) * llama : add attn sinks * ggml : add attn sinks * cuda : add attn sinks * vulkan : add support for sinks in softmax remove unnecessary return * ggml : add fused swiglu_oai op (#11) * ggml : add fused swiglu_oai op * Update ggml/src/ggml-cpu/ops.cpp Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * update CUDA impl * cont : metal impl * add vulkan impl * test-backend-ops : more test cases, clean up * llama : remove unfused impl * remove extra lines --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> --------- Co-authored-by: slaren <slarengh@gmail.com> * repack mxfp4 upon conversion * clean up a bit * enable thinking * add quick hack to render only some special tokens * fix bf16 conversion * remove vocab hack * webui ok * support chat parsing for gpt-oss * fix webui * direct mapping mxfp4, FINALLY * force using mxfp4 * properly use lazy tensor * ggml : add mxfp4 ggml : use e8m0 conversion instead of powf Co-authored-by: Diego Devesa <slarengh@gmail.com> change kvalues_mxfp4 table to match e2m1 (#6) metal : remove quantization for now (not used) cuda : fix disabled CUDA graphs due to ffn moe bias vulkan : add support for mxfp4 cont : add cm2 dequant * ggml : add ggml_add_id (#13) * ggml : add ggml_add_id * add cuda impl * llama : add weight support check for add_id * perf opt * add vulkan impl * rename cuda files * add metal impl * allow in-place ggml_add_id * llama : keep biases on CPU with --cpu-moe * llama : fix compile error ggml-ci * cuda : add fallback for __nv_cvt_e8m0_to_bf16raw ggml-ci * cleanup ggml-ci * sycl : fix supports_op for MXFP4 ggml-ci * fix Unknown reasoning format * ggml-cpu : fix AVX build ggml-ci * fix hip build ggml-ci * cuda : add mxfp4 dequantization support for cuBLAS ggml-ci * ggml-cpu : fix mxfp4 fallback definitions for some architectures ggml-ci * cuda : fix version required for __nv_cvt_e8m0_to_bf16raw --------- Co-authored-by: Xuan Son Nguyen <son@huggingface.co> Co-authored-by: slaren <slarengh@gmail.com>
This commit is contained in:
@@ -88,6 +88,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_HUNYUAN_MOE, "hunyuan-moe" },
|
||||
{ LLM_ARCH_HUNYUAN_DENSE, "hunyuan-dense" },
|
||||
{ LLM_ARCH_SMOLLM3, "smollm3" },
|
||||
{ LLM_ARCH_OPENAI_MOE, "gpt-oss" },
|
||||
{ LLM_ARCH_LFM2, "lfm2" },
|
||||
{ LLM_ARCH_DREAM, "dream" },
|
||||
{ LLM_ARCH_SMALLTHINKER, "smallthinker" },
|
||||
@@ -1971,6 +1972,25 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_OPENAI_MOE,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_ATTN_SINKS, "blk.%d.attn_sinks" },
|
||||
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
|
||||
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
|
||||
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
|
||||
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_LFM2,
|
||||
{
|
||||
@@ -2086,6 +2106,7 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
|
||||
{LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_ATTN_K_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_ATTN_V_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_ATTN_SINKS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SCALE}},
|
||||
{LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
{LLM_TENSOR_DEC_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
|
||||
|
||||
@@ -92,6 +92,7 @@ enum llm_arch {
|
||||
LLM_ARCH_HUNYUAN_MOE,
|
||||
LLM_ARCH_HUNYUAN_DENSE,
|
||||
LLM_ARCH_SMOLLM3,
|
||||
LLM_ARCH_OPENAI_MOE,
|
||||
LLM_ARCH_LFM2,
|
||||
LLM_ARCH_DREAM,
|
||||
LLM_ARCH_SMALLTHINKER,
|
||||
@@ -265,6 +266,7 @@ enum llm_tensor {
|
||||
LLM_TENSOR_ATTN_OUT_NORM,
|
||||
LLM_TENSOR_ATTN_POST_NORM,
|
||||
LLM_TENSOR_ATTN_ROT_EMBD,
|
||||
LLM_TENSOR_ATTN_SINKS,
|
||||
LLM_TENSOR_FFN_GATE_INP,
|
||||
LLM_TENSOR_FFN_GATE_INP_SHEXP,
|
||||
LLM_TENSOR_FFN_NORM,
|
||||
|
||||
@@ -66,6 +66,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
|
||||
{ "llama4", LLM_CHAT_TEMPLATE_LLAMA4 },
|
||||
{ "smolvlm", LLM_CHAT_TEMPLATE_SMOLVLM },
|
||||
{ "hunyuan-moe", LLM_CHAT_TEMPLATE_HUNYUAN_MOE },
|
||||
{ "gpt-oss", LLM_CHAT_TEMPLATE_OPENAI_MOE },
|
||||
{ "hunyuan-dense", LLM_CHAT_TEMPLATE_HUNYUAN_DENSE },
|
||||
{ "kimi-k2", LLM_CHAT_TEMPLATE_KIMI_K2 },
|
||||
};
|
||||
@@ -194,6 +195,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
|
||||
return LLM_CHAT_TEMPLATE_DOTS1;
|
||||
} else if (tmpl_contains("<|startoftext|>") && tmpl_contains("<|extra_4|>")) {
|
||||
return LLM_CHAT_TEMPLATE_HUNYUAN_MOE;
|
||||
} else if (tmpl_contains("<|start|>") && tmpl_contains("<|channel|>")) {
|
||||
return LLM_CHAT_TEMPLATE_OPENAI_MOE;
|
||||
} else if (tmpl_contains("<|hy_place▁holder▁no▁2|>") && tmpl_contains("<|hy_place▁holder▁no▁3|>")) {
|
||||
return LLM_CHAT_TEMPLATE_HUNYUAN_DENSE;
|
||||
} else if (tmpl_contains("<|im_assistant|>assistant<|im_middle|>")) {
|
||||
@@ -706,6 +709,16 @@ int32_t llm_chat_apply_template(
|
||||
ss << "<|startoftext|>" << message->content << "<|extra_0|>";
|
||||
}
|
||||
}
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_OPENAI_MOE) {
|
||||
// OpenAI MoE (based on Harmony chat template)
|
||||
for (auto message : chat) {
|
||||
std::string role(message->role);
|
||||
ss << "<|start|>" << role << "<|message|>" << message->content;
|
||||
ss << (role == "assistant" ? "<|return|>" : "<|end|>");
|
||||
}
|
||||
if (add_ass) {
|
||||
ss << "<|start|>assistant";
|
||||
}
|
||||
} else if (tmpl == LLM_CHAT_TEMPLATE_HUNYUAN_DENSE) {
|
||||
// tencent/Hunyuan-4B-Instruct
|
||||
for (size_t i = 0; i < chat.size(); i++) {
|
||||
|
||||
@@ -46,6 +46,7 @@ enum llm_chat_template {
|
||||
LLM_CHAT_TEMPLATE_SMOLVLM,
|
||||
LLM_CHAT_TEMPLATE_DOTS1,
|
||||
LLM_CHAT_TEMPLATE_HUNYUAN_MOE,
|
||||
LLM_CHAT_TEMPLATE_OPENAI_MOE,
|
||||
LLM_CHAT_TEMPLATE_HUNYUAN_DENSE,
|
||||
LLM_CHAT_TEMPLATE_KIMI_K2,
|
||||
LLM_CHAT_TEMPLATE_UNKNOWN,
|
||||
|
||||
@@ -740,6 +740,8 @@ ggml_tensor * llm_graph_context::build_ffn(
|
||||
cur = ggml_reglu(ctx0, cur);
|
||||
cb(cur, "ffn_reglu", il);
|
||||
} break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
|
||||
if (gate && type_gate == LLM_FFN_PAR) {
|
||||
@@ -787,6 +789,45 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
||||
llama_expert_gating_func_type gating_op,
|
||||
int il,
|
||||
ggml_tensor * probs_in) const {
|
||||
return build_moe_ffn(
|
||||
cur,
|
||||
gate_inp, /* gate_inp_b */ nullptr,
|
||||
up_exps, /* up_exps_b */ nullptr,
|
||||
gate_exps, /* gate_exps_b */ nullptr,
|
||||
down_exps, /* down_exps_b */ nullptr,
|
||||
exp_probs_b,
|
||||
n_expert,
|
||||
n_expert_used,
|
||||
type_op,
|
||||
norm_w,
|
||||
scale_w,
|
||||
w_scale,
|
||||
gating_op,
|
||||
il,
|
||||
probs_in
|
||||
);
|
||||
}
|
||||
|
||||
ggml_tensor * llm_graph_context::build_moe_ffn(
|
||||
ggml_tensor * cur,
|
||||
ggml_tensor * gate_inp,
|
||||
ggml_tensor * gate_inp_b,
|
||||
ggml_tensor * up_exps,
|
||||
ggml_tensor * up_exps_b,
|
||||
ggml_tensor * gate_exps,
|
||||
ggml_tensor * gate_exps_b,
|
||||
ggml_tensor * down_exps,
|
||||
ggml_tensor * down_exps_b,
|
||||
ggml_tensor * exp_probs_b,
|
||||
int64_t n_expert,
|
||||
int64_t n_expert_used,
|
||||
llm_ffn_op_type type_op,
|
||||
bool norm_w,
|
||||
bool scale_w,
|
||||
float w_scale,
|
||||
llama_expert_gating_func_type gating_op,
|
||||
int il,
|
||||
ggml_tensor * probs_in) const {
|
||||
const int64_t n_embd = cur->ne[0];
|
||||
const int64_t n_tokens = cur->ne[1];
|
||||
const bool weight_before_ffn = arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN
|
||||
@@ -800,6 +841,11 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
||||
logits = probs_in;
|
||||
}
|
||||
|
||||
if (gate_inp_b) {
|
||||
logits = ggml_add(ctx0, logits, gate_inp_b);
|
||||
cb(logits, "ffn_moe_logits_biased", il);
|
||||
}
|
||||
|
||||
ggml_tensor * probs = nullptr;
|
||||
switch (gating_op) {
|
||||
case LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX:
|
||||
@@ -810,6 +856,10 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
||||
{
|
||||
probs = ggml_sigmoid(ctx0, logits); // [n_expert, n_tokens]
|
||||
} break;
|
||||
case LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT:
|
||||
{
|
||||
probs = logits; // [n_expert, n_tokens]
|
||||
} break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
@@ -838,6 +888,13 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
||||
ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
|
||||
cb(weights, "ffn_moe_weights", il);
|
||||
|
||||
if (gating_op == LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT) {
|
||||
weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);
|
||||
weights = ggml_soft_max(ctx0, weights); // [n_expert_used, n_tokens]
|
||||
weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens);
|
||||
cb(weights, "ffn_moe_weights_softmax", il);
|
||||
}
|
||||
|
||||
if (norm_w) {
|
||||
weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);
|
||||
|
||||
@@ -866,6 +923,11 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
||||
ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
|
||||
cb(up, "ffn_moe_up", il);
|
||||
|
||||
if (up_exps_b) {
|
||||
up = ggml_add_id(ctx0, up, up_exps_b, selected_experts);
|
||||
cb(up, "ffn_moe_up_biased", il);
|
||||
}
|
||||
|
||||
ggml_tensor * experts = nullptr;
|
||||
if (gate_exps) {
|
||||
cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
|
||||
@@ -874,6 +936,11 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
||||
cur = up;
|
||||
}
|
||||
|
||||
if (gate_exps_b) {
|
||||
cur = ggml_add_id(ctx0, cur, gate_exps_b, selected_experts);
|
||||
cb(cur, "ffn_moe_gate_biased", il);
|
||||
}
|
||||
|
||||
switch (type_op) {
|
||||
case LLM_FFN_SILU:
|
||||
if (gate_exps) {
|
||||
@@ -891,6 +958,14 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
||||
cur = ggml_gelu(ctx0, cur);
|
||||
cb(cur, "ffn_moe_gelu", il);
|
||||
} break;
|
||||
case LLM_FFN_SWIGLU_OAI_MOE:
|
||||
{
|
||||
// TODO: move to hparams?
|
||||
constexpr float alpha = 1.702f;
|
||||
constexpr float limit = 7.0f;
|
||||
cur = ggml_swiglu_oai(ctx0, cur, up, alpha, limit);
|
||||
cb(cur, "ffn_moe_swiglu_oai", il);
|
||||
} break;
|
||||
case LLM_FFN_RELU:
|
||||
if (gate_exps) {
|
||||
cur = ggml_reglu_split(ctx0, cur, up);
|
||||
@@ -906,6 +981,11 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
||||
experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens]
|
||||
cb(experts, "ffn_moe_down", il);
|
||||
|
||||
if (down_exps_b) {
|
||||
experts = ggml_add_id(ctx0, experts, down_exps_b, selected_experts);
|
||||
cb(experts, "ffn_moe_down_biased", il);
|
||||
}
|
||||
|
||||
if (!weight_before_ffn) {
|
||||
experts = ggml_mul(ctx0, experts, weights);
|
||||
cb(cur, "ffn_moe_weighted", il);
|
||||
@@ -1144,6 +1224,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
|
||||
ggml_tensor * kq_b,
|
||||
ggml_tensor * kq_mask,
|
||||
ggml_tensor * v_mla,
|
||||
ggml_tensor * sinks,
|
||||
float kq_scale) const {
|
||||
const bool v_trans = v->nb[1] > v->nb[2];
|
||||
|
||||
@@ -1180,7 +1261,8 @@ ggml_tensor * llm_graph_context::build_attn_mha(
|
||||
cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
|
||||
hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
|
||||
|
||||
ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
|
||||
ggml_flash_attn_ext_add_sinks(cur, sinks);
|
||||
ggml_flash_attn_ext_set_prec (cur, GGML_PREC_F32);
|
||||
|
||||
if (v_mla) {
|
||||
#if 0
|
||||
@@ -1228,6 +1310,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
|
||||
}
|
||||
|
||||
kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
|
||||
ggml_soft_max_add_sinks(kq, sinks);
|
||||
|
||||
if (!v_trans) {
|
||||
// note: avoid this branch
|
||||
@@ -1298,7 +1381,7 @@ ggml_tensor * llm_graph_context::build_attn(
|
||||
ggml_tensor * k = k_cur;
|
||||
ggml_tensor * v = v_cur;
|
||||
|
||||
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
|
||||
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, nullptr, kq_scale);
|
||||
cb(cur, "kqv_out", il);
|
||||
|
||||
if (wo) {
|
||||
@@ -1386,7 +1469,7 @@ ggml_tensor * llm_graph_context::build_attn(
|
||||
ggml_tensor * k = mctx_cur->get_k(ctx0, il);
|
||||
ggml_tensor * v = mctx_cur->get_v(ctx0, il);
|
||||
|
||||
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
|
||||
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, nullptr, kq_scale);
|
||||
cb(cur, "kqv_out", il);
|
||||
|
||||
if (wo) {
|
||||
@@ -1415,6 +1498,32 @@ ggml_tensor * llm_graph_context::build_attn(
|
||||
ggml_tensor * v_mla,
|
||||
float kq_scale,
|
||||
int il) const {
|
||||
return build_attn_with_sinks(
|
||||
inp,
|
||||
wo,
|
||||
wo_b,
|
||||
q_cur,
|
||||
k_cur,
|
||||
v_cur,
|
||||
kq_b,
|
||||
v_mla,
|
||||
nullptr,
|
||||
kq_scale,
|
||||
il);
|
||||
}
|
||||
|
||||
ggml_tensor * llm_graph_context::build_attn_with_sinks(
|
||||
llm_graph_input_attn_kv_unified_iswa * inp,
|
||||
ggml_tensor * wo,
|
||||
ggml_tensor * wo_b,
|
||||
ggml_tensor * q_cur,
|
||||
ggml_tensor * k_cur,
|
||||
ggml_tensor * v_cur,
|
||||
ggml_tensor * kq_b,
|
||||
ggml_tensor * v_mla,
|
||||
ggml_tensor * sinks,
|
||||
float kq_scale,
|
||||
int il) const {
|
||||
// these nodes are added to the graph together so that they are not reordered
|
||||
// by doing so, the number of splits in the graph is reduced
|
||||
ggml_build_forward_expand(gf, q_cur);
|
||||
@@ -1452,7 +1561,7 @@ ggml_tensor * llm_graph_context::build_attn(
|
||||
ggml_tensor * k = mctx_cur->get_k(ctx0, il);
|
||||
ggml_tensor * v = mctx_cur->get_v(ctx0, il);
|
||||
|
||||
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
|
||||
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, sinks, kq_scale);
|
||||
cb(cur, "kqv_out", il);
|
||||
|
||||
if (wo) {
|
||||
@@ -1506,7 +1615,7 @@ ggml_tensor * llm_graph_context::build_attn(
|
||||
ggml_tensor * k = k_cur;
|
||||
ggml_tensor * v = v_cur;
|
||||
|
||||
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
|
||||
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, nullptr, kq_scale);
|
||||
cb(cur, "kqv_out", il);
|
||||
|
||||
if (wo) {
|
||||
|
||||
@@ -39,6 +39,7 @@ enum llm_ffn_op_type {
|
||||
LLM_FFN_SWIGLU,
|
||||
LLM_FFN_GEGLU,
|
||||
LLM_FFN_REGLU,
|
||||
LLM_FFN_SWIGLU_OAI_MOE,
|
||||
};
|
||||
|
||||
enum llm_ffn_gate_type {
|
||||
@@ -619,6 +620,7 @@ struct llm_graph_context {
|
||||
llm_ffn_gate_type type_gate,
|
||||
int il) const;
|
||||
|
||||
// build MoE FFN without bias tensors
|
||||
ggml_tensor * build_moe_ffn(
|
||||
ggml_tensor * cur,
|
||||
ggml_tensor * gate_inp,
|
||||
@@ -636,6 +638,27 @@ struct llm_graph_context {
|
||||
int il,
|
||||
ggml_tensor * probs_in = nullptr) const;
|
||||
|
||||
ggml_tensor * build_moe_ffn(
|
||||
ggml_tensor * cur,
|
||||
ggml_tensor * gate_inp,
|
||||
ggml_tensor * gate_inp_b,
|
||||
ggml_tensor * up_exps,
|
||||
ggml_tensor * up_exps_b,
|
||||
ggml_tensor * gate_exps,
|
||||
ggml_tensor * gate_exps_b,
|
||||
ggml_tensor * down_exps,
|
||||
ggml_tensor * down_exps_b,
|
||||
ggml_tensor * exp_probs_b,
|
||||
int64_t n_expert,
|
||||
int64_t n_expert_used,
|
||||
llm_ffn_op_type type_op,
|
||||
bool norm_w,
|
||||
bool scale_w,
|
||||
float w_scale,
|
||||
llama_expert_gating_func_type gating_op,
|
||||
int il,
|
||||
ggml_tensor * probs_in = nullptr) const;
|
||||
|
||||
//
|
||||
// inputs
|
||||
//
|
||||
@@ -662,6 +685,7 @@ struct llm_graph_context {
|
||||
ggml_tensor * v, // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
|
||||
ggml_tensor * kq_b,
|
||||
ggml_tensor * kq_mask,
|
||||
ggml_tensor * sinks,
|
||||
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
|
||||
float kq_scale) const;
|
||||
|
||||
@@ -708,6 +732,20 @@ struct llm_graph_context {
|
||||
float kq_scale,
|
||||
int il) const;
|
||||
|
||||
// TODO: temporary to keep the diff small. after the code is public will refactor to simplify this
|
||||
ggml_tensor * build_attn_with_sinks(
|
||||
llm_graph_input_attn_kv_unified_iswa * inp,
|
||||
ggml_tensor * wo,
|
||||
ggml_tensor * wo_b,
|
||||
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
|
||||
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] optional
|
||||
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] optional
|
||||
ggml_tensor * kq_b,
|
||||
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
|
||||
ggml_tensor * sinks, // [n_head_q]
|
||||
float kq_scale,
|
||||
int il) const;
|
||||
|
||||
llm_graph_input_attn_cross * build_attn_inp_cross() const;
|
||||
|
||||
ggml_tensor * build_attn(
|
||||
|
||||
@@ -9,9 +9,10 @@
|
||||
#define LLAMA_MAX_EXPERTS 384 // Kimi-K2
|
||||
|
||||
enum llama_expert_gating_func_type {
|
||||
LLAMA_EXPERT_GATING_FUNC_TYPE_NONE = 0,
|
||||
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX = 1,
|
||||
LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID = 2,
|
||||
LLAMA_EXPERT_GATING_FUNC_TYPE_NONE = 0,
|
||||
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX = 1,
|
||||
LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID = 2,
|
||||
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT = 3, // applied to the router weights instead of the logits
|
||||
};
|
||||
|
||||
enum llama_swa_type {
|
||||
|
||||
@@ -35,6 +35,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
|
||||
case LLAMA_FTYPE_MOSTLY_Q5_0: return "Q5_0";
|
||||
case LLAMA_FTYPE_MOSTLY_Q5_1: return "Q5_1";
|
||||
case LLAMA_FTYPE_MOSTLY_Q8_0: return "Q8_0";
|
||||
case LLAMA_FTYPE_MOSTLY_MXFP4_MOE: return "MXFP4 MoE";
|
||||
case LLAMA_FTYPE_MOSTLY_Q2_K: return "Q2_K - Medium";
|
||||
case LLAMA_FTYPE_MOSTLY_Q2_K_S: return "Q2_K - Small";
|
||||
case LLAMA_FTYPE_MOSTLY_Q3_K_S: return "Q3_K - Small";
|
||||
|
||||
@@ -192,6 +192,13 @@ static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w
|
||||
ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], w->ne[1], w->ne[2], w->ne[3]);
|
||||
op_tensor = ggml_add(ctx, a, w);
|
||||
} break;
|
||||
case GGML_OP_ADD_ID:
|
||||
{
|
||||
int n_expert_used = hparams.n_expert_used;
|
||||
ggml_tensor * a = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, w->ne[0], n_expert_used, 512);
|
||||
ggml_tensor * c = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_expert_used, 512);
|
||||
op_tensor = ggml_add_id(ctx, a, w, c);
|
||||
} break;
|
||||
case GGML_OP_MUL:
|
||||
{
|
||||
ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], w->ne[1], w->ne[2], w->ne[3]);
|
||||
@@ -260,6 +267,10 @@ static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w
|
||||
ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_embd, w->ne[1], 1, 1);
|
||||
op_tensor = ggml_im2col(ctx, w, b, 1, 0, 0, 0, 1, 0, false, GGML_TYPE_F16);
|
||||
} break;
|
||||
case GGML_OP_SCALE:
|
||||
{
|
||||
op_tensor = ggml_scale(ctx, w, 1.0f);
|
||||
} break;
|
||||
default:
|
||||
GGML_ABORT("%s: missing test for op %s for tensor %s", __func__, ggml_op_name(op), w->name);
|
||||
}
|
||||
@@ -1813,6 +1824,17 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_OPENAI_MOE:
|
||||
{
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
|
||||
ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
|
||||
|
||||
hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
|
||||
hparams.set_swa_pattern(2);
|
||||
|
||||
// TODO: switch (hparams.n_layer)
|
||||
} break;
|
||||
case LLM_ARCH_LFM2:
|
||||
{
|
||||
ml.get_key(LLM_KV_SHORTCONV_L_CACHE, hparams.n_shortconv_l_cache);
|
||||
@@ -2045,11 +2067,15 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// tensors with "bias" suffix are always used with GGML_OP_ADD
|
||||
// tensors with "bias" suffix are always used with GGML_OP_ADD or GGML_OP_ADD_ID
|
||||
ggml_op op;
|
||||
bool bias = tn.suffix != nullptr && strcmp(tn.suffix, "bias") == 0;
|
||||
if (bias) {
|
||||
op = GGML_OP_ADD;
|
||||
if (info.op == GGML_OP_MUL_MAT_ID) {
|
||||
op = GGML_OP_ADD_ID;
|
||||
} else {
|
||||
op = GGML_OP_ADD;
|
||||
}
|
||||
} else {
|
||||
op = info.op;
|
||||
}
|
||||
@@ -5400,6 +5426,46 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_OPENAI_MOE:
|
||||
{
|
||||
const int64_t n_ff_exp = hparams.n_ff_exp;
|
||||
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
|
||||
// output
|
||||
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
||||
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
auto & layer = layers[i];
|
||||
|
||||
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
||||
layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
|
||||
|
||||
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_head * n_rot}, 0);
|
||||
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_head_kv * n_rot}, 0);
|
||||
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_head_kv * n_rot}, 0);
|
||||
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_rot, n_embd}, 0);
|
||||
|
||||
layer.attn_sinks = create_tensor(tn(LLM_TENSOR_ATTN_SINKS, "weight", i), {n_head}, 0);
|
||||
|
||||
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert}, 0);
|
||||
layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
|
||||
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0);
|
||||
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
|
||||
|
||||
// bias
|
||||
layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_head * n_rot}, 0);
|
||||
layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_head_kv * n_rot}, 0);
|
||||
layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_head_kv * n_rot}, 0);
|
||||
layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0);
|
||||
|
||||
layer.ffn_gate_inp_b = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "bias", i), {n_expert}, 0);
|
||||
layer.ffn_gate_exps_b = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "bias", i), {n_ff_exp, n_expert}, 0);
|
||||
layer.ffn_down_exps_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "bias", i), { n_embd, n_expert}, 0);
|
||||
layer.ffn_up_exps_b = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "bias", i), {n_ff_exp, n_expert}, 0);
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_LFM2:
|
||||
{
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
@@ -5772,7 +5838,7 @@ void llama_model::print_info() const {
|
||||
LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp);
|
||||
}
|
||||
|
||||
if (arch == LLM_ARCH_QWEN3MOE) {
|
||||
if (arch == LLM_ARCH_QWEN3MOE || arch == LLM_ARCH_OPENAI_MOE) {
|
||||
LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp);
|
||||
}
|
||||
|
||||
@@ -17540,6 +17606,136 @@ struct llm_build_smollm3 : public llm_graph_context {
|
||||
}
|
||||
};
|
||||
|
||||
struct llm_build_openai_moe_iswa : public llm_graph_context {
|
||||
llm_build_openai_moe_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
inpL = build_inp_embd(model.tok_embd);
|
||||
|
||||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
auto * inp_attn = build_attn_inp_kv_unified_iswa();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
ggml_tensor * inpSA = inpL;
|
||||
|
||||
// norm
|
||||
cur = build_norm(inpL,
|
||||
model.layers[il].attn_norm, nullptr,
|
||||
LLM_NORM_RMS, il);
|
||||
cb(cur, "attn_norm", il);
|
||||
|
||||
// self-attention
|
||||
{
|
||||
// compute Q and K and RoPE them
|
||||
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
|
||||
cb(Qcur, "Qcur", il);
|
||||
if (model.layers[il].bq) {
|
||||
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
|
||||
cb(Qcur, "Qcur", il);
|
||||
}
|
||||
|
||||
ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
|
||||
cb(Kcur, "Kcur", il);
|
||||
if (model.layers[il].bk) {
|
||||
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
|
||||
cb(Kcur, "Kcur", il);
|
||||
}
|
||||
|
||||
ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
|
||||
cb(Vcur, "Vcur", il);
|
||||
if (model.layers[il].bv) {
|
||||
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
|
||||
cb(Vcur, "Vcur", il);
|
||||
}
|
||||
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_rot, n_head, n_tokens);
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_rot, n_head_kv, n_tokens);
|
||||
Vcur = ggml_reshape_3d(ctx0, Vcur, n_rot, n_head_kv, n_tokens);
|
||||
|
||||
Qcur = ggml_rope_ext(
|
||||
ctx0, Qcur, inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
|
||||
Kcur = ggml_rope_ext(
|
||||
ctx0, Kcur, inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
|
||||
cb(Qcur, "Qcur", il);
|
||||
cb(Kcur, "Kcur", il);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
cur = build_attn_with_sinks(inp_attn,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, model.layers[il].attn_sinks, 1.0f/sqrtf(float(n_rot)), il);
|
||||
|
||||
cb(cur, "attn_out", il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1) {
|
||||
// skip computing output for unused tokens
|
||||
ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
||||
}
|
||||
|
||||
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
||||
cb(ffn_inp, "ffn_inp", il);
|
||||
|
||||
cur = ffn_inp;
|
||||
cur = build_norm(cur,
|
||||
model.layers[il].attn_post_norm, nullptr,
|
||||
LLM_NORM_RMS, il);
|
||||
cb(cur, "attn_post_norm", il);
|
||||
|
||||
// MoE branch
|
||||
cur = build_moe_ffn(cur,
|
||||
model.layers[il].ffn_gate_inp, model.layers[il].ffn_gate_inp_b,
|
||||
model.layers[il].ffn_up_exps, model.layers[il].ffn_up_exps_b,
|
||||
model.layers[il].ffn_gate_exps, model.layers[il].ffn_gate_exps_b,
|
||||
model.layers[il].ffn_down_exps, model.layers[il].ffn_down_exps_b,
|
||||
nullptr,
|
||||
n_expert, n_expert_used,
|
||||
LLM_FFN_SWIGLU_OAI_MOE, false,
|
||||
false, 0.0,
|
||||
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT,
|
||||
il);
|
||||
cb(cur, "ffn_moe_out", il);
|
||||
|
||||
cur = ggml_add(ctx0, cur, ffn_inp);
|
||||
|
||||
cur = build_cvec(cur, il);
|
||||
cb(cur, "l_out", il);
|
||||
|
||||
// input for next layer
|
||||
inpL = cur;
|
||||
}
|
||||
|
||||
cur = inpL;
|
||||
|
||||
cur = build_norm(cur,
|
||||
model.output_norm, NULL,
|
||||
LLM_NORM_RMS, -1);
|
||||
|
||||
cb(cur, "result_norm", -1);
|
||||
res->t_embd = cur;
|
||||
|
||||
// lm_head
|
||||
cur = build_lora_mm(model.output, cur);
|
||||
|
||||
cb(cur, "result_output", -1);
|
||||
res->t_logits = cur;
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
}
|
||||
};
|
||||
|
||||
struct llm_build_lfm2 : public llm_graph_context {
|
||||
const llama_model & model;
|
||||
|
||||
@@ -18283,6 +18479,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
|
||||
{
|
||||
llm = std::make_unique<llm_build_smollm3>(*this, params);
|
||||
} break;
|
||||
case LLM_ARCH_OPENAI_MOE:
|
||||
{
|
||||
llm = std::make_unique<llm_build_openai_moe_iswa>(*this, params);
|
||||
} break;
|
||||
case LLM_ARCH_FALCON_H1:
|
||||
{
|
||||
llm = std::make_unique<llm_build_falcon_h1>(*this, params);
|
||||
@@ -18498,6 +18698,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
|
||||
case LLM_ARCH_MINICPM3:
|
||||
case LLM_ARCH_DOTS1:
|
||||
case LLM_ARCH_HUNYUAN_MOE:
|
||||
case LLM_ARCH_OPENAI_MOE:
|
||||
case LLM_ARCH_HUNYUAN_DENSE:
|
||||
case LLM_ARCH_LFM2:
|
||||
case LLM_ARCH_SMALLTHINKER:
|
||||
|
||||
@@ -252,10 +252,14 @@ struct llama_layer {
|
||||
struct ggml_tensor * ffn_up_enc = nullptr;
|
||||
|
||||
// ff MoE
|
||||
struct ggml_tensor * ffn_gate_inp = nullptr;
|
||||
struct ggml_tensor * ffn_gate_exps = nullptr;
|
||||
struct ggml_tensor * ffn_down_exps = nullptr;
|
||||
struct ggml_tensor * ffn_up_exps = nullptr;
|
||||
struct ggml_tensor * ffn_gate_inp = nullptr;
|
||||
struct ggml_tensor * ffn_gate_exps = nullptr;
|
||||
struct ggml_tensor * ffn_down_exps = nullptr;
|
||||
struct ggml_tensor * ffn_up_exps = nullptr;
|
||||
struct ggml_tensor * ffn_gate_inp_b = nullptr;
|
||||
struct ggml_tensor * ffn_gate_exps_b = nullptr;
|
||||
struct ggml_tensor * ffn_down_exps_b = nullptr;
|
||||
struct ggml_tensor * ffn_up_exps_b = nullptr;
|
||||
|
||||
// ff shared expert (shexp)
|
||||
struct ggml_tensor * ffn_gate_inp_shexp = nullptr;
|
||||
@@ -360,6 +364,9 @@ struct llama_layer {
|
||||
struct ggml_tensor * laurel_r = nullptr;
|
||||
struct ggml_tensor * laurel_post_norm = nullptr;
|
||||
|
||||
// openai-moe
|
||||
struct ggml_tensor * attn_sinks = nullptr;
|
||||
|
||||
struct llama_layer_posnet posnet;
|
||||
|
||||
struct llama_layer_convnext convnext;
|
||||
|
||||
@@ -211,7 +211,10 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t
|
||||
const int64_t nx = tensor->ne[0];
|
||||
const int64_t qk_k = ggml_blck_size(new_type);
|
||||
|
||||
if (arch == LLM_ARCH_FALCON || nx % qk_k != 0) {
|
||||
if (ftype == LLAMA_FTYPE_MOSTLY_MXFP4_MOE) {
|
||||
new_type = GGML_TYPE_Q8_0;
|
||||
}
|
||||
else if (arch == LLM_ARCH_FALCON || nx % qk_k != 0) {
|
||||
new_type = GGML_TYPE_Q8_0;
|
||||
}
|
||||
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS ||
|
||||
@@ -223,6 +226,14 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t
|
||||
new_type = GGML_TYPE_Q6_K;
|
||||
}
|
||||
}
|
||||
} else if (ftype == LLAMA_FTYPE_MOSTLY_MXFP4_MOE) {
|
||||
// MoE tensors -> MXFP4
|
||||
// other tensors -> Q8_0
|
||||
if (tensor->ne[2] > 1) {
|
||||
new_type = GGML_TYPE_MXFP4;
|
||||
} else {
|
||||
new_type = GGML_TYPE_Q8_0;
|
||||
}
|
||||
} else if (name == "token_embd.weight" || name == "per_layer_token_embd.weight") {
|
||||
if (qs.params->token_embedding_type < GGML_TYPE_COUNT) {
|
||||
new_type = qs.params->token_embedding_type;
|
||||
@@ -533,6 +544,8 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
||||
case LLAMA_FTYPE_MOSTLY_BF16: default_type = GGML_TYPE_BF16; break;
|
||||
case LLAMA_FTYPE_ALL_F32: default_type = GGML_TYPE_F32; break;
|
||||
|
||||
case LLAMA_FTYPE_MOSTLY_MXFP4_MOE: default_type = GGML_TYPE_MXFP4; break;
|
||||
|
||||
// K-quants
|
||||
case LLAMA_FTYPE_MOSTLY_Q2_K_S:
|
||||
case LLAMA_FTYPE_MOSTLY_Q2_K: default_type = GGML_TYPE_Q2_K; break;
|
||||
@@ -984,6 +997,29 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
|
||||
const float * imatrix_03 = imatrix ? imatrix + i03 * n_per_row : nullptr;
|
||||
|
||||
new_size += llama_tensor_quantize_impl(new_type, f32_data_03, new_data_03, chunk_size, nrows, n_per_row, imatrix_03, workers, nthread_use);
|
||||
|
||||
// TODO: temporary sanity check that the F16 -> MXFP4 is lossless
|
||||
#if 1
|
||||
if (new_type == GGML_TYPE_MXFP4) {
|
||||
auto * x = f32_data_03;
|
||||
|
||||
//LLAMA_LOG_INFO("nrows = %d, n_per_row = %d\n", nrows, n_per_row);
|
||||
std::vector<float> deq(nrows*n_per_row);
|
||||
const ggml_type_traits * qtype = ggml_get_type_traits(new_type);
|
||||
qtype->to_float(new_data_03, deq.data(), deq.size());
|
||||
|
||||
double err = 0.0f;
|
||||
for (int i = 0; i < (int) deq.size(); ++i) {
|
||||
err += fabsf(deq[i] - x[i]);
|
||||
//if (fabsf(deq[i] - x[i]) > 0.00001 && i < 256) {
|
||||
if (deq[i] != x[i]) {
|
||||
LLAMA_LOG_INFO("deq[%d] = %f, x[%d] = %f\n", i, deq[i], i, x[i]);
|
||||
}
|
||||
}
|
||||
//LLAMA_LOG_INFO("err = %f\n", err);
|
||||
GGML_ASSERT(err == 0.00000);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
LLAMA_LOG_INFO("size = %8.2f MiB -> %8.2f MiB\n", ggml_nbytes(tensor)/1024.0/1024.0, new_size/1024.0/1024.0);
|
||||
}
|
||||
|
||||
@@ -2314,6 +2314,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
||||
|| t.first == "<|eot_id|>"
|
||||
|| t.first == "<|im_end|>"
|
||||
|| t.first == "<|end|>"
|
||||
|| t.first == "<|return|>" // o200k_harmony
|
||||
|| t.first == "<|call|>" // o200k_harmony
|
||||
|| t.first == "<end_of_turn>"
|
||||
|| t.first == "<|endoftext|>"
|
||||
|| t.first == "<|eom_id|>"
|
||||
@@ -2337,6 +2339,13 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
||||
}
|
||||
}
|
||||
|
||||
// @ngxson : quick hack for gpt-oss, always render these tokens
|
||||
for (const auto & t : token_to_id) {
|
||||
if (t.first == "<|channel|>" || t.first == "<|message|>" || t.first == "<|start|>") {
|
||||
id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_USER_DEFINED;
|
||||
}
|
||||
}
|
||||
|
||||
// sanity checks
|
||||
if (special_eos_id != LLAMA_TOKEN_NULL && special_eog_ids.count(special_eos_id) == 0) {
|
||||
special_eog_ids.insert(special_eos_id);
|
||||
@@ -2352,6 +2361,36 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
||||
special_eog_ids.insert(special_eom_id);
|
||||
LLAMA_LOG_WARN("%s: special_eom_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__);
|
||||
}
|
||||
|
||||
// TODO: workaround for o200k_harmony tokenizer: the "<|end|>" token should not be EOG
|
||||
// we don't have a good way to detect this, so for now, if we have "<|return|>" and "<|call|>" tokens,
|
||||
// we remove the "<|end|>" token from the EOG list
|
||||
{
|
||||
bool has_return = false;
|
||||
bool has_call = false;
|
||||
bool has_end = false;
|
||||
|
||||
llama_token end_id = LLAMA_TOKEN_NULL;
|
||||
|
||||
LLAMA_LOG_INFO("%s: printing all EOG tokens:\n", __func__);
|
||||
for (auto tid : special_eog_ids) {
|
||||
LLAMA_LOG_INFO("%s: - %d ('%s')\n", __func__, tid, id_to_token[tid].text.c_str());
|
||||
|
||||
if (id_to_token[tid].text == "<|return|>") {
|
||||
has_return = true;
|
||||
} else if (id_to_token[tid].text == "<|call|>") {
|
||||
has_call = true;
|
||||
} else if (id_to_token[tid].text == "<|end|>") {
|
||||
has_end = true;
|
||||
end_id = tid;
|
||||
}
|
||||
}
|
||||
|
||||
if (has_return && has_call && has_end) {
|
||||
special_eog_ids.erase(end_id);
|
||||
LLAMA_LOG_WARN("%s: special_eog_ids contains both '<|return|>' and '<|call|>' tokens, removing '<|end|>' token from EOG list\n", __func__);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// build special tokens cache
|
||||
|
||||
Reference in New Issue
Block a user