llama : add support for qwen3 reranker (#15824)

This commit is contained in:
Douglas Hanley
2025-09-25 03:53:09 -05:00
committed by GitHub
parent dfcd53f7ec
commit b5bd037832
9 changed files with 166 additions and 78 deletions

View File

@@ -721,6 +721,7 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_CLS_OUT, "cls.output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },

View File

@@ -204,7 +204,10 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
std::vector<int> target_pos(n_seqs_unq, -1);
std::vector<int> target_row(n_seqs_unq, -1);
bool last = cparams.pooling_type == LLAMA_POOLING_TYPE_LAST;
const bool last = (
cparams.pooling_type == LLAMA_POOLING_TYPE_LAST ||
(cparams.pooling_type == LLAMA_POOLING_TYPE_RANK && arch == LLM_ARCH_QWEN3) // qwen3 reranking & embedding models use last token
);
for (int i = 0; i < n_tokens; ++i) {
const llama_pos pos = ubatch->pos[i];
@@ -1177,7 +1180,7 @@ ggml_tensor * llm_graph_context::build_inp_mean() const {
}
ggml_tensor * llm_graph_context::build_inp_cls() const {
auto inp = std::make_unique<llm_graph_input_cls>(cparams);
auto inp = std::make_unique<llm_graph_input_cls>(cparams, arch);
auto & cur = inp->cls;
@@ -1877,34 +1880,32 @@ void llm_graph_context::build_pooling(
case LLAMA_POOLING_TYPE_RANK:
{
ggml_tensor * inp_cls = build_inp_cls();
inp = ggml_get_rows(ctx0, inp, inp_cls);
cur = ggml_get_rows(ctx0, inp, inp_cls);
// classification head
// https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
if (cls) {
// classification head
// https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
cur = ggml_mul_mat(ctx0, cls, inp);
cur = ggml_mul_mat(ctx0, cls, cur);
if (cls_b) {
cur = ggml_add(ctx0, cur, cls_b);
}
cur = ggml_tanh(ctx0, cur);
}
// some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
// https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
if (cls_out) {
cur = ggml_mul_mat(ctx0, cls_out, cur);
if (cls_out_b) {
cur = ggml_add(ctx0, cur, cls_out_b);
}
}
} else if (cls_out) {
// Single layer classification head (direct projection)
// https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476
cur = ggml_mul_mat(ctx0, cls_out, inp);
// some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
// https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
// Single layer classification head (direct projection)
// https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476
if (cls_out) {
cur = ggml_mul_mat(ctx0, cls_out, cur);
if (cls_out_b) {
cur = ggml_add(ctx0, cur, cls_out_b);
}
} else {
GGML_ABORT("RANK pooling requires either cls+cls_b or cls_out+cls_out_b");
}
// softmax for qwen3 reranker
if (arch == LLM_ARCH_QWEN3) {
cur = ggml_soft_max(ctx0, cur);
}
} break;
default:

View File

@@ -206,7 +206,7 @@ public:
class llm_graph_input_cls : public llm_graph_input_i {
public:
llm_graph_input_cls(const llama_cparams & cparams) : cparams(cparams) {}
llm_graph_input_cls(const llama_cparams & cparams, const llm_arch arch) : cparams(cparams), arch(arch) {}
virtual ~llm_graph_input_cls() = default;
void set_input(const llama_ubatch * ubatch) override;
@@ -214,6 +214,7 @@ public:
ggml_tensor * cls; // I32 [n_batch]
const llama_cparams cparams;
const llm_arch arch;
};
class llm_graph_input_rs : public llm_graph_input_i {

View File

@@ -3167,6 +3167,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
}
// output rerank head
cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED);
for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i];