退回到 b7516 版本

This commit is contained in:
2026-01-16 18:12:13 +08:00
parent 9d7890f8c6
commit 7e0d40b535
380 changed files with 18454 additions and 38808 deletions

View File

@@ -217,13 +217,13 @@ int main(int argc, char ** argv) {
struct llama_batch batch = llama_batch_init(n_batch, 0, 1);
// allocate output
const int n_embd_out = llama_model_n_embd_out(model);
std::vector<float> embeddings(n_chunks * n_embd_out, 0);
const int n_embd = llama_model_n_embd(model);
std::vector<float> embeddings(n_chunks * n_embd, 0);
float * emb = embeddings.data();
// break into batches
unsigned int p = 0; // number of prompts processed already
unsigned int s = 0; // number of prompts in current batch
int p = 0; // number of prompts processed already
int s = 0; // number of prompts in current batch
for (int k = 0; k < n_chunks; k++) {
// clamp to n_batch tokens
auto & inp = chunks[k].tokens;
@@ -231,9 +231,9 @@ int main(int argc, char ** argv) {
const uint64_t n_toks = inp.size();
// encode if at capacity
if (batch.n_tokens + n_toks > n_batch || s >= llama_n_seq_max(ctx)) {
float * out = emb + p * n_embd_out;
batch_process(ctx, batch, out, s, n_embd_out);
if (batch.n_tokens + n_toks > n_batch) {
float * out = emb + p * n_embd;
batch_process(ctx, batch, out, s, n_embd);
common_batch_clear(batch);
p += s;
s = 0;
@@ -245,12 +245,12 @@ int main(int argc, char ** argv) {
}
// final batch
float * out = emb + p * n_embd_out;
batch_process(ctx, batch, out, s, n_embd_out);
float * out = emb + p * n_embd;
batch_process(ctx, batch, out, s, n_embd);
// save embeddings to chunks
for (int i = 0; i < n_chunks; i++) {
chunks[i].embedding = std::vector<float>(emb + i * n_embd_out, emb + (i + 1) * n_embd_out);
chunks[i].embedding = std::vector<float>(emb + i * n_embd, emb + (i + 1) * n_embd);
// clear tokens as they are no longer needed
chunks[i].tokens.clear();
}
@@ -266,8 +266,8 @@ int main(int argc, char ** argv) {
batch_add_seq(query_batch, query_tokens, 0);
std::vector<float> query_emb(n_embd_out, 0);
batch_process(ctx, query_batch, query_emb.data(), 1, n_embd_out);
std::vector<float> query_emb(n_embd, 0);
batch_process(ctx, query_batch, query_emb.data(), 1, n_embd);
common_batch_clear(query_batch);
@@ -275,7 +275,7 @@ int main(int argc, char ** argv) {
{
std::vector<std::pair<int, float>> similarities;
for (int i = 0; i < n_chunks; i++) {
float sim = common_embd_similarity_cos(chunks[i].embedding.data(), query_emb.data(), n_embd_out);
float sim = common_embd_similarity_cos(chunks[i].embedding.data(), query_emb.data(), n_embd);
similarities.push_back(std::make_pair(i, sim));
}