退回到 b7516 版本

This commit is contained in:
2026-01-16 18:12:13 +08:00
parent 9d7890f8c6
commit 7e0d40b535
380 changed files with 18454 additions and 38808 deletions

View File

@@ -33,7 +33,7 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & toke
}
}
static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd_out, int embd_norm) {
static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd, int embd_norm) {
const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
// clear previous kv_cache values (irrelevant for embeddings)
@@ -65,8 +65,8 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
}
float * out = output + embd_pos * n_embd_out;
common_embd_normalize(embd, out, n_embd_out, embd_norm);
float * out = output + embd_pos * n_embd;
common_embd_normalize(embd, out, n_embd, embd_norm);
}
}
@@ -252,8 +252,8 @@ int main(int argc, char ** argv) {
}
// allocate output
const int n_embd_out = llama_model_n_embd_out(model);
std::vector<float> embeddings(n_embd_count * n_embd_out, 0);
const int n_embd = llama_model_n_embd(model);
std::vector<float> embeddings(n_embd_count * n_embd, 0);
float * emb = embeddings.data();
// break into batches
@@ -267,8 +267,8 @@ int main(int argc, char ** argv) {
// encode if at capacity
if (batch.n_tokens + n_toks > n_batch || s >= n_seq_max) {
float * out = emb + e * n_embd_out;
batch_decode(ctx, batch, out, s, n_embd_out, params.embd_normalize);
float * out = emb + e * n_embd;
batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize);
e += pooling_type == LLAMA_POOLING_TYPE_NONE ? batch.n_tokens : s;
s = 0;
common_batch_clear(batch);
@@ -280,8 +280,8 @@ int main(int argc, char ** argv) {
}
// final batch
float * out = emb + e * n_embd_out;
batch_decode(ctx, batch, out, s, n_embd_out, params.embd_normalize);
float * out = emb + e * n_embd;
batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize);
if (params.embd_out.empty()) {
LOG("\n");
@@ -289,19 +289,19 @@ int main(int argc, char ** argv) {
if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
for (int j = 0; j < n_embd_count; j++) {
LOG("embedding %d: ", j);
for (int i = 0; i < std::min(3, n_embd_out); i++) {
for (int i = 0; i < std::min(3, n_embd); i++) {
if (params.embd_normalize == 0) {
LOG("%6.0f ", emb[j * n_embd_out + i]);
LOG("%6.0f ", emb[j * n_embd + i]);
} else {
LOG("%9.6f ", emb[j * n_embd_out + i]);
LOG("%9.6f ", emb[j * n_embd + i]);
}
}
LOG(" ... ");
for (int i = n_embd_out - 3; i < n_embd_out; i++) {
for (int i = n_embd - 3; i < n_embd; i++) {
if (params.embd_normalize == 0) {
LOG("%6.0f ", emb[j * n_embd_out + i]);
LOG("%6.0f ", emb[j * n_embd + i]);
} else {
LOG("%9.6f ", emb[j * n_embd_out + i]);
LOG("%9.6f ", emb[j * n_embd + i]);
}
}
LOG("\n");
@@ -320,9 +320,9 @@ int main(int argc, char ** argv) {
for (uint32_t i = 0; i < n_cls_out; i++) {
// NOTE: if you change this log - update the tests in ci/run.sh
if (n_cls_out == 1) {
LOG("rerank score %d: %8.3f\n", j, emb[j * n_embd_out]);
LOG("rerank score %d: %8.3f\n", j, emb[j * n_embd]);
} else {
LOG("rerank score %d: %8.3f [%s]\n", j, emb[j * n_embd_out + i], cls_out_labels[i].c_str());
LOG("rerank score %d: %8.3f [%s]\n", j, emb[j * n_embd + i], cls_out_labels[i].c_str());
}
}
}
@@ -330,11 +330,11 @@ int main(int argc, char ** argv) {
// print the first part of the embeddings or for a single prompt, the full embedding
for (int j = 0; j < n_prompts; j++) {
LOG("embedding %d: ", j);
for (int i = 0; i < (n_prompts > 1 ? std::min(16, n_embd_out) : n_embd_out); i++) {
for (int i = 0; i < (n_prompts > 1 ? std::min(16, n_embd) : n_embd); i++) {
if (params.embd_normalize == 0) {
LOG("%6.0f ", emb[j * n_embd_out + i]);
LOG("%6.0f ", emb[j * n_embd + i]);
} else {
LOG("%9.6f ", emb[j * n_embd_out + i]);
LOG("%9.6f ", emb[j * n_embd + i]);
}
}
LOG("\n");
@@ -350,7 +350,7 @@ int main(int argc, char ** argv) {
LOG("\n");
for (int i = 0; i < n_prompts; i++) {
for (int j = 0; j < n_prompts; j++) {
float sim = common_embd_similarity_cos(emb + i * n_embd_out, emb + j * n_embd_out, n_embd_out);
float sim = common_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
LOG("%6.2f ", sim);
}
LOG("%1.10s", prompts[i].c_str());
@@ -368,9 +368,9 @@ int main(int argc, char ** argv) {
if (notArray) LOG(" {\n \"object\": \"embedding\",\n \"index\": %d,\n \"embedding\": ",j);
LOG("[");
for (int i = 0;;) { // at least one iteration (n_embd > 0)
LOG(params.embd_normalize == 0 ? "%1.0f" : "%1.7f", emb[j * n_embd_out + i]);
LOG(params.embd_normalize == 0 ? "%1.0f" : "%1.7f", emb[j * n_embd + i]);
i++;
if (i < n_embd_out) LOG(","); else break;
if (i < n_embd) LOG(","); else break;
}
LOG(notArray ? "]\n }" : "]");
j++;
@@ -383,7 +383,7 @@ int main(int argc, char ** argv) {
for (int i = 0;;) { // at least two iteration (n_embd_count > 1)
LOG(" [");
for (int j = 0;;) { // at least two iteration (n_embd_count > 1)
float sim = common_embd_similarity_cos(emb + i * n_embd_out, emb + j * n_embd_out, n_embd_out);
float sim = common_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
LOG("%6.2f", sim);
j++;
if (j < n_embd_count) LOG(", "); else break;
@@ -397,7 +397,7 @@ int main(int argc, char ** argv) {
if (notArray) LOG("\n}\n");
} else if (params.embd_out == "raw") {
print_raw_embeddings(emb, n_embd_count, n_embd_out, model, pooling_type, params.embd_normalize);
print_raw_embeddings(emb, n_embd_count, n_embd, model, pooling_type, params.embd_normalize);
}
LOG("\n");