Sync from upstream llama.cpp repository

This commit is contained in:
2026-01-16 10:43:34 +08:00
parent 3bc369a6f7
commit f4ae4cc7da
2053 changed files with 956010 additions and 1 deletions

45
examples/CMakeLists.txt Normal file
View File

@@ -0,0 +1,45 @@
# dependencies
find_package(Threads REQUIRED)
# third-party
# ...
# flags
llama_add_compile_flags()
# examples
if (EMSCRIPTEN)
else()
add_subdirectory(batched)
add_subdirectory(debug)
add_subdirectory(embedding)
add_subdirectory(eval-callback)
add_subdirectory(gguf-hash)
add_subdirectory(gguf)
add_subdirectory(idle)
add_subdirectory(lookahead)
add_subdirectory(lookup)
add_subdirectory(parallel)
add_subdirectory(passkey)
add_subdirectory(retrieval)
add_subdirectory(save-load-state)
add_subdirectory(simple)
add_subdirectory(simple-chat)
add_subdirectory(speculative)
add_subdirectory(speculative-simple)
add_subdirectory(gen-docs)
add_subdirectory(training)
add_subdirectory(diffusion)
if (NOT GGML_BACKEND_DL)
add_subdirectory(convert-llama2c-to-ggml)
# these examples use the backends directly and cannot be built with dynamic loading
if (GGML_SYCL)
add_subdirectory(sycl)
endif()
endif()
endif()

9
examples/batched.swift/.gitignore vendored Normal file
View File

@@ -0,0 +1,9 @@
.DS_Store
/.build
/Packages
xcuserdata/
DerivedData/
.swiftpm/configuration/registries.json
.swiftpm/xcode/package.xcworkspace/contents.xcworkspacedata
.netrc
batched_swift

View File

@@ -0,0 +1,6 @@
.PHONY: build
build:
xcodebuild -scheme llama-batched-swift -destination "generic/platform=macOS" -derivedDataPath build
rm -f ./llama-batched-swift
ln -s ./build/Build/Products/Debug/llama-batched-swift ./llama-batched-swift

View File

@@ -0,0 +1,22 @@
// swift-tools-version: 5.5
// The swift-tools-version declares the minimum version of Swift required to build this package.
import PackageDescription
let package = Package(
name: "llama-batched-swift",
platforms: [.macOS(.v12)],
dependencies: [
.package(name: "llama", path: "../../"),
],
targets: [
// Targets are the basic building blocks of a package, defining a module or a test suite.
// Targets can depend on other targets in this package and products from dependencies.
.executableTarget(
name: "llama-batched-swift",
dependencies: ["llama"],
path: "Sources",
linkerSettings: [.linkedFramework("Foundation"), .linkedFramework("AppKit")]
),
]
)

View File

@@ -0,0 +1,5 @@
This is a swift clone of `examples/batched`.
```bash
$ ./llama-batched-swift MODEL_PATH [PROMPT] [PARALLEL]
```

View File

@@ -0,0 +1,256 @@
import Foundation
import llama
let arguments = CommandLine.arguments
// Check that we have at least one argument (the model path)
guard arguments.count > 1 else {
print("Usage: swift MODEL_PATH [PROMPT] [PARALLEL]")
exit(1)
}
let modelPath: String = arguments[1]
let prompt: String = arguments.count > 2 ? arguments[2] : "Hello my name is"
let n_parallel: Int = arguments.count > 3 && Int(arguments[3]) != nil ? Int(arguments[3])! : 1
// total length of the sequences including the prompt
let n_len: Int = 32
// init LLM
llama_backend_init()
defer {
llama_backend_free()
}
let model_params = llama_model_default_params()
guard let model = llama_model_load_from_file(modelPath.cString(using: .utf8), model_params) else {
print("Failed to load model")
exit(1)
}
defer {
llama_model_free(model)
}
guard let vocab = llama_model_get_vocab(model) else {
print("Failed to get vocab")
exit(1)
}
var tokens = tokenize(text: prompt, add_bos: true)
let n_kv_req = UInt32(tokens.count) + UInt32((n_len - Int(tokens.count)) * n_parallel)
var context_params = llama_context_default_params()
context_params.n_ctx = n_kv_req
context_params.n_batch = UInt32(max(n_len, n_parallel))
context_params.n_threads = 8
context_params.n_threads_batch = 8
let context = llama_init_from_model(model, context_params)
guard context != nil else {
print("Failed to initialize context")
exit(1)
}
defer {
llama_free(context)
}
var sparams = llama_sampler_chain_default_params()
let smpl = llama_sampler_chain_init(sparams)
guard smpl != nil else {
print("Failed to initialize sampling")
exit(1)
}
defer {
llama_sampler_free(smpl)
}
llama_sampler_chain_add(smpl, llama_sampler_init_top_k(40));
llama_sampler_chain_add(smpl, llama_sampler_init_top_p(0.9, 1));
llama_sampler_chain_add(smpl, llama_sampler_init_temp (0.4));
llama_sampler_chain_add(smpl, llama_sampler_init_dist (1234));
let n_ctx = llama_n_ctx(context)
print("\nn_len = \(n_len), n_ctx = \(n_ctx), n_batch = \(context_params.n_batch), n_parallel = \(n_parallel), n_kv_req = \(n_kv_req)\n")
if n_kv_req > n_ctx {
print("error: n_kv_req (%d) > n_ctx, the required KV cache size is not big enough\n", n_kv_req)
exit(1)
}
var buffer: [CChar] = []
for id: llama_token in tokens {
print(token_to_piece(token: id, buffer: &buffer) ?? "", terminator: "")
}
print("\n")
var batch = llama_batch_init(max(Int32(tokens.count), Int32(n_parallel)), 0, 1)
defer {
llama_batch_free(batch)
}
// evaluate the initial prompt
batch.n_tokens = Int32(tokens.count)
for (i, token) in tokens.enumerated() {
batch.token[i] = token
batch.pos[i] = Int32(i)
batch.n_seq_id[i] = 1
// batch.seq_id[i][0] = 0
// TODO: is this the proper way to do this?
if let seq_id = batch.seq_id[i] {
seq_id[0] = 0
}
batch.logits[i] = 0
}
// llama_decode will output logits only for the last token of the prompt
batch.logits[Int(batch.n_tokens) - 1] = 1
if llama_decode(context, batch) != 0 {
print("llama_decode() failed")
exit(1)
}
for i in 1 ..< n_parallel {
llama_memory_seq_cp(llama_get_memory(context), 0, Int32(i), 0, batch.n_tokens)
}
if n_parallel > 1 {
print("generating \(n_parallel) sequences ...\n")
}
var streams: [String] = .init(repeating: "", count: n_parallel)
var streamBuffers: [[CChar]] = .init(repeating: [], count: n_parallel)
var i_batch = [Int32](repeating: batch.n_tokens - 1, count: n_parallel)
var n_cur = batch.n_tokens
var n_decode = 0
let t_main_start = ggml_time_us()
while n_cur <= n_len {
// prepare the next batch
batch.n_tokens = 0
// sample the next token for each parallel sequence / stream
for i in 0 ..< n_parallel {
if i_batch[i] < 0 {
// the stream has already finished
continue
}
let new_token_id = llama_sampler_sample(smpl, context, i_batch[i])
// is it an end of stream? -> mark the stream as finished
if llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_len {
i_batch[i] = -1
// print("")
if n_parallel > 1 {
print("stream \(i) finished at n_cur = \(n_cur)")
}
continue
}
let nextStringPiece = token_to_piece(token: new_token_id, buffer: &streamBuffers[i]) ?? ""
// if there is only one stream, we print immediately to stdout
if n_parallel == 1 {
print(nextStringPiece, terminator: "")
}
streams[i] += nextStringPiece
// push this new token for next evaluation
batch.token[Int(batch.n_tokens)] = new_token_id
batch.pos[Int(batch.n_tokens)] = n_cur
batch.n_seq_id[Int(batch.n_tokens)] = 1
if let seq_id = batch.seq_id[Int(batch.n_tokens)] {
seq_id[0] = Int32(i)
}
batch.logits[Int(batch.n_tokens)] = 1
i_batch[i] = batch.n_tokens
batch.n_tokens += 1
n_decode += 1
}
// all streams are finished
if batch.n_tokens == 0 {
break
}
n_cur += 1
// evaluate the current batch with the transformer model
if llama_decode(context, batch) != 0 {
print("llama_decode() failed")
exit(1)
}
}
if n_parallel > 1 {
print("\n")
for (i, stream) in streams.enumerated() {
print("sequence \(i):\n\n\(prompt)\(stream)\n")
}
}
let t_main_end = ggml_time_us()
print("decoded \(n_decode) tokens in \(String(format: "%.2f", Double(t_main_end - t_main_start) / 1_000_000.0)) s, speed: \(String(format: "%.2f", Double(n_decode) / (Double(t_main_end - t_main_start) / 1_000_000.0))) t/s\n\n")
llama_perf_sampler_print(smpl)
llama_perf_context_print(context)
private func tokenize(text: String, add_bos: Bool) -> [llama_token] {
let utf8Count = text.utf8.count
let n_tokens = utf8Count + (add_bos ? 1 : 0)
let tokens = UnsafeMutablePointer<llama_token>.allocate(capacity: n_tokens)
let tokenCount = llama_tokenize(vocab, text, Int32(utf8Count), tokens, Int32(n_tokens), add_bos, /*special tokens*/ false)
var swiftTokens: [llama_token] = []
for i in 0 ..< tokenCount {
swiftTokens.append(tokens[Int(i)])
}
tokens.deallocate()
return swiftTokens
}
private func token_to_piece(token: llama_token, buffer: inout [CChar]) -> String? {
var result = [CChar](repeating: 0, count: 8)
let nTokens = llama_token_to_piece(vocab, token, &result, Int32(result.count), 0, false)
if nTokens < 0 {
let actualTokensCount = -Int(nTokens)
result = .init(repeating: 0, count: actualTokensCount)
let check = llama_token_to_piece(
vocab,
token,
&result,
Int32(result.count),
0,
false
)
assert(check == actualTokensCount)
} else {
result.removeLast(result.count - Int(nTokens))
}
if buffer.isEmpty, let utfString = String(cString: result + [0], encoding: .utf8) {
return utfString
} else {
buffer.append(contentsOf: result)
let data = Data(buffer.map { UInt8(bitPattern: $0) })
if buffer.count >= 4 { // 4 bytes is the max length of a utf8 character so if we're here we need to reset the buffer
buffer = []
}
guard let bufferString = String(data: data, encoding: .utf8) else {
return nil
}
buffer = []
return bufferString
}
}

View File

@@ -0,0 +1,5 @@
set(TARGET llama-batched)
add_executable(${TARGET} batched.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_17)

View File

@@ -0,0 +1,44 @@
# llama.cpp/example/batched
The example demonstrates batched generation from a given prompt
```bash
./llama-batched -m ./models/llama-7b-v2/ggml-model-f16.gguf -p "Hello my name is" -np 4 --kv-unified
...
main: n_len = 32, n_ctx = 2048, n_parallel = 4, n_kv_req = 113
Hello my name is
main: generating 4 sequences ...
main: stream 0 finished
main: stream 1 finished
main: stream 2 finished
main: stream 3 finished
sequence 0:
Hello my name is Shirley. I am a 25-year-old female who has been working for over 5 years as a b
sequence 1:
Hello my name is Renee and I'm a 32 year old female from the United States. I'm looking for a man between
sequence 2:
Hello my name is Diana. I am looking for a housekeeping job. I have experience with children and have my own transportation. I am
sequence 3:
Hello my name is Cody. I am a 3 year old neutered male. I am a very friendly cat. I am very playful and
main: decoded 108 tokens in 3.57 s, speed: 30.26 t/s
llama_print_timings: load time = 587.00 ms
llama_print_timings: sample time = 2.56 ms / 112 runs ( 0.02 ms per token, 43664.72 tokens per second)
llama_print_timings: prompt eval time = 4089.11 ms / 118 tokens ( 34.65 ms per token, 28.86 tokens per second)
llama_print_timings: eval time = 0.00 ms / 1 runs ( 0.00 ms per token, inf tokens per second)
llama_print_timings: total time = 4156.04 ms
```

View File

@@ -0,0 +1,261 @@
#include "arg.h"
#include "common.h"
#include "log.h"
#include "llama.h"
#include "sampling.h"
#include <algorithm>
#include <cstdio>
#include <string>
#include <vector>
static void print_usage(int, char ** argv) {
LOG("\nexample usage:\n");
LOG("\n %s -m model.gguf -p \"Hello my name is\" -n 32 -np 4\n", argv[0]);
LOG("\n");
}
int main(int argc, char ** argv) {
common_params params;
params.prompt = "Hello my name is";
params.n_predict = 32;
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_BATCHED, print_usage)) {
return 1;
}
common_init();
// number of parallel batches
int n_parallel = params.n_parallel;
// total length of the sequences including the prompt
int n_predict = params.n_predict;
// init LLM
llama_backend_init();
llama_numa_init(params.numa);
// initialize the model
llama_model_params model_params = common_model_params_to_llama(params);
llama_model * model = llama_model_load_from_file(params.model.path.c_str(), model_params);
if (model == NULL) {
LOG_ERR("%s: error: unable to load model\n" , __func__);
return 1;
}
const llama_vocab * vocab = llama_model_get_vocab(model);
// tokenize the prompt
std::vector<llama_token> tokens_list;
tokens_list = common_tokenize(vocab, params.prompt, true);
const int n_kv_req = tokens_list.size() + (n_predict - tokens_list.size())*n_parallel;
// initialize the context
llama_context_params ctx_params = common_context_params_to_llama(params);
ctx_params.n_ctx = n_kv_req;
ctx_params.n_batch = std::max(n_predict, n_parallel);
auto sparams = llama_sampler_chain_default_params();
sparams.no_perf = false;
std::vector<llama_sampler_seq_config> sampler_configs;
for (int32_t i = 0; i < n_parallel; ++i) {
llama_sampler * smpl = llama_sampler_chain_init(sparams);
llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.sampling.top_k));
llama_sampler_chain_add(smpl, llama_sampler_init_top_p(params.sampling.top_p, params.sampling.min_keep));
llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.sampling.temp));
llama_sampler_chain_add(smpl, llama_sampler_init_dist (params.sampling.seed));
sampler_configs.push_back({ i, smpl });
}
if (params.sampling.backend_sampling) {
ctx_params.samplers = sampler_configs.data();
ctx_params.n_samplers = sampler_configs.size();
}
llama_context * ctx = llama_init_from_model(model, ctx_params);
if (ctx == NULL) {
LOG_ERR("%s: error: failed to create the llama_context\n" , __func__);
return 1;
}
const int n_ctx = llama_n_ctx(ctx);
LOG_INF("\n%s: n_predict = %d, n_ctx = %d, n_batch = %u, n_parallel = %d, n_kv_req = %d\n", __func__, n_predict, n_ctx, ctx_params.n_batch, n_parallel, n_kv_req);
// make sure the KV cache is big enough to hold all the prompt and generated tokens
if (n_kv_req > n_ctx) {
LOG_ERR("%s: error: n_kv_req (%d) > n_ctx, the required KV cache size is not big enough\n", __func__, n_kv_req);
LOG_ERR("%s: either reduce n_parallel or increase n_ctx\n", __func__);
return 1;
}
// print the prompt token-by-token
LOG("\n");
for (auto id : tokens_list) {
LOG("%s", common_token_to_piece(ctx, id).c_str());
}
// create a llama_batch
// we use this object to submit token data for decoding
llama_batch batch = llama_batch_init(std::max(tokens_list.size(), (size_t) n_parallel), 0, n_parallel);
std::vector<llama_seq_id> seq_ids(n_parallel, 0);
for (int32_t i = 0; i < n_parallel; ++i) {
seq_ids[i] = i;
}
// evaluate the initial prompt
for (size_t i = 0; i < tokens_list.size(); ++i) {
common_batch_add(batch, tokens_list[i], i, seq_ids, false);
}
GGML_ASSERT(batch.n_tokens == (int) tokens_list.size());
if (llama_model_has_encoder(model)) {
if (llama_encode(ctx, batch)) {
LOG_ERR("%s : failed to eval\n", __func__);
return 1;
}
llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
if (decoder_start_token_id == LLAMA_TOKEN_NULL) {
decoder_start_token_id = llama_vocab_bos(vocab);
}
common_batch_clear(batch);
common_batch_add(batch, decoder_start_token_id, 0, seq_ids, false);
}
// llama_decode will output logits only for the last token of the prompt
batch.logits[batch.n_tokens - 1] = true;
if (llama_decode(ctx, batch) != 0) {
LOG_ERR("%s: llama_decode() failed\n", __func__);
return 1;
}
//// assign the system KV cache to all parallel sequences
//// this way, the parallel sequences will "reuse" the prompt tokens without having to copy them
//for (int32_t i = 1; i < n_parallel; ++i) {
// llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
//}
if (n_parallel > 1) {
LOG("\n\n%s: generating %d sequences ...\n", __func__, n_parallel);
}
// main loop
// we will store the parallel decoded sequences in this vector
std::vector<std::string> streams(n_parallel);
// remember the batch index of the last token for each parallel sequence
// we need this to determine which logits to sample from
std::vector<int32_t> i_batch(n_parallel, batch.n_tokens - 1);
int n_cur = batch.n_tokens;
int n_decode = 0;
const auto t_main_start = ggml_time_us();
while (n_cur <= n_predict) {
// prepare the next batch
common_batch_clear(batch);
// sample the next token for each parallel sequence / stream
for (int32_t i = 0; i < n_parallel; ++i) {
if (i_batch[i] < 0) {
// the stream has already finished
continue;
}
const llama_token new_token_id = llama_sampler_sample(sampler_configs[i].sampler, ctx, i_batch[i]);
// is it an end of generation? -> mark the stream as finished
if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_predict) {
i_batch[i] = -1;
LOG("\n");
if (n_parallel > 1) {
LOG_INF("%s: stream %d finished at n_cur = %d", __func__, i, n_cur);
}
continue;
}
// if there is only one stream, we print immediately to stdout
if (n_parallel == 1) {
LOG("%s", common_token_to_piece(ctx, new_token_id).c_str());
}
streams[i] += common_token_to_piece(ctx, new_token_id);
i_batch[i] = batch.n_tokens;
// push this new token for next evaluation
common_batch_add(batch, new_token_id, n_cur, { i }, true);
n_decode += 1;
}
// all streams are finished
if (batch.n_tokens == 0) {
break;
}
n_cur += 1;
// evaluate the current batch with the transformer model
if (llama_decode(ctx, batch)) {
LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1);
return 1;
}
}
if (n_parallel > 1) {
LOG("\n");
for (int32_t i = 0; i < n_parallel; ++i) {
LOG("sequence %d:\n\n%s%s\n\n", i, params.prompt.c_str(), streams[i].c_str());
}
}
const auto t_main_end = ggml_time_us();
LOG_INF("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
__func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
LOG("\n");
llama_perf_sampler_print(sampler_configs[0].sampler);
llama_perf_context_print(ctx);
fprintf(stderr, "\n");
llama_batch_free(batch);
for (auto & sampler_config : sampler_configs) {
llama_sampler_free(sampler_config.sampler);
}
llama_free(ctx);
llama_model_free(model);
llama_backend_free();
return 0;
}

View File

@@ -0,0 +1,5 @@
set(TARGET llama-convert-llama2c-to-ggml)
add_executable(${TARGET} convert-llama2c-to-ggml.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_17)

View File

@@ -0,0 +1,25 @@
## Convert llama2.c model to ggml
This example reads weights from project [llama2.c](https://github.com/karpathy/llama2.c) and saves them in ggml compatible format. The vocab that is available in `models/ggml-vocab.bin` is used by default.
To convert the model first download the models from the [llama2.c](https://github.com/karpathy/llama2.c) repository.
```
usage: ./llama-convert-llama2c-to-ggml [options]
options:
-h, --help show this help message and exit
--copy-vocab-from-model FNAME path of gguf llama model or llama2.c vocabulary from which to copy vocab (default 'models/7B/ggml-model-f16.gguf')
--llama2c-model FNAME [REQUIRED] model path from which to load Karpathy's llama2.c model
--llama2c-output-model FNAME model path to save the converted llama2.c model (default ak_llama_model.bin')
```
An example command using a model from [karpathy/tinyllamas](https://huggingface.co/karpathy/tinyllamas) is as follows:
`$ ./llama-convert-llama2c-to-ggml --copy-vocab-from-model llama-2-7b-chat.gguf.q2_K.bin --llama2c-model stories42M.bin --llama2c-output-model stories42M.gguf.bin`
Note: The vocabulary for `stories260K.bin` should be its own tokenizer `tok512.bin` found in [karpathy/tinyllamas/stories260K](https://huggingface.co/karpathy/tinyllamas/tree/main/stories260K).
Now you can use the model with a command like:
`$ ./llama-cli -m stories42M.gguf.bin -p "One day, Lily met a Shoggoth" -n 500 -c 256`

View File

@@ -0,0 +1,941 @@
#include "ggml.h"
#include "gguf.h"
#include "llama.h"
#include "common.h"
#include "log.h"
#include <unordered_map>
#include <vector>
#include <cassert>
#include <climits>
#include <cstring>
#include <cstdarg>
#include <cinttypes>
#include <ctime>
#include <random>
#include <stdexcept>
#include <sstream>
#include <algorithm>
#include <string>
// GGUF keys & tensor names.
#define KV_GENERAL_ARCHITECTURE "general.architecture"
#define KV_GENERAL_NAME "general.name"
#define KV_TOKENIZER_MODEL "tokenizer.ggml.model"
#define KV_TOKENIZER_LIST "tokenizer.ggml.tokens"
#define KV_TOKENIZER_TOKEN_TYPE "tokenizer.ggml.token_type"
#define KV_TOKENIZER_SCORES "tokenizer.ggml.scores"
#define KV_TOKENIZER_BOS_ID "tokenizer.ggml.bos_token_id"
#define KV_TOKENIZER_EOS_ID "tokenizer.ggml.eos_token_id"
#define KV_TOKENIZER_UNK_ID "tokenizer.ggml.unknown_token_id"
#define KV_TOKENIZER_SEP_ID "tokenizer.ggml.seperator_token_id"
#define KV_TOKENIZER_PAD_ID "tokenizer.ggml.padding_token_id"
#define KV_TOKENIZER_HF_JSON "tokenizer.huggingface.json"
#define KV_CONTEXT_LENGTH "llama.context_length"
#define KV_EMBEDDING_LENGTH "llama.embedding_length"
#define KV_BLOCK_COUNT "llama.block_count"
#define KV_FEED_FORWARD_LENGTH "llama.feed_forward_length"
#define KV_ATTENTION_HEAD_COUNT "llama.attention.head_count"
#define KV_ATTENTION_HEAD_COUNT_KV "llama.attention.head_count_kv"
#define KV_ATTENTION_LAYERNORM_RMS_EPS "llama.attention.layer_norm_rms_epsilon"
#define KV_ROPE_DIMENSION_COUNT "llama.rope.dimension_count"
#define TN_TOKEN_EMBD "token_embd.weight"
#define TN_OUTPUT_NORM "output_norm.weight"
#define TN_OUTPUT "output.weight"
#define TN_ATTN_NORM "blk.%d.attn_norm.weight"
#define TN_ATTN_Q "blk.%d.attn_q.weight"
#define TN_ATTN_K "blk.%d.attn_k.weight"
#define TN_ATTN_V "blk.%d.attn_v.weight"
#define TN_ATTN_OUTPUT "blk.%d.attn_output.weight"
#define TN_FFN_NORM "blk.%d.ffn_norm.weight"
#define TN_FFN_GATE "blk.%d.ffn_gate.weight"
#define TN_FFN_DOWN "blk.%d.ffn_down.weight"
#define TN_FFN_UP "blk.%d.ffn_up.weight"
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
#define LLAMA_FILE_MAGIC_GGJT 0x67676a74u // 'ggjt'
#define LLAMA_FILE_VERSION_GGJT_V3 3
#define TOKENIZER_NAME "llama"
#define UNKNOWN_TOKEN_ID 0
#define BOS_TOKEN_ID 1
#define EOS_TOKEN_ID 2
//////////////////////////////////////// llama2.c model structs and functions to load models, alloc memory etc.
typedef struct {
int dim; // transformer dimension
int hidden_dim; // for ffn layers
int n_layers; // number of layers
int n_heads; // number of query heads
int n_kv_heads; // number of key/value heads (can be < query heads because of multiquery)
int vocab_size; // vocabulary size, usually 256 (byte-level)
int seq_len; // max sequence length
} Config;
struct TransformerWeights {
// token embedding table
std::vector<float> token_embedding_table; // (vocab_size, dim)
// weights for rmsnorms
std::vector<float> rms_att_weight; // (layer, dim) rmsnorm weights
std::vector<float> rms_ffn_weight; // (layer, dim)
// weights for matmuls
std::vector<float> wq; // (layer, dim, dim)
std::vector<float> wk; // (layer, dim, dim)
std::vector<float> wv; // (layer, dim, dim)
std::vector<float> wo; // (layer, dim, dim)
// weights for ffn
std::vector<float> w1; // (layer, hidden_dim, dim)
std::vector<float> w2; // (layer, dim, hidden_dim)
std::vector<float> w3; // (layer, hidden_dim, dim)
// final rmsnorm
std::vector<float> rms_final_weight; // (dim,)
// freq_cis for RoPE relatively positional embeddings
// std::vector<float> freq_cis_real; // (seq_len, dim/2)
// std::vector<float> freq_cis_imag; // (seq_len, dim/2)
// (optional) classifier weights for the logits, on the last layer
std::vector<float> wcls;
};
static void alloc_weights(TransformerWeights * w, const Config * p, bool shared_weights) {
const int n_multiqueries = p->n_kv_heads <= 0 || p->n_kv_heads >= p->n_heads ? 1 : p->n_heads / p->n_kv_heads;
try {
w->token_embedding_table.resize(p->vocab_size * p->dim);
LOG_INF("%s: Allocating [%d] x [%d] = [%d] float space for w->token_embedding_table\n",__func__,p->vocab_size , p->dim, p->vocab_size * p->dim);
w->rms_att_weight.resize(p->n_layers * p->dim);
LOG_INF("%s: Allocating [%d] x [%d] = [%d] float space for w->rms_att_weight\n",__func__,p->n_layers, p->dim, p->n_layers * p->dim);
w->rms_ffn_weight.resize(p->n_layers * p->dim);
LOG_INF("%s: Allocating [%d] x [%d] = [%d] float space for w->rms_ffn_weight\n",__func__,p->n_layers , p->dim, p->n_layers * p->dim);
w->wq.resize(p->n_layers * p->dim * p->dim);
LOG_INF("%s: Allocating [%d] x [%d] x [%d] = [%d] float space for w->wq\n",__func__,p->n_layers, p->dim, p->dim, p->n_layers * p->dim * p->dim);
w->wk.resize(p->n_layers * p->dim * p->dim / n_multiqueries);
LOG_INF("%s: Allocating [%d] x [%d] x [%d] = [%d] float space for w->wk\n",__func__,p->n_layers, p->dim, p->dim / n_multiqueries, p->n_layers * p->dim * p->dim / n_multiqueries);
w->wv.resize(p->n_layers * p->dim * p->dim / n_multiqueries);
LOG_INF("%s: Allocating [%d] x [%d] x [%d] = [%d] float space for w->wv\n",__func__, p->n_layers, p->dim, p->dim / n_multiqueries, p->n_layers * p->dim * p->dim / n_multiqueries);
w->wo.resize(p->n_layers * p->dim * p->dim);
LOG_INF("%s: Allocating [%d] x [%d] x [%d] = [%d] float space for w->wo\n",__func__,p->n_layers, p->dim, p->dim, p->n_layers * p->dim * p->dim);
w->w1.resize(p->n_layers * p->hidden_dim * p->dim);
LOG_INF("%s: Allocating [%d] x [%d] x [%d] = [%d] float space for w->w1\n",__func__,p->n_layers, p->hidden_dim, p->dim, p->n_layers * p->hidden_dim * p->dim);
w->w2.resize(p->n_layers * p->hidden_dim * p->dim);
LOG_INF("%s: Allocating [%d] x [%d] x [%d] = [%d] float space for w->w2\n",__func__,p->n_layers, p->dim, p->hidden_dim, p->n_layers * p->hidden_dim * p->dim);
w->w3.resize(p->n_layers * p->hidden_dim * p->dim);
LOG_INF("%s: Allocating [%d] x [%d] x [%d] = [%d] float space for w->w3\n",__func__,p->n_layers, p->hidden_dim, p->dim, p->n_layers * p->hidden_dim * p->dim);
w->rms_final_weight.resize(p->dim);
LOG_INF("%s: Allocating [%d] float space for w->rms_final_weight\n",__func__,p->dim);
if (shared_weights) {
w->wcls = {};
} else {
w->wcls.resize(p->vocab_size * p->dim);
LOG_INF("%s: Allocating [%d] x [%d] = [%d] float space for w->wcls\n",__func__,p->vocab_size , p->dim, p->vocab_size * p->dim);
}
}
catch (std::length_error &) {
die("Invalid configuration. Failed to allocate memory for weights");
}
}
static int checkpoint_init_weights(TransformerWeights * w, const Config * p, FILE * f, bool shared_weights) {
if (fread(w->token_embedding_table.data(), sizeof(float), w->token_embedding_table.size(), f) != w->token_embedding_table.size()) return 1;
if (fread(w->rms_att_weight.data(), sizeof(float), w->rms_att_weight.size(), f) != w->rms_att_weight.size()) return 1;
if (fread(w->wq.data(), sizeof(float), w->wq.size(), f) != w->wq.size()) return 1;
if (fread(w->wk.data(), sizeof(float), w->wk.size(), f) != w->wk.size()) return 1;
if (fread(w->wv.data(), sizeof(float), w->wv.size(), f) != w->wv.size()) return 1;
if (fread(w->wo.data(), sizeof(float), w->wo.size(), f) != w->wo.size()) return 1;
if (fread(w->rms_ffn_weight.data(), sizeof(float), w->rms_ffn_weight.size(), f) != w->rms_ffn_weight.size()) return 1;
if (fread(w->w1.data(), sizeof(float), w->w1.size(), f) != w->w1.size()) return 1;
if (fread(w->w2.data(), sizeof(float), w->w2.size(), f) != w->w2.size()) return 1;
if (fread(w->w3.data(), sizeof(float), w->w3.size(), f) != w->w3.size()) return 1;
if (fread(w->rms_final_weight.data(), sizeof(float), w->rms_final_weight.size(), f) != w->rms_final_weight.size()) return 1;
// Skip freq_cis_real & freq_cis_imag
int head_size = p->dim / p->n_heads;
fseek(f, p->seq_len * head_size * sizeof(float), SEEK_CUR);
if (!shared_weights && fread(w->wcls.data(), sizeof(float), w->wcls.size(), f) != w->wcls.size()) return 1;
// Check we didn't forget to read anything
auto curr = ftell(f);
fseek(f, 0, SEEK_END);
auto end = ftell(f);
if (curr != end) {
LOG_ERR("%s: Error: failed to read the checkpoint file to the end (curr = %ld, end = %ld)\n", __func__, curr, end);
return 1;
}
return 0;
}
static void print_sample_weights(TransformerWeights *w){
LOG_INF("----- Quick print of first of the weight vales of all the variables\n");
LOG_INF("%f\n", w->token_embedding_table[0]);
LOG_INF("%f\n", w->rms_att_weight[0]);
LOG_INF("%f\n", w->rms_ffn_weight[0]);
LOG_INF("%f\n", w->wq[0]);
LOG_INF("%f\n", w->wk[0]);
LOG_INF("%f\n", w->wv[0]);
LOG_INF("%f\n", w->wo[0]);
LOG_INF("%f\n", w->w1[0]);
LOG_INF("%f\n", w->w2[0]);
LOG_INF("%f\n", w->w3[0]);
LOG_INF("%f\n", w->rms_att_weight[0]);
if (!w->wcls.empty()) LOG_INF("%f\n", w->wcls[0]);
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////
//////////////////////////////////////// ggml structs and functions required to load models, configs and save the model.
struct my_llama_vocab {
using id = int32_t;
using token = std::string;
using ttype = llama_token_type;
struct token_data {
token text;
float score;
ttype type;
};
std::unordered_map<token, id> token_to_id;
std::vector<token_data> id_to_token;
};
struct my_llama_hparams {
uint32_t n_vocab = 32000;
uint32_t n_ctx = 512; // this is provided as user input?
uint32_t n_embd = 4096;
uint32_t n_ff = 11008;
uint32_t n_mult = 4;
uint32_t n_head = 32;
uint32_t n_head_kv = 32;
uint32_t n_layer = 32;
uint32_t n_rot = 64;
bool operator!=(const my_llama_hparams& other) const {
return memcmp(this, &other, sizeof(my_llama_hparams));
}
};
struct my_llama_layer {
// normalization
struct ggml_tensor * attention_norm;
// attention
struct ggml_tensor * wq;
struct ggml_tensor * wk;
struct ggml_tensor * wv;
struct ggml_tensor * wo;
// normalization
struct ggml_tensor * ffn_norm;
// ff
struct ggml_tensor * w1;
struct ggml_tensor * w2;
struct ggml_tensor * w3;
};
struct my_llama_model {
struct ggml_context * ctx = NULL;
std::string name;
my_llama_hparams hparams;
struct ggml_tensor * tok_embeddings;
struct ggml_tensor * norm;
struct ggml_tensor * output;
std::vector<my_llama_layer> layers;
uint32_t train_its = 0;
uint32_t train_samples = 0;
uint32_t train_tokens = 0;
};
struct train_params {
const char * fn_vocab_model;
const char * fn_llama2c_model;
const char * fn_llama2c_output_model;
const char * fn_train_data;
const char * fn_checkpoint_in;
const char * fn_checkpoint_out;
const char * fn_model_out;
uint32_t seed;
int n_ctx;
int n_embd;
int n_mult;
int n_head;
int n_layer;
int n_rotmax;
int n_threads;
int n_batch;
int n_examples;
int n_predict;
int print_info_interval;
int print_details_interval;
bool samples_start_after_nl;
bool use_adam;
bool use_flash;
bool use_scratch;
// only adam
int warmup;
int cos_decay_steps;
float cos_decay_restart;
float cos_decay_alpha;
int lbfgs_n_iter;
int adam_n_iter;
float adam_alpha;
float adam_decay;
int mem_model_gb;
int mem_compute_gb;
int mem_compute0_gb;
int mem_compute1_gb;
};
static void print_params(struct my_llama_hparams * params) {
LOG_INF("%s: n_vocab: %u\n", __func__, params->n_vocab);
LOG_INF("%s: n_ctx: %u\n", __func__, params->n_ctx);
LOG_INF("%s: n_embd: %u\n", __func__, params->n_embd);
LOG_INF("%s: n_mult: %u\n", __func__, params->n_mult);
LOG_INF("%s: n_head: %u\n", __func__, params->n_head);
LOG_INF("%s: n_head_kv: %u\n", __func__, params->n_head_kv);
LOG_INF("%s: n_ff: %u\n", __func__, params->n_ff);
LOG_INF("%s: n_layer: %u\n", __func__, params->n_layer);
LOG_INF("%s: n_rot: %u\n", __func__, params->n_rot);
}
static void print_tensor_info(const struct ggml_context * ctx) {
for (auto * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
LOG_INF("%s: Allocating ", __func__);
int64_t total = 1;
int i = 0;
for (; i < ggml_n_dims(t); ++i) {
if (i > 0) { LOG_INF("x "); }
LOG_INF("[%" PRId64 "] ", t->ne[i]);
total *= t->ne[i];
}
if (i > 1) { LOG_INF("= [%" PRId64 "] ", total); }
LOG_INF("float space for %s\n", ggml_get_name(t));
}
}
static void init_model(struct my_llama_model * model) {
const auto & hparams = model->hparams;
const uint32_t n_embd = hparams.n_embd;
const uint32_t n_layer = hparams.n_layer;
const uint32_t n_vocab = hparams.n_vocab;
const uint32_t n_multiqueries = hparams.n_head_kv <= 0 || hparams.n_head_kv >= hparams.n_head ? 1 : hparams.n_head / hparams.n_head_kv;
const uint32_t n_ff = hparams.n_ff;
struct ggml_context * ctx = model->ctx;
model->train_its = 0;
model->train_samples = 0;
model->train_tokens = 0;
model->tok_embeddings = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_vocab);
model->norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
model->output = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_vocab);
ggml_set_name(model->tok_embeddings, "tok_embeddings.weight");
ggml_set_name(model->norm, "norm.weight");
ggml_set_name(model->output, "output.weight");
model->layers.resize(n_layer);
for (uint32_t i = 0; i < n_layer; ++i) {
auto & layer = model->layers[i];
std::string layers_i = "layers." + std::to_string(i);
layer.attention_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.wq = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_embd);
layer.wk = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_embd / n_multiqueries);
layer.wv = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_embd / n_multiqueries);
layer.wo = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_embd);
layer.ffn_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
layer.w1 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ff);
layer.w2 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_ff, n_embd);
layer.w3 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ff);
ggml_set_name(layer.attention_norm, (layers_i + ".attention_norm.weight").c_str());
ggml_set_name(layer.wq, (layers_i + ".attention.wq.weight").c_str());
ggml_set_name(layer.wk, (layers_i + ".attention.wk.weight").c_str());
ggml_set_name(layer.wv, (layers_i + ".attention.wv.weight").c_str());
ggml_set_name(layer.wo, (layers_i + ".attention.wo.weight").c_str());
ggml_set_name(layer.ffn_norm, (layers_i + ".ffn_norm.weight").c_str());
ggml_format_name(layer.w1, "%s.feed_forward.w1.weight", layers_i.c_str());
ggml_format_name(layer.w2, "%s.feed_forward.w2.weight", layers_i.c_str());
ggml_format_name(layer.w3, "%s.feed_forward.w3.weight", layers_i.c_str());
}
print_tensor_info(ctx);
}
static float get_f32_2d(struct ggml_tensor * tensor, int64_t i0, int64_t i1) {
float * ptr = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1]);
return *ptr;
}
static int32_t get_i32_2d(struct ggml_tensor * tensor, int64_t i0, int64_t i1) {
int32_t * ptr = (int32_t *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1]);
return *ptr;
}
static void print_row(struct ggml_tensor * probs, int i) {
for (int k = 0; k < probs->ne[0]; ++k) {
float p = get_f32_2d(probs, k, i);
LOG(" %f", p);
}
LOG("\n");
}
static void print_matrix(struct ggml_tensor * probs) {
assert(ggml_is_matrix(probs));
for (int i = 0; i < probs->ne[1]; ++i) {
for (int k = 0; k < probs->ne[0]; ++k) {
float p = get_f32_2d(probs, k, i);
LOG(" %.2f", p);
}
LOG("\n");
}
}
struct my_llama_file {
// use FILE * so we don't have to re-open the file to mmap
FILE * fp;
size_t size;
my_llama_file(const char * fname, const char * mode) {
fp = std::fopen(fname, mode);
if (fp == NULL) {
size = 0;
} else {
seek(0, SEEK_END);
size = tell();
seek(0, SEEK_SET);
}
}
size_t tell() const {
#ifdef _WIN32
__int64 ret = _ftelli64(fp);
#else
long ret = std::ftell(fp);
#endif
GGML_ASSERT(ret != -1); // this really shouldn't fail
return (size_t) ret;
}
void seek(size_t offset, int whence) {
#ifdef _WIN32
int ret = _fseeki64(fp, (__int64) offset, whence);
#else
int ret = std::fseek(fp, (long) offset, whence);
#endif
GGML_ASSERT(ret == 0); // same
}
void read_raw(void * ptr, size_t size) {
if (size == 0) {
return;
}
errno = 0;
std::size_t ret = std::fread(ptr, size, 1, fp);
if (ferror(fp)) {
die_fmt("fread failed: %s", strerror(errno));
}
if (ret != 1) {
die("unexpectedly reached end of file");
}
}
std::uint32_t read_u32() {
std::uint32_t ret;
read_raw(&ret, sizeof(ret));
return ret;
}
std::float_t read_f32() {
std::float_t ret;
read_raw(&ret, sizeof(ret));
return ret;
}
std::string read_string(std::uint32_t len) {
std::vector<char> chars(len);
read_raw(chars.data(), len);
return std::string(chars.data(), len);
}
~my_llama_file() {
if (fp) {
std::fclose(fp);
}
}
};
static bool is_ggml_file(const char * filename) {
my_llama_file file(filename, "rb");
if (file.size < 4) {
return false;
}
std::string magic = file.read_string(4);
return magic == GGUF_MAGIC;
}
static std::string llama_escape_whitespaces(const std::string & text) {
std::ostringstream out;
for (char c : text) {
if (c == ' ') out << "\xe2\x96\x81";
else out << c;
}
return out.str();
}
static void load_vocab(const char * filename, const Config * config, struct my_llama_vocab * vocab) {
if (is_ggml_file(filename)) {
LOG_INF("%s: Loading vocabulary from gguf file %s\n", __func__, filename);
struct ggml_context * ctx_data = NULL;
struct gguf_init_params params = {
/*.no_alloc = */ false,
/*.ctx = */ &ctx_data,
};
struct gguf_context * ctx = gguf_init_from_file(filename, params);
GGML_ASSERT(ctx != NULL);
const int model_idx = gguf_find_key(ctx, KV_TOKENIZER_MODEL);
GGML_ASSERT(model_idx >= 0);
std::string tokenizer_name = gguf_get_val_str(ctx, model_idx);
GGML_ASSERT(tokenizer_name == TOKENIZER_NAME);
const int token_idx = gguf_find_key(ctx, KV_TOKENIZER_LIST);
GGML_ASSERT(token_idx >= 0);
const int score_idx = gguf_find_key(ctx, KV_TOKENIZER_SCORES);
GGML_ASSERT(score_idx >= 0);
const float * scores = (const float * ) gguf_get_arr_data(ctx, score_idx);
const int toktype_idx = gguf_find_key(ctx, KV_TOKENIZER_TOKEN_TYPE);
GGML_ASSERT(toktype_idx >= 0);
const int * toktypes = (const int * ) gguf_get_arr_data(ctx, toktype_idx);
const uint32_t n_vocab = gguf_get_arr_n(ctx, token_idx);
if (n_vocab != static_cast<uint32_t>(config->vocab_size)) {
die_fmt("vocab size mismatch: (gguf) %u != (llama2c) %d", n_vocab, config->vocab_size);
}
vocab->id_to_token.resize(n_vocab);
for (uint32_t i = 0; i < n_vocab; i++) {
std::string word = gguf_get_arr_str(ctx, token_idx, i);
vocab->token_to_id[word] = i;
auto & token_data = vocab->id_to_token[i];
token_data.text = std::move(word);
token_data.score = scores[i];
token_data.type = (llama_token_type) toktypes[i];
}
ggml_free(ctx_data);
gguf_free(ctx);
} else {
// assume llama2.c vocabulary
LOG_INF("%s: Assuming llama2.c vocabulary since %s is not a gguf file\n", __func__, filename);
my_llama_file file(filename, "rb");
if (!file.fp) {
die_fmt("%s: %s", strerror(errno), filename);
}
const int n_vocab = config->vocab_size;
/* uint32_t max_token_length = */ file.read_u32(); // unused
vocab->id_to_token.resize(n_vocab);
for (my_llama_vocab::id id=0; id<n_vocab; ++id) {
float_t score = file.read_f32();
uint32_t len = file.read_u32();
std::string text = file.read_string(len);
unsigned char byte_val;
my_llama_vocab::ttype type = LLAMA_TOKEN_TYPE_NORMAL;
if (id == UNKNOWN_TOKEN_ID) {
text = "<unk>";
type = LLAMA_TOKEN_TYPE_UNKNOWN;
} else if (id == BOS_TOKEN_ID) {
text = "<s>";
type = LLAMA_TOKEN_TYPE_CONTROL;
} else if (id == EOS_TOKEN_ID) {
text = "</s>";
type = LLAMA_TOKEN_TYPE_CONTROL;
} else if (text.empty()) {
type = LLAMA_TOKEN_TYPE_CONTROL;
} else if (sscanf(text.c_str(), "<0x%02hhX>", &byte_val) == 1) {
// Text of byte tokens is already in the expected format.
type = LLAMA_TOKEN_TYPE_BYTE;
} else {
type = LLAMA_TOKEN_TYPE_NORMAL;
}
text = llama_escape_whitespaces(text);
vocab->id_to_token[id].text = text;
vocab->id_to_token[id].score = score;
vocab->id_to_token[id].type = type;
vocab->token_to_id.emplace(text, id);
}
}
}
static void convert_weights_ak_to_gg(struct ggml_tensor * gg_weights, const float * karpathy_weights) {
int size = 1;
for (int dim = 0; dim < ggml_n_dims(gg_weights); ++dim) {
size *= gg_weights->ne[dim];
}
for (int ct = 0; ct < size; ++ct) {
int64_t i0 = 0; int64_t i1 = 0;
int64_t i2 = 0; int64_t i3 = 0;
ggml_unravel_index(gg_weights, ct, &i0, &i1, &i2, &i3);
ggml_set_f32_nd(gg_weights, i0, i1, i2, i3, karpathy_weights[ct]);
}
}
static void save_as_llama_model(
struct my_llama_vocab * vocab, struct my_llama_model * model, TransformerWeights* w, const char * filename
) {
// convert AK weights into GG weights one by one.
// w->token_embedding_table -> model->tok_embeddings
// float* -> struct ggml_tensor
convert_weights_ak_to_gg(model->tok_embeddings, w->token_embedding_table.data());
convert_weights_ak_to_gg(model->output, !w->wcls.empty() ? w->wcls.data() : w->token_embedding_table.data());
convert_weights_ak_to_gg(model->norm, w->rms_final_weight.data());
//print_row(model->norm, 0);
// for rms-att-weight
int row_length = model->hparams.n_embd;
int n_ff = model->hparams.n_ff;
const uint32_t n_multiqueries = model->hparams.n_head_kv <= 0 || model->hparams.n_head_kv >= model->hparams.n_head ? 1 : model->hparams.n_head / model->hparams.n_head_kv;
for (uint32_t i = 0; i < model->hparams.n_layer; ++i){
auto & layer = model->layers[i];
// 1d
convert_weights_ak_to_gg(layer.attention_norm, &w->rms_att_weight[i*row_length]);
convert_weights_ak_to_gg(layer.ffn_norm , &w->rms_ffn_weight[i*row_length]);
// from 3d matrix layer x dim x dim to 2d matrix dim x dim
convert_weights_ak_to_gg(layer.wq , &w->wq[i*row_length*row_length]);
convert_weights_ak_to_gg(layer.wo , &w->wo[i*row_length*row_length]);
// from 3d matrix layer x dim x dim to 2d matrix dim x dim / n_multiqueries
convert_weights_ak_to_gg(layer.wk , &w->wk[i*row_length*row_length/n_multiqueries]);
convert_weights_ak_to_gg(layer.wv , &w->wv[i*row_length*row_length/n_multiqueries]);
convert_weights_ak_to_gg(layer.w1 , &w->w1[i*row_length*n_ff]);
convert_weights_ak_to_gg(layer.w2 , &w->w2[i*n_ff*row_length]);
convert_weights_ak_to_gg(layer.w3 , &w->w3[i*row_length*n_ff]);
}
struct gguf_context * ctx = gguf_init_empty();
std::vector<const char*> tokens;
std::vector<float> scores;
std::vector<llama_token_type> token_types;
for (const my_llama_vocab::token_data & token_data : vocab->id_to_token) {
tokens.push_back(token_data.text.c_str());
scores.push_back(token_data.score);
token_types.push_back(token_data.type);
}
gguf_set_arr_str(ctx, KV_TOKENIZER_LIST, tokens.data(), tokens.size());
gguf_set_arr_data(ctx, KV_TOKENIZER_SCORES, GGUF_TYPE_FLOAT32, scores.data(), scores.size());
gguf_set_arr_data(ctx, KV_TOKENIZER_TOKEN_TYPE, GGUF_TYPE_INT32, token_types.data(), token_types.size());
gguf_set_val_str(ctx, KV_TOKENIZER_MODEL, TOKENIZER_NAME);
gguf_set_val_str(ctx, KV_GENERAL_ARCHITECTURE, "llama");
gguf_set_val_str(ctx, KV_GENERAL_NAME, "llama");
// special tokens
gguf_set_val_u32(ctx, KV_TOKENIZER_UNK_ID, UNKNOWN_TOKEN_ID);
gguf_set_val_u32(ctx, KV_TOKENIZER_BOS_ID, BOS_TOKEN_ID);
gguf_set_val_u32(ctx, KV_TOKENIZER_EOS_ID, EOS_TOKEN_ID);
gguf_set_val_u32(ctx, KV_TOKENIZER_SEP_ID, LLAMA_TOKEN_NULL);
gguf_set_val_u32(ctx, KV_TOKENIZER_PAD_ID, LLAMA_TOKEN_NULL);
gguf_set_val_u32(ctx, KV_CONTEXT_LENGTH, model->hparams.n_ctx);
gguf_set_val_u32(ctx, KV_EMBEDDING_LENGTH, model->hparams.n_embd);
gguf_set_val_u32(ctx, KV_FEED_FORWARD_LENGTH, model->hparams.n_ff);
gguf_set_val_u32(ctx, KV_ATTENTION_HEAD_COUNT, model->hparams.n_head);
gguf_set_val_u32(ctx, KV_ATTENTION_HEAD_COUNT, model->hparams.n_head);
gguf_set_val_u32(ctx, KV_ATTENTION_HEAD_COUNT_KV, model->hparams.n_head_kv);
gguf_set_val_u32(ctx, KV_BLOCK_COUNT, model->hparams.n_layer);
gguf_set_val_u32(ctx, KV_ROPE_DIMENSION_COUNT, model->hparams.n_rot);
gguf_set_val_f32(ctx, KV_ATTENTION_LAYERNORM_RMS_EPS, 1e-5f);
// write tensors
ggml_set_name(model->tok_embeddings, TN_TOKEN_EMBD);
gguf_add_tensor(ctx, model->tok_embeddings);
ggml_set_name(model->norm, TN_OUTPUT_NORM);
gguf_add_tensor(ctx, model->norm);
ggml_set_name(model->output, TN_OUTPUT);
gguf_add_tensor(ctx, model->output);
for (uint32_t i = 0; i < model->hparams.n_layer; ++i) {
auto & layer = model->layers[i];
ggml_format_name(layer.wq, TN_ATTN_Q, i);
gguf_add_tensor(ctx, layer.wq);
ggml_format_name(layer.wk, TN_ATTN_K, i);
gguf_add_tensor(ctx, layer.wk);
ggml_format_name(layer.wv, TN_ATTN_V, i);
gguf_add_tensor(ctx, layer.wv);
ggml_format_name(layer.wo, TN_ATTN_OUTPUT, i);
gguf_add_tensor(ctx, layer.wo);
ggml_format_name(layer.attention_norm, TN_ATTN_NORM, i);
gguf_add_tensor(ctx, layer.attention_norm);
ggml_format_name(layer.w1, TN_FFN_GATE, i);
gguf_add_tensor(ctx, layer.w1);
ggml_format_name(layer.w2, TN_FFN_DOWN, i);
gguf_add_tensor(ctx, layer.w2);
ggml_format_name(layer.w3, TN_FFN_UP, i);
gguf_add_tensor(ctx, layer.w3);
ggml_format_name(layer.ffn_norm, TN_FFN_NORM, i);
gguf_add_tensor(ctx, layer.ffn_norm);
}
gguf_write_to_file(ctx, filename, false);
gguf_free(ctx);
}
static struct train_params get_default_train_params() {
struct train_params params;
params.fn_vocab_model = "models/7B/ggml-model-f16.gguf";
params.fn_llama2c_output_model = "ak_llama_model.bin";
params.fn_train_data = "shakespeare.txt";
params.fn_checkpoint_in = "checkpoint.bin";
params.fn_checkpoint_out = "checkpoint.bin";
params.fn_model_out = "ggml-checkpoint-f32.bin";
params.seed = -1;
params.n_ctx = 128;
params.n_embd = 256;
params.n_mult = 256;
params.n_head = 8;
params.n_layer = 16;
params.n_rotmax = 64;
params.n_threads = 6;
params.n_batch = 8;
params.n_examples = 8;
params.n_predict = 1024;
params.print_info_interval = 1;
params.print_details_interval = 2;
params.samples_start_after_nl = false;
params.use_adam = true;
params.use_flash = false;
params.use_scratch = true;
// only adam
params.warmup = 100;
params.cos_decay_steps = 1000;
params.cos_decay_restart = 1.1f;
params.cos_decay_alpha = 0.0f;
params.lbfgs_n_iter = 16;
params.adam_n_iter = 16;
params.adam_alpha = 1e-3f;
params.adam_decay = 1e-3f;
params.mem_model_gb = 2;
params.mem_compute_gb = 24;
params.mem_compute0_gb = 8;
params.mem_compute1_gb = 2;
return params;
}
static void print_usage(int /*argc*/, char ** argv, const struct train_params * params) {
fprintf(stderr, "usage: %s [options]\n", argv[0]);
fprintf(stderr, "\n");
fprintf(stderr, "options:\n");
fprintf(stderr, " -h, --help show this help message and exit\n");
fprintf(stderr, " --copy-vocab-from-model FNAME path of gguf llama model or llama2.c vocabulary from which to copy vocab (default '%s')\n", params->fn_vocab_model);
fprintf(stderr, " --llama2c-model FNAME [REQUIRED] model path from which to load Karpathy's llama2.c model\n");
fprintf(stderr, " --llama2c-output-model FNAME model path to save the converted llama2.c model (default %s')\n", params->fn_llama2c_output_model);
fprintf(stderr, "\n");
}
static bool params_parse(int argc, char ** argv, struct train_params * params) {
bool invalid_param = false;
bool reqd_param_found = false;
std::string arg;
struct train_params default_params = get_default_train_params();
const std::string arg_prefix = "--";
for (int i = 1; i < argc; i++) {
arg = argv[i];
if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) {
std::replace(arg.begin(), arg.end(), '_', '-');
}
if (arg == "--copy-vocab-from-model") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->fn_vocab_model = argv[i];
} else if (arg == "--llama2c-model") {
if (++i >= argc) {
invalid_param = true;
break;
}
reqd_param_found = true;
params->fn_llama2c_model = argv[i];
} else if (arg == "--llama2c-output-model") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->fn_llama2c_output_model = argv[i];
} else if (arg == "-h" || arg == "--help") {
print_usage(argc, argv, &default_params);
exit(0);
} else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
print_usage(argc, argv, &default_params);
exit(1);
}
}
if (invalid_param) {
fprintf(stderr, "error: invalid parameter for argument: %s\n", arg.c_str());
print_usage(argc, argv, &default_params);
exit(1);
}
if (!reqd_param_found){
fprintf(stderr, "error: please specify a llama2.c .bin file to be converted with argument --llama2c-model\n");
print_usage(argc, argv, &default_params);
exit(1);
}
return true;
}
static std::string basename(const std::string &path) {
size_t pos = path.find_last_of("/\\");
if (pos == std::string::npos) {
return path;
}
return path.substr(pos + 1);
}
int main(int argc, char ** argv) {
common_init();
struct train_params params = get_default_train_params();
if (!params_parse(argc, argv, &params)) {
return 1;
}
Config config;
TransformerWeights weights = {};
{
LOG_INF("%s: Loading llama2c model from %s\n", __func__, params.fn_llama2c_model);
FILE * file = fopen(params.fn_llama2c_model, "rb");
if (!file) {
LOG_ERR("%s: Unable to open the checkpoint file %s!\n", __func__, params.fn_llama2c_model);
return 1;
}
// read in the config header
if (fread(&config, sizeof(Config), 1, file) != 1) {
LOG_ERR("%s: Unable to read llama2c config from %s!\n",__func__,params.fn_llama2c_model);
return 1;
}
auto shared_weights = config.vocab_size > 0;
config.vocab_size = abs(config.vocab_size);
// read in the Transformer weights
alloc_weights(&weights, &config, shared_weights);
if (checkpoint_init_weights(&weights, &config, file, shared_weights)) {
LOG_ERR("%s: Unable to initialize transformer weights from %s!",__func__,params.fn_llama2c_model);
return 1;
}
fclose(file);
}
struct my_llama_vocab vocab;
load_vocab(params.fn_vocab_model, &config, &vocab);
struct my_llama_model model;
model.hparams.n_vocab = config.vocab_size; //llama_vocab_n_vocab(lctx);
model.hparams.n_ctx = params.n_ctx;
model.hparams.n_embd = config.dim; //params.n_embd;
model.hparams.n_ff = config.hidden_dim;
model.hparams.n_mult = 32;//params.n_mult;
model.hparams.n_head = config.n_heads; //params.n_head;
model.hparams.n_head_kv = config.n_kv_heads;
model.hparams.n_layer = config.n_layers; //params.n_layer;
model.hparams.n_rot = std::min((uint32_t)params.n_rotmax, model.hparams.n_embd / model.hparams.n_head);
print_params(&model.hparams);
struct ggml_init_params lcparams;
lcparams.mem_size = 1024ll*1024ll*1024ll*((size_t) params.mem_model_gb);
lcparams.mem_buffer = NULL;
lcparams.no_alloc = false;
model.ctx = ggml_init(lcparams);
init_model(&model);
model.name = basename(params.fn_llama2c_model);
save_as_llama_model(&vocab, &model, &weights, params.fn_llama2c_output_model);
LOG_INF("%s: Saving llama.c model file %s in ggml format at %s\n", __func__, params.fn_llama2c_model, params.fn_llama2c_output_model);
ggml_free(model.ctx);
return 0;
}

1462
examples/convert_legacy_llama.py Executable file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,5 @@
set(TARGET llama-debug)
add_executable(${TARGET} debug.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_17)

54
examples/debug/README.md Normal file
View File

@@ -0,0 +1,54 @@
# llama.cpp/examples/debug
This is a utility intended to help debug a model by registering a callback that
logs GGML operations and tensor data. It can also store the generated logits or
embeddings as well as the prompt and token ids for comparision with the original
model.
### Usage
```shell
llama-debug \
--hf-repo ggml-org/models \
--hf-file phi-2/ggml-model-q4_0.gguf \
--model phi-2-q4_0.gguf \
--prompt hello \
--save-logits \
--verbose
```
The tensor data is logged as debug and required the --verbose flag. The reason
for this is that while useful for a model with many layers there can be a lot of
output. You can filter the tensor names using the `--tensor-filter` option.
A recommended approach is to first run without `--verbose` and see if the
generated logits/embeddings are close to the original model. If they are not,
then it might be required to inspect tensor by tensor and in that case it is
useful to enable the `--verbose` flag along with `--tensor-filter` to focus on
specific tensors.
### Options
This example supports all standard `llama.cpp` options and also accepts the
following options:
```console
$ llama-debug --help
...
----- example-specific params -----
--save-logits save final logits to files for verification (default: false)
--logits-output-dir PATH directory for saving logits output files (default: data)
--tensor-filter REGEX filter tensor names for debug output (regex pattern, can be specified multiple times)
```
### Output Files
When `--save-logits` is enabled, the following files are created in the output
directory:
* `llamacpp-<model>[-embeddings].bin` - Binary output (logits or embeddings)
* `llamacpp-<model>[-embeddings].txt` - Text output (logits or embeddings, one per line)
* `llamacpp-<model>[-embeddings]-prompt.txt` - Prompt text and token IDs
* `llamacpp-<model>[-embeddings]-tokens.bin` - Binary token IDs for programmatic comparison
These files can be compared against the original model's output to verify the
converted model.

253
examples/debug/debug.cpp Normal file
View File

@@ -0,0 +1,253 @@
#include "debug.h"
#include "arg.h"
#include "common.h"
#include "log.h"
#include "llama.h"
#include <cstdlib>
#include <string>
#include <vector>
#include <filesystem>
#include <fstream>
#include <regex>
static void print_usage(int /*argc*/, char ** argv) {
const std::string usage_template = R"(
example usage:
Print tensors:
{prog} -m model.gguf -p "Hello my name is" --verbose
The tensors to be printed can be filtered with --tensor-filter option.
Save logits/embeddings:
{prog} -m model.gguf -p "Hello my name is" --save-logits
Add --embedding to save embeddings)" "\n";
// Fix the source code indentation above that is introduced by the raw string literal.
std::string usage = std::regex_replace(usage_template, std::regex("\\n {8}"), "\n");
usage = std::regex_replace(usage, std::regex("\\{prog\\}"), argv[0]);
LOG("%s\n", usage.c_str());
}
static bool has_pooling(llama_context * ctx) {
switch (llama_pooling_type(ctx)) {
case LLAMA_POOLING_TYPE_NONE:
case LLAMA_POOLING_TYPE_UNSPECIFIED:
return false;
default:
return true;
}
}
struct output_data {
float * data_ptr = nullptr;
int data_size = 0;
std::string type_suffix;
std::vector<float> embd_norm;
std::string prompt;
std::vector<llama_token> tokens;
output_data(llama_context * ctx, const llama_model * model, const common_params & params) {
const llama_vocab * vocab = llama_model_get_vocab(model);
const bool add_bos = llama_vocab_get_add_bos(vocab);
tokens = common_tokenize(ctx, params.prompt, add_bos);
prompt = params.prompt;
if (params.embedding) {
const int n_embd = llama_model_n_embd_out(model);
const bool pooling = has_pooling(ctx);
const int n_embd_count = pooling ? 1 : tokens.size();
const int n_floats = n_embd * n_embd_count;
float * embd_raw = pooling ? llama_get_embeddings_seq(ctx, 0) : llama_get_embeddings(ctx);
if (embd_raw == nullptr) {
throw std::runtime_error("failed to get embeddings from the model");
}
LOG_DBG("pooling_enabled: %s\n", pooling ? "true" : "false");
LOG_DBG("n_embd: %d\n", n_embd);
LOG_DBG("n_floats: %d\n", n_floats);
LOG_DBG("n_embd_count: %d\n", n_embd_count);
data_ptr = embd_raw;
data_size = n_floats;
type_suffix = "-embeddings";
if (params.embd_normalize >= 0) {
embd_norm.resize(n_floats);
for (int i = 0; i < n_embd_count; i++) {
common_embd_normalize(embd_raw+i*n_embd, embd_norm.data()+i*n_embd, n_embd, params.embd_normalize);
}
data_ptr = embd_norm.data();
}
} else {
const float * logits = llama_get_logits_ith(ctx, tokens.size() - 1);
const int n_logits = llama_vocab_n_tokens(vocab);
data_ptr = const_cast<float*>(logits);
data_size = n_logits;
type_suffix = "";
}
}
};
static void save_output_data(const output_data & output, const std::string & model_name, const std::string & output_dir) {
std::filesystem::create_directory(output_dir);
auto base_path = std::filesystem::path{output_dir} / ("llamacpp-" + model_name + output.type_suffix);
// Save logits/embeddings to binary file.
{
std::filesystem::path filepath{base_path.string() + ".bin"};
std::ofstream file{filepath, std::ios::binary};
if (!file) {
throw std::runtime_error("failed to open binary output file: " + filepath.string());
}
file.write(reinterpret_cast<const char*>(output.data_ptr), output.data_size * sizeof(float));
LOG("Data saved to %s\n", filepath.c_str());
}
// Save logits/embeddings to text file.
{
std::filesystem::path filepath{base_path.string() + ".txt"};
std::ofstream file{filepath};
if (!file) {
throw std::runtime_error("failed to open text output file: " + filepath.string());
}
for (int i = 0; i < output.data_size; i++) {
file << i << ": " << output.data_ptr[i] << '\n';
}
LOG("Data saved to %s\n", filepath.c_str());
}
// Save prompt and tokens to text file.
{
std::filesystem::path filepath{base_path.string() + "-prompt.txt"};
std::ofstream file{filepath};
if (!file) {
throw std::runtime_error("failed to open prompt output file: " + filepath.string());
}
file << "prompt: " << output.prompt << '\n';
file << "n_tokens: " << output.tokens.size() << '\n';
file << "token ids: ";
for (size_t i = 0; i < output.tokens.size(); i++) {
file << output.tokens[i];
if (i + 1 < output.tokens.size()) {
file << ", ";
}
}
file << '\n';
LOG("Prompt saved to %s\n", filepath.c_str());
}
// Save token ids to binary file.
{
std::filesystem::path filepath{base_path.string() + "-tokens.bin"};
std::ofstream file{filepath, std::ios::binary};
if (!file) {
throw std::runtime_error("failed to open tokens binary file: " + filepath.string());
}
file.write(reinterpret_cast<const char*>(output.tokens.data()), output.tokens.size() * sizeof(llama_token));
LOG("Tokens saved to %s\n", filepath.c_str());
}
}
static void print_tokenized_prompt(llama_context * ctx, const std::vector<llama_token> & tokens, const std::string & prompt) {
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);
LOG("Model add_bos: %s\n", llama_vocab_get_add_bos(vocab) ? "true" : "false");
LOG("Input prompt: \"%s\"\n", prompt.c_str());
LOG("Token ids (%zu):\n", tokens.size());
for (auto id : tokens) {
std::string piece(128, '\0');
int n = llama_token_to_piece(vocab, id, piece.data(), piece.size(), 0, true);
if (n < 0) {
LOG_ERR("failed to convert token %d to piece\n", id);
continue;
}
piece.resize(n);
LOG("%s(%d) ", piece.c_str(), id);
}
LOG("\n");
}
static bool run(llama_context * ctx, const common_params & params) {
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);
const bool add_bos = llama_vocab_get_add_bos(vocab);
std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, add_bos);
if (tokens.empty()) {
LOG_ERR("%s : there are not input tokens to process - (try to provide a prompt with '-p')\n", __func__);
return false;
}
if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) {
LOG_ERR("%s : failed to eval\n", __func__);
return false;
}
print_tokenized_prompt(ctx, tokens, params.prompt);
if (params.save_logits) {
output_data output {ctx, model, params};
std::filesystem::path model_path{params.model.path};
std::string model_name{model_path.stem().string()};
save_output_data(output, model_name, params.logits_output_dir);
}
return true;
}
int main(int argc, char ** argv) {
common_params params;
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_DEBUG, print_usage)) {
return 1;
}
common_init();
llama_backend_init();
llama_numa_init(params.numa);
base_callback_data cb_data(params, params.tensor_filter);
auto llama_init = common_init_from_params(params);
auto * model = llama_init->model();
auto * ctx = llama_init->context();
if (model == nullptr || ctx == nullptr) {
LOG_ERR("%s : failed to init\n", __func__);
return 1;
}
{
LOG_INF("\n");
LOG_INF("%s\n", common_params_get_system_info(params).c_str());
LOG_INF("\n");
}
if (!run(ctx, params)) {
return 1;
}
LOG("\n");
llama_perf_context_print(ctx);
llama_backend_free();
return 0;
}

View File

@@ -0,0 +1,49 @@
# Migration notice for binary filenames
> [!IMPORTANT]
[2024 Jun 12] Binaries have been renamed w/ a `llama-` prefix. `main` is now `llama-cli`, `server` is `llama-server`, etc (https://github.com/ggerganov/llama.cpp/pull/7809)
This migration was important, but it is a breaking change that may not always be immediately obvious to users.
Please update all scripts and workflows to use the new binary names.
| Old Filename | New Filename |
| ---- | ---- |
| main | llama-cli |
| server | llama-server |
| llama-bench | llama-bench |
| embedding | llama-embedding |
| quantize | llama-quantize |
| tokenize | llama-tokenize |
| export-lora | llama-export-lora |
| libllava.a | libllava.a |
| baby-llama | llama-baby-llama |
| batched | llama-batched |
| batched-bench | llama-batched-bench |
| benchmark-matmult | llama-benchmark-matmult |
| convert-llama2c-to-ggml | llama-convert-llama2c-to-ggml |
| eval-callback | llama-eval-callback |
| gbnf-validator | llama-gbnf-validator |
| gguf | llama-gguf |
| gguf-split | llama-gguf-split |
| gritlm | llama-gritlm |
| imatrix | llama-imatrix |
| infill | llama-infill |
| llava-cli | llama-llava-cli |
| lookahead | llama-lookahead |
| lookup | llama-lookup |
| lookup-create | llama-lookup-create |
| lookup-merge | llama-lookup-merge |
| lookup-stats | llama-lookup-stats |
| parallel | llama-parallel |
| passkey | llama-passkey |
| perplexity | llama-perplexity |
| q8dot | llama-q8dot |
| quantize-stats | llama-quantize-stats |
| retrieval | llama-retrieval |
| save-load-state | llama-save-load-state |
| simple | llama-simple |
| speculative | llama-speculative |
| vdot | llama-vdot |
| tests/test-c.o | tests/test-c.o |

View File

@@ -0,0 +1,35 @@
// Warns users that this filename was deprecated, and provides a link for more information.
#include <cstdio>
#include <string>
#include <unordered_map>
// Main
int main(int argc, char** argv) {
std::string filename = "main";
if (argc >= 1) {
filename = argv[0];
}
// Get only the program name from the full path
auto pos = filename.find_last_of("/\\");
if (pos != std::string::npos) {
filename = filename.substr(pos+1);
}
// Append "llama-" to the beginning of filename to get the replacemnt filename
auto replacement_filename = "llama-" + filename;
// The exception is if the filename is "main", then our replacement filename is "llama-cli"
if (filename == "main") {
replacement_filename = "llama-cli";
}
fprintf(stdout, "\n");
fprintf(stdout, "WARNING: The binary '%s' is deprecated.\n", filename.c_str());
fprintf(stdout, " Please use '%s' instead.\n", replacement_filename.c_str());
fprintf(stdout, " See https://github.com/ggerganov/llama.cpp/tree/master/examples/deprecation-warning/README.md for more information.\n");
fprintf(stdout, "\n");
return EXIT_FAILURE;
}

View File

@@ -0,0 +1,5 @@
set(TARGET llama-diffusion-cli)
add_executable(${TARGET} diffusion-cli.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_17)

View File

@@ -0,0 +1,59 @@
# Diffusion Text Generation
This directory contains implementations for Diffusion LLMs (DLLMs)
More Info:
- https://github.com/ggml-org/llama.cpp/pull/14644
- https://github.com/ggml-org/llama.cpp/pull/14771
## Parameters
The diffusion CLI supports various parameters to control the generation process:
### Core Diffusion Parameters
- `--diffusion-steps`: Number of diffusion steps (default: 256)
- `--diffusion-algorithm`: Algorithm for token selection
- `0`: ORIGIN - Token will be generated in a purely random order from https://arxiv.org/abs/2107.03006.
- `1`: ENTROPY_BASED - Entropy-based selection
- `2`: MARGIN_BASED - Margin-based selection
- `3`: RANDOM - Random selection
- `4`: CONFIDENCE_BASED - Confidence-based selection (default)
- More documentation here https://github.com/DreamLM/Dream
- `--diffusion-visual`: Enable live visualization during generation
### Scheduling Parameters
Choose one of the following scheduling methods:
**Timestep-based scheduling:**
- `--diffusion-eps`: Epsilon value for timestep scheduling (e.g., 0.001)
**Block-based scheduling:**
- `--diffusion-block-length`: Block size for block-based scheduling (e.g., 32)
### Sampling Parameters
- `--temp`: Temperature for sampling (0.0 = greedy/deterministic, higher = more random)
- `--top-k`: Top-k filtering for sampling
- `--top-p`: Top-p (nucleus) filtering for sampling
- `--seed`: Random seed for reproducibility
### Model Parameters
- `-m`: Path to the GGUF model file
- `-p`: Input prompt text
- `-ub`: Maximum sequence length (ubatch size)
- `-c`: Context size
- `-b`: Batch size
### Examples
#### Dream architechture:
```
llama-diffusion-cli -m dream7b.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-eps 0.001 --diffusion-algorithm 3 --diffusion-steps 256 --diffusion-visual
```
#### LLaDA architechture:
```
llama-diffusion-cli -m llada-8b.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-block-length 32 --diffusion-steps 256 --diffusion-visual
```
#### RND1 architecture:
```
llama-diffusion-cli -m RND1-Base-0910.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-algorithm 1 --diffusion-steps 256 --diffusion-visual --temp 0.5 --diffusion-eps 0.001
```

View File

@@ -0,0 +1,694 @@
#include "arg.h"
#include "chat.h"
#include "common.h"
#include "llama.h"
#include "log.h"
#include <limits.h>
#include <algorithm>
#include <cmath>
#include <cstring>
#include <limits>
#include <random>
#include <string>
#include <vector>
enum diffusion_algorithm { ORIGIN = 0, ENTROPY_BASED = 1, MARGIN_BASED = 2, RANDOM = 3, CONFIDENCE_BASED = 4 };
// Unified transfer scheduling methods
enum transfer_schedule {
TIMESTEP_BASED = 0, // Dream-style: (1.0 - s/t) * remaining
BLOCK_BASED = 1, // LLaDA-style: process in blocks with get_num_transfer_tokens
};
typedef bool (*diffusion_step_callback_t)(int32_t step,
int32_t total_steps,
const llama_token * tokens,
int32_t n_tokens,
void * user_data);
struct diffusion_params {
int32_t steps = 0;
float temperature = 0;
llama_token mask_token_id = LLAMA_TOKEN_NULL;
diffusion_step_callback_t step_callback = nullptr;
void * step_callback_user_data = nullptr;
int32_t seed = 0;
bool visual_mode = false;
bool shift_logits = false; // Shift logits by -1 after decode
float top_p = 0.;
int32_t top_k = 0.;
diffusion_algorithm algorithm = CONFIDENCE_BASED;
transfer_schedule schedule = TIMESTEP_BASED;
float cfg_scale = 0.; // Config scale for classifier-free guidance
float eps = 0.; // Timestep scheduling
int32_t block_length = 0; // Block size (for block scheduling)
float alg_temp = 0; // algorithm temperature (0.0 = deterministic)
bool add_gumbel_noise = false; // Add gumbel noise to the logits if temp > 0.0
int32_t max_length = 0; // Maximum sequence length
};
struct callback_data {
diffusion_params * diff_params;
const llama_vocab * vocab;
int32_t n_input;
};
static float calculate_confidence(const llama_token_data_array & cur_p,
diffusion_algorithm algorithm,
std::mt19937 & rng) {
switch (algorithm) {
case CONFIDENCE_BASED:
return cur_p.data[cur_p.selected].p; // Selected token probability
case ENTROPY_BASED:
{
float entropy = 0.0f;
const float epsilon = 1e-10f;
for (size_t i = 0; i < cur_p.size; i++) {
float prob = cur_p.data[i].p;
entropy += prob * logf(prob + epsilon);
}
return -entropy; // Higher entropy = lower confidence
}
case MARGIN_BASED:
return (cur_p.size > 1) ? cur_p.data[0].p - cur_p.data[1].p : cur_p.data[0].p;
case RANDOM:
{
std::uniform_real_distribution<float> uniform(0.0f, 1.0f);
return uniform(rng); // Random confidence
}
case ORIGIN:
return cur_p.data[cur_p.selected].p;
default:
return 0.0f;
}
}
// Unified transfer count calculation function
static int32_t calculate_transfer_count(int32_t step,
int32_t total_steps,
int32_t remaining_masked,
transfer_schedule schedule,
float eps,
const std::vector<int32_t> & num_transfer_tokens = {}) {
switch (schedule) {
case TIMESTEP_BASED:
{
float t = 1.0f - (float) step / total_steps * (1.0f - eps);
float s = 1.0f - (float) (step + 1) / total_steps * (1.0f - eps);
float p_transfer = (step < total_steps - 1) ? (1.0f - s / t) : 1.0f;
return (int32_t) (remaining_masked * p_transfer);
}
case BLOCK_BASED:
if (!num_transfer_tokens.empty() && step < (int32_t) num_transfer_tokens.size()) {
return num_transfer_tokens[step];
}
return remaining_masked / (total_steps - step); // Fallback
default:
return remaining_masked / (total_steps - step);
}
}
static bool diffusion_step_callback(int32_t step,
int32_t total_steps,
const llama_token * tokens,
int32_t n_tokens,
void * user_data) {
(void) user_data;
callback_data * data = static_cast<callback_data *>(user_data);
auto print_progress_bar = [](int32_t step, int32_t total_steps) {
int progress_percent = (step * 100) / total_steps;
int progress_bars = (step * 50) / total_steps;
LOG_INF("\rdiffusion step: %d/%d [%s%s] %d%%",
step,
total_steps,
std::string(progress_bars, '=').c_str(),
std::string(50 - progress_bars, ' ').c_str(),
progress_percent);
};
if (data->diff_params->visual_mode) {
// Visual mode: clear
LOG_INF("\033[2J\033[H"); // Clear screen and move cursor to top-left
print_progress_bar(step, total_steps);
LOG_INF("\n");
std::string current_text = " ";
for (int32_t i = data->n_input; i < n_tokens; i++) {
std::string token_str;
if (tokens[i] != llama_vocab_mask(data->vocab)) {
char piece[256];
int n_chars = llama_token_to_piece(data->vocab, tokens[i], piece, sizeof(piece), 0, false);
if (n_chars > 0) {
piece[n_chars] = '\0';
token_str = piece;
}
} else {
token_str = " ";
}
current_text += token_str;
}
LOG_INF("%s\n", current_text.c_str());
} else {
print_progress_bar(step, total_steps);
}
return true;
}
static void add_gumbel_noise(float * logits, int32_t n_vocab, float temperature, std::mt19937 & rng) {
if (temperature == 0.0f) {
return;
}
std::uniform_real_distribution<double> uniform(0.0, 1.0);
for (int32_t i = 0; i < n_vocab; i++) {
double noise = uniform(rng);
// Prevent log(0)
noise = std::max(noise, 1e-20);
double gumbel_noise = std::pow(-std::log(noise), temperature);
logits[i] = std::exp(logits[i]) / gumbel_noise;
}
}
static std::vector<int32_t> get_num_transfer_tokens(int32_t mask_count, int32_t steps) {
std::vector<int32_t> num_transfer_tokens(steps);
int32_t base = mask_count / steps;
int32_t remainder = mask_count % steps;
for (int32_t i = 0; i < steps; i++) {
num_transfer_tokens[i] = base + (i < remainder ? 1 : 0);
}
return num_transfer_tokens;
}
static void diffusion_generate(llama_context * ctx,
const llama_token * input_tokens,
llama_token * output_tokens,
int32_t n_input,
const diffusion_params & params,
int32_t & n_generated) {
n_generated = 0;
if (!ctx || !input_tokens || !output_tokens || n_input <= 0 || params.max_length <= n_input) {
return;
}
const llama_model * model = llama_get_model(ctx);
// Initialize with input and pad with mask tokens
std::copy(input_tokens, input_tokens + n_input, output_tokens);
std::fill(output_tokens + n_input, output_tokens + params.max_length, params.mask_token_id);
std::mt19937 rng(params.seed);
llama_set_causal_attn(ctx, false);
int32_t n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(model));
std::vector<llama_token_data> candidates(n_vocab);
std::vector<llama_token_data> conf_candidates;
conf_candidates.reserve(params.max_length);
std::vector<int32_t> mask_positions;
mask_positions.reserve(params.max_length);
// Setup sampler chain
struct llama_sampler * sampler = llama_sampler_chain_init(llama_sampler_chain_default_params());
if (params.top_k > 0) {
llama_sampler_chain_add(sampler, llama_sampler_init_top_k(params.top_k));
}
if (params.top_p < 1.0f) {
llama_sampler_chain_add(sampler, llama_sampler_init_top_p(params.top_p, 1));
}
if (params.temperature > 0.0f) {
llama_sampler_chain_add(sampler, llama_sampler_init_temp(params.temperature));
}
llama_sampler_chain_add(sampler, llama_sampler_init_dist(params.seed));
struct llama_sampler * dist_sampler = llama_sampler_init_dist(params.seed);
llama_batch batch = llama_batch_init(params.max_length, 0, 1);
batch.n_tokens = params.max_length;
// Pre-allocate buffers for CFG if needed
int32_t logits_size = n_vocab * params.max_length;
std::vector<float> cond_logits_buffer;
std::vector<llama_token> un_x_buffer;
if (params.cfg_scale > 0.0f) {
cond_logits_buffer.resize(logits_size);
un_x_buffer.resize(params.max_length);
}
// For block-based processing
std::vector<int32_t> num_transfer_tokens;
int32_t num_blocks = 1;
int32_t steps_per_block = params.steps;
if (params.schedule == BLOCK_BASED) {
GGML_ASSERT(params.max_length % params.block_length == 0);
num_blocks = params.max_length / params.block_length;
GGML_ASSERT(params.steps % num_blocks == 0);
steps_per_block = params.steps / num_blocks;
}
std::vector<float> confidence(params.max_length);
int64_t total_sampling_time = 0;
int64_t total_time = 0;
int64_t time_start = ggml_time_us();
for (int block_num = 0; block_num < num_blocks; block_num++) {
int32_t block_start = (params.schedule == BLOCK_BASED) ? n_input + block_num * params.block_length : 0;
int32_t block_end = (params.schedule == BLOCK_BASED) ?
std::min(n_input + (block_num + 1) * params.block_length, params.max_length) :
params.max_length;
// Count masked tokens in current block for block-based processing
if (params.schedule == BLOCK_BASED) {
int32_t block_mask_count = 0;
for (int i = block_start; i < block_end; i++) {
if (output_tokens[i] == params.mask_token_id) {
block_mask_count++;
}
}
num_transfer_tokens = get_num_transfer_tokens(block_mask_count, steps_per_block);
}
for (int32_t step = 0; step < steps_per_block; step++) {
int32_t global_step = block_num * steps_per_block + step;
if (params.step_callback) {
if (!params.step_callback(
global_step, params.steps, output_tokens, params.max_length, params.step_callback_user_data)) {
break;
}
}
// Setup batch
for (int32_t i = 0; i < params.max_length; i++) {
batch.token[i] = output_tokens[i];
batch.pos[i] = i;
batch.n_seq_id[i] = 1;
batch.seq_id[i][0] = 0;
batch.logits[i] = 1;
}
float * logits = nullptr;
if (params.cfg_scale > 0.0f) {
int ret = llama_decode(ctx, batch);
if (ret != 0) {
LOG_ERR("Failed to generate conditional");
break;
}
float * cond_logits_ptr = llama_get_logits(ctx);
std::memcpy(cond_logits_buffer.data(), cond_logits_ptr, logits_size * sizeof(float));
// Unconditional generation (mask input)
std::copy(output_tokens, output_tokens + params.max_length, un_x_buffer.begin());
for (int32_t i = 0; i < n_input; i++) {
un_x_buffer[i] = params.mask_token_id;
}
for (int32_t i = 0; i < params.max_length; i++) {
batch.token[i] = un_x_buffer[i];
}
ret = llama_decode(ctx, batch);
if (ret != 0) {
LOG_ERR("Failed to generate unconditional");
break;
}
float * uncond_logits = llama_get_logits(ctx);
// Apply CFG
for (int32_t i = 0; i < logits_size; i++) {
cond_logits_buffer[i] =
uncond_logits[i] + (params.cfg_scale + 1.0f) * (cond_logits_buffer[i] - uncond_logits[i]);
}
logits = cond_logits_buffer.data();
} else {
int ret = llama_decode(ctx, batch);
if (ret != 0) {
LOG_ERR("%s: failed to decode at step %d, ret = %d\n", __func__, global_step, ret);
break;
}
logits = llama_get_logits(ctx);
}
if (!logits) {
LOG_ERR("%s: failed to get logits at step %d\n", __func__, global_step);
break;
}
auto get_logits_for_pos = [&](int32_t pos) -> const float * {
if (params.shift_logits) {
return pos == 0 ? logits : logits + (pos - 1) * n_vocab;
}
return logits + (pos) *n_vocab;
};
int64_t time_start_sampling = ggml_time_us();
mask_positions.clear();
for (int32_t i = 0; i < params.max_length; i++) {
if (output_tokens[i] == params.mask_token_id) {
// For block-based, only consider current block
if (params.schedule != BLOCK_BASED || (i >= block_start && i < block_end)) {
mask_positions.push_back(i);
}
}
}
if (mask_positions.empty()) {
break;
}
if (params.add_gumbel_noise && params.temperature > 0.0f) {
add_gumbel_noise(logits, n_vocab, params.temperature, rng);
}
if (params.algorithm == ORIGIN) {
int32_t transfer_count = calculate_transfer_count(
step, steps_per_block, mask_positions.size(), params.schedule, params.eps, num_transfer_tokens);
float p_transfer = (float) transfer_count / mask_positions.size();
for (int32_t pos : mask_positions) {
if (std::uniform_real_distribution<float>(0.0f, 1.0f)(rng) < p_transfer) {
const float * pos_logits = get_logits_for_pos(pos);
for (int32_t token_id = 0; token_id < n_vocab; token_id++) {
candidates[token_id].id = token_id;
candidates[token_id].logit = pos_logits[token_id];
candidates[token_id].p = 0.0f;
}
llama_token_data_array cur_p = {
candidates.data(),
(size_t) n_vocab,
-1,
false,
};
llama_sampler_apply(sampler, &cur_p);
output_tokens[pos] = cur_p.data[cur_p.selected].id;
}
}
} else {
std::vector<std::pair<float, int32_t>> confidences;
std::vector<llama_token> sampled_tokens(mask_positions.size());
for (size_t i = 0; i < mask_positions.size(); i++) {
int32_t pos = mask_positions[i];
const float * pos_logits = get_logits_for_pos(pos);
for (int32_t token_id = 0; token_id < n_vocab; token_id++) {
candidates[token_id].logit = pos_logits[token_id];
candidates[token_id].p = 0.0f;
candidates[token_id].id = token_id;
}
llama_token_data_array cur_p = {
candidates.data(),
candidates.size(),
-1,
false,
};
llama_sampler_apply(sampler, &cur_p);
llama_token sampled_token = cur_p.data[cur_p.selected].id;
float conf = calculate_confidence(cur_p, params.algorithm, rng);
sampled_tokens[i] = sampled_token;
confidences.emplace_back(conf, i);
}
int32_t transfer_count = calculate_transfer_count(
step, steps_per_block, mask_positions.size(), params.schedule, params.eps, num_transfer_tokens);
if (transfer_count > 0) {
if (params.alg_temp == 0.0f) {
std::partial_sort(confidences.begin(),
confidences.begin() + std::min(transfer_count, (int32_t) confidences.size()),
confidences.end(),
[](const std::pair<float, int32_t> & a, const std::pair<float, int32_t> & b) {
if (a.first != b.first) {
return a.first > b.first;
}
return a.second < b.second;
});
for (int32_t i = 0; i < std::min(transfer_count, (int32_t) confidences.size()); i++) {
int32_t mask_idx = confidences[i].second;
int32_t pos = mask_positions[mask_idx];
output_tokens[pos] = sampled_tokens[mask_idx];
}
} else {
conf_candidates.clear();
for (size_t i = 0; i < confidences.size(); i++) {
float conf_logit = confidences[i].first / params.alg_temp;
conf_candidates.emplace_back(llama_token_data{ (int32_t) i, conf_logit, 0.0f });
}
llama_token_data_array conf_array = {
conf_candidates.data(),
conf_candidates.size(),
-1,
false,
};
for (int32_t i = 0; i < std::min(transfer_count, (int32_t) confidences.size()); i++) {
llama_sampler_apply(dist_sampler, &conf_array);
int32_t selected_idx = conf_array.selected;
int32_t mask_idx = selected_idx;
int32_t pos = mask_positions[mask_idx];
output_tokens[pos] = sampled_tokens[mask_idx];
conf_candidates[selected_idx].p = 0.0f;
conf_array.selected = -1;
}
}
}
}
int64_t time_end_sampling = ggml_time_us();
total_sampling_time += time_end_sampling - time_start_sampling;
}
}
int64_t time_end = ggml_time_us();
total_time += time_end - time_start;
LOG_INF("\ntotal time: %0.2fms, time per step: %0.2fms, sampling time per step: %0.2fms\n",
total_time / 1000.0,
total_time / 1000.0 / params.steps,
total_sampling_time / 1000.0 / params.steps);
llama_batch_free(batch);
llama_sampler_free(sampler);
llama_sampler_free(dist_sampler);
n_generated = params.max_length;
}
static std::string format_input_text(const std::string & prompt, const std::string & system_prompt, bool use_chat_template, llama_model * model) {
if (!use_chat_template) {
return prompt;
}
auto chat_templates = common_chat_templates_init(model, "");
common_chat_templates_inputs inputs;
common_chat_msg system_msg;
if (!system_prompt.empty()) {
system_msg.role = "system";
system_msg.content = system_prompt;
inputs.messages.push_back(system_msg);
}
common_chat_msg user_msg;
user_msg.role = "user";
user_msg.content = prompt;
inputs.messages.push_back(user_msg);
inputs.add_generation_prompt = true;
auto result = common_chat_templates_apply(chat_templates.get(), inputs);
return result.prompt;
}
int main(int argc, char ** argv) {
ggml_time_init();
common_params params;
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_DIFFUSION)) {
return 1;
}
common_init();
llama_backend_init();
llama_model_params model_params = llama_model_default_params();
model_params.n_gpu_layers = params.n_gpu_layers;
model_params.devices = params.devices.data();
model_params.use_mmap = params.use_mmap;
model_params.use_direct_io = params.use_direct_io;
model_params.use_mlock = params.use_mlock;
model_params.check_tensors = params.check_tensors;
llama_model * model = llama_model_load_from_file(params.model.path.c_str(), model_params);
if (!model) {
LOG_ERR("error: failed to load model '%s'\n", params.model.path.c_str());
return 1;
}
if (!llama_model_is_diffusion(model)) {
LOG_ERR("error: unsupported model for diffusion");
llama_model_free(model);
return 1;
}
llama_context_params ctx_params = llama_context_default_params();
ctx_params.n_ctx = params.n_ctx;
ctx_params.n_batch = params.n_batch;
ctx_params.n_ubatch = params.n_ubatch;
ctx_params.flash_attn_type = params.flash_attn_type;
ctx_params.no_perf = params.no_perf;
ctx_params.type_k = params.cache_type_k;
ctx_params.type_v = params.cache_type_v;
llama_context * ctx = llama_init_from_model(model, ctx_params);
if (!ctx) {
LOG_ERR("error: failed to create context\n");
llama_model_free(model);
return 1;
}
llama_set_n_threads(ctx, params.cpuparams.n_threads, params.cpuparams_batch.n_threads);
const llama_vocab * vocab = llama_model_get_vocab(model);
std::string formatted_prompt = format_input_text(params.prompt, params.system_prompt, params.enable_chat_template, model);
std::vector<llama_token> input_tokens = common_tokenize(vocab,
formatted_prompt,
/*add special tokens*/ true,
/*parse special*/ true);
int n_input = input_tokens.size();
if (n_input >= params.n_ctx) {
LOG_ERR("error: input too long (%d tokens), max context is %d\n", n_input, params.n_ctx);
llama_free(ctx);
llama_model_free(model);
return 1;
}
llama_token mask_token_id = llama_vocab_mask(vocab);
GGML_ASSERT(mask_token_id != LLAMA_TOKEN_NULL);
bool visual_mode = params.diffusion.visual_mode;
int32_t n_generated = 0;
std::vector<llama_token> output_tokens(params.n_ubatch);
struct diffusion_params diff_params;
char shift_logits_str[8];
if (llama_model_meta_val_str(model, "diffusion.shift_logits", shift_logits_str, sizeof(shift_logits_str)) >= 0) {
diff_params.shift_logits = (strcmp(shift_logits_str, "true") == 0);
} else {
diff_params.shift_logits = true;
}
//Use either eps or block length, but not both
GGML_ASSERT((params.diffusion.eps == 0) ^ (params.diffusion.block_length == 0));
if (params.diffusion.eps) {
diff_params.schedule = TIMESTEP_BASED;
diff_params.eps = params.diffusion.eps;
} else if (params.diffusion.block_length) {
diff_params.schedule = BLOCK_BASED;
diff_params.block_length = params.diffusion.block_length;
}
diff_params.mask_token_id = mask_token_id;
diff_params.seed = params.sampling.seed;
diff_params.temperature = params.sampling.temp;
diff_params.steps = params.diffusion.steps;
diff_params.algorithm = static_cast<diffusion_algorithm>(params.diffusion.algorithm);
diff_params.max_length = params.n_ubatch;
diff_params.top_p = params.sampling.top_p;
diff_params.top_k = params.sampling.top_k;
diff_params.visual_mode = params.diffusion.visual_mode;
diff_params.add_gumbel_noise = params.diffusion.add_gumbel_noise;
diff_params.step_callback = diffusion_step_callback;
callback_data cb_data = { &diff_params, vocab, n_input };
diff_params.step_callback_user_data = &cb_data;
const char * alg_names[] = { "ORIGIN", "ENTROPY_BASED", "MARGIN_BASED", "RANDOM", "CONFIDENCE_BASED" };
const char * sched_names[] = { "TIMESTEP_BASED", "BLOCK_BASED" };
const char * alg_name =
(diff_params.algorithm >= 0 && diff_params.algorithm <= 4) ? alg_names[diff_params.algorithm] : "UNKNOWN";
const char * sched_name =
(diff_params.schedule >= 0 && diff_params.schedule <= 1) ? sched_names[diff_params.schedule] : "UNKNOWN";
LOG_INF("diffusion_params: - %-25s llama_token = %d\n", "mask_token_id", mask_token_id);
LOG_INF("diffusion_params: - %-25s u32 = %d\n", "steps", diff_params.steps);
LOG_INF("diffusion_params: - %-25s u32 = %d\n", "max_length", diff_params.max_length);
LOG_INF("diffusion_params: - %-25s enum = %d (%s)\n", "algorithm", diff_params.algorithm, alg_name);
LOG_INF("diffusion_params: - %-25s enum = %d (%s)\n", "schedule", diff_params.schedule, sched_name);
LOG_INF("diffusion_params: - %-25s f32 = %.3f\n", "temperature", diff_params.temperature);
if (diff_params.schedule == TIMESTEP_BASED) {
LOG_INF("diffusion_params: - %-25s f32 = %.6f\n", "eps", diff_params.eps);
LOG_INF("diffusion_params: - %-25s f32 = %.3f\n", "alg_temp", diff_params.alg_temp);
}
if (diff_params.schedule == BLOCK_BASED) {
LOG_INF("diffusion_params: - %-25s u32 = %d\n", "block_length", diff_params.block_length);
LOG_INF("diffusion_params: - %-25s f32 = %.3f\n", "cfg_scale", diff_params.cfg_scale);
}
diffusion_generate(ctx, input_tokens.data(), output_tokens.data(), n_input, diff_params, n_generated);
if (n_generated > 0) {
if (visual_mode) {
//clear screen and move cursor to top-left
LOG_INF("\033[2J\033[H");
}
output_tokens.erase(output_tokens.begin(), output_tokens.begin() + n_input);
std::string output_data = common_detokenize(vocab, output_tokens, false);
LOG_INF("\n%s\n", output_data.c_str());
} else {
LOG_INF("Error: diffusion generation failed\n");
}
llama_free(ctx);
llama_model_free(model);
llama_backend_free();
return 0;
}

View File

@@ -0,0 +1,5 @@
set(TARGET llama-embedding)
add_executable(${TARGET} embedding.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_17)

View File

@@ -0,0 +1,61 @@
# llama.cpp/example/embedding
This example demonstrates generate high-dimensional embedding vector of a given text with llama.cpp.
## Quick Start
To get started right away, run the following command, making sure to use the correct path for the model you have:
### Unix-based systems (Linux, macOS, etc.):
```bash
./llama-embedding -m ./path/to/model --pooling mean --log-disable -p "Hello World!" 2>/dev/null
```
### Windows:
```powershell
llama-embedding.exe -m ./path/to/model --pooling mean --log-disable -p "Hello World!" 2>$null
```
The above command will output space-separated float values.
## extra parameters
### --embd-normalize $integer$
| $integer$ | description | formula |
|-----------|---------------------|---------|
| $-1$ | none |
| $0$ | max absolute int16 | $\Large{{32760 * x_i} \over\max \lvert x_i\rvert}$
| $1$ | taxicab | $\Large{x_i \over\sum \lvert x_i\rvert}$
| $2$ | euclidean (default) | $\Large{x_i \over\sqrt{\sum x_i^2}}$
| $>2$ | p-norm | $\Large{x_i \over\sqrt[p]{\sum \lvert x_i\rvert^p}}$
### --embd-output-format $'string'$
| $'string'$ | description | |
|------------|------------------------------|--|
| '' | same as before | (default)
| 'array' | single embeddings | $[[x_1,...,x_n]]$
| | multiple embeddings | $[[x_1,...,x_n],[x_1,...,x_n],...,[x_1,...,x_n]]$
| 'json' | openai style |
| 'json+' | add cosine similarity matrix |
| 'raw' | plain text output |
### --embd-separator $"string"$
| $"string"$ | |
|--------------|-|
| "\n" | (default)
| "<#embSep#>" | for example
| "<#sep#>" | other example
## examples
### Unix-based systems (Linux, macOS, etc.):
```bash
./llama-embedding -p 'Castle<#sep#>Stronghold<#sep#>Dog<#sep#>Cat' --pooling mean --embd-separator '<#sep#>' --embd-normalize 2 --embd-output-format '' -m './path/to/model.gguf' --n-gpu-layers 99 --log-disable 2>/dev/null
```
### Windows:
```powershell
llama-embedding.exe -p 'Castle<#sep#>Stronghold<#sep#>Dog<#sep#>Cat' --pooling mean --embd-separator '<#sep#>' --embd-normalize 2 --embd-output-format '' -m './path/to/model.gguf' --n-gpu-layers 99 --log-disable 2>/dev/null
```

View File

@@ -0,0 +1,411 @@
#include "arg.h"
#include "common.h"
#include "log.h"
#include "llama.h"
#include <ctime>
#include <algorithm>
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
static std::vector<std::string> split_lines(const std::string & s, const std::string & separator = "\n") {
std::vector<std::string> lines;
size_t start = 0;
size_t end = s.find(separator);
while (end != std::string::npos) {
lines.push_back(s.substr(start, end - start));
start = end + separator.length();
end = s.find(separator, start);
}
lines.push_back(s.substr(start)); // Add the last part
return lines;
}
static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) {
size_t n_tokens = tokens.size();
for (size_t i = 0; i < n_tokens; i++) {
common_batch_add(batch, tokens[i], i, { seq_id }, true);
}
}
static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd_out, int embd_norm) {
const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
// clear previous kv_cache values (irrelevant for embeddings)
llama_memory_clear(llama_get_memory(ctx), true);
// run model
LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
if (llama_decode(ctx, batch) < 0) {
LOG_ERR("%s : failed to process\n", __func__);
}
for (int i = 0; i < batch.n_tokens; i++) {
if (!batch.logits[i]) {
continue;
}
const float * embd = nullptr;
int embd_pos = 0;
if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
// try to get token embeddings
embd = llama_get_embeddings_ith(ctx, i);
embd_pos = i;
GGML_ASSERT(embd != NULL && "failed to get token embeddings");
} else {
// try to get sequence embeddings - supported only when pooling_type is not NONE
embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
embd_pos = batch.seq_id[i][0];
GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
}
float * out = output + embd_pos * n_embd_out;
common_embd_normalize(embd, out, n_embd_out, embd_norm);
}
}
// plain, pipe-friendly output: one embedding per line
static void print_raw_embeddings(const float * emb,
int n_embd_count,
int n_embd,
const llama_model * model,
enum llama_pooling_type pooling_type,
int embd_normalize) {
const uint32_t n_cls_out = llama_model_n_cls_out(model);
const bool is_rank = (pooling_type == LLAMA_POOLING_TYPE_RANK);
const int cols = is_rank ? std::min<int>(n_embd, (int) n_cls_out) : n_embd;
for (int j = 0; j < n_embd_count; ++j) {
for (int i = 0; i < cols; ++i) {
if (embd_normalize == 0) {
LOG("%1.0f%s", emb[j * n_embd + i], (i + 1 < cols ? " " : ""));
} else {
LOG("%1.7f%s", emb[j * n_embd + i], (i + 1 < cols ? " " : ""));
}
}
LOG("\n");
}
}
int main(int argc, char ** argv) {
common_params params;
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_EMBEDDING)) {
return 1;
}
common_init();
params.embedding = true;
// get max number of sequences per batch
const int n_seq_max = llama_max_parallel_sequences();
// if the number of prompts that would be encoded is known in advance, it's more efficient to specify the
// --parallel argument accordingly. for convenience, if not specified, we fallback to unified KV cache
// in order to support any number of prompts
if (params.n_parallel == 1) {
LOG_INF("%s: n_parallel == 1 -> unified KV cache is enabled\n", __func__);
params.kv_unified = true;
params.n_parallel = n_seq_max;
}
// utilize the full context
if (params.n_batch < params.n_ctx) {
LOG_WRN("%s: setting batch size to %d\n", __func__, params.n_ctx);
params.n_batch = params.n_ctx;
}
// for non-causal models, batch size must be equal to ubatch size
if (params.attention_type != LLAMA_ATTENTION_TYPE_CAUSAL) {
params.n_ubatch = params.n_batch;
}
llama_backend_init();
llama_numa_init(params.numa);
// load the model
auto llama_init = common_init_from_params(params);
auto * model = llama_init->model();
auto * ctx = llama_init->context();
if (model == NULL) {
LOG_ERR("%s: unable to load model\n", __func__);
return 1;
}
const llama_vocab * vocab = llama_model_get_vocab(model);
const int n_ctx_train = llama_model_n_ctx_train(model);
const int n_ctx = llama_n_ctx(ctx);
const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
if (llama_model_has_encoder(model) && llama_model_has_decoder(model)) {
LOG_ERR("%s: computing embeddings in encoder-decoder models is not supported\n", __func__);
return 1;
}
if (n_ctx > n_ctx_train) {
LOG_WRN("%s: warning: model was trained on only %d context tokens (%d specified)\n",
__func__, n_ctx_train, n_ctx);
}
// print system information
{
LOG_INF("\n");
LOG_INF("%s\n", common_params_get_system_info(params).c_str());
}
// split the prompt into lines
std::vector<std::string> prompts = split_lines(params.prompt, params.embd_sep);
// max batch size
const uint64_t n_batch = params.n_batch;
// get added sep and eos token, if any
const std::string added_sep_token = llama_vocab_get_add_sep(vocab) ? llama_vocab_get_text(vocab, llama_vocab_sep(vocab)) : "";
const std::string added_eos_token = llama_vocab_get_add_eos(vocab) ? llama_vocab_get_text(vocab, llama_vocab_eos(vocab)) : "";
const char * rerank_prompt = llama_model_chat_template(model, "rerank");
// tokenize the prompts and trim
std::vector<std::vector<int32_t>> inputs;
for (const auto & prompt : prompts) {
std::vector<llama_token> inp;
// split classification pairs and insert expected separator tokens
if (pooling_type == LLAMA_POOLING_TYPE_RANK && prompt.find(params.cls_sep) != std::string::npos) {
std::vector<std::string> pairs = split_lines(prompt, params.cls_sep);
if (rerank_prompt != nullptr) {
const std::string query = pairs[0];
const std::string doc = pairs[1];
std::string final_prompt = rerank_prompt;
string_replace_all(final_prompt, "{query}" , query);
string_replace_all(final_prompt, "{document}", doc );
inp = common_tokenize(vocab, final_prompt, true, true);
} else {
std::string final_prompt;
for (size_t i = 0; i < pairs.size(); i++) {
final_prompt += pairs[i];
if (i != pairs.size() - 1) {
if (!added_eos_token.empty()) {
final_prompt += added_eos_token;
}
if (!added_sep_token.empty()) {
final_prompt += added_sep_token;
}
}
}
inp = common_tokenize(ctx, final_prompt, true, true);
}
} else {
inp = common_tokenize(ctx, prompt, true, true);
}
if (inp.size() > n_batch) {
LOG_ERR("%s: number of tokens in input line (%lld) exceeds batch size (%lld), increase batch size and re-run\n",
__func__, (long long int) inp.size(), (long long int) n_batch);
return 1;
}
inputs.push_back(inp);
}
// check if the last token is SEP/EOS
// it should be automatically added by the tokenizer when 'tokenizer.ggml.add_eos_token' is set to 'true'
for (auto & inp : inputs) {
if (inp.empty() || (inp.back() != llama_vocab_sep(vocab) && inp.back() != llama_vocab_eos(vocab))) {
LOG_WRN("%s: last token in the prompt is not SEP or EOS\n", __func__);
LOG_WRN("%s: 'tokenizer.ggml.add_eos_token' should be set to 'true' in the GGUF header\n", __func__);
}
}
// tokenization stats
if (params.verbose_prompt) {
for (int i = 0; i < (int) inputs.size(); i++) {
LOG_INF("%s: prompt %d: '%s'\n", __func__, i, prompts[i].c_str());
LOG_INF("%s: number of tokens in prompt = %zu\n", __func__, inputs[i].size());
for (int j = 0; j < (int) inputs[i].size(); j++) {
LOG("%6d -> '%s'\n", inputs[i][j], common_token_to_piece(ctx, inputs[i][j]).c_str());
}
LOG("\n\n");
}
}
// initialize batch
const int n_prompts = prompts.size();
struct llama_batch batch = llama_batch_init(n_batch, 0, 1);
// count number of embeddings
int n_embd_count = 0;
if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
for (int k = 0; k < n_prompts; k++) {
n_embd_count += inputs[k].size();
}
} else {
n_embd_count = n_prompts;
}
// allocate output
const int n_embd_out = llama_model_n_embd_out(model);
std::vector<float> embeddings(n_embd_count * n_embd_out, 0);
float * emb = embeddings.data();
// break into batches
int e = 0; // number of embeddings already stored
int s = 0; // number of prompts in current batch
for (int k = 0; k < n_prompts; k++) {
// clamp to n_batch tokens
auto & inp = inputs[k];
const uint64_t n_toks = inp.size();
// encode if at capacity
if (batch.n_tokens + n_toks > n_batch || s >= n_seq_max) {
float * out = emb + e * n_embd_out;
batch_decode(ctx, batch, out, s, n_embd_out, params.embd_normalize);
e += pooling_type == LLAMA_POOLING_TYPE_NONE ? batch.n_tokens : s;
s = 0;
common_batch_clear(batch);
}
// add to batch
batch_add_seq(batch, inp, s);
s += 1;
}
// final batch
float * out = emb + e * n_embd_out;
batch_decode(ctx, batch, out, s, n_embd_out, params.embd_normalize);
if (params.embd_out.empty()) {
LOG("\n");
if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
for (int j = 0; j < n_embd_count; j++) {
LOG("embedding %d: ", j);
for (int i = 0; i < std::min(3, n_embd_out); i++) {
if (params.embd_normalize == 0) {
LOG("%6.0f ", emb[j * n_embd_out + i]);
} else {
LOG("%9.6f ", emb[j * n_embd_out + i]);
}
}
LOG(" ... ");
for (int i = n_embd_out - 3; i < n_embd_out; i++) {
if (params.embd_normalize == 0) {
LOG("%6.0f ", emb[j * n_embd_out + i]);
} else {
LOG("%9.6f ", emb[j * n_embd_out + i]);
}
}
LOG("\n");
}
} else if (pooling_type == LLAMA_POOLING_TYPE_RANK) {
const uint32_t n_cls_out = llama_model_n_cls_out(model);
std::vector<std::string> cls_out_labels;
for (uint32_t i = 0; i < n_cls_out; i++) {
const char * label = llama_model_cls_label(model, i);
const std::string label_i(label == nullptr ? "" : label);
cls_out_labels.emplace_back(label_i.empty() ? std::to_string(i) : label_i);
}
for (int j = 0; j < n_embd_count; j++) {
for (uint32_t i = 0; i < n_cls_out; i++) {
// NOTE: if you change this log - update the tests in ci/run.sh
if (n_cls_out == 1) {
LOG("rerank score %d: %8.3f\n", j, emb[j * n_embd_out]);
} else {
LOG("rerank score %d: %8.3f [%s]\n", j, emb[j * n_embd_out + i], cls_out_labels[i].c_str());
}
}
}
} else {
// print the first part of the embeddings or for a single prompt, the full embedding
for (int j = 0; j < n_prompts; j++) {
LOG("embedding %d: ", j);
for (int i = 0; i < (n_prompts > 1 ? std::min(16, n_embd_out) : n_embd_out); i++) {
if (params.embd_normalize == 0) {
LOG("%6.0f ", emb[j * n_embd_out + i]);
} else {
LOG("%9.6f ", emb[j * n_embd_out + i]);
}
}
LOG("\n");
}
// print cosine similarity matrix
if (n_prompts > 1) {
LOG("\n");
LOG("cosine similarity matrix:\n\n");
for (int i = 0; i < n_prompts; i++) {
LOG("%6.6s ", prompts[i].c_str());
}
LOG("\n");
for (int i = 0; i < n_prompts; i++) {
for (int j = 0; j < n_prompts; j++) {
float sim = common_embd_similarity_cos(emb + i * n_embd_out, emb + j * n_embd_out, n_embd_out);
LOG("%6.2f ", sim);
}
LOG("%1.10s", prompts[i].c_str());
LOG("\n");
}
}
}
}
if (params.embd_out == "json" || params.embd_out == "json+" || params.embd_out == "array") {
const bool notArray = params.embd_out != "array";
LOG(notArray ? "{\n \"object\": \"list\",\n \"data\": [\n" : "[");
for (int j = 0;;) { // at least one iteration (one prompt)
if (notArray) LOG(" {\n \"object\": \"embedding\",\n \"index\": %d,\n \"embedding\": ",j);
LOG("[");
for (int i = 0;;) { // at least one iteration (n_embd > 0)
LOG(params.embd_normalize == 0 ? "%1.0f" : "%1.7f", emb[j * n_embd_out + i]);
i++;
if (i < n_embd_out) LOG(","); else break;
}
LOG(notArray ? "]\n }" : "]");
j++;
if (j < n_embd_count) LOG(notArray ? ",\n" : ","); else break;
}
LOG(notArray ? "\n ]" : "]\n");
if (params.embd_out == "json+" && n_prompts > 1) {
LOG(",\n \"cosineSimilarity\": [\n");
for (int i = 0;;) { // at least two iteration (n_embd_count > 1)
LOG(" [");
for (int j = 0;;) { // at least two iteration (n_embd_count > 1)
float sim = common_embd_similarity_cos(emb + i * n_embd_out, emb + j * n_embd_out, n_embd_out);
LOG("%6.2f", sim);
j++;
if (j < n_embd_count) LOG(", "); else break;
}
LOG(" ]");
i++;
if (i < n_embd_count) LOG(",\n"); else break;
}
LOG("\n ]");
}
if (notArray) LOG("\n}\n");
} else if (params.embd_out == "raw") {
print_raw_embeddings(emb, n_embd_count, n_embd_out, model, pooling_type, params.embd_normalize);
}
LOG("\n");
llama_perf_context_print(ctx);
// clean up
llama_batch_free(batch);
llama_backend_free();
return 0;
}

View File

@@ -0,0 +1,26 @@
set(TARGET llama-eval-callback)
add_executable(${TARGET} eval-callback.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_17)
if(LLAMA_BUILD_TESTS)
if(NOT ${CMAKE_SYSTEM_PROCESSOR} MATCHES "s390x")
set(MODEL_NAME "tinyllamas/stories15M-q4_0.gguf")
set(MODEL_HASH "SHA256=66967fbece6dbe97886593fdbb73589584927e29119ec31f08090732d1861739")
else()
set(MODEL_NAME "tinyllamas/stories15M-be.Q4_0.gguf")
set(MODEL_HASH "SHA256=9aec857937849d976f30397e97eb1cabb53eb9dcb1ce4611ba8247fb5f44c65d")
endif()
set(MODEL_DEST "${CMAKE_BINARY_DIR}/${MODEL_NAME}")
set(TEST_TARGET test-eval-callback)
add_test(NAME ${TEST_TARGET}-download-model COMMAND ${CMAKE_COMMAND}
-DDEST=${MODEL_DEST}
-DNAME=${MODEL_NAME}
-DHASH=${MODEL_HASH}
-P ${CMAKE_SOURCE_DIR}/cmake/download-models.cmake
)
set_tests_properties(${TEST_TARGET}-download-model PROPERTIES FIXTURES_SETUP ${TEST_TARGET}-download-model)
add_test(NAME ${TEST_TARGET} COMMAND llama-eval-callback -m "${MODEL_DEST}" --prompt hello --seed 42 -ngl 0)
set_tests_properties(${TEST_TARGET} PROPERTIES FIXTURES_REQUIRED ${TEST_TARGET}-download-model)
endif()

View File

@@ -0,0 +1,95 @@
# llama.cpp/examples/eval-callback
A simple example which demonstrates how to use callback during the inference.
It simply prints to the console all operations and tensor data.
Usage:
```shell
llama-eval-callback \
--hf-repo ggml-org/models \
--hf-file phi-2/ggml-model-q4_0.gguf \
--model phi-2-q4_0.gguf \
--prompt hello \
--seed 42 \
-ngl 33
```
Will print:
```shell
llm_load_tensors: offloaded 33/33 layers to GPU
...
llama_new_context_with_model: n_ctx = 512
...
llama_new_context_with_model: CUDA0 compute buffer size = 105.00 MiB
llama_new_context_with_model: CUDA_Host compute buffer size = 6.01 MiB
llama_new_context_with_model: graph nodes = 1225
llama_new_context_with_model: graph splits = 2
ggml_debug: inp_embd = (f32) GET_ROWS(token_embd.weight{2560, 51200, 1, 1}, inp_tokens{1, 1, 1, 1}}) = {2560, 1, 1, 1}
[
[
[ -0.0181, 0.0272, 0.0272, ...],
],
]
ggml_debug: norm-0 = (f32) NORM(CUDA0#inp_embd#0{2560, 1, 1, 1}, }) = {2560, 1, 1, 1}
[
[
[ -0.6989, 1.0636, 1.0636, ...],
],
]
ggml_debug: norm_w-0 = (f32) MUL(norm-0{2560, 1, 1, 1}, blk.0.attn_norm.weight{2560, 1, 1, 1}}) = {2560, 1, 1, 1}
[
[
[ -0.1800, 0.2817, 0.2632, ...],
],
]
ggml_debug: attn_norm-0 = (f32) ADD(norm_w-0{2560, 1, 1, 1}, blk.0.attn_norm.bias{2560, 1, 1, 1}}) = {2560, 1, 1, 1}
[
[
[ -0.1863, 0.2970, 0.2604, ...],
],
]
ggml_debug: wqkv-0 = (f32) MUL_MAT(blk.0.attn_qkv.weight{2560, 7680, 1, 1}, attn_norm-0{2560, 1, 1, 1}}) = {7680, 1, 1, 1}
[
[
[ -1.1238, 1.2876, -1.8086, ...],
],
]
ggml_debug: bqkv-0 = (f32) ADD(wqkv-0{7680, 1, 1, 1}, blk.0.attn_qkv.bias{7680, 1, 1, 1}}) = {7680, 1, 1, 1}
[
[
[ -1.1135, 1.4604, -1.9226, ...],
],
]
ggml_debug: bqkv-0 (view) = (f32) VIEW(bqkv-0{7680, 1, 1, 1}, }) = {2560, 1, 1, 1}
[
[
[ -1.1135, 1.4604, -1.9226, ...],
],
]
ggml_debug: Qcur-0 = (f32) CONT(bqkv-0 (view){2560, 1, 1, 1}, }) = {2560, 1, 1, 1}
[
[
[ -1.1135, 1.4604, -1.9226, ...],
],
]
ggml_debug: Qcur-0 (reshaped) = (f32) RESHAPE(Qcur-0{2560, 1, 1, 1}, }) = {80, 32, 1, 1}
[
[
[ -1.1135, 1.4604, -1.9226, ...],
[ -0.3608, 0.5076, -1.8866, ...],
[ 1.7643, 0.0273, -2.1065, ...],
...
],
]
ggml_debug: Qcur-0 = (f32) ROPE(Qcur-0 (reshaped){80, 32, 1, 1}, CUDA0#inp_pos#0{1, 1, 1, 1}}) = {80, 32, 1, 1}
[
[
[ -1.1135, 1.4604, -1.9226, ...],
[ -0.3608, 0.5076, -1.8866, ...],
[ 1.7643, 0.0273, -2.1065, ...],
...
],
]
```

View File

@@ -0,0 +1,80 @@
#include "arg.h"
#include "common.h"
#include "debug.h"
#include "log.h"
#include "llama.h"
#include "llama-cpp.h"
#include <string>
#include <vector>
static bool run(llama_context * ctx, const common_params & params) {
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);
const bool add_bos = llama_vocab_get_add_bos(vocab);
std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, add_bos);
if (tokens.empty()) {
LOG_ERR("%s : there are not input tokens to process - (try to provide a prompt with '-p')\n", __func__);
return false;
}
if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) {
LOG_ERR("%s : failed to eval\n", __func__);
return false;
}
return true;
}
int main(int argc, char ** argv) {
base_callback_data cb_data;
common_params params;
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) {
return 1;
}
common_init();
llama_backend_init();
llama_numa_init(params.numa);
// pass the callback to the backend scheduler
// it will be executed for each node during the graph computation
params.cb_eval = common_debug_cb_eval<false>;
params.cb_eval_user_data = &cb_data;
params.warmup = false;
// init
auto llama_init = common_init_from_params(params);
auto * model = llama_init->model();
auto * ctx = llama_init->context();
if (model == nullptr || ctx == nullptr) {
LOG_ERR("%s : failed to init\n", __func__);
return 1;
}
// print system information
{
LOG_INF("\n");
LOG_INF("%s\n", common_params_get_system_info(params).c_str());
LOG_INF("\n");
}
bool OK = run(ctx, params);
if (!OK) {
return 1;
}
LOG("\n");
llama_perf_context_print(ctx);
llama_backend_free();
return 0;
}

View File

@@ -0,0 +1,5 @@
set(TARGET llama-gen-docs)
add_executable(${TARGET} gen-docs.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_17)

View File

@@ -0,0 +1,142 @@
#include "arg.h"
#include "common.h"
#include <fstream>
#include <sstream>
#include <string>
// Export usage message (-h) to markdown format
// Automatically update the markdown docs
#define HELP_START_MARKER "<!-- HELP_START -->"
#define HELP_END_MARKER "<!-- HELP_END -->"
#define NOTE_MESSAGE "<!-- IMPORTANT: The list below is auto-generated by llama-gen-docs; do NOT modify it manually -->"
struct md_file {
llama_example ex;
std::string fname;
std::string specific_section_header;
};
std::vector<md_file> md_files = {
{LLAMA_EXAMPLE_CLI, "tools/cli/README.md", "CLI-specific params"},
{LLAMA_EXAMPLE_COMPLETION, "tools/completion/README.md", "Completion-specific params"},
{LLAMA_EXAMPLE_SERVER, "tools/server/README.md", "Server-specific params"},
};
static void write_table_header(std::ostringstream & ss) {
ss << "| Argument | Explanation |\n";
ss << "| -------- | ----------- |\n";
}
static void write_table_entry(std::ostringstream & ss, const common_arg & opt) {
ss << "| `";
// args
auto all_args = opt.get_args();
for (const auto & arg : all_args) {
if (arg == all_args.front()) {
ss << arg;
if (all_args.size() > 1) ss << ", ";
} else {
ss << arg << (arg != all_args.back() ? ", " : "");
}
}
// value hint
if (opt.value_hint) {
std::string md_value_hint(opt.value_hint);
string_replace_all(md_value_hint, "|", "\\|");
ss << " " << md_value_hint;
}
if (opt.value_hint_2) {
std::string md_value_hint_2(opt.value_hint_2);
string_replace_all(md_value_hint_2, "|", "\\|");
ss << " " << md_value_hint_2;
}
// help text
std::string md_help(opt.help);
md_help = string_strip(md_help);
string_replace_all(md_help, "\n", "<br/>");
string_replace_all(md_help, "|", "\\|");
ss << "` | " << md_help << " |\n";
}
static void write_table(std::ostringstream & ss, std::vector<common_arg *> & opts) {
write_table_header(ss);
for (const auto & opt : opts) {
write_table_entry(ss, *opt);
}
}
static void write_help(std::ostringstream & ss, const md_file & md) {
common_params params;
auto ctx_arg = common_params_parser_init(params, md.ex);
std::vector<common_arg *> common_options;
std::vector<common_arg *> sparam_options;
std::vector<common_arg *> specific_options;
for (auto & opt : ctx_arg.options) {
// in case multiple LLAMA_EXAMPLE_* are set, we prioritize the LLAMA_EXAMPLE_* matching current example
if (opt.is_sparam) {
sparam_options.push_back(&opt);
} else if (opt.in_example(ctx_arg.ex)) {
specific_options.push_back(&opt);
} else {
common_options.push_back(&opt);
}
}
ss << HELP_START_MARKER << "\n\n";
ss << NOTE_MESSAGE << "\n\n";
ss << "### Common params\n\n";
write_table(ss, common_options);
ss << "\n\n### Sampling params\n\n";
write_table(ss, sparam_options);
ss << "\n\n### " << md.specific_section_header << "\n\n";
write_table(ss, specific_options);
ss << "\n" << HELP_END_MARKER;
}
int main(int, char **) {
for (const auto & md : md_files) {
std::ifstream infile(md.fname);
if (!infile.is_open()) {
fprintf(stderr, "failed to open file '%s' for reading\n", md.fname.c_str());
return 1;
}
std::ostringstream ss;
ss << infile.rdbuf();
infile.close();
std::string content = ss.str();
size_t help_start = content.find(HELP_START_MARKER);
size_t help_end = content.find(HELP_END_MARKER);
if (help_start == std::string::npos || help_end == std::string::npos || help_end <= help_start) {
fprintf(stderr, "failed to find help markers in file '%s'\n", md.fname.c_str());
return 1;
}
std::ostringstream new_help_ss;
write_help(new_help_ss, md);
std::string new_help = new_help_ss.str();
content = content.substr(0, help_start) + new_help + content.substr(help_end + strlen(HELP_END_MARKER));
std::ofstream outfile(md.fname);
if (!outfile.is_open()) {
fprintf(stderr, "failed to open file '%s' for writing\n", md.fname.c_str());
return 1;
}
outfile << content;
outfile.close();
printf("Updated help in '%s'\n", md.fname.c_str());
}
return 0;
}

View File

@@ -0,0 +1,22 @@
set(TARGET llama-gguf-hash)
add_executable(${TARGET} gguf-hash.cpp)
install(TARGETS ${TARGET} RUNTIME)
# clibs dependencies
include_directories(deps/)
add_library(xxhash OBJECT deps/xxhash/xxhash.c deps/xxhash/xxhash.h)
target_link_libraries(${TARGET} PRIVATE xxhash)
add_library(sha1 OBJECT deps/sha1/sha1.c deps/sha1/sha1.h)
target_link_libraries(${TARGET} PRIVATE sha1)
if (NOT MSVC)
# disable warnings in 3rd party code
target_compile_options(sha1 PRIVATE -w)
endif()
add_library(sha256 OBJECT deps/sha256/sha256.c deps/sha256/sha256.h)
target_link_libraries(${TARGET} PRIVATE sha256)
target_link_libraries(${TARGET} PRIVATE ggml ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_17)

View File

@@ -0,0 +1,206 @@
# llama-gguf-hash
CLI to hash GGUF files to detect difference on a per model and per tensor level.
**Command line options:**
- `--help`: display help message
- `--xxh64`: use xhash 64bit hash mode (default)
- `--sha1`: use sha1
- `--uuid`: use uuid
- `--sha256`: use sha256
- `--all`: use all hash
- `--no-layer`: exclude per layer hash
- `--uuid`: generate UUIDv5 ID
- `-c`, `--check <manifest>`: verify against a manifest
## About
While most POSIX systems already have hash checking programs like sha256sum, it
is designed to check entire files. This is not ideal for our purpose if we want
to check for consistency of the tensor data even if the metadata content of the
gguf KV store has been updated.
This program is designed to hash a gguf tensor payload on a 'per tensor layer'
in addition to a 'entire tensor model' hash. The intent is that the entire
tensor layer can be checked first but if there is any detected inconsistencies,
then the per tensor hash can be used to narrow down the specific tensor layer
that has inconsistencies.
For Maintainers:
- Detection of tensor inconsistency during development and automated tests
- This is served by xxh64 which is fast
- This is also served by having per tensor layer to assist in narrowing down
the location of the faulty tensor layer
- This is also served by sha1 which is much slower but more widely supported
For Model Creators:
- Optional consistent UUID generation based on model tensor content
- This is served by UUIDv5 which is useful for databases keys
- llama.cpp UUIDv5 Namespace: `ef001206-dadc-5f6d-a15f-3359e577d4e5`
- Made via UUIDv5 URL namespace of `en.wikipedia.org/wiki/Llama.cpp`
For Model Users:
- Assurance of tensor layer integrity even if metadata was updated
- This is served by sha256 which is still considered very secure as of 2024
### Design Note
- The default behavior of this program if no arguments is provided is to hash
using xxhash's xxh32 mode because it is very fast and is primarily targeted
towards maintainers who may want to use this in automated tests.
- xxhash support xxh32 and xxh128 for 32bit hash and 128bit hash respectively
however we picked 64bit xxhash as most computers are 64bit as of 2024 and thus
would have a better affinity to calculating hash that is 64bit in size.
## Compile Example
```bash
cmake -B build -DCMAKE_BUILD_TYPE=Debug -DLLAMA_FATAL_WARNINGS=ON
make -C build clean
make -C build llama-gguf-hash VERBOSE=1
./build/bin/llama-gguf-hash test.gguf
./build/bin/llama-gguf-hash --xxh64 test.gguf
./build/bin/llama-gguf-hash --sha1 test.gguf
./build/bin/llama-gguf-hash --uuid test.gguf
./build/bin/llama-gguf-hash --sha256 test.gguf
```
## Generation and Verification Example
To generate we may use this command
```bash
./llama-gguf-hash --all test.gguf > test.gguf.manifest
```
Which would generate a manifest that looks like below, which contains multiple hash type and per tensor layer hashes as well
(This excludes UUID as that is an ID not a hash)
```bash
xxh64 f66e9cd66a4396a0 test.gguf:tensor_0
sha1 59f79ecefd8125a996fdf419239051a7e99e5f20 test.gguf:tensor_0
sha256 c0510d38fa060c46265e0160a85c7243096b01dd31c2f355bdbb5516b20de1bd test.gguf:tensor_0
xxh64 7d3a1f9ac04d0537 test.gguf:tensor_1
sha1 4765f592eacf096df4628ba59476af94d767080a test.gguf:tensor_1
sha256 8514cbcc73692a2c56bd7a33a022edd5ff819614bd23b19915d7224387f397a7 test.gguf:tensor_1
xxh64 a0af5d700049693b test.gguf:tensor_2
sha1 25cbfbad4513cc348e2c95ebdee69d6ff2fd8753 test.gguf:tensor_2
sha256 947e6b36e20f2cc95e1d2ce1c1669d813d574657ac6b5ac5196158d454d35180 test.gguf:tensor_2
xxh64 e83fddf559d7b6a6 test.gguf:tensor_3
sha1 a9cba73e2d90f2ee3dae2548caa42bef3fe6a96c test.gguf:tensor_3
sha256 423b044e016d8ac73c39f23f60bf01bedef5ecb03c0230accd824c91fe86f1a1 test.gguf:tensor_3
xxh64 1257733306b7992d test.gguf:tensor_4
sha1 d7bc61db93bb685ce9d598da89717c66729b7543 test.gguf:tensor_4
sha256 79737cb3912d4201384cf7f16a1a37ff7823f23ea796cb205b6ca361ab9e3ebf test.gguf:tensor_4
xxh64 d238d16ba4711e58 test.gguf:tensor_5
sha1 0706566c198fe1072f37e0a5135b4b5f23654c52 test.gguf:tensor_5
sha256 60949be8298eced0ecdde64487643d018407bd261691e061d9e9c3dbc9fd358b test.gguf:tensor_5
xxh64 3fbc3b65ab8c7f39 test.gguf:tensor_6
sha1 73922a0727226a409049f6fc3172a52219ca6f00 test.gguf:tensor_6
sha256 574f4c46ff384a3b9a225eb955d2a871847a2e8b3fa59387a8252832e92ef7b0 test.gguf:tensor_6
xxh64 c22021c29854f093 test.gguf:tensor_7
sha1 efc39cece6a951188fc41e354c73bbfe6813d447 test.gguf:tensor_7
sha256 4c0410cd3c500f078ae5b21e8dc9eb79e29112713b2ab58a882f82a3868d4d75 test.gguf:tensor_7
xxh64 936df61f5d64261f test.gguf:tensor_8
sha1 c2490296d789a4f34398a337fed8377d943d9f06 test.gguf:tensor_8
sha256 c4401313feeba0261275c3b25bd2d8fe40ce04e0f440c2980ed0e9674c30ff01 test.gguf:tensor_8
xxh64 93fd20c64421c081 test.gguf:tensor_9
sha1 7047ce1e78437a6884337a3751c7ee0421918a65 test.gguf:tensor_9
sha256 23d57cf0d7a6e90b0b3616b41300e0cd354781e812add854a5f95aa55f2bc514 test.gguf:tensor_9
xxh64 5a54d3aad816f302 test.gguf
sha1 d15be52c4ff213e823cb6dd13af7ee2f978e7042 test.gguf
sha256 7dd641b32f59b60dbd4b5420c4b0f6321ccf48f58f6ae201a3dbc4a58a27c6e4 test.gguf
```
We can then use the normal check command which will by default check for the highest security strength hash and verify against that:
```bash
$ ./llama-gguf-hash --check test.gguf.manifest test.gguf
manifest test.gguf.manifest sha256 sha1 xxh64
sha256 c0510d38fa060c46265e0160a85c7243096b01dd31c2f355bdbb5516b20de1bd test.gguf:tensor_0 - Ok
sha256 8514cbcc73692a2c56bd7a33a022edd5ff819614bd23b19915d7224387f397a7 test.gguf:tensor_1 - Ok
sha256 947e6b36e20f2cc95e1d2ce1c1669d813d574657ac6b5ac5196158d454d35180 test.gguf:tensor_2 - Ok
sha256 423b044e016d8ac73c39f23f60bf01bedef5ecb03c0230accd824c91fe86f1a1 test.gguf:tensor_3 - Ok
sha256 79737cb3912d4201384cf7f16a1a37ff7823f23ea796cb205b6ca361ab9e3ebf test.gguf:tensor_4 - Ok
sha256 60949be8298eced0ecdde64487643d018407bd261691e061d9e9c3dbc9fd358b test.gguf:tensor_5 - Ok
sha256 574f4c46ff384a3b9a225eb955d2a871847a2e8b3fa59387a8252832e92ef7b0 test.gguf:tensor_6 - Ok
sha256 4c0410cd3c500f078ae5b21e8dc9eb79e29112713b2ab58a882f82a3868d4d75 test.gguf:tensor_7 - Ok
sha256 c4401313feeba0261275c3b25bd2d8fe40ce04e0f440c2980ed0e9674c30ff01 test.gguf:tensor_8 - Ok
sha256 23d57cf0d7a6e90b0b3616b41300e0cd354781e812add854a5f95aa55f2bc514 test.gguf:tensor_9 - Ok
sha256 7dd641b32f59b60dbd4b5420c4b0f6321ccf48f58f6ae201a3dbc4a58a27c6e4 test.gguf - Ok
Verification results for test.gguf.manifest - Success
```
Or we may explicitly ask for a faster hash like:
```bash
$ ./llama-gguf-hash --check test.gguf.manifest --xxh64 test.gguf
manifest test.gguf.manifest sha256 sha1 xxh64
xxh64 f66e9cd66a4396a0 test.gguf:tensor_0 - Ok
xxh64 7d3a1f9ac04d0537 test.gguf:tensor_1 - Ok
xxh64 a0af5d700049693b test.gguf:tensor_2 - Ok
xxh64 e83fddf559d7b6a6 test.gguf:tensor_3 - Ok
xxh64 1257733306b7992d test.gguf:tensor_4 - Ok
xxh64 d238d16ba4711e58 test.gguf:tensor_5 - Ok
xxh64 3fbc3b65ab8c7f39 test.gguf:tensor_6 - Ok
xxh64 c22021c29854f093 test.gguf:tensor_7 - Ok
xxh64 936df61f5d64261f test.gguf:tensor_8 - Ok
xxh64 93fd20c64421c081 test.gguf:tensor_9 - Ok
xxh64 5a54d3aad816f302 test.gguf - Ok
Verification results for test.gguf.manifest - Success
```
Or maybe we want to just check that all the hash is valid:
```bash
$./llama-gguf-hash --check test.gguf.manifest --all test.gguf.manifest
manifest test.gguf.manifest sha256 sha1 xxh64
xxh64 f66e9cd66a4396a0 test.gguf:tensor_0 - Ok
sha1 59f79ecefd8125a996fdf419239051a7e99e5f20 test.gguf:tensor_0 - Ok
sha256 c0510d38fa060c46265e0160a85c7243096b01dd31c2f355bdbb5516b20de1bd test.gguf:tensor_0 - Ok
xxh64 7d3a1f9ac04d0537 test.gguf:tensor_1 - Ok
sha1 4765f592eacf096df4628ba59476af94d767080a test.gguf:tensor_1 - Ok
sha256 8514cbcc73692a2c56bd7a33a022edd5ff819614bd23b19915d7224387f397a7 test.gguf:tensor_1 - Ok
xxh64 a0af5d700049693b test.gguf:tensor_2 - Ok
sha1 25cbfbad4513cc348e2c95ebdee69d6ff2fd8753 test.gguf:tensor_2 - Ok
sha256 947e6b36e20f2cc95e1d2ce1c1669d813d574657ac6b5ac5196158d454d35180 test.gguf:tensor_2 - Ok
xxh64 e83fddf559d7b6a6 test.gguf:tensor_3 - Ok
sha1 a9cba73e2d90f2ee3dae2548caa42bef3fe6a96c test.gguf:tensor_3 - Ok
sha256 423b044e016d8ac73c39f23f60bf01bedef5ecb03c0230accd824c91fe86f1a1 test.gguf:tensor_3 - Ok
xxh64 1257733306b7992d test.gguf:tensor_4 - Ok
sha1 d7bc61db93bb685ce9d598da89717c66729b7543 test.gguf:tensor_4 - Ok
sha256 79737cb3912d4201384cf7f16a1a37ff7823f23ea796cb205b6ca361ab9e3ebf test.gguf:tensor_4 - Ok
xxh64 d238d16ba4711e58 test.gguf:tensor_5 - Ok
sha1 0706566c198fe1072f37e0a5135b4b5f23654c52 test.gguf:tensor_5 - Ok
sha256 60949be8298eced0ecdde64487643d018407bd261691e061d9e9c3dbc9fd358b test.gguf:tensor_5 - Ok
xxh64 3fbc3b65ab8c7f39 test.gguf:tensor_6 - Ok
sha1 73922a0727226a409049f6fc3172a52219ca6f00 test.gguf:tensor_6 - Ok
sha256 574f4c46ff384a3b9a225eb955d2a871847a2e8b3fa59387a8252832e92ef7b0 test.gguf:tensor_6 - Ok
xxh64 c22021c29854f093 test.gguf:tensor_7 - Ok
sha1 efc39cece6a951188fc41e354c73bbfe6813d447 test.gguf:tensor_7 - Ok
sha256 4c0410cd3c500f078ae5b21e8dc9eb79e29112713b2ab58a882f82a3868d4d75 test.gguf:tensor_7 - Ok
xxh64 936df61f5d64261f test.gguf:tensor_8 - Ok
sha1 c2490296d789a4f34398a337fed8377d943d9f06 test.gguf:tensor_8 - Ok
sha256 c4401313feeba0261275c3b25bd2d8fe40ce04e0f440c2980ed0e9674c30ff01 test.gguf:tensor_8 - Ok
xxh64 93fd20c64421c081 test.gguf:tensor_9 - Ok
sha1 7047ce1e78437a6884337a3751c7ee0421918a65 test.gguf:tensor_9 - Ok
sha256 23d57cf0d7a6e90b0b3616b41300e0cd354781e812add854a5f95aa55f2bc514 test.gguf:tensor_9 - Ok
xxh64 5a54d3aad816f302 test.gguf - Ok
sha1 d15be52c4ff213e823cb6dd13af7ee2f978e7042 test.gguf - Ok
sha256 7dd641b32f59b60dbd4b5420c4b0f6321ccf48f58f6ae201a3dbc4a58a27c6e4 test.gguf - Ok
Verification results for test.gguf.manifest - Success
```
## Crypto/Hash Libraries Used
These micro c libraries dependencies was installed via the [clib c package manager](https://github.com/clibs)
- https://github.com/Cyan4973/xxHash
- https://github.com/clibs/sha1/
- https://github.com/jb55/sha256.c

View File

@@ -0,0 +1,13 @@
{
"name": "rotate-bits",
"version": "0.1.1",
"repo": "jb55/rotate-bits.h",
"description": "rotate bits",
"keywords": ["rotl", "rotr"],
"src": ["rotate-bits.h"],
"license": "Public Domain",
"development": {
"thlorenz/tap.c": "*"
}
}

View File

@@ -0,0 +1,46 @@
#ifndef __ROTATE_DEFS_H
#define __ROTATE_DEFS_H
#ifdef _MSC_VER
#include <stdlib.h>
#define ROTL32(v, n) _rotl((v), (n))
#define ROTL64(v, n) _rotl64((v), (n))
#define ROTR32(v, n) _rotr((v), (n))
#define ROTR64(v, n) _rotr64((v), (n))
#else
#include <stdint.h>
#define U8V(v) ((uint8_t)(v) & 0xFFU)
#define U16V(v) ((uint16_t)(v) & 0xFFFFU)
#define U32V(v) ((uint32_t)(v) & 0xFFFFFFFFU)
#define U64V(v) ((uint64_t)(v) & 0xFFFFFFFFFFFFFFFFU)
#define ROTL32(v, n) \
(U32V((uint32_t)(v) << (n)) | ((uint32_t)(v) >> (32 - (n))))
// tests fail if we don't have this cast...
#define ROTL64(v, n) \
(U64V((uint64_t)(v) << (n)) | ((uint64_t)(v) >> (64 - (n))))
#define ROTR32(v, n) ROTL32(v, 32 - (n))
#define ROTR64(v, n) ROTL64(v, 64 - (n))
#endif
#define ROTL8(v, n) \
(U8V((uint8_t)(v) << (n)) | ((uint8_t)(v) >> (8 - (n))))
#define ROTL16(v, n) \
(U16V((uint16_t)(v) << (n)) | ((uint16_t)(v) >> (16 - (n))))
#define ROTR8(v, n) ROTL8(v, 8 - (n))
#define ROTR16(v, n) ROTL16(v, 16 - (n))
#endif

View File

@@ -0,0 +1,9 @@
{
"name": "sha1",
"version": "0.0.1",
"repo": "clibs/sha1",
"description": "sha1 hash algorithm",
"keywords": ["sha1", "hash"],
"license": "public domain",
"src": ["sha1.c", "sha1.h"]
}

View File

@@ -0,0 +1,295 @@
/*
SHA-1 in C
By Steve Reid <steve@edmweb.com>
100% Public Domain
Test Vectors (from FIPS PUB 180-1)
"abc"
A9993E36 4706816A BA3E2571 7850C26C 9CD0D89D
"abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq"
84983E44 1C3BD26E BAAE4AA1 F95129E5 E54670F1
A million repetitions of "a"
34AA973C D4C4DAA4 F61EEB2B DBAD2731 6534016F
*/
/* #define LITTLE_ENDIAN * This should be #define'd already, if true. */
/* #define SHA1HANDSOFF * Copies data before messing with it. */
#define SHA1HANDSOFF
#include <stdio.h>
#include <string.h>
/* for uint32_t */
#include <stdint.h>
#include "sha1.h"
#define rol(value, bits) (((value) << (bits)) | ((value) >> (32 - (bits))))
/* blk0() and blk() perform the initial expand. */
/* I got the idea of expanding during the round function from SSLeay */
#if BYTE_ORDER == LITTLE_ENDIAN
#define blk0(i) (block->l[i] = (rol(block->l[i],24)&0xFF00FF00) \
|(rol(block->l[i],8)&0x00FF00FF))
#elif BYTE_ORDER == BIG_ENDIAN
#define blk0(i) block->l[i]
#else
#error "Endianness not defined!"
#endif
#define blk(i) (block->l[i&15] = rol(block->l[(i+13)&15]^block->l[(i+8)&15] \
^block->l[(i+2)&15]^block->l[i&15],1))
/* (R0+R1), R2, R3, R4 are the different operations used in SHA1 */
#define R0(v,w,x,y,z,i) z+=((w&(x^y))^y)+blk0(i)+0x5A827999+rol(v,5);w=rol(w,30);
#define R1(v,w,x,y,z,i) z+=((w&(x^y))^y)+blk(i)+0x5A827999+rol(v,5);w=rol(w,30);
#define R2(v,w,x,y,z,i) z+=(w^x^y)+blk(i)+0x6ED9EBA1+rol(v,5);w=rol(w,30);
#define R3(v,w,x,y,z,i) z+=(((w|x)&y)|(w&x))+blk(i)+0x8F1BBCDC+rol(v,5);w=rol(w,30);
#define R4(v,w,x,y,z,i) z+=(w^x^y)+blk(i)+0xCA62C1D6+rol(v,5);w=rol(w,30);
/* Hash a single 512-bit block. This is the core of the algorithm. */
void SHA1Transform(
uint32_t state[5],
const unsigned char buffer[64]
)
{
uint32_t a, b, c, d, e;
typedef union
{
unsigned char c[64];
uint32_t l[16];
} CHAR64LONG16;
#ifdef SHA1HANDSOFF
CHAR64LONG16 block[1]; /* use array to appear as a pointer */
memcpy(block, buffer, 64);
#else
/* The following had better never be used because it causes the
* pointer-to-const buffer to be cast into a pointer to non-const.
* And the result is written through. I threw a "const" in, hoping
* this will cause a diagnostic.
*/
CHAR64LONG16 *block = (const CHAR64LONG16 *) buffer;
#endif
/* Copy context->state[] to working vars */
a = state[0];
b = state[1];
c = state[2];
d = state[3];
e = state[4];
/* 4 rounds of 20 operations each. Loop unrolled. */
R0(a, b, c, d, e, 0);
R0(e, a, b, c, d, 1);
R0(d, e, a, b, c, 2);
R0(c, d, e, a, b, 3);
R0(b, c, d, e, a, 4);
R0(a, b, c, d, e, 5);
R0(e, a, b, c, d, 6);
R0(d, e, a, b, c, 7);
R0(c, d, e, a, b, 8);
R0(b, c, d, e, a, 9);
R0(a, b, c, d, e, 10);
R0(e, a, b, c, d, 11);
R0(d, e, a, b, c, 12);
R0(c, d, e, a, b, 13);
R0(b, c, d, e, a, 14);
R0(a, b, c, d, e, 15);
R1(e, a, b, c, d, 16);
R1(d, e, a, b, c, 17);
R1(c, d, e, a, b, 18);
R1(b, c, d, e, a, 19);
R2(a, b, c, d, e, 20);
R2(e, a, b, c, d, 21);
R2(d, e, a, b, c, 22);
R2(c, d, e, a, b, 23);
R2(b, c, d, e, a, 24);
R2(a, b, c, d, e, 25);
R2(e, a, b, c, d, 26);
R2(d, e, a, b, c, 27);
R2(c, d, e, a, b, 28);
R2(b, c, d, e, a, 29);
R2(a, b, c, d, e, 30);
R2(e, a, b, c, d, 31);
R2(d, e, a, b, c, 32);
R2(c, d, e, a, b, 33);
R2(b, c, d, e, a, 34);
R2(a, b, c, d, e, 35);
R2(e, a, b, c, d, 36);
R2(d, e, a, b, c, 37);
R2(c, d, e, a, b, 38);
R2(b, c, d, e, a, 39);
R3(a, b, c, d, e, 40);
R3(e, a, b, c, d, 41);
R3(d, e, a, b, c, 42);
R3(c, d, e, a, b, 43);
R3(b, c, d, e, a, 44);
R3(a, b, c, d, e, 45);
R3(e, a, b, c, d, 46);
R3(d, e, a, b, c, 47);
R3(c, d, e, a, b, 48);
R3(b, c, d, e, a, 49);
R3(a, b, c, d, e, 50);
R3(e, a, b, c, d, 51);
R3(d, e, a, b, c, 52);
R3(c, d, e, a, b, 53);
R3(b, c, d, e, a, 54);
R3(a, b, c, d, e, 55);
R3(e, a, b, c, d, 56);
R3(d, e, a, b, c, 57);
R3(c, d, e, a, b, 58);
R3(b, c, d, e, a, 59);
R4(a, b, c, d, e, 60);
R4(e, a, b, c, d, 61);
R4(d, e, a, b, c, 62);
R4(c, d, e, a, b, 63);
R4(b, c, d, e, a, 64);
R4(a, b, c, d, e, 65);
R4(e, a, b, c, d, 66);
R4(d, e, a, b, c, 67);
R4(c, d, e, a, b, 68);
R4(b, c, d, e, a, 69);
R4(a, b, c, d, e, 70);
R4(e, a, b, c, d, 71);
R4(d, e, a, b, c, 72);
R4(c, d, e, a, b, 73);
R4(b, c, d, e, a, 74);
R4(a, b, c, d, e, 75);
R4(e, a, b, c, d, 76);
R4(d, e, a, b, c, 77);
R4(c, d, e, a, b, 78);
R4(b, c, d, e, a, 79);
/* Add the working vars back into context.state[] */
state[0] += a;
state[1] += b;
state[2] += c;
state[3] += d;
state[4] += e;
/* Wipe variables */
a = b = c = d = e = 0;
#ifdef SHA1HANDSOFF
memset(block, '\0', sizeof(block));
#endif
}
/* SHA1Init - Initialize new context */
void SHA1Init(
SHA1_CTX * context
)
{
/* SHA1 initialization constants */
context->state[0] = 0x67452301;
context->state[1] = 0xEFCDAB89;
context->state[2] = 0x98BADCFE;
context->state[3] = 0x10325476;
context->state[4] = 0xC3D2E1F0;
context->count[0] = context->count[1] = 0;
}
/* Run your data through this. */
void SHA1Update(
SHA1_CTX * context,
const unsigned char *data,
uint32_t len
)
{
uint32_t i;
uint32_t j;
j = context->count[0];
if ((context->count[0] += len << 3) < j)
context->count[1]++;
context->count[1] += (len >> 29);
j = (j >> 3) & 63;
if ((j + len) > 63)
{
memcpy(&context->buffer[j], data, (i = 64 - j));
SHA1Transform(context->state, context->buffer);
for (; i + 63 < len; i += 64)
{
SHA1Transform(context->state, &data[i]);
}
j = 0;
}
else
i = 0;
memcpy(&context->buffer[j], &data[i], len - i);
}
/* Add padding and return the message digest. */
void SHA1Final(
unsigned char digest[20],
SHA1_CTX * context
)
{
unsigned i;
unsigned char finalcount[8];
unsigned char c;
#if 0 /* untested "improvement" by DHR */
/* Convert context->count to a sequence of bytes
* in finalcount. Second element first, but
* big-endian order within element.
* But we do it all backwards.
*/
unsigned char *fcp = &finalcount[8];
for (i = 0; i < 2; i++)
{
uint32_t t = context->count[i];
int j;
for (j = 0; j < 4; t >>= 8, j++)
*--fcp = (unsigned char) t}
#else
for (i = 0; i < 8; i++)
{
finalcount[i] = (unsigned char) ((context->count[(i >= 4 ? 0 : 1)] >> ((3 - (i & 3)) * 8)) & 255); /* Endian independent */
}
#endif
c = 0200;
SHA1Update(context, &c, 1);
while ((context->count[0] & 504) != 448)
{
c = 0000;
SHA1Update(context, &c, 1);
}
SHA1Update(context, finalcount, 8); /* Should cause a SHA1Transform() */
for (i = 0; i < 20; i++)
{
digest[i] = (unsigned char)
((context->state[i >> 2] >> ((3 - (i & 3)) * 8)) & 255);
}
/* Wipe variables */
memset(context, '\0', sizeof(*context));
memset(&finalcount, '\0', sizeof(finalcount));
}
void SHA1(
char *hash_out,
const char *str,
uint32_t len)
{
SHA1_CTX ctx;
unsigned int ii;
SHA1Init(&ctx);
for (ii=0; ii<len; ii+=1)
SHA1Update(&ctx, (const unsigned char*)str + ii, 1);
SHA1Final((unsigned char *)hash_out, &ctx);
}

View File

@@ -0,0 +1,52 @@
#ifndef SHA1_H
#define SHA1_H
/*
SHA-1 in C
By Steve Reid <steve@edmweb.com>
100% Public Domain
*/
#include "stdint.h"
#if defined(__cplusplus)
extern "C" {
#endif
typedef struct
{
uint32_t state[5];
uint32_t count[2];
unsigned char buffer[64];
} SHA1_CTX;
void SHA1Transform(
uint32_t state[5],
const unsigned char buffer[64]
);
void SHA1Init(
SHA1_CTX * context
);
void SHA1Update(
SHA1_CTX * context,
const unsigned char *data,
uint32_t len
);
void SHA1Final(
unsigned char digest[20],
SHA1_CTX * context
);
void SHA1(
char *hash_out,
const char *str,
uint32_t len);
#if defined(__cplusplus)
}
#endif
#endif /* SHA1_H */

View File

@@ -0,0 +1,15 @@
{
"name": "sha256",
"version": "0.0.2",
"repo": "jb55/sha256.c",
"description": "sha256 in c",
"keywords": ["sha256", "sha2"],
"src": ["sha256.c", "sha256.h"],
"dependencies": {
"jb55/rotate-bits.h": "0.1.1"
},
"development": {
"thlorenz/tap.c": "*"
}
}

View File

@@ -0,0 +1,221 @@
/* Crypto/Sha256.c -- SHA-256 Hash
2010-06-11 : Igor Pavlov : Public domain
This code is based on public domain code from Wei Dai's Crypto++ library. */
#include "rotate-bits/rotate-bits.h"
#include "sha256.h"
/* define it for speed optimization */
#define _SHA256_UNROLL
#define _SHA256_UNROLL2
void
sha256_init(sha256_t *p)
{
p->state[0] = 0x6a09e667;
p->state[1] = 0xbb67ae85;
p->state[2] = 0x3c6ef372;
p->state[3] = 0xa54ff53a;
p->state[4] = 0x510e527f;
p->state[5] = 0x9b05688c;
p->state[6] = 0x1f83d9ab;
p->state[7] = 0x5be0cd19;
p->count = 0;
}
#define S0(x) (ROTR32(x, 2) ^ ROTR32(x,13) ^ ROTR32(x, 22))
#define S1(x) (ROTR32(x, 6) ^ ROTR32(x,11) ^ ROTR32(x, 25))
#define s0(x) (ROTR32(x, 7) ^ ROTR32(x,18) ^ (x >> 3))
#define s1(x) (ROTR32(x,17) ^ ROTR32(x,19) ^ (x >> 10))
#define blk0(i) (W[i] = data[i])
#define blk2(i) (W[i&15] += s1(W[(i-2)&15]) + W[(i-7)&15] + s0(W[(i-15)&15]))
#define Ch(x,y,z) (z^(x&(y^z)))
#define Maj(x,y,z) ((x&y)|(z&(x|y)))
#define a(i) T[(0-(i))&7]
#define b(i) T[(1-(i))&7]
#define c(i) T[(2-(i))&7]
#define d(i) T[(3-(i))&7]
#define e(i) T[(4-(i))&7]
#define f(i) T[(5-(i))&7]
#define g(i) T[(6-(i))&7]
#define h(i) T[(7-(i))&7]
#ifdef _SHA256_UNROLL2
#define R(a,b,c,d,e,f,g,h, i) h += S1(e) + Ch(e,f,g) + K[i+j] + (j?blk2(i):blk0(i));\
d += h; h += S0(a) + Maj(a, b, c)
#define RX_8(i) \
R(a,b,c,d,e,f,g,h, i); \
R(h,a,b,c,d,e,f,g, (i+1)); \
R(g,h,a,b,c,d,e,f, (i+2)); \
R(f,g,h,a,b,c,d,e, (i+3)); \
R(e,f,g,h,a,b,c,d, (i+4)); \
R(d,e,f,g,h,a,b,c, (i+5)); \
R(c,d,e,f,g,h,a,b, (i+6)); \
R(b,c,d,e,f,g,h,a, (i+7))
#else
#define R(i) h(i) += S1(e(i)) + Ch(e(i),f(i),g(i)) + K[i+j] + (j?blk2(i):blk0(i));\
d(i) += h(i); h(i) += S0(a(i)) + Maj(a(i), b(i), c(i))
#ifdef _SHA256_UNROLL
#define RX_8(i) R(i+0); R(i+1); R(i+2); R(i+3); R(i+4); R(i+5); R(i+6); R(i+7);
#endif
#endif
static const uint32_t K[64] = {
0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5,
0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5,
0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3,
0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174,
0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc,
0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da,
0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7,
0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967,
0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13,
0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85,
0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3,
0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070,
0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5,
0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3,
0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208,
0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2
};
static void
sha256_transform(uint32_t *state, const uint32_t *data)
{
uint32_t W[16] = {0};
unsigned j;
#ifdef _SHA256_UNROLL2
uint32_t a,b,c,d,e,f,g,h;
a = state[0];
b = state[1];
c = state[2];
d = state[3];
e = state[4];
f = state[5];
g = state[6];
h = state[7];
#else
uint32_t T[8];
for (j = 0; j < 8; j++)
T[j] = state[j];
#endif
for (j = 0; j < 64; j += 16)
{
#if defined(_SHA256_UNROLL) || defined(_SHA256_UNROLL2)
RX_8(0); RX_8(8);
#else
unsigned i;
for (i = 0; i < 16; i++) { R(i); }
#endif
}
#ifdef _SHA256_UNROLL2
state[0] += a;
state[1] += b;
state[2] += c;
state[3] += d;
state[4] += e;
state[5] += f;
state[6] += g;
state[7] += h;
#else
for (j = 0; j < 8; j++)
state[j] += T[j];
#endif
/* Wipe variables */
/* memset(W, 0, sizeof(W)); */
/* memset(T, 0, sizeof(T)); */
}
#undef S0
#undef S1
#undef s0
#undef s1
static void
sha256_write_byte_block(sha256_t *p)
{
uint32_t data32[16];
unsigned i;
for (i = 0; i < 16; i++)
data32[i] =
((uint32_t)(p->buffer[i * 4 ]) << 24) +
((uint32_t)(p->buffer[i * 4 + 1]) << 16) +
((uint32_t)(p->buffer[i * 4 + 2]) << 8) +
((uint32_t)(p->buffer[i * 4 + 3]));
sha256_transform(p->state, data32);
}
void
sha256_hash(unsigned char *buf, const unsigned char *data, size_t size)
{
sha256_t hash;
sha256_init(&hash);
sha256_update(&hash, data, size);
sha256_final(&hash, buf);
}
void
sha256_update(sha256_t *p, const unsigned char *data, size_t size)
{
uint32_t curBufferPos = (uint32_t)p->count & 0x3F;
while (size > 0)
{
p->buffer[curBufferPos++] = *data++;
p->count++;
size--;
if (curBufferPos == 64)
{
curBufferPos = 0;
sha256_write_byte_block(p);
}
}
}
void
sha256_final(sha256_t *p, unsigned char *digest)
{
uint64_t lenInBits = (p->count << 3);
uint32_t curBufferPos = (uint32_t)p->count & 0x3F;
unsigned i;
p->buffer[curBufferPos++] = 0x80;
while (curBufferPos != (64 - 8))
{
curBufferPos &= 0x3F;
if (curBufferPos == 0)
sha256_write_byte_block(p);
p->buffer[curBufferPos++] = 0;
}
for (i = 0; i < 8; i++)
{
p->buffer[curBufferPos++] = (unsigned char)(lenInBits >> 56);
lenInBits <<= 8;
}
sha256_write_byte_block(p);
for (i = 0; i < 8; i++)
{
*digest++ = (unsigned char)(p->state[i] >> 24);
*digest++ = (unsigned char)(p->state[i] >> 16);
*digest++ = (unsigned char)(p->state[i] >> 8);
*digest++ = (unsigned char)(p->state[i]);
}
sha256_init(p);
}

View File

@@ -0,0 +1,24 @@
/* Sha256.h -- SHA-256 Hash
2010-06-11 : Igor Pavlov : Public domain */
#ifndef __CRYPTO_SHA256_H
#define __CRYPTO_SHA256_H
#include <stdlib.h>
#include <stdint.h>
#define SHA256_DIGEST_SIZE 32
typedef struct sha256_t
{
uint32_t state[8];
uint64_t count;
unsigned char buffer[64];
} sha256_t;
void sha256_init(sha256_t *p);
void sha256_update(sha256_t *p, const unsigned char *data, size_t size);
void sha256_final(sha256_t *p, unsigned char *digest);
void sha256_hash(unsigned char *buf, const unsigned char *data, size_t size);
#endif

View File

@@ -0,0 +1,12 @@
{
"name": "xxhash",
"version": "0.8.2",
"repo": "Cyan4973/xxhash",
"description": "Extremely fast non-cryptographic hash algorithm",
"keywords": ["xxhash", "hashing"],
"license": "BSD-2-Clause",
"src": [
"xxhash.c",
"xxhash.h"
]
}

View File

@@ -0,0 +1,42 @@
/*
* xxHash - Extremely Fast Hash algorithm
* Copyright (C) 2012-2023 Yann Collet
*
* BSD 2-Clause License (https://www.opensource.org/licenses/bsd-license.php)
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following disclaimer
* in the documentation and/or other materials provided with the
* distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
* You can contact the author at:
* - xxHash homepage: https://www.xxhash.com
* - xxHash source repository: https://github.com/Cyan4973/xxHash
*/
/*
* xxhash.c instantiates functions defined in xxhash.h
*/
#define XXH_STATIC_LINKING_ONLY /* access advanced declarations */
#define XXH_IMPLEMENTATION /* access definitions */
#include "xxhash.h"

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,694 @@
#include "ggml.h"
#include "gguf.h"
#include <cstdlib> /* abort() */
#include <cstddef>
#include <cstdio>
#include <string>
#include <stdexcept>
#include <algorithm>
#include <cstring>
#include <sstream>
#include <fstream>
#ifdef __cplusplus
extern "C" {
#endif
#include "xxhash/xxhash.h"
#include "sha1/sha1.h"
#include "sha256/sha256.h"
#ifdef __cplusplus
}
#endif
// uuid.uuid5(uuid.NAMESPACE_URL, 'en.wikipedia.org/wiki/Llama.cpp')
#define UUID_NAMESPACE_LLAMA_CPP "ef001206-dadc-5f6d-a15f-3359e577d4e5"
#define UUID_NAMESPACE_LLAMA_CPP_HEX 0xef, 0x00, 0x12, 0x06, 0xda, 0xdc, 0x5f, 0x6d, 0xa1, 0x5f, 0x33, 0x59, 0xe5, 0x77, 0xd4, 0xe5
#define HASH_TYPE_SHA256_STR "sha256"
#define HASH_TYPE_SHA1_STR "sha1"
#define HASH_TYPE_XXH64_STR "xxh64"
#define HASH_TYPE_UUID_STR "uuid"
typedef enum {
HASH_EXIT_SUCCESS = 0, // All hash has been generated or validated
HASH_EXIT_FAILURE = 1, // Generic Failure
HASH_EXIT_MISMATCH = 2, // Hash mismatched during validation
HASH_EXIT_MANIFEST_MISSING_ENTRY = 3, // Hash attempted validation but missing entry in manifest
HASH_EXIT_MANIFEST_UNKNOWN_HASH = 4, // Manifest is present, but we do not know any hash format within it
HASH_EXIT_MANIFEST_FILE_ERROR = 5 // Manifest is either missing or not a known format
} hash_exit_code_t;
typedef enum {
HASH_MANIFEST_NOT_FOUND,
HASH_MANIFEST_MISMATCH,
HASH_MANIFEST_OK,
} hash_manifest_result_t;
struct hash_params {
std::string input;
bool xxh64 = false;
bool sha1 = false;
bool sha256 = false;
bool uuid = false;
bool no_layer = false;
bool manifest_is_usable = false;
std::string manifest_file;
};
struct manifest_check_params {
bool xxh64 = false;
bool sha1 = false;
bool sha256 = false;
bool uuid = false;
};
static char const * hash_manifest_result_to_str(hash_manifest_result_t value) {
switch (value) {
case HASH_MANIFEST_NOT_FOUND: return "Not Found";
case HASH_MANIFEST_MISMATCH: return "Mismatch";
case HASH_MANIFEST_OK: return "Ok";
}
return "?";
}
static char const * hash_exit_code_to_str(hash_exit_code_t value) {
switch (value) {
case HASH_EXIT_SUCCESS: return "Success";
case HASH_EXIT_FAILURE: return "Failure";
case HASH_EXIT_MISMATCH: return "Mismatch";
case HASH_EXIT_MANIFEST_MISSING_ENTRY: return "Manifest Missing Entry";
case HASH_EXIT_MANIFEST_UNKNOWN_HASH: return "Manifest Unknown Hash";
case HASH_EXIT_MANIFEST_FILE_ERROR: return "Manifest File Error";
}
return "?";
}
static void hash_print_usage(const char * executable) {
const hash_params default_params;
printf("\n");
printf("usage: %s [options] GGUF_IN\n", executable);
printf("\n");
printf("Hash a GGUF file");
printf("\n");
printf("options:\n");
printf(" -h, --help show this help message and exit\n");
printf(" --xxh64 use xxh64 hash\n");
printf(" --sha1 use sha1 hash\n");
printf(" --sha256 use sha256 hash\n");
printf(" --all use all hash\n");
printf(" --no-layer exclude per layer hash\n");
printf(" --uuid generate UUIDv5 ID\n");
printf(" -c, --check <manifest> verify against a manifest\n");
printf("\n");
}
static void hash_params_parse_ex(int argc, const char ** argv, hash_params & params) {
std::string arg;
bool invalid_param = false;
const std::string arg_prefix = "--";
int arg_idx = 1;
for (; arg_idx < argc && strncmp(argv[arg_idx], "--", 2) == 0; arg_idx++) {
arg = argv[arg_idx];
if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) {
std::replace(arg.begin(), arg.end(), '_', '-');
}
bool arg_found = false;
if (arg == "-h" || arg == "--help") {
hash_print_usage(argv[0]);
exit(0);
}
if (arg == "--xxh64") {
arg_found = true;
params.xxh64 = true;
}
if (arg == "--sha1") {
arg_found = true;
params.sha1 = true;
}
if (arg == "--uuid") {
arg_found = true;
params.uuid = true;
}
if (arg == "--sha256") {
arg_found = true;
params.sha256 = true;
}
if (arg == "--all") {
arg_found = true;
params.sha256 = true;
params.sha1 = true;
params.xxh64 = true;
}
if (arg == "--no-layer") {
arg_found = true;
params.no_layer = true;
}
if (arg == "-c" || arg == "--check") {
if (++arg_idx >= argc) {
invalid_param = true;
break;
}
arg_found = true;
params.manifest_file = argv[arg_idx];
}
if (!arg_found) {
throw std::invalid_argument("error: unknown argument: " + arg);
}
}
if (invalid_param) {
throw std::invalid_argument("error: invalid parameter for argument:" + arg);
}
if (argc - arg_idx < 1) {
throw std::invalid_argument("error: bad arguments");
}
params.input = argv[arg_idx++];
}
static bool hash_params_parse(int argc, const char ** argv, hash_params & params) {
bool result = true;
try {
hash_params_parse_ex(argc, argv, params);
}
catch (const std::invalid_argument & ex) {
fprintf(stderr, "%s\n", ex.what());
hash_print_usage(argv[0]);
exit(EXIT_FAILURE);
}
return result;
}
static bool manifest_type(const std::string & manifest_file, manifest_check_params & manifest_check) {
if (manifest_file.empty()) {
return false;
}
std::ifstream file(manifest_file);
if (!file.is_open()) {
return false;
}
std::string manifest_entry_line;
while (getline(file, manifest_entry_line)) {
// hash_type_str hash_str tensor_name
// e.g. 'xxh64 f66e9cd66a4396a0 test.gguf:tensor_0'
std::istringstream line_stream(manifest_entry_line);
std::string file_hash_type;
if (line_stream >> file_hash_type) {
if (file_hash_type == HASH_TYPE_SHA256_STR) {
manifest_check.sha256 = true;
} else if (file_hash_type == HASH_TYPE_SHA1_STR) {
manifest_check.sha1 = true;
} else if (file_hash_type == HASH_TYPE_XXH64_STR) {
manifest_check.xxh64 = true;
} else if (file_hash_type == HASH_TYPE_UUID_STR) {
manifest_check.uuid = true;
}
}
}
return true;
}
static hash_manifest_result_t manifest_verify(const std::string& manifest_file, const std::string& hash_type_str, const std::string& hash_str, const std::string& tensor_name) {
if (manifest_file.empty()) {
return HASH_MANIFEST_NOT_FOUND;
}
std::ifstream file(manifest_file);
if (!file.is_open()) {
return HASH_MANIFEST_NOT_FOUND;
}
std::string manifest_entry_line;
while (getline(file, manifest_entry_line)) {
std::istringstream line_stream(manifest_entry_line);
std::string file_hash_type;
std::string file_hash;
std::string file_tensor_name;
if (line_stream >> file_hash_type >> file_hash >> file_tensor_name) {
// Line parsed. Check hash validity
if (file_hash_type != hash_type_str) {
continue;
}
if (file_tensor_name != tensor_name) {
continue;
}
return (file_hash == hash_str) ? HASH_MANIFEST_OK : HASH_MANIFEST_MISMATCH;
}
}
return HASH_MANIFEST_NOT_FOUND;
}
static void generate_uuidv5(const unsigned char sha1_digest[20], unsigned char uuid[16]) {
// Ref: https://www.rfc-editor.org/rfc/rfc9562.html#section-5.5
// Assumes that digest was processed correctly with the expected namespace
for (int i = 0; i < 16; i++) {
uuid[i] = sha1_digest[i];
}
// Set bits corresponding to UUID ver 5
uuid[ 6] &= ~(0xF << 4);
uuid[ 6] |= (5 << 4);
// Set bits corresponding to UUID variant 0b10XX
uuid[ 8] &= ~(0xc << 4);
uuid[ 8] |= (0x8 << 4);
}
static hash_exit_code_t gguf_hash(const hash_params & hash_params) {
const std::string & fname = hash_params.input;
struct ggml_context * ctx_data = NULL;
struct gguf_init_params params = {
/*.no_alloc = */ false,
/*.ctx = */ &ctx_data,
};
// xxh64 init
XXH64_state_t* xxh64_model_hash_state = NULL;
if (hash_params.xxh64) {
xxh64_model_hash_state = XXH64_createState();
if (xxh64_model_hash_state==NULL) {
abort();
}
XXH64_hash_t const seed = 0;
if (XXH64_reset(xxh64_model_hash_state, seed) == XXH_ERROR) {
abort();
}
}
// sha1 init
SHA1_CTX sha1_model_hash_ctx;
if (hash_params.sha1) {
SHA1Init(&sha1_model_hash_ctx);
}
// sha256 init
sha256_t sha256_model_hash_ctx;
if (hash_params.sha256) {
sha256_init(&sha256_model_hash_ctx);
}
// sha1 for uuid init
SHA1_CTX sha1_for_uuid_ctx;
if (hash_params.uuid) {
unsigned char const uuidv5_namespace[] = {UUID_NAMESPACE_LLAMA_CPP_HEX};
SHA1Init(&sha1_for_uuid_ctx);
SHA1Update( &sha1_for_uuid_ctx, (unsigned char const *)uuidv5_namespace, sizeof(uuidv5_namespace));
}
struct gguf_context * ctx = gguf_init_from_file(fname.c_str(), params);
const int n_tensors = gguf_get_n_tensors(ctx);
bool tensor_layer_in_manifest = false;
bool model_in_manifest = false;
bool tensor_layer_has_mismatch = false;
bool model_has_mismatch = false;
for (int i = 0; i < n_tensors; ++i) {
const char * name = gguf_get_tensor_name(ctx, i);
struct ggml_tensor * cur = ggml_get_tensor(ctx_data, name);
auto n_bytes = ggml_nbytes(cur);
auto *raw_data = cur->data;
const std::string tensor_layer_name = fname + ":" + name;
if (hash_params.xxh64) {
if (!hash_params.no_layer) {
// Per Layer Hash
XXH64_hash_t hash = XXH64(raw_data, n_bytes, 0);
char hex_result[17];
for (int offset = 0; offset < 8; offset++) {
unsigned int shift_bits_by = (8 * (8 - offset - 1));
snprintf( ( hex_result + (2*offset)), sizeof(hex_result) - (2*offset), "%02x", (unsigned char) (hash >> shift_bits_by)&0xff);
}
if (hash_params.manifest_is_usable) {
hash_manifest_result_t verify_result = manifest_verify(hash_params.manifest_file, HASH_TYPE_XXH64_STR, hex_result, tensor_layer_name);
switch (verify_result) {
case HASH_MANIFEST_NOT_FOUND:
break;
case HASH_MANIFEST_MISMATCH:
tensor_layer_in_manifest = true;
tensor_layer_has_mismatch = true;
break;
case HASH_MANIFEST_OK:
tensor_layer_in_manifest = true;
break;
}
printf("%-8s %-s %s - %s\n", HASH_TYPE_XXH64_STR, hex_result, tensor_layer_name.c_str(), hash_manifest_result_to_str(verify_result));
} else {
printf("%-8s %-s %s\n", HASH_TYPE_XXH64_STR, hex_result, tensor_layer_name.c_str());
}
}
// Overall Model Hash
if (XXH64_update(xxh64_model_hash_state, raw_data, n_bytes) == XXH_ERROR) abort();
}
if (hash_params.sha1) {
if (!hash_params.no_layer) {
// Per Layer Hash
char result[21]; // sha1 outputs 20 bytes
SHA1( result, (const char *)raw_data, n_bytes);
char hex_result[41] = {0};
for (int offset = 0; offset < 20; offset++) {
snprintf( ( hex_result + (2*offset)), sizeof(hex_result) - (2*offset), "%02x", result[offset]&0xff);
}
if (hash_params.manifest_is_usable) {
hash_manifest_result_t verify_result = manifest_verify(hash_params.manifest_file, HASH_TYPE_SHA1_STR, hex_result, tensor_layer_name);
switch (verify_result) {
case HASH_MANIFEST_NOT_FOUND:
break;
case HASH_MANIFEST_MISMATCH:
tensor_layer_in_manifest = true;
tensor_layer_has_mismatch = true;
break;
case HASH_MANIFEST_OK:
tensor_layer_in_manifest = true;
break;
}
printf("%-8s %-s %s - %s\n", HASH_TYPE_SHA1_STR, hex_result, tensor_layer_name.c_str(), hash_manifest_result_to_str(verify_result));
} else {
printf("%-8s %-s %s\n", HASH_TYPE_SHA1_STR, hex_result, tensor_layer_name.c_str());
}
}
// Overall Model Hash
SHA1Update( &sha1_model_hash_ctx, (unsigned char const *)raw_data, n_bytes);
}
if (hash_params.sha256) {
if (!hash_params.no_layer) {
// Per Layer Hash
unsigned char result[SHA256_DIGEST_SIZE]; // sha256 outputs 32 bytes
sha256_hash((unsigned char*) result, (const unsigned char *)raw_data, n_bytes);
char hex_result[SHA256_DIGEST_SIZE * 2 + 1] = {0};
for (int offset = 0; offset < SHA256_DIGEST_SIZE; offset++) {
snprintf( ( hex_result + (2*offset)), sizeof(hex_result) - (2*offset), "%02x", result[offset]&0xff);
}
if (hash_params.manifest_is_usable) {
hash_manifest_result_t verify_result = manifest_verify(hash_params.manifest_file, HASH_TYPE_SHA256_STR, hex_result, tensor_layer_name);
switch (verify_result) {
case HASH_MANIFEST_NOT_FOUND:
break;
case HASH_MANIFEST_MISMATCH:
tensor_layer_in_manifest = true;
tensor_layer_has_mismatch = true;
break;
case HASH_MANIFEST_OK:
tensor_layer_in_manifest = true;
break;
}
printf("%-8s %-s %s - %s\n", HASH_TYPE_SHA256_STR, hex_result, tensor_layer_name.c_str(), hash_manifest_result_to_str(verify_result));
} else {
printf("%-8s %-s %s\n", HASH_TYPE_SHA256_STR, hex_result, tensor_layer_name.c_str());
}
}
// Overall Model Hash
sha256_update( &sha256_model_hash_ctx, (unsigned char const *)raw_data, n_bytes);
}
if (hash_params.uuid) {
SHA1Update( &sha1_for_uuid_ctx, (unsigned char const *)raw_data, n_bytes);
}
}
if (hash_params.xxh64) {
XXH64_hash_t const hash = XXH64_digest(xxh64_model_hash_state);
char hex_result[17];
for (int offset = 0; offset < 8; offset++) {
unsigned int shift_bits_by = (8 * (8 - offset - 1));
snprintf( ( hex_result + (2*offset)), sizeof(hex_result) - (2*offset), "%02x", (unsigned char) (hash >> shift_bits_by)&0xff);
}
if (hash_params.manifest_is_usable) {
hash_manifest_result_t verify_result = manifest_verify(hash_params.manifest_file, HASH_TYPE_XXH64_STR, hex_result, fname);
switch (verify_result) {
case HASH_MANIFEST_NOT_FOUND:
break;
case HASH_MANIFEST_MISMATCH:
model_in_manifest = true;
model_has_mismatch = true;
break;
case HASH_MANIFEST_OK:
model_in_manifest = true;
break;
}
printf("%-8s %-s %s - %s\n", HASH_TYPE_XXH64_STR, hex_result, fname.c_str(), hash_manifest_result_to_str(verify_result));
} else {
printf("%-8s %-s %s\n", HASH_TYPE_XXH64_STR, hex_result, fname.c_str());
}
}
if (hash_params.sha1) {
unsigned char result[21];
SHA1Final(result, &sha1_model_hash_ctx);
char hex_result[41];
for (int offset = 0; offset < 20; offset++) {
snprintf( ( hex_result + (2*offset)), sizeof(hex_result) - (2*offset), "%02x", result[offset]&0xff);
}
if (hash_params.manifest_is_usable) {
hash_manifest_result_t verify_result = manifest_verify(hash_params.manifest_file, HASH_TYPE_SHA1_STR, hex_result, fname);
switch (verify_result) {
case HASH_MANIFEST_NOT_FOUND:
break;
case HASH_MANIFEST_MISMATCH:
model_in_manifest = true;
model_has_mismatch = true;
break;
case HASH_MANIFEST_OK:
model_in_manifest = true;
break;
}
printf("%-8s %-s %s - %s\n", HASH_TYPE_SHA1_STR, hex_result, fname.c_str(), hash_manifest_result_to_str(verify_result));
} else {
printf("%-8s %-s %s\n", HASH_TYPE_SHA1_STR, hex_result, fname.c_str());
}
}
if (hash_params.sha256) {
unsigned char result[SHA256_DIGEST_SIZE]; // sha256 outputs 32 bytes
sha256_final( &sha256_model_hash_ctx, result);
char hex_result[SHA256_DIGEST_SIZE * 2 + 1] = {0};
for (int offset = 0; offset < SHA256_DIGEST_SIZE; offset++) {
snprintf( ( hex_result + (2*offset)), sizeof(hex_result) - (2*offset), "%02x", result[offset]&0xff);
}
if (hash_params.manifest_is_usable) {
hash_manifest_result_t verify_result = manifest_verify(hash_params.manifest_file, HASH_TYPE_SHA256_STR, hex_result, fname);
switch (verify_result) {
case HASH_MANIFEST_NOT_FOUND:
break;
case HASH_MANIFEST_MISMATCH:
model_in_manifest = true;
model_has_mismatch = true;
break;
case HASH_MANIFEST_OK:
model_in_manifest = true;
break;
}
printf("%-8s %-s %s - %s\n", HASH_TYPE_SHA256_STR, hex_result, fname.c_str(), hash_manifest_result_to_str(verify_result));
} else {
printf("%-8s %-s %s\n", HASH_TYPE_SHA256_STR, hex_result, fname.c_str());
}
}
if (hash_params.uuid) {
unsigned char result[21];
SHA1Final(result, &sha1_for_uuid_ctx);
unsigned char uuid[16];
generate_uuidv5(result, uuid);
char string_buffer[37] = {0};
snprintf(string_buffer, sizeof(string_buffer), "%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x",
uuid[0], uuid[1], uuid[2], uuid[3],
uuid[4], uuid[5], uuid[6], uuid[7],
uuid[8], uuid[9], uuid[10], uuid[11],
uuid[12], uuid[13], uuid[14], uuid[15]);
if (hash_params.manifest_is_usable) {
hash_manifest_result_t verify_result = manifest_verify(hash_params.manifest_file, HASH_TYPE_SHA256_STR, string_buffer, fname);
switch (verify_result) {
case HASH_MANIFEST_NOT_FOUND:
break;
case HASH_MANIFEST_MISMATCH:
model_in_manifest = true;
model_has_mismatch = true;
break;
case HASH_MANIFEST_OK:
model_in_manifest = true;
break;
}
printf("%-8s %-s %s - %s\n", HASH_TYPE_UUID_STR, string_buffer, fname.c_str(), hash_manifest_result_to_str(verify_result));
} else {
printf("%-8s %-s %s\n", HASH_TYPE_UUID_STR, string_buffer, fname.c_str());
}
}
ggml_free(ctx_data);
gguf_free(ctx);
if (hash_params.manifest_is_usable) {
// In hash verification mode
if (!model_in_manifest) {
// model missing in manifest?
// Check tensor layer...
if (!tensor_layer_in_manifest) {
// Still missing? Maybe we are reading the wrong manifest.
return HASH_EXIT_MANIFEST_MISSING_ENTRY;
}
if (tensor_layer_has_mismatch) {
// Per tensor check found error
return HASH_EXIT_FAILURE;
}
// All per tensor layer checks passed? Sounds good enough.
return HASH_EXIT_SUCCESS;
}
// Overall model check passed, but let's check per layer just in case
// If missing, we don't care too much as the overall model checked
if (tensor_layer_in_manifest && tensor_layer_has_mismatch) {
return HASH_EXIT_FAILURE;
}
if (model_has_mismatch) {
// model has failed hash somewhere in the model
return HASH_EXIT_FAILURE;
}
// All checks appears to be fine
return HASH_EXIT_SUCCESS;
}
// In hash generation mode
return HASH_EXIT_SUCCESS;
}
int main(int argc, const char ** argv) {
hash_params params;
manifest_check_params manifest_check;
hash_params_parse(argc, argv, params);
if (!params.manifest_file.empty()) {
if (!manifest_type(params.manifest_file, manifest_check)) {
printf("ERROR cannot open manifest %s", params.manifest_file.c_str());
return HASH_EXIT_MANIFEST_FILE_ERROR;
}
if (!manifest_check.sha256 && !manifest_check.sha1 && !manifest_check.xxh64 && !manifest_check.uuid) {
printf("ERROR manifest does not have any known hash format in %s", params.manifest_file.c_str());
return HASH_EXIT_MANIFEST_UNKNOWN_HASH;
}
printf("manifest %s", params.manifest_file.c_str());
if (manifest_check.sha256) {
printf(" sha256");
}
if (manifest_check.sha1) {
printf(" sha1");
}
if (manifest_check.xxh64) {
printf(" xxh64");
}
if (manifest_check.uuid) {
printf(" uuid");
}
printf("\n");
// Autoselect the highest security hash if manifest is provided but
// the user has not specifically defined the hash they care about
if (!params.xxh64 && !params.sha1 && !params.uuid && !params.sha256) {
// User has not selected a specific value, pick most secure hash
if (manifest_check.sha256) {
params.sha256 = true;
} else if (manifest_check.sha1) {
params.sha1 = true;
} else if (manifest_check.xxh64) {
params.xxh64 = true;
} else if (manifest_check.uuid) {
params.uuid = true;
}
}
params.manifest_is_usable = true;
}
// By default if no swich argument provided, assume xxh64
if (!params.xxh64 && !params.sha1 && !params.uuid && !params.sha256) {
params.xxh64 = true;
}
hash_exit_code_t exit_code = gguf_hash(params);
if (params.manifest_is_usable) {
printf("\nVerification results for %s - %s\n", params.manifest_file.c_str(), hash_exit_code_to_str(exit_code));
}
return exit_code;
}

View File

@@ -0,0 +1,5 @@
set(TARGET llama-gguf)
add_executable(${TARGET} gguf.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE ggml ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_17)

270
examples/gguf/gguf.cpp Normal file
View File

@@ -0,0 +1,270 @@
#include "ggml.h"
#include "gguf.h"
#include <cstdio>
#include <string>
#include <sstream>
#include <vector>
#undef MIN
#undef MAX
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define MAX(a, b) ((a) > (b) ? (a) : (b))
template <typename T>
static std::string to_string(const T & val) {
std::stringstream ss;
ss << val;
return ss.str();
}
static bool gguf_ex_write(const std::string & fname) {
struct gguf_context * ctx = gguf_init_empty();
gguf_set_val_u8 (ctx, "some.parameter.uint8", 0x12);
gguf_set_val_i8 (ctx, "some.parameter.int8", -0x13);
gguf_set_val_u16 (ctx, "some.parameter.uint16", 0x1234);
gguf_set_val_i16 (ctx, "some.parameter.int16", -0x1235);
gguf_set_val_u32 (ctx, "some.parameter.uint32", 0x12345678);
gguf_set_val_i32 (ctx, "some.parameter.int32", -0x12345679);
gguf_set_val_f32 (ctx, "some.parameter.float32", 0.123456789f);
gguf_set_val_u64 (ctx, "some.parameter.uint64", 0x123456789abcdef0ull);
gguf_set_val_i64 (ctx, "some.parameter.int64", -0x123456789abcdef1ll);
gguf_set_val_f64 (ctx, "some.parameter.float64", 0.1234567890123456789);
gguf_set_val_bool(ctx, "some.parameter.bool", true);
gguf_set_val_str (ctx, "some.parameter.string", "hello world");
gguf_set_arr_data(ctx, "some.parameter.arr.i16", GGUF_TYPE_INT16, std::vector<int16_t>{ 1, 2, 3, 4, }.data(), 4);
gguf_set_arr_data(ctx, "some.parameter.arr.f32", GGUF_TYPE_FLOAT32, std::vector<float>{ 3.145f, 2.718f, 1.414f, }.data(), 3);
gguf_set_arr_str (ctx, "some.parameter.arr.str", std::vector<const char *>{ "hello", "world", "!" }.data(), 3);
struct ggml_init_params params = {
/*.mem_size =*/ 128ull*1024ull*1024ull,
/*.mem_buffer =*/ NULL,
/*.no_alloc =*/ false,
};
struct ggml_context * ctx_data = ggml_init(params);
const int n_tensors = 10;
// tensor infos
for (int i = 0; i < n_tensors; ++i) {
const std::string name = "tensor_" + to_string(i);
int64_t ne[GGML_MAX_DIMS] = { 1 };
int32_t n_dims = rand() % GGML_MAX_DIMS + 1;
for (int j = 0; j < n_dims; ++j) {
ne[j] = rand() % 10 + 1;
}
struct ggml_tensor * cur = ggml_new_tensor(ctx_data, GGML_TYPE_F32, n_dims, ne);
ggml_set_name(cur, name.c_str());
{
float * data = (float *) cur->data;
for (int j = 0; j < ggml_nelements(cur); ++j) {
data[j] = 100 + i;
}
}
gguf_add_tensor(ctx, cur);
}
gguf_write_to_file(ctx, fname.c_str(), false);
printf("%s: wrote file '%s;\n", __func__, fname.c_str());
ggml_free(ctx_data);
gguf_free(ctx);
return true;
}
// just read tensor info
static bool gguf_ex_read_0(const std::string & fname) {
struct gguf_init_params params = {
/*.no_alloc = */ false,
/*.ctx = */ NULL,
};
struct gguf_context * ctx = gguf_init_from_file(fname.c_str(), params);
if (!ctx) {
fprintf(stderr, "%s: failed to load '%s'\n", __func__, fname.c_str());
return false;
}
printf("%s: version: %d\n", __func__, gguf_get_version(ctx));
printf("%s: alignment: %zu\n", __func__, gguf_get_alignment(ctx));
printf("%s: data offset: %zu\n", __func__, gguf_get_data_offset(ctx));
// kv
{
const int n_kv = gguf_get_n_kv(ctx);
printf("%s: n_kv: %d\n", __func__, n_kv);
for (int i = 0; i < n_kv; ++i) {
const char * key = gguf_get_key(ctx, i);
printf("%s: kv[%d]: key = %s\n", __func__, i, key);
}
}
// find kv string
{
const char * findkey = "some.parameter.string";
const int keyidx = gguf_find_key(ctx, findkey);
if (keyidx == -1) {
printf("%s: find key: %s not found.\n", __func__, findkey);
} else {
const char * key_value = gguf_get_val_str(ctx, keyidx);
printf("%s: find key: %s found, kv[%d] value = %s\n", __func__, findkey, keyidx, key_value);
}
}
// tensor info
{
const int n_tensors = gguf_get_n_tensors(ctx);
printf("%s: n_tensors: %d\n", __func__, n_tensors);
for (int i = 0; i < n_tensors; ++i) {
const char * name = gguf_get_tensor_name (ctx, i);
const size_t size = gguf_get_tensor_size (ctx, i);
const size_t offset = gguf_get_tensor_offset(ctx, i);
printf("%s: tensor[%d]: name = %s, size = %zu, offset = %zu\n", __func__, i, name, size, offset);
}
}
gguf_free(ctx);
return true;
}
// read and create ggml_context containing the tensors and their data
static bool gguf_ex_read_1(const std::string & fname, bool check_data) {
struct ggml_context * ctx_data = NULL;
struct gguf_init_params params = {
/*.no_alloc = */ false,
/*.ctx = */ &ctx_data,
};
struct gguf_context * ctx = gguf_init_from_file(fname.c_str(), params);
printf("%s: version: %d\n", __func__, gguf_get_version(ctx));
printf("%s: alignment: %zu\n", __func__, gguf_get_alignment(ctx));
printf("%s: data offset: %zu\n", __func__, gguf_get_data_offset(ctx));
// kv
{
const int n_kv = gguf_get_n_kv(ctx);
printf("%s: n_kv: %d\n", __func__, n_kv);
for (int i = 0; i < n_kv; ++i) {
const char * key = gguf_get_key(ctx, i);
printf("%s: kv[%d]: key = %s\n", __func__, i, key);
}
}
// tensor info
{
const int n_tensors = gguf_get_n_tensors(ctx);
printf("%s: n_tensors: %d\n", __func__, n_tensors);
for (int i = 0; i < n_tensors; ++i) {
const char * name = gguf_get_tensor_name (ctx, i);
const size_t size = gguf_get_tensor_size (ctx, i);
const size_t offset = gguf_get_tensor_offset(ctx, i);
const auto type = gguf_get_tensor_type (ctx, i);
const char * type_name = ggml_type_name(type);
const size_t type_size = ggml_type_size(type);
const size_t n_elements = size / type_size;
printf("%s: tensor[%d]: name = %s, size = %zu, offset = %zu, type = %s, n_elts = %zu\n", __func__, i, name, size, offset, type_name, n_elements);
}
}
// data
{
const int n_tensors = gguf_get_n_tensors(ctx);
for (int i = 0; i < n_tensors; ++i) {
printf("%s: reading tensor %d data\n", __func__, i);
const char * name = gguf_get_tensor_name(ctx, i);
struct ggml_tensor * cur = ggml_get_tensor(ctx_data, name);
printf("%s: tensor[%d]: n_dims = %d, ne = (%d, %d, %d, %d), name = %s, data = %p\n",
__func__, i, ggml_n_dims(cur), int(cur->ne[0]), int(cur->ne[1]), int(cur->ne[2]), int(cur->ne[3]), cur->name, cur->data);
// print first 10 elements
const float * data = (const float *) cur->data;
printf("%s data[:10] : ", name);
for (int j = 0; j < MIN(10, ggml_nelements(cur)); ++j) {
printf("%f ", data[j]);
}
printf("\n\n");
// check data
if (check_data) {
const float * data = (const float *) cur->data;
for (int j = 0; j < ggml_nelements(cur); ++j) {
if (data[j] != 100 + i) {
fprintf(stderr, "%s: tensor[%d], data[%d]: found %f, expected %f\n", __func__, i, j, data[j], float(100 + i));
gguf_free(ctx);
return false;
}
}
}
}
}
printf("%s: ctx_data size: %zu\n", __func__, ggml_get_mem_size(ctx_data));
ggml_free(ctx_data);
gguf_free(ctx);
return true;
}
int main(int argc, char ** argv) {
if (argc < 3) {
printf("usage: %s data.gguf r|w [n]\n", argv[0]);
printf("r: read data.gguf file\n");
printf("w: write data.gguf file\n");
printf("n: no check of tensor data\n");
return -1;
}
bool check_data = true;
if (argc == 4) {
check_data = false;
}
srand(123456);
const std::string fname(argv[1]);
const std::string mode (argv[2]);
GGML_ASSERT((mode == "r" || mode == "w") && "mode must be r or w");
if (mode == "w") {
GGML_ASSERT(gguf_ex_write(fname) && "failed to write gguf file");
} else if (mode == "r") {
GGML_ASSERT(gguf_ex_read_0(fname) && "failed to read gguf file");
GGML_ASSERT(gguf_ex_read_1(fname, check_data) && "failed to read gguf file");
}
return 0;
}

View File

@@ -0,0 +1,5 @@
set(TARGET llama-idle)
add_executable(${TARGET} idle.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)

3
examples/idle/README.md Normal file
View File

@@ -0,0 +1,3 @@
# llama.cpp/example/idle
https://github.com/ggml-org/llama.cpp/pull/17766

110
examples/idle/idle.cpp Normal file
View File

@@ -0,0 +1,110 @@
#include "arg.h"
#include "common.h"
#include "log.h"
#include "llama.h"
#include <cmath>
#include <cstdio>
#include <cstring>
#include <string>
#include <thread>
#include <vector>
static void print_usage(int /*argc*/, char ** argv) {
printf("\nexample usage:\n");
printf("\n %s -m model.gguf [-ngl n_gpu_layers]\n", argv[0]);
printf("\n");
}
int main(int argc, char ** argv) {
common_params params;
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON, print_usage)) {
return 1;
}
common_init();
// init LLM
llama_backend_init();
llama_numa_init(params.numa);
// initialize the model
llama_model_params model_params = common_model_params_to_llama(params);
llama_model * model = llama_model_load_from_file(params.model.path.c_str(), model_params);
if (model == NULL) {
LOG_ERR("%s: error: unable to load model\n" , __func__);
return 1;
}
const llama_vocab * vocab = llama_model_get_vocab(model);
// we need just a dummy token to evaluate
std::vector<llama_token> prompt_tokens(1, llama_vocab_bos(vocab));
llama_context_params ctx_params = llama_context_default_params();
ctx_params.n_ctx = 512;
ctx_params.n_batch = 512;
ctx_params.no_perf = false;
llama_context * ctx = llama_init_from_model(model, ctx_params);
if (ctx == NULL) {
fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
return 1;
}
llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size());
const int n_iters = 3;
// warm-up
llama_decode(ctx, batch);
llama_memory_clear(llama_get_memory(ctx), true);
llama_synchronize(ctx);
for (int64_t t_pause_ms = 0; t_pause_ms <= 4000; t_pause_ms += 800) {
double t_sum_us = 0.0;
double t_sum2_us = 0.0;
for (int i = 0; i < n_iters; i++) {
// this pause is important - it simulates "idle GPU"
std::this_thread::sleep_for(std::chrono::milliseconds(t_pause_ms));
const int64_t t_start_us = llama_time_us();
// this should take constant time
llama_decode(ctx, batch);
llama_synchronize(ctx);
const int64_t t_end_us = llama_time_us();
const double t_cur_us = t_end_us - t_start_us;
#if 1
// print individual decode times
printf(" - decode time: %8.2f ms\n", t_cur_us / 1000);
#endif
t_sum_us += t_cur_us;
t_sum2_us += t_cur_us * t_cur_us;
llama_memory_clear(llama_get_memory(ctx), true);
llama_synchronize(ctx); // just in case
}
const double t_avg_us = t_sum_us / n_iters;
const double t_dev_us = sqrt((t_sum2_us / (n_iters - 1)) - (t_avg_us * t_avg_us * n_iters) / (n_iters - 1));
printf("iters: %4d, pause: %5d ms, avg decode time: %8.2f +/- %4.2f ms\n", n_iters, (int) t_pause_ms, t_avg_us / 1000, t_dev_us / 1000);
fflush(stdout);
}
llama_free(ctx);
llama_model_free(model);
return 0;
}

View File

@@ -0,0 +1,82 @@
# Usage:
#! ./llama-server -m some-model.gguf &
#! pip install pydantic
#! python json_schema_pydantic_example.py
from pydantic import BaseModel, Field, TypeAdapter
from annotated_types import MinLen
from typing import Annotated, List, Optional
import json, requests
if True:
def create_completion(*, response_model=None, endpoint="http://localhost:8080/v1/chat/completions", messages, **kwargs):
'''
Creates a chat completion using an OpenAI-compatible endpoint w/ JSON schema support
(llama.cpp server, llama-cpp-python, Anyscale / Together...)
The response_model param takes a type (+ supports Pydantic) and behaves just as w/ Instructor (see below)
'''
response_format = None
type_adapter = None
if response_model:
type_adapter = TypeAdapter(response_model)
schema = type_adapter.json_schema()
messages = [{
"role": "system",
"content": f"You respond in JSON format with the following schema: {json.dumps(schema, indent=2)}"
}] + messages
response_format={"type": "json_object", "schema": schema}
data = requests.post(endpoint, headers={"Content-Type": "application/json"},
json=dict(messages=messages, response_format=response_format, **kwargs)).json()
if 'error' in data:
raise Exception(data['error']['message'])
content = data["choices"][0]["message"]["content"]
return type_adapter.validate_json(content) if type_adapter else content
else:
# This alternative branch uses Instructor + OpenAI client lib.
# Instructor support streamed iterable responses, retry & more.
# (see https://python.useinstructor.com/)
#! pip install instructor openai
import instructor, openai
client = instructor.patch(
openai.OpenAI(api_key="123", base_url="http://localhost:8080"),
mode=instructor.Mode.JSON_SCHEMA)
create_completion = client.chat.completions.create
if __name__ == '__main__':
class QAPair(BaseModel):
class Config:
extra = 'forbid' # triggers additionalProperties: false in the JSON schema
question: str
concise_answer: str
justification: str
stars: Annotated[int, Field(ge=1, le=5)]
class PyramidalSummary(BaseModel):
class Config:
extra = 'forbid' # triggers additionalProperties: false in the JSON schema
title: str
summary: str
question_answers: Annotated[List[QAPair], MinLen(2)]
sub_sections: Optional[Annotated[List['PyramidalSummary'], MinLen(2)]]
print("# Summary\n", create_completion(
model="...",
response_model=PyramidalSummary,
messages=[{
"role": "user",
"content": f"""
You are a highly efficient corporate document summarizer.
Create a pyramidal summary of an imaginary internal document about our company processes
(starting high-level, going down to each sub sections).
Keep questions short, and answers even shorter (trivia / quizz style).
"""
}]))

View File

@@ -0,0 +1,837 @@
#!/usr/bin/env python3
from __future__ import annotations
import argparse
import itertools
import json
import re
import sys
from typing import Any, List, Optional, Set, Tuple, Union
def _build_repetition(item_rule, min_items, max_items, separator_rule=None):
if max_items == 0:
return ""
if min_items == 0 and max_items == 1:
return f'{item_rule}?'
if not separator_rule:
if min_items == 1 and max_items is None:
return f'{item_rule}+'
elif min_items == 0 and max_items is None:
return f'{item_rule}*'
else:
return f'{item_rule}{{{min_items},{max_items if max_items is not None else ""}}}'
result = item_rule + ' ' + _build_repetition(f'({separator_rule} {item_rule})', min_items - 1 if min_items > 0 else 0, max_items - 1 if max_items is not None else None)
return f'({result})?' if min_items == 0 else result
def _generate_min_max_int(min_value: Optional[int], max_value: Optional[int], out: list, decimals_left: int = 16, top_level: bool = True):
has_min = min_value != None
has_max = max_value != None
def digit_range(from_char: str, to_char: str):
out.append("[")
if from_char == to_char:
out.append(from_char)
else:
out.append(from_char)
out.append("-")
out.append(to_char)
out.append("]")
def more_digits(min_digits: int, max_digits: int):
out.append("[0-9]")
if min_digits == max_digits and min_digits == 1:
return
out.append("{")
out.append(str(min_digits))
if max_digits != min_digits:
out.append(",")
if max_digits != sys.maxsize:
out.append(str(max_digits))
out.append("}")
def uniform_range(from_str: str, to_str: str):
i = 0
while i < len(from_str) and from_str[i] == to_str[i]:
i += 1
if i > 0:
out.append("\"")
out.append(from_str[:i])
out.append("\"")
if i < len(from_str):
if i > 0:
out.append(" ")
sub_len = len(from_str) - i - 1
if sub_len > 0:
from_sub = from_str[i+1:]
to_sub = to_str[i+1:]
sub_zeros = "0" * sub_len
sub_nines = "9" * sub_len
to_reached = False
out.append("(")
if from_sub == sub_zeros:
digit_range(from_str[i], chr(ord(to_str[i]) - 1))
out.append(" ")
more_digits(sub_len, sub_len)
else:
out.append("[")
out.append(from_str[i])
out.append("] ")
out.append("(")
uniform_range(from_sub, sub_nines)
out.append(")")
if ord(from_str[i]) < ord(to_str[i]) - 1:
out.append(" | ")
if to_sub == sub_nines:
digit_range(chr(ord(from_str[i]) + 1), to_str[i])
to_reached = True
else:
digit_range(chr(ord(from_str[i]) + 1), chr(ord(to_str[i]) - 1))
out.append(" ")
more_digits(sub_len, sub_len)
if not to_reached:
out.append(" | ")
digit_range(to_str[i], to_str[i])
out.append(" ")
uniform_range(sub_zeros, to_sub)
out.append(")")
else:
out.append("[")
out.append(from_str[i])
out.append("-")
out.append(to_str[i])
out.append("]")
if has_min and has_max:
if min_value < 0 and max_value < 0:
out.append("\"-\" (")
_generate_min_max_int(-max_value, -min_value, out, decimals_left, top_level=True)
out.append(")")
return
if min_value < 0:
out.append("\"-\" (")
_generate_min_max_int(0, -min_value, out, decimals_left, top_level=True)
out.append(") | ")
min_value = 0
min_s = str(min_value)
max_s = str(max_value)
min_digits = len(min_s)
max_digits = len(max_s)
for digits in range(min_digits, max_digits):
uniform_range(min_s, "9" * digits)
min_s = "1" + "0" * digits
out.append(" | ")
uniform_range(min_s, max_s)
return
less_decimals = max(decimals_left - 1, 1)
if has_min:
if min_value < 0:
out.append("\"-\" (")
_generate_min_max_int(None, -min_value, out, decimals_left, top_level=False)
out.append(") | [0] | [1-9] ")
more_digits(0, decimals_left - 1)
elif min_value == 0:
if top_level:
out.append("[0] | [1-9] ")
more_digits(0, less_decimals)
else:
more_digits(1, decimals_left)
elif min_value <= 9:
c = str(min_value)
range_start = '1' if top_level else '0'
if c > range_start:
digit_range(range_start, chr(ord(c) - 1))
out.append(" ")
more_digits(1, less_decimals)
out.append(" | ")
digit_range(c, "9")
out.append(" ")
more_digits(0, less_decimals)
else:
min_s = str(min_value)
length = len(min_s)
c = min_s[0]
if c > "1":
digit_range("1" if top_level else "0", chr(ord(c) - 1))
out.append(" ")
more_digits(length, less_decimals)
out.append(" | ")
digit_range(c, c)
out.append(" (")
_generate_min_max_int(int(min_s[1:]), None, out, less_decimals, top_level=False)
out.append(")")
if c < "9":
out.append(" | ")
digit_range(chr(ord(c) + 1), "9")
out.append(" ")
more_digits(length - 1, less_decimals)
return
if has_max:
if max_value >= 0:
if top_level:
out.append("\"-\" [1-9] ")
more_digits(0, less_decimals)
out.append(" | ")
_generate_min_max_int(0, max_value, out, decimals_left, top_level=True)
else:
out.append("\"-\" (")
_generate_min_max_int(-max_value, None, out, decimals_left, top_level=False)
out.append(")")
return
raise RuntimeError("At least one of min_value or max_value must be set")
class BuiltinRule:
def __init__(self, content: str, deps: list | None = None):
self.content = content
self.deps = deps or []
# Constraining spaces to prevent model "running away".
SPACE_RULE = '| " " | "\\n"{1,2} [ \\t]{0,20}'
PRIMITIVE_RULES = {
'boolean' : BuiltinRule('("true" | "false") space', []),
'decimal-part' : BuiltinRule('[0-9]{1,16}', []),
'integral-part': BuiltinRule('[0] | [1-9] [0-9]{0,15}', []),
'number' : BuiltinRule('("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space', ['integral-part', 'decimal-part']),
'integer' : BuiltinRule('("-"? integral-part) space', ['integral-part']),
'value' : BuiltinRule('object | array | string | number | boolean | null', ['object', 'array', 'string', 'number', 'boolean', 'null']),
'object' : BuiltinRule('"{" space ( string ":" space value ("," space string ":" space value)* )? "}" space', ['string', 'value']),
'array' : BuiltinRule('"[" space ( value ("," space value)* )? "]" space', ['value']),
'uuid' : BuiltinRule(r'"\"" [0-9a-fA-F]{8} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{12} "\"" space', []),
'char' : BuiltinRule(r'[^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})', []),
'string' : BuiltinRule(r'"\"" char* "\"" space', ['char']),
'null' : BuiltinRule('"null" space', []),
}
# TODO: support "uri", "email" string formats
STRING_FORMAT_RULES = {
'date' : BuiltinRule('[0-9]{4} "-" ( "0" [1-9] | "1" [0-2] ) "-" ( \"0\" [1-9] | [1-2] [0-9] | "3" [0-1] )', []),
'time' : BuiltinRule('([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9]{3} )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )', []),
'date-time' : BuiltinRule('date "T" time', ['date', 'time']),
'date-string' : BuiltinRule('"\\"" date "\\"" space', ['date']),
'time-string' : BuiltinRule('"\\"" time "\\"" space', ['time']),
'date-time-string': BuiltinRule('"\\"" date-time "\\"" space', ['date-time']),
}
DOTALL = '[\\U00000000-\\U0010FFFF]'
DOT = '[^\\x0A\\x0D]'
RESERVED_NAMES = set(["root", "dot", *PRIMITIVE_RULES.keys(), *STRING_FORMAT_RULES.keys()])
INVALID_RULE_CHARS_RE = re.compile(r'[^a-zA-Z0-9-]+')
GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"\\]')
GRAMMAR_RANGE_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"\]\-\\]')
GRAMMAR_LITERAL_ESCAPES = {'\r': '\\r', '\n': '\\n', '"': '\\"', '-': '\\-', ']': '\\]', '\\': '\\\\'}
NON_LITERAL_SET = set('|.()[]{}*+?')
ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = set('^$.[]()|{}*+?')
class SchemaConverter:
def __init__(self, *, prop_order, allow_fetch, dotall, raw_pattern):
self._prop_order = prop_order
self._allow_fetch = allow_fetch
self._dotall = dotall
self._raw_pattern = raw_pattern
self._rules = {
'space': SPACE_RULE,
}
self._refs = {}
self._refs_being_resolved = set()
def _format_literal(self, literal):
escaped = GRAMMAR_LITERAL_ESCAPE_RE.sub(
lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)) or m.group(0), literal
)
return f'"{escaped}"'
def not_literal(self, literal: str, dotall: bool = True, maybe_escaped_underscores = False) -> str:
'''
not_literal('a') -> '[^a]'
not_literal('abc') -> '([^a] | "a" ([^b] | "b" ([^c])?)?)?'
'''
assert len(literal) > 0, 'Empty literal not supported'
def recurse(i: int):
c = literal[i]
if maybe_escaped_underscores and c == '_':
yield f'[^{c}\\\\]'
yield ' | '
yield f'"\\\\"? "{c}"'
else:
yield f'[^{c}]'
if i < len(literal) - 1:
yield ' | '
yield self._format_literal(c)
yield ' ('
yield from recurse(i + 1)
yield ')?'
return ''.join(('(', *recurse(0), ')'))
def _not_strings(self, strings):
class TrieNode:
def __init__(self):
self.children = {}
self.is_end_of_string = False
def insert(self, string):
node = self
for c in string:
node = node.children.setdefault(c, TrieNode())
node.is_end_of_string = True
trie = TrieNode()
for s in strings:
trie.insert(s)
char_rule = self._add_primitive('char', PRIMITIVE_RULES['char'])
out = ['["] ( ']
def visit(node):
rejects = []
first = True
for c in sorted(node.children.keys()):
child = node.children[c]
rejects.append(c)
if first:
first = False
else:
out.append(' | ')
out.append(f'[{c}]')
if child.children:
out.append(f' (')
visit(child)
out.append(')')
elif child.is_end_of_string:
out.append(f' {char_rule}+')
if node.children:
if not first:
out.append(' | ')
out.append(f'[^"{"".join(rejects)}] {char_rule}*')
visit(trie)
out.append(f' ){"" if trie.is_end_of_string else "?"} ["] space')
return ''.join(out)
def _add_rule(self, name, rule):
esc_name = INVALID_RULE_CHARS_RE.sub('-', name)
if esc_name not in self._rules or self._rules[esc_name] == rule:
key = esc_name
else:
i = 0
while f'{esc_name}{i}' in self._rules and self._rules[f'{esc_name}{i}'] != rule:
i += 1
key = f'{esc_name}{i}'
self._rules[key] = rule
return key
def resolve_refs(self, schema: dict, url: str):
'''
Resolves all $ref fields in the given schema, fetching any remote schemas,
replacing $ref with absolute reference URL and populating self._refs with the
respective referenced (sub)schema dictionaries.
'''
def visit(n: dict):
if isinstance(n, list):
return [visit(x) for x in n]
elif isinstance(n, dict):
ref = n.get('$ref')
if ref is not None and ref not in self._refs:
if ref.startswith('https://'):
assert self._allow_fetch, 'Fetching remote schemas is not allowed (use --allow-fetch for force)'
import requests
frag_split = ref.split('#')
base_url = frag_split[0]
target = self._refs.get(base_url)
if target is None:
target = self.resolve_refs(requests.get(ref).json(), base_url)
self._refs[base_url] = target
if len(frag_split) == 1 or frag_split[-1] == '':
return target
elif ref.startswith('#/'):
target = schema
ref = f'{url}{ref}'
n['$ref'] = ref
else:
raise ValueError(f'Unsupported ref {ref}')
for sel in ref.split('#')[-1].split('/')[1:]:
assert target is not None, f'Error resolving ref {ref}: {sel} not in {target}'
if isinstance(target, list):
try:
sel_index = int(sel)
except ValueError:
raise ValueError(f'Error resolving ref {ref}: {sel} not in {target}')
assert 0 <= sel_index < len(target), f'Error resolving ref {ref}: {sel} not in {target}'
target = target[sel_index]
else:
assert sel in target, f'Error resolving ref {ref}: {sel} not in {target}'
target = target[sel]
self._refs[ref] = target
else:
for v in n.values():
visit(v)
return n
return visit(schema)
def _generate_union_rule(self, name, alt_schemas):
return ' | '.join((
self.visit(alt_schema, f'{name}{"-" if name else "alternative-"}{i}')
for i, alt_schema in enumerate(alt_schemas)
))
def _visit_pattern(self, pattern, name):
'''
Transforms a regular expression pattern into a GBNF rule.
Input: https://json-schema.org/understanding-json-schema/reference/regular_expressions
Output: https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md
Unsupported features: negative/positive lookaheads, greedy/non-greedy modifiers.
Mostly a 1:1 translation, except for {x} / {x,} / {x,y} quantifiers for which
we define sub-rules to keep the output lean.
'''
assert pattern.startswith('^') and pattern.endswith('$'), 'Pattern must start with "^" and end with "$"'
pattern = pattern[1:-1]
sub_rule_ids = {}
i = 0
length = len(pattern)
def to_rule(s: tuple[str, bool]) -> str:
(txt, is_literal) = s
return "\"" + txt + "\"" if is_literal else txt
def transform() -> tuple[str, bool]:
'''
Parse a unit at index i (advancing it), and return its string representation + whether it's a literal.
'''
nonlocal i
nonlocal pattern
nonlocal sub_rule_ids
start = i
# For each component of this sequence, store its string representation and whether it's a literal.
# We only need a flat structure here to apply repetition operators to the last item, and
# to merge literals at the and (we're parsing grouped ( sequences ) recursively and don't treat '|' specially
# (GBNF's syntax is luckily very close to regular expressions!)
seq: list[tuple[str, bool]] = []
def get_dot():
if self._dotall:
rule = DOTALL
else:
# Accept any character... except \n and \r line break chars (\x0A and \xOD)
rule = DOT
return self._add_rule(f'dot', rule)
def join_seq():
nonlocal seq
ret = []
for is_literal, g in itertools.groupby(seq, lambda x: x[1]):
if is_literal:
ret.append((''.join(x[0] for x in g), True))
else:
ret.extend(g)
if len(ret) == 1:
return ret[0]
return (' '.join(to_rule(x) for x in seq), False)
while i < length:
c = pattern[i]
if c == '.':
seq.append((get_dot(), False))
i += 1
elif c == '(':
i += 1
if i < length:
assert pattern[i] != '?', f'Unsupported pattern syntax "{pattern[i]}" at index {i} of /{pattern}/'
seq.append((f'({to_rule(transform())})', False))
elif c == ')':
i += 1
assert start > 0 and pattern[start-1] == '(', f'Unbalanced parentheses; start = {start}, i = {i}, pattern = {pattern}'
return join_seq()
elif c == '[':
square_brackets = c
i += 1
while i < length and pattern[i] != ']':
if pattern[i] == '\\':
square_brackets += pattern[i:i+2]
i += 2
else:
square_brackets += pattern[i]
i += 1
assert i < length, f'Unbalanced square brackets; start = {start}, i = {i}, pattern = {pattern}'
square_brackets += ']'
i += 1
seq.append((square_brackets, False))
elif c == '|':
seq.append(('|', False))
i += 1
elif c in ('*', '+', '?'):
seq[-1] = (to_rule(seq[-1]) + c, False)
i += 1
elif c == '{':
curly_brackets = c
i += 1
while i < length and pattern[i] != '}':
curly_brackets += pattern[i]
i += 1
assert i < length, f'Unbalanced curly brackets; start = {start}, i = {i}, pattern = {pattern}'
curly_brackets += '}'
i += 1
nums = [s.strip() for s in curly_brackets[1:-1].split(',')]
min_times = 0
max_times = None
try:
if len(nums) == 1:
min_times = int(nums[0])
max_times = min_times
else:
assert len(nums) == 2
min_times = int(nums[0]) if nums[0] else 0
max_times = int(nums[1]) if nums[1] else None
except ValueError:
raise ValueError(f'Invalid quantifier {curly_brackets} in /{pattern}/')
(sub, sub_is_literal) = seq[-1]
if not sub_is_literal:
id = sub_rule_ids.get(sub)
if id is None:
id = self._add_rule(f'{name}-{len(sub_rule_ids) + 1}', sub)
sub_rule_ids[sub] = id
sub = id
seq[-1] = (_build_repetition(f'"{sub}"' if sub_is_literal else sub, min_times, max_times), False)
else:
literal = ''
while i < length:
if pattern[i] == '\\' and i < length - 1:
next = pattern[i + 1]
if next in ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS:
i += 1
literal += pattern[i]
i += 1
else:
literal += pattern[i:i+2]
i += 2
elif pattern[i] == '"' and not self._raw_pattern:
literal += '\\"'
i += 1
elif pattern[i] not in NON_LITERAL_SET and \
(i == length - 1 or literal == '' or pattern[i+1] == '.' or pattern[i+1] not in NON_LITERAL_SET):
literal += pattern[i]
i += 1
else:
break
if literal:
seq.append((literal, True))
return join_seq()
return self._add_rule(
name,
to_rule(transform()) if self._raw_pattern \
else "\"\\\"\" (" + to_rule(transform()) + ") \"\\\"\" space")
def _resolve_ref(self, ref):
ref_fragment = ref.split('#')[-1]
ref_name = 'ref' + re.sub(r'[^a-zA-Z0-9-]+', '-', ref_fragment)
if ref_name not in self._rules and ref not in self._refs_being_resolved:
self._refs_being_resolved.add(ref)
resolved = self._refs[ref]
ref_name = self.visit(resolved, ref_name)
self._refs_being_resolved.remove(ref)
return ref_name
def _generate_constant_rule(self, value):
return self._format_literal(json.dumps(value))
def visit(self, schema, name):
schema_type = schema.get('type')
schema_format = schema.get('format')
rule_name = name + '-' if name in RESERVED_NAMES else name or 'root'
if (ref := schema.get('$ref')) is not None:
return self._add_rule(rule_name, self._resolve_ref(ref))
elif 'oneOf' in schema or 'anyOf' in schema:
return self._add_rule(rule_name, self._generate_union_rule(name, schema.get('oneOf') or schema['anyOf']))
elif isinstance(schema_type, list):
return self._add_rule(rule_name, self._generate_union_rule(name, [{**schema, 'type': t} for t in schema_type]))
elif 'const' in schema:
return self._add_rule(rule_name, self._generate_constant_rule(schema['const']) + ' space')
elif 'enum' in schema:
rule = '(' + ' | '.join((self._generate_constant_rule(v) for v in schema['enum'])) + ') space'
return self._add_rule(rule_name, rule)
elif schema_type in (None, 'object') and \
('properties' in schema or \
('additionalProperties' in schema and schema['additionalProperties'] is not True)):
required = set(schema.get('required', []))
properties = list(schema.get('properties', {}).items())
return self._add_rule(rule_name, self._build_object_rule(properties, required, name, schema.get('additionalProperties')))
elif schema_type in (None, 'object', 'string') and 'allOf' in schema:
required = set()
properties = []
enum_sets = []
hybrid_name = name
def add_component(comp_schema, is_required):
if (ref := comp_schema.get('$ref')) is not None:
comp_schema = self._refs[ref]
if 'properties' in comp_schema:
for prop_name, prop_schema in comp_schema['properties'].items():
properties.append((prop_name, prop_schema))
if is_required:
required.add(prop_name)
if 'enum' in comp_schema:
enum_sets.append(set(comp_schema['enum']))
for t in schema['allOf']:
if 'anyOf' in t:
for tt in t['anyOf']:
add_component(tt, is_required=False)
else:
add_component(t, is_required=True)
if enum_sets:
enum_intersection = enum_sets[0]
for s in enum_sets[1:]:
enum_intersection &= s
if enum_intersection:
rule = '(' + ' | '.join((self._generate_constant_rule(v) for v in sorted(enum_intersection))) + ') space'
return self._add_rule(rule_name, rule)
return self._add_rule(rule_name, self._build_object_rule(properties, required, hybrid_name, additional_properties=None))
elif schema_type in (None, 'array') and ('items' in schema or 'prefixItems' in schema):
items = schema.get('items') or schema['prefixItems']
if isinstance(items, list):
return self._add_rule(
rule_name,
'"[" space ' +
' "," space '.join(
self.visit(item, f'{name}{"-" if name else ""}tuple-{i}')
for i, item in enumerate(items)) +
' "]" space')
else:
item_rule_name = self.visit(items, f'{name}{"-" if name else ""}item')
min_items = schema.get("minItems", 0)
max_items = schema.get("maxItems")
return self._add_rule(rule_name, '"[" space ' + _build_repetition(item_rule_name, min_items, max_items, separator_rule='"," space') + ' "]" space')
elif schema_type in (None, 'string') and 'pattern' in schema:
return self._visit_pattern(schema['pattern'], rule_name)
elif schema_type in (None, 'string') and re.match(r'^uuid[1-5]?$', schema_format or ''):
return self._add_primitive(
'root' if rule_name == 'root' else schema_format,
PRIMITIVE_RULES['uuid']
)
elif schema_type in (None, 'string') and f'{schema_format}-string' in STRING_FORMAT_RULES:
prim_name = f'{schema_format}-string'
return self._add_rule(rule_name, self._add_primitive(prim_name, STRING_FORMAT_RULES[prim_name]))
elif schema_type == 'string' and ('minLength' in schema or 'maxLength' in schema):
char_rule = self._add_primitive('char', PRIMITIVE_RULES['char'])
min_len = schema.get('minLength', 0)
max_len = schema.get('maxLength')
return self._add_rule(rule_name, r'"\"" ' + _build_repetition(char_rule, min_len, max_len) + r' "\"" space')
elif schema_type in (None, 'integer') and \
('minimum' in schema or 'exclusiveMinimum' in schema or 'maximum' in schema or 'exclusiveMaximum' in schema):
min_value = None
max_value = None
if 'minimum' in schema:
min_value = schema['minimum']
elif 'exclusiveMinimum' in schema:
min_value = schema['exclusiveMinimum'] + 1
if 'maximum' in schema:
max_value = schema['maximum']
elif 'exclusiveMaximum' in schema:
max_value = schema['exclusiveMaximum'] - 1
out = ["("]
_generate_min_max_int(min_value, max_value, out)
out.append(") space")
return self._add_rule(rule_name, ''.join(out))
elif (schema_type == 'object') or (len(schema) == 0):
return self._add_rule(rule_name, self._add_primitive('object', PRIMITIVE_RULES['object']))
else:
assert schema_type in PRIMITIVE_RULES, f'Unrecognized schema: {schema}'
# TODO: support minimum, maximum, exclusiveMinimum, exclusiveMaximum at least for zero
return self._add_primitive('root' if rule_name == 'root' else schema_type, PRIMITIVE_RULES[schema_type])
def _add_primitive(self, name: str, rule: BuiltinRule):
n = self._add_rule(name, rule.content)
for dep in rule.deps:
dep_rule = PRIMITIVE_RULES.get(dep) or STRING_FORMAT_RULES.get(dep)
assert dep_rule, f'Rule {dep} not known'
if dep not in self._rules:
self._add_primitive(dep, dep_rule)
return n
def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[str], name: str, additional_properties: Optional[Union[bool, Any]]):
prop_order = self._prop_order
# sort by position in prop_order (if specified) then by original order
sorted_props = [kv[0] for _, kv in sorted(enumerate(properties), key=lambda ikv: (prop_order.get(ikv[1][0], len(prop_order)), ikv[0]))]
prop_kv_rule_names = {}
for prop_name, prop_schema in properties:
prop_rule_name = self.visit(prop_schema, f'{name}{"-" if name else ""}{prop_name}')
prop_kv_rule_names[prop_name] = self._add_rule(
f'{name}{"-" if name else ""}{prop_name}-kv',
fr'{self._format_literal(json.dumps(prop_name))} space ":" space {prop_rule_name}'
)
required_props = [k for k in sorted_props if k in required]
optional_props = [k for k in sorted_props if k not in required]
if additional_properties is not None and additional_properties != False:
sub_name = f'{name}{"-" if name else ""}additional'
value_rule = self.visit(additional_properties, f'{sub_name}-value') if isinstance(additional_properties, dict) else \
self._add_primitive('value', PRIMITIVE_RULES['value'])
key_rule = self._add_primitive('string', PRIMITIVE_RULES['string']) if not sorted_props \
else self._add_rule(f'{sub_name}-k', self._not_strings(sorted_props))
prop_kv_rule_names["*"] = self._add_rule(
f'{sub_name}-kv',
f'{key_rule} ":" space {value_rule}'
)
optional_props.append("*")
rule = '"{" space '
rule += ' "," space '.join(prop_kv_rule_names[k] for k in required_props)
if optional_props:
rule += ' ('
if required_props:
rule += ' "," space ( '
def get_recursive_refs(ks, first_is_optional):
[k, *rest] = ks
kv_rule_name = prop_kv_rule_names[k]
comma_ref = f'( "," space {kv_rule_name} )'
if first_is_optional:
res = comma_ref + ('*' if k == '*' else '?')
else:
res = kv_rule_name + (' ' + comma_ref + "*" if k == '*' else '')
if len(rest) > 0:
res += ' ' + self._add_rule(
f'{name}{"-" if name else ""}{k}-rest',
get_recursive_refs(rest, first_is_optional=True)
)
return res
rule += ' | '.join(
get_recursive_refs(optional_props[i:], first_is_optional=False)
for i in range(len(optional_props))
)
if required_props:
rule += ' )'
rule += ' )?'
rule += ' "}" space'
return rule
def format_grammar(self):
return '\n'.join(
f'{name} ::= {rule}'
for name, rule in sorted(self._rules.items(), key=lambda kv: kv[0])
)
def main(args_in = None):
parser = argparse.ArgumentParser(
description='''
Generates a grammar (suitable for use in ./llama-cli) that produces JSON conforming to a
given JSON schema. Only a subset of JSON schema features are supported; more may be
added in the future.
''',
)
parser.add_argument(
'--prop-order',
default=[],
type=lambda s: s.split(','),
help='''
comma-separated property names defining the order of precedence for object properties;
properties not specified here are given lower precedence than those that are, and
are kept in their original order from the schema. Required properties are always
given precedence over optional properties.
'''
)
parser.add_argument(
'--allow-fetch',
action='store_true',
default=False,
help='Whether to allow fetching referenced schemas over HTTPS')
parser.add_argument(
'--dotall',
action='store_true',
default=False,
help='Whether to treat dot (".") as matching all chars including line breaks in regular expression patterns')
parser.add_argument(
'--raw-pattern',
action='store_true',
default=False,
help='Treats string patterns as raw patterns w/o quotes (or quote escapes)')
parser.add_argument('schema', help='file containing JSON schema ("-" for stdin)')
args = parser.parse_args(args_in)
if args.schema.startswith('https://'):
url = args.schema
import requests
schema = requests.get(url).json()
elif args.schema == '-':
url = 'stdin'
schema = json.load(sys.stdin)
else:
url = f'file://{args.schema}'
with open(args.schema) as f:
schema = json.load(f)
converter = SchemaConverter(
prop_order={name: idx for idx, name in enumerate(args.prop_order)},
allow_fetch=args.allow_fetch,
dotall=args.dotall,
raw_pattern=args.raw_pattern)
schema = converter.resolve_refs(schema, url)
converter.visit(schema, '')
print(converter.format_grammar())
if __name__ == '__main__':
main()

33
examples/llama.android/.gitignore vendored Normal file
View File

@@ -0,0 +1,33 @@
# Gradle files
.gradle/
build/
# Local configuration file (sdk path, etc)
local.properties
# Log/OS Files
*.log
# Android Studio generated files and folders
captures/
.externalNativeBuild/
.cxx/
*.apk
output.json
# IntelliJ
*.iml
.idea/
misc.xml
deploymentTargetDropDown.xml
render.experimental.xml
# Keystore files
*.jks
*.keystore
# Google Services (e.g. APIs or Firebase)
google-services.json
# Android Profiling
*.hprof

1
examples/llama.android/app/.gitignore vendored Normal file
View File

@@ -0,0 +1 @@
/build

View File

@@ -0,0 +1,58 @@
plugins {
alias(libs.plugins.android.application)
alias(libs.plugins.jetbrains.kotlin.android)
}
android {
namespace = "com.example.llama"
compileSdk = 36
defaultConfig {
applicationId = "com.example.llama.aichat"
minSdk = 33
targetSdk = 36
versionCode = 1
versionName = "1.0"
testInstrumentationRunner = "androidx.test.runner.AndroidJUnitRunner"
vectorDrawables {
useSupportLibrary = true
}
}
buildTypes {
debug {
isMinifyEnabled = true
isShrinkResources = true
proguardFiles(
getDefaultProguardFile("proguard-android.txt"),
"proguard-rules.pro"
)
}
release {
isMinifyEnabled = true
isShrinkResources = true
proguardFiles(
getDefaultProguardFile("proguard-android-optimize.txt"),
"proguard-rules.pro"
)
}
}
compileOptions {
sourceCompatibility = JavaVersion.VERSION_17
targetCompatibility = JavaVersion.VERSION_17
}
}
dependencies {
implementation(libs.bundles.androidx)
implementation(libs.material)
implementation(project(":lib"))
testImplementation(libs.junit)
androidTestImplementation(libs.androidx.junit)
androidTestImplementation(libs.androidx.espresso.core)
}

View File

@@ -0,0 +1,29 @@
# Add project specific ProGuard rules here.
# You can control the set of applied configuration files using the
# proguardFiles setting in build.gradle.
#
# For more details, see
# http://developer.android.com/guide/developing/tools/proguard.html
# If your project uses WebView with JS, uncomment the following
# and specify the fully qualified class name to the JavaScript interface
# class:
#-keepclassmembers class fqcn.of.javascript.interface.for.webview {
# public *;
#}
# Uncomment this to preserve the line number information for
# debugging stack traces.
#-keepattributes SourceFile,LineNumberTable
# If you keep the line number information, uncomment this to
# hide the original source file name.
#-renamesourcefileattribute SourceFile
-keep class com.arm.aichat.* { *; }
-keep class com.arm.aichat.gguf.* { *; }
-assumenosideeffects class android.util.Log {
public static int v(...);
public static int d(...);
}

View File

@@ -0,0 +1,27 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android">
<application
android:allowBackup="true"
android:dataExtractionRules="@xml/data_extraction_rules"
android:extractNativeLibs="true"
android:fullBackupContent="@xml/backup_rules"
android:icon="@mipmap/ic_launcher_round"
android:label="@string/app_name"
android:roundIcon="@mipmap/ic_launcher_round"
android:supportsRtl="true"
android:theme="@style/Theme.AiChatSample"
>
<activity
android:name=".MainActivity"
android:exported="true">
<intent-filter>
<action android:name="android.intent.action.MAIN" />
<category android:name="android.intent.category.LAUNCHER" />
</intent-filter>
</activity>
</application>
</manifest>

View File

@@ -0,0 +1,275 @@
package com.example.llama
import android.net.Uri
import android.os.Bundle
import android.util.Log
import android.widget.EditText
import android.widget.TextView
import android.widget.Toast
import androidx.activity.addCallback
import androidx.activity.enableEdgeToEdge
import androidx.activity.result.contract.ActivityResultContracts
import androidx.appcompat.app.AppCompatActivity
import androidx.lifecycle.lifecycleScope
import androidx.recyclerview.widget.LinearLayoutManager
import androidx.recyclerview.widget.RecyclerView
import com.arm.aichat.AiChat
import com.arm.aichat.InferenceEngine
import com.arm.aichat.gguf.GgufMetadata
import com.arm.aichat.gguf.GgufMetadataReader
import com.google.android.material.floatingactionbutton.FloatingActionButton
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.Job
import kotlinx.coroutines.flow.onCompletion
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext
import java.io.File
import java.io.FileOutputStream
import java.io.InputStream
import java.util.UUID
class MainActivity : AppCompatActivity() {
// Android views
private lateinit var ggufTv: TextView
private lateinit var messagesRv: RecyclerView
private lateinit var userInputEt: EditText
private lateinit var userActionFab: FloatingActionButton
// Arm AI Chat inference engine
private lateinit var engine: InferenceEngine
private var generationJob: Job? = null
// Conversation states
private var isModelReady = false
private val messages = mutableListOf<Message>()
private val lastAssistantMsg = StringBuilder()
private val messageAdapter = MessageAdapter(messages)
override fun onCreate(savedInstanceState: Bundle?) {
super.onCreate(savedInstanceState)
enableEdgeToEdge()
setContentView(R.layout.activity_main)
// View model boilerplate and state management is out of this basic sample's scope
onBackPressedDispatcher.addCallback { Log.w(TAG, "Ignore back press for simplicity") }
// Find views
ggufTv = findViewById(R.id.gguf)
messagesRv = findViewById(R.id.messages)
messagesRv.layoutManager = LinearLayoutManager(this).apply { stackFromEnd = true }
messagesRv.adapter = messageAdapter
userInputEt = findViewById(R.id.user_input)
userActionFab = findViewById(R.id.fab)
// Arm AI Chat initialization
lifecycleScope.launch(Dispatchers.Default) {
engine = AiChat.getInferenceEngine(applicationContext)
}
// Upon CTA button tapped
userActionFab.setOnClickListener {
if (isModelReady) {
// If model is ready, validate input and send to engine
handleUserInput()
} else {
// Otherwise, prompt user to select a GGUF metadata on the device
getContent.launch(arrayOf("*/*"))
}
}
}
private val getContent = registerForActivityResult(
ActivityResultContracts.OpenDocument()
) { uri ->
Log.i(TAG, "Selected file uri:\n $uri")
uri?.let { handleSelectedModel(it) }
}
/**
* Handles the file Uri from [getContent] result
*/
private fun handleSelectedModel(uri: Uri) {
// Update UI states
userActionFab.isEnabled = false
userInputEt.hint = "Parsing GGUF..."
ggufTv.text = "Parsing metadata from selected file \n$uri"
lifecycleScope.launch(Dispatchers.IO) {
// Parse GGUF metadata
Log.i(TAG, "Parsing GGUF metadata...")
contentResolver.openInputStream(uri)?.use {
GgufMetadataReader.create().readStructuredMetadata(it)
}?.let { metadata ->
// Update UI to show GGUF metadata to user
Log.i(TAG, "GGUF parsed: \n$metadata")
withContext(Dispatchers.Main) {
ggufTv.text = metadata.toString()
}
// Ensure the model file is available
val modelName = metadata.filename() + FILE_EXTENSION_GGUF
contentResolver.openInputStream(uri)?.use { input ->
ensureModelFile(modelName, input)
}?.let { modelFile ->
loadModel(modelName, modelFile)
withContext(Dispatchers.Main) {
isModelReady = true
userInputEt.hint = "Type and send a message!"
userInputEt.isEnabled = true
userActionFab.setImageResource(R.drawable.outline_send_24)
userActionFab.isEnabled = true
}
}
}
}
}
/**
* Prepare the model file within app's private storage
*/
private suspend fun ensureModelFile(modelName: String, input: InputStream) =
withContext(Dispatchers.IO) {
File(ensureModelsDirectory(), modelName).also { file ->
// Copy the file into local storage if not yet done
if (!file.exists()) {
Log.i(TAG, "Start copying file to $modelName")
withContext(Dispatchers.Main) {
userInputEt.hint = "Copying file..."
}
FileOutputStream(file).use { input.copyTo(it) }
Log.i(TAG, "Finished copying file to $modelName")
} else {
Log.i(TAG, "File already exists $modelName")
}
}
}
/**
* Load the model file from the app private storage
*/
private suspend fun loadModel(modelName: String, modelFile: File) =
withContext(Dispatchers.IO) {
Log.i(TAG, "Loading model $modelName")
withContext(Dispatchers.Main) {
userInputEt.hint = "Loading model..."
}
engine.loadModel(modelFile.path)
}
/**
* Validate and send the user message into [InferenceEngine]
*/
private fun handleUserInput() {
userInputEt.text.toString().also { userMsg ->
if (userMsg.isEmpty()) {
Toast.makeText(this, "Input message is empty!", Toast.LENGTH_SHORT).show()
} else {
userInputEt.text = null
userInputEt.isEnabled = false
userActionFab.isEnabled = false
// Update message states
messages.add(Message(UUID.randomUUID().toString(), userMsg, true))
lastAssistantMsg.clear()
messages.add(Message(UUID.randomUUID().toString(), lastAssistantMsg.toString(), false))
generationJob = lifecycleScope.launch(Dispatchers.Default) {
engine.sendUserPrompt(userMsg)
.onCompletion {
withContext(Dispatchers.Main) {
userInputEt.isEnabled = true
userActionFab.isEnabled = true
}
}.collect { token ->
withContext(Dispatchers.Main) {
val messageCount = messages.size
check(messageCount > 0 && !messages[messageCount - 1].isUser)
messages.removeAt(messageCount - 1).copy(
content = lastAssistantMsg.append(token).toString()
).let { messages.add(it) }
messageAdapter.notifyItemChanged(messages.size - 1)
}
}
}
}
}
}
/**
* Run a benchmark with the model file
*/
@Deprecated("This benchmark doesn't accurately indicate GUI performance expected by app developers")
private suspend fun runBenchmark(modelName: String, modelFile: File) =
withContext(Dispatchers.Default) {
Log.i(TAG, "Starts benchmarking $modelName")
withContext(Dispatchers.Main) {
userInputEt.hint = "Running benchmark..."
}
engine.bench(
pp=BENCH_PROMPT_PROCESSING_TOKENS,
tg=BENCH_TOKEN_GENERATION_TOKENS,
pl=BENCH_SEQUENCE,
nr=BENCH_REPETITION
).let { result ->
messages.add(Message(UUID.randomUUID().toString(), result, false))
withContext(Dispatchers.Main) {
messageAdapter.notifyItemChanged(messages.size - 1)
}
}
}
/**
* Create the `models` directory if not exist.
*/
private fun ensureModelsDirectory() =
File(filesDir, DIRECTORY_MODELS).also {
if (it.exists() && !it.isDirectory) { it.delete() }
if (!it.exists()) { it.mkdir() }
}
override fun onStop() {
generationJob?.cancel()
super.onStop()
}
override fun onDestroy() {
engine.destroy()
super.onDestroy()
}
companion object {
private val TAG = MainActivity::class.java.simpleName
private const val DIRECTORY_MODELS = "models"
private const val FILE_EXTENSION_GGUF = ".gguf"
private const val BENCH_PROMPT_PROCESSING_TOKENS = 512
private const val BENCH_TOKEN_GENERATION_TOKENS = 128
private const val BENCH_SEQUENCE = 1
private const val BENCH_REPETITION = 3
}
}
fun GgufMetadata.filename() = when {
basic.name != null -> {
basic.name?.let { name ->
basic.sizeLabel?.let { size ->
"$name-$size"
} ?: name
}
}
architecture?.architecture != null -> {
architecture?.architecture?.let { arch ->
basic.uuid?.let { uuid ->
"$arch-$uuid"
} ?: "$arch-${System.currentTimeMillis()}"
}
}
else -> {
"model-${System.currentTimeMillis().toHexString()}"
}
}

View File

@@ -0,0 +1,51 @@
package com.example.llama
import android.view.LayoutInflater
import android.view.View
import android.view.ViewGroup
import android.widget.TextView
import androidx.recyclerview.widget.RecyclerView
data class Message(
val id: String,
val content: String,
val isUser: Boolean
)
class MessageAdapter(
private val messages: List<Message>
) : RecyclerView.Adapter<RecyclerView.ViewHolder>() {
companion object {
private const val VIEW_TYPE_USER = 1
private const val VIEW_TYPE_ASSISTANT = 2
}
override fun getItemViewType(position: Int): Int {
return if (messages[position].isUser) VIEW_TYPE_USER else VIEW_TYPE_ASSISTANT
}
override fun onCreateViewHolder(parent: ViewGroup, viewType: Int): RecyclerView.ViewHolder {
val layoutInflater = LayoutInflater.from(parent.context)
return if (viewType == VIEW_TYPE_USER) {
val view = layoutInflater.inflate(R.layout.item_message_user, parent, false)
UserMessageViewHolder(view)
} else {
val view = layoutInflater.inflate(R.layout.item_message_assistant, parent, false)
AssistantMessageViewHolder(view)
}
}
override fun onBindViewHolder(holder: RecyclerView.ViewHolder, position: Int) {
val message = messages[position]
if (holder is UserMessageViewHolder || holder is AssistantMessageViewHolder) {
val textView = holder.itemView.findViewById<TextView>(R.id.msg_content)
textView.text = message.content
}
}
override fun getItemCount(): Int = messages.size
class UserMessageViewHolder(view: View) : RecyclerView.ViewHolder(view)
class AssistantMessageViewHolder(view: View) : RecyclerView.ViewHolder(view)
}

View File

@@ -0,0 +1,4 @@
<shape xmlns:android="http://schemas.android.com/apk/res/android" android:shape="rectangle">
<solid android:color="#E5E5EA" />
<corners android:radius="16dp" />
</shape>

View File

@@ -0,0 +1,4 @@
<shape xmlns:android="http://schemas.android.com/apk/res/android" android:shape="rectangle">
<solid android:color="#4285F4" />
<corners android:radius="16dp" />
</shape>

View File

@@ -0,0 +1,170 @@
<?xml version="1.0" encoding="utf-8"?>
<vector xmlns:android="http://schemas.android.com/apk/res/android"
android:width="108dp"
android:height="108dp"
android:viewportWidth="108"
android:viewportHeight="108">
<path
android:fillColor="#3DDC84"
android:pathData="M0,0h108v108h-108z" />
<path
android:fillColor="#00000000"
android:pathData="M9,0L9,108"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M19,0L19,108"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M29,0L29,108"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M39,0L39,108"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M49,0L49,108"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M59,0L59,108"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M69,0L69,108"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M79,0L79,108"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M89,0L89,108"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M99,0L99,108"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M0,9L108,9"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M0,19L108,19"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M0,29L108,29"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M0,39L108,39"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M0,49L108,49"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M0,59L108,59"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M0,69L108,69"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M0,79L108,79"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M0,89L108,89"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M0,99L108,99"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M19,29L89,29"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M19,39L89,39"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M19,49L89,49"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M19,59L89,59"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M19,69L89,69"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M19,79L89,79"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M29,19L29,89"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M39,19L39,89"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M49,19L49,89"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M59,19L59,89"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M69,19L69,89"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
<path
android:fillColor="#00000000"
android:pathData="M79,19L79,89"
android:strokeWidth="0.8"
android:strokeColor="#33FFFFFF" />
</vector>

View File

@@ -0,0 +1,30 @@
<vector xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:aapt="http://schemas.android.com/aapt"
android:width="108dp"
android:height="108dp"
android:viewportWidth="108"
android:viewportHeight="108">
<path android:pathData="M31,63.928c0,0 6.4,-11 12.1,-13.1c7.2,-2.6 26,-1.4 26,-1.4l38.1,38.1L107,108.928l-32,-1L31,63.928z">
<aapt:attr name="android:fillColor">
<gradient
android:endX="85.84757"
android:endY="92.4963"
android:startX="42.9492"
android:startY="49.59793"
android:type="linear">
<item
android:color="#44000000"
android:offset="0.0" />
<item
android:color="#00000000"
android:offset="1.0" />
</gradient>
</aapt:attr>
</path>
<path
android:fillColor="#FFFFFF"
android:fillType="nonZero"
android:pathData="M65.3,45.828l3.8,-6.6c0.2,-0.4 0.1,-0.9 -0.3,-1.1c-0.4,-0.2 -0.9,-0.1 -1.1,0.3l-3.9,6.7c-6.3,-2.8 -13.4,-2.8 -19.7,0l-3.9,-6.7c-0.2,-0.4 -0.7,-0.5 -1.1,-0.3C38.8,38.328 38.7,38.828 38.9,39.228l3.8,6.6C36.2,49.428 31.7,56.028 31,63.928h46C76.3,56.028 71.8,49.428 65.3,45.828zM43.4,57.328c-0.8,0 -1.5,-0.5 -1.8,-1.2c-0.3,-0.7 -0.1,-1.5 0.4,-2.1c0.5,-0.5 1.4,-0.7 2.1,-0.4c0.7,0.3 1.2,1 1.2,1.8C45.3,56.528 44.5,57.328 43.4,57.328L43.4,57.328zM64.6,57.328c-0.8,0 -1.5,-0.5 -1.8,-1.2s-0.1,-1.5 0.4,-2.1c0.5,-0.5 1.4,-0.7 2.1,-0.4c0.7,0.3 1.2,1 1.2,1.8C66.5,56.528 65.6,57.328 64.6,57.328L64.6,57.328z"
android:strokeWidth="1"
android:strokeColor="#00000000" />
</vector>

View File

@@ -0,0 +1,10 @@
<vector xmlns:android="http://schemas.android.com/apk/res/android"
android:width="24dp"
android:height="24dp"
android:viewportWidth="24"
android:viewportHeight="24"
android:tint="?attr/colorControlNormal">
<path
android:fillColor="@android:color/white"
android:pathData="M20,6h-8l-2,-2L4,4c-1.1,0 -1.99,0.9 -1.99,2L2,18c0,1.1 0.9,2 2,2h16c1.1,0 2,-0.9 2,-2L22,8c0,-1.1 -0.9,-2 -2,-2zM20,18L4,18L4,8h16v10z"/>
</vector>

View File

@@ -0,0 +1,11 @@
<vector xmlns:android="http://schemas.android.com/apk/res/android"
android:width="24dp"
android:height="24dp"
android:viewportWidth="24"
android:viewportHeight="24"
android:tint="?attr/colorControlNormal"
android:autoMirrored="true">
<path
android:fillColor="@android:color/white"
android:pathData="M4.01,6.03l7.51,3.22 -7.52,-1 0.01,-2.22m7.5,8.72L4,17.97v-2.22l7.51,-1M2.01,3L2,10l15,2 -15,2 0.01,7L23,12 2.01,3z"/>
</vector>

View File

@@ -0,0 +1,77 @@
<?xml version="1.0" encoding="utf-8"?>
<androidx.constraintlayout.widget.ConstraintLayout xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:app="http://schemas.android.com/apk/res-auto"
xmlns:tools="http://schemas.android.com/tools"
android:id="@+id/main"
android:layout_height="match_parent"
android:layout_width="match_parent">
<LinearLayout
android:fitsSystemWindows="true"
android:layout_width="match_parent"
android:layout_height="match_parent"
android:orientation="vertical"
android:layout_marginEnd="4dp"
tools:context=".MainActivity">
<ScrollView
android:layout_width="match_parent"
android:layout_height="0dp"
android:layout_weight="1"
android:fadeScrollbars="false">
<TextView
android:id="@+id/gguf"
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:padding="16dp"
android:text="Selected GGUF model's metadata will show here."
style="@style/TextAppearance.MaterialComponents.Body2" />
</ScrollView>
<com.google.android.material.divider.MaterialDivider
android:layout_width="match_parent"
android:layout_height="2dp"
android:layout_marginHorizontal="16dp" />
<androidx.recyclerview.widget.RecyclerView
android:id="@+id/messages"
android:layout_width="match_parent"
android:layout_height="0dp"
android:layout_weight="4"
android:fadeScrollbars="false"
android:scrollbars="vertical"
app:reverseLayout="true"
tools:listitem="@layout/item_message_assistant"/>
<LinearLayout
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:orientation="horizontal"
android:paddingStart="16dp"
android:paddingEnd="4dp">
<EditText
android:id="@+id/user_input"
android:enabled="false"
android:layout_width="0dp"
android:layout_weight="1"
android:layout_height="match_parent"
android:padding="8dp"
style="@style/TextAppearance.MaterialComponents.Body2"
android:hint="Please first pick a GGUF model file to import." />
<com.google.android.material.floatingactionbutton.FloatingActionButton
android:id="@+id/fab"
android:enabled="true"
style="@style/Widget.Material3.FloatingActionButton.Primary"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:layout_margin="12dp"
android:src="@drawable/outline_folder_open_24" />
</LinearLayout>
</LinearLayout>
</androidx.constraintlayout.widget.ConstraintLayout>

View File

@@ -0,0 +1,16 @@
<?xml version="1.0" encoding="utf-8"?>
<LinearLayout xmlns:android="http://schemas.android.com/apk/res/android"
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:layout_marginHorizontal="16dp"
android:layout_marginVertical="8dp"
android:gravity="start">
<TextView
android:id="@+id/msg_content"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:background="@drawable/bg_assistant_message"
android:padding="12dp"
android:textColor="@android:color/black" />
</LinearLayout>

View File

@@ -0,0 +1,16 @@
<?xml version="1.0" encoding="utf-8"?>
<LinearLayout xmlns:android="http://schemas.android.com/apk/res/android"
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:layout_marginHorizontal="16dp"
android:layout_marginVertical="8dp"
android:gravity="end">
<TextView
android:id="@+id/msg_content"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:background="@drawable/bg_user_message"
android:padding="12dp"
android:textColor="@android:color/white" />
</LinearLayout>

View File

@@ -0,0 +1,6 @@
<?xml version="1.0" encoding="utf-8"?>
<adaptive-icon xmlns:android="http://schemas.android.com/apk/res/android">
<background android:drawable="@drawable/ic_launcher_background" />
<foreground android:drawable="@drawable/ic_launcher_foreground" />
<monochrome android:drawable="@drawable/ic_launcher_foreground" />
</adaptive-icon>

View File

@@ -0,0 +1,6 @@
<?xml version="1.0" encoding="utf-8"?>
<adaptive-icon xmlns:android="http://schemas.android.com/apk/res/android">
<background android:drawable="@drawable/ic_launcher_background" />
<foreground android:drawable="@drawable/ic_launcher_foreground" />
<monochrome android:drawable="@drawable/ic_launcher_foreground" />
</adaptive-icon>

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.4 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.8 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 982 B

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.7 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.9 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.8 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.8 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.8 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.8 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 7.6 KiB

View File

@@ -0,0 +1,10 @@
<?xml version="1.0" encoding="utf-8"?>
<resources>
<color name="purple_200">#FFBB86FC</color>
<color name="purple_500">#FF6200EE</color>
<color name="purple_700">#FF3700B3</color>
<color name="teal_200">#FF03DAC5</color>
<color name="teal_700">#FF018786</color>
<color name="black">#FF000000</color>
<color name="white">#FFFFFFFF</color>
</resources>

View File

@@ -0,0 +1,3 @@
<resources>
<string name="app_name">AI Chat basic sample</string>
</resources>

View File

@@ -0,0 +1,10 @@
<?xml version="1.0" encoding="utf-8"?>
<resources>
<style name="Base.Theme.AiChatSample" parent="Theme.Material3.DayNight.NoActionBar">
<!-- Customize your light theme here. -->
<!-- <item name="colorPrimary">@color/my_light_primary</item> -->
</style>
<style name="Theme.AiChatSample" parent="Base.Theme.AiChatSample" />
</resources>

View File

@@ -0,0 +1,13 @@
<?xml version="1.0" encoding="utf-8"?><!--
Sample backup rules file; uncomment and customize as necessary.
See https://developer.android.com/guide/topics/data/autobackup
for details.
Note: This file is ignored for devices older that API 31
See https://developer.android.com/about/versions/12/backup-restore
-->
<full-backup-content>
<!--
<include domain="sharedpref" path="."/>
<exclude domain="sharedpref" path="device.xml"/>
-->
</full-backup-content>

View File

@@ -0,0 +1,19 @@
<?xml version="1.0" encoding="utf-8"?><!--
Sample data extraction rules file; uncomment and customize as necessary.
See https://developer.android.com/about/versions/12/backup-restore#xml-changes
for details.
-->
<data-extraction-rules>
<cloud-backup>
<!-- TODO: Use <include> and <exclude> to control what is backed up.
<include .../>
<exclude .../>
-->
</cloud-backup>
<!--
<device-transfer>
<include .../>
<exclude .../>
</device-transfer>
-->
</data-extraction-rules>

View File

@@ -0,0 +1,6 @@
// Top-level build file where you can add configuration options common to all sub-projects/modules.
plugins {
alias(libs.plugins.android.application) apply false
alias(libs.plugins.android.library) apply false
alias(libs.plugins.jetbrains.kotlin.android) apply false
}

View File

@@ -0,0 +1,24 @@
# Project-wide Gradle settings.
# IDE (e.g. Android Studio) users:
# Gradle settings configured through the IDE *will override*
# any settings specified in this file.
# For more details on how to configure your build environment visit
# http://www.gradle.org/docs/current/userguide/build_environment.html
# Specifies the JVM arguments used for the daemon process.
# The setting is particularly useful for tweaking memory settings.
org.gradle.jvmargs=-Xmx2048m -Dfile.encoding=UTF-8
# When configured, Gradle will run in incubating parallel mode.
# This option should only be used with decoupled projects. More details, visit
# http://www.gradle.org/docs/current/userguide/multi_project_builds.html#sec:decoupled_projects
# org.gradle.parallel=true
# AndroidX package structure to make it clearer which packages are bundled with the
# Android operating system, and which are packaged with your app's APK
# https://developer.android.com/topic/libraries/support-library/androidx-rn
android.useAndroidX=true
# Kotlin code style for this project: "official" or "obsolete":
kotlin.code.style=official
# Enables namespacing of each library's R class so that its R class includes only the
# resources declared in the library itself and none from the library's dependencies,
# thereby reducing the size of the R class for that library
android.nonTransitiveRClass=true
android.native.buildOutput=verbose

View File

@@ -0,0 +1,53 @@
[versions]
# Plugins
agp = "8.13.2"
kotlin = "2.3.0"
# AndroidX
activity = "1.12.2"
appcompat = "1.7.1"
core-ktx = "1.17.0"
constraint-layout = "2.2.1"
datastore-preferences = "1.2.0"
# Material
material = "1.13.0"
# Testing
espresso-core = "3.7.0"
androidx-junit = "1.3.0"
junit = "4.13.2"
[plugins]
android-application = { id = "com.android.application", version.ref = "agp" }
android-library = { id = "com.android.library", version.ref = "agp" }
jetbrains-kotlin-android = { id = "org.jetbrains.kotlin.android", version.ref = "kotlin" }
[libraries]
# AndroidX
androidx-activity = { group = "androidx.activity", name = "activity", version.ref = "activity" }
androidx-appcompat = { group = "androidx.appcompat", name = "appcompat", version.ref = "appcompat" }
androidx-constraintlayout = { group = "androidx.constraintlayout", name = "constraintlayout", version.ref = "constraint-layout" }
androidx-core-ktx = { group = "androidx.core", name = "core-ktx", version.ref = "core-ktx" }
androidx-datastore-preferences = { group = "androidx.datastore", name = "datastore-preferences", version.ref = "datastore-preferences" }
#Material
material = { group = "com.google.android.material", name = "material", version.ref = "material" }
# Testing
androidx-espresso-core = { group = "androidx.test.espresso", name = "espresso-core", version.ref = "espresso-core" }
androidx-junit = { group = "androidx.test.ext", name = "junit", version.ref = "androidx-junit" }
junit = { group = "junit", name = "junit", version.ref = "junit" }
[bundles]
androidx = [
"androidx-activity",
"androidx-appcompat",
"androidx-constraintlayout",
"androidx-core-ktx",
"androidx-datastore-preferences",
]

Binary file not shown.

View File

@@ -0,0 +1,6 @@
#Tue Apr 01 11:15:06 PDT 2025
distributionBase=GRADLE_USER_HOME
distributionPath=wrapper/dists
distributionUrl=https\://services.gradle.org/distributions/gradle-8.14.3-bin.zip
zipStoreBase=GRADLE_USER_HOME
zipStorePath=wrapper/dists

185
examples/llama.android/gradlew vendored Executable file
View File

@@ -0,0 +1,185 @@
#!/usr/bin/env sh
#
# Copyright 2015 the original author or authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
##############################################################################
##
## Gradle start up script for UN*X
##
##############################################################################
# Attempt to set APP_HOME
# Resolve links: $0 may be a link
PRG="$0"
# Need this for relative symlinks.
while [ -h "$PRG" ] ; do
ls=`ls -ld "$PRG"`
link=`expr "$ls" : '.*-> \(.*\)$'`
if expr "$link" : '/.*' > /dev/null; then
PRG="$link"
else
PRG=`dirname "$PRG"`"/$link"
fi
done
SAVED="`pwd`"
cd "`dirname \"$PRG\"`/" >/dev/null
APP_HOME="`pwd -P`"
cd "$SAVED" >/dev/null
APP_NAME="Gradle"
APP_BASE_NAME=`basename "$0"`
# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"'
# Use the maximum available, or set MAX_FD != -1 to use that value.
MAX_FD="maximum"
warn () {
echo "$*"
}
die () {
echo
echo "$*"
echo
exit 1
}
# OS specific support (must be 'true' or 'false').
cygwin=false
msys=false
darwin=false
nonstop=false
case "`uname`" in
CYGWIN* )
cygwin=true
;;
Darwin* )
darwin=true
;;
MINGW* )
msys=true
;;
NONSTOP* )
nonstop=true
;;
esac
CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar
# Determine the Java command to use to start the JVM.
if [ -n "$JAVA_HOME" ] ; then
if [ -x "$JAVA_HOME/jre/sh/java" ] ; then
# IBM's JDK on AIX uses strange locations for the executables
JAVACMD="$JAVA_HOME/jre/sh/java"
else
JAVACMD="$JAVA_HOME/bin/java"
fi
if [ ! -x "$JAVACMD" ] ; then
die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME
Please set the JAVA_HOME variable in your environment to match the
location of your Java installation."
fi
else
JAVACMD="java"
which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
Please set the JAVA_HOME variable in your environment to match the
location of your Java installation."
fi
# Increase the maximum file descriptors if we can.
if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then
MAX_FD_LIMIT=`ulimit -H -n`
if [ $? -eq 0 ] ; then
if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then
MAX_FD="$MAX_FD_LIMIT"
fi
ulimit -n $MAX_FD
if [ $? -ne 0 ] ; then
warn "Could not set maximum file descriptor limit: $MAX_FD"
fi
else
warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT"
fi
fi
# For Darwin, add options to specify how the application appears in the dock
if $darwin; then
GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\""
fi
# For Cygwin or MSYS, switch paths to Windows format before running java
if [ "$cygwin" = "true" -o "$msys" = "true" ] ; then
APP_HOME=`cygpath --path --mixed "$APP_HOME"`
CLASSPATH=`cygpath --path --mixed "$CLASSPATH"`
JAVACMD=`cygpath --unix "$JAVACMD"`
# We build the pattern for arguments to be converted via cygpath
ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null`
SEP=""
for dir in $ROOTDIRSRAW ; do
ROOTDIRS="$ROOTDIRS$SEP$dir"
SEP="|"
done
OURCYGPATTERN="(^($ROOTDIRS))"
# Add a user-defined pattern to the cygpath arguments
if [ "$GRADLE_CYGPATTERN" != "" ] ; then
OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)"
fi
# Now convert the arguments - kludge to limit ourselves to /bin/sh
i=0
for arg in "$@" ; do
CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -`
CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option
if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition
eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"`
else
eval `echo args$i`="\"$arg\""
fi
i=`expr $i + 1`
done
case $i in
0) set -- ;;
1) set -- "$args0" ;;
2) set -- "$args0" "$args1" ;;
3) set -- "$args0" "$args1" "$args2" ;;
4) set -- "$args0" "$args1" "$args2" "$args3" ;;
5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;;
6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;;
7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;;
8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;;
9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;;
esac
fi
# Escape application args
save () {
for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done
echo " "
}
APP_ARGS=`save "$@"`
# Collect all arguments for the java command, following the shell quoting and substitution rules
eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS"
exec "$JAVACMD" "$@"

1
examples/llama.android/lib/.gitignore vendored Normal file
View File

@@ -0,0 +1 @@
/build

View File

@@ -0,0 +1,78 @@
plugins {
alias(libs.plugins.android.library)
alias(libs.plugins.jetbrains.kotlin.android)
}
android {
namespace = "com.arm.aichat"
compileSdk = 36
ndkVersion = "29.0.13113456"
defaultConfig {
minSdk = 33
testInstrumentationRunner = "androidx.test.runner.AndroidJUnitRunner"
consumerProguardFiles("consumer-rules.pro")
ndk {
abiFilters += listOf("arm64-v8a", "x86_64")
}
externalNativeBuild {
cmake {
arguments += "-DCMAKE_BUILD_TYPE=Release"
arguments += "-DCMAKE_MESSAGE_LOG_LEVEL=DEBUG"
arguments += "-DCMAKE_VERBOSE_MAKEFILE=ON"
arguments += "-DBUILD_SHARED_LIBS=ON"
arguments += "-DLLAMA_BUILD_COMMON=ON"
arguments += "-DLLAMA_OPENSSL=OFF"
arguments += "-DGGML_NATIVE=OFF"
arguments += "-DGGML_BACKEND_DL=ON"
arguments += "-DGGML_CPU_ALL_VARIANTS=ON"
arguments += "-DGGML_LLAMAFILE=OFF"
}
}
aarMetadata {
minCompileSdk = 35
}
}
externalNativeBuild {
cmake {
path("src/main/cpp/CMakeLists.txt")
version = "3.31.6"
}
}
compileOptions {
sourceCompatibility = JavaVersion.VERSION_17
targetCompatibility = JavaVersion.VERSION_17
}
kotlin {
jvmToolchain(17)
compileOptions {
targetCompatibility = JavaVersion.VERSION_17
}
}
packaging {
resources {
excludes += "/META-INF/{AL2.0,LGPL2.1}"
}
}
publishing {
singleVariant("release") {
withJavadocJar()
}
}
}
dependencies {
implementation(libs.androidx.core.ktx)
implementation(libs.androidx.datastore.preferences)
testImplementation(libs.junit)
androidTestImplementation(libs.androidx.junit)
}

View File

@@ -0,0 +1,8 @@
-keep class com.arm.aichat.* { *; }
-keep class com.arm.aichat.gguf.* { *; }
-keepclasseswithmembernames class * {
native <methods>;
}
-keep class kotlin.Metadata { *; }

View File

@@ -0,0 +1,21 @@
# Add project specific ProGuard rules here.
# You can control the set of applied configuration files using the
# proguardFiles setting in build.gradle.
#
# For more details, see
# http://developer.android.com/guide/developing/tools/proguard.html
# If your project uses WebView with JS, uncomment the following
# and specify the fully qualified class name to the JavaScript interface
# class:
#-keepclassmembers class fqcn.of.javascript.interface.for.webview {
# public *;
#}
# Uncomment this to preserve the line number information for
# debugging stack traces.
#-keepattributes SourceFile,LineNumberTable
# If you keep the line number information, uncomment this to
# hide the original source file name.
#-renamesourcefileattribute SourceFile

View File

@@ -0,0 +1,24 @@
package android.llama.cpp
import androidx.test.platform.app.InstrumentationRegistry
import androidx.test.ext.junit.runners.AndroidJUnit4
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.Assert.*
/**
* Instrumented test, which will execute on an Android device.
*
* See [testing documentation](http://d.android.com/tools/testing).
*/
@RunWith(AndroidJUnit4::class)
class ExampleInstrumentedTest {
@Test
fun useAppContext() {
// Context of the app under test.
val appContext = InstrumentationRegistry.getInstrumentation().targetContext
assertEquals("android.llama.cpp.test", appContext.packageName)
}
}

View File

@@ -0,0 +1,4 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android">
</manifest>

View File

@@ -0,0 +1,56 @@
cmake_minimum_required(VERSION 3.31.6)
project("ai-chat" VERSION 1.0.0 LANGUAGES C CXX)
set(CMAKE_C_STANDARD 11)
set(CMAKE_C_STANDARD_REQUIRED true)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED true)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS}" CACHE STRING "" FORCE)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}" CACHE STRING "" FORCE)
# --------------------------------------------------------------------------
# AI Chat library
# --------------------------------------------------------------------------
if(DEFINED ANDROID_ABI)
message(STATUS "Detected Android ABI: ${ANDROID_ABI}")
if(ANDROID_ABI STREQUAL "arm64-v8a")
set(GGML_SYSTEM_ARCH "ARM")
set(GGML_CPU_KLEIDIAI ON)
set(GGML_OPENMP ON)
elseif(ANDROID_ABI STREQUAL "x86_64")
set(GGML_SYSTEM_ARCH "x86")
set(GGML_CPU_KLEIDIAI OFF)
set(GGML_OPENMP OFF)
else()
message(FATAL_ERROR "Unsupported ABI: ${ANDROID_ABI}")
endif()
endif()
set(LLAMA_SRC ${CMAKE_CURRENT_LIST_DIR}/../../../../../../)
add_subdirectory(${LLAMA_SRC} build-llama)
add_library(${CMAKE_PROJECT_NAME} SHARED
ai_chat.cpp)
target_compile_definitions(${CMAKE_PROJECT_NAME} PRIVATE
GGML_SYSTEM_ARCH=${GGML_SYSTEM_ARCH}
GGML_CPU_KLEIDIAI=$<BOOL:${GGML_CPU_KLEIDIAI}>
GGML_OPENMP=$<BOOL:${GGML_OPENMP}>
)
target_include_directories(${CMAKE_PROJECT_NAME} PRIVATE
${LLAMA_SRC}
${LLAMA_SRC}/common
${LLAMA_SRC}/include
${LLAMA_SRC}/ggml/include
${LLAMA_SRC}/ggml/src)
target_link_libraries(${CMAKE_PROJECT_NAME}
llama
common
android
log)

View File

@@ -0,0 +1,565 @@
#include <android/log.h>
#include <jni.h>
#include <iomanip>
#include <cmath>
#include <string>
#include <unistd.h>
#include <sampling.h>
#include "logging.h"
#include "chat.h"
#include "common.h"
#include "llama.h"
template<class T>
static std::string join(const std::vector<T> &values, const std::string &delim) {
std::ostringstream str;
for (size_t i = 0; i < values.size(); i++) {
str << values[i];
if (i < values.size() - 1) { str << delim; }
}
return str.str();
}
/**
* LLama resources: context, model, batch and sampler
*/
constexpr int N_THREADS_MIN = 2;
constexpr int N_THREADS_MAX = 4;
constexpr int N_THREADS_HEADROOM = 2;
constexpr int DEFAULT_CONTEXT_SIZE = 8192;
constexpr int OVERFLOW_HEADROOM = 4;
constexpr int BATCH_SIZE = 512;
constexpr float DEFAULT_SAMPLER_TEMP = 0.3f;
static llama_model * g_model;
static llama_context * g_context;
static llama_batch g_batch;
static common_chat_templates_ptr g_chat_templates;
static common_sampler * g_sampler;
extern "C"
JNIEXPORT void JNICALL
Java_com_arm_aichat_internal_InferenceEngineImpl_init(JNIEnv *env, jobject /*unused*/, jstring nativeLibDir) {
// Set llama log handler to Android
llama_log_set(aichat_android_log_callback, nullptr);
// Loading all CPU backend variants
const auto *path_to_backend = env->GetStringUTFChars(nativeLibDir, 0);
LOGi("Loading backends from %s", path_to_backend);
ggml_backend_load_all_from_path(path_to_backend);
env->ReleaseStringUTFChars(nativeLibDir, path_to_backend);
// Initialize backends
llama_backend_init();
LOGi("Backend initiated; Log handler set.");
}
extern "C"
JNIEXPORT jint JNICALL
Java_com_arm_aichat_internal_InferenceEngineImpl_load(JNIEnv *env, jobject, jstring jmodel_path) {
llama_model_params model_params = llama_model_default_params();
const auto *model_path = env->GetStringUTFChars(jmodel_path, 0);
LOGd("%s: Loading model from: \n%s\n", __func__, model_path);
auto *model = llama_model_load_from_file(model_path, model_params);
env->ReleaseStringUTFChars(jmodel_path, model_path);
if (!model) {
return 1;
}
g_model = model;
return 0;
}
static llama_context *init_context(llama_model *model, const int n_ctx = DEFAULT_CONTEXT_SIZE) {
if (!model) {
LOGe("%s: model cannot be null", __func__);
return nullptr;
}
// Multi-threading setup
const int n_threads = std::max(N_THREADS_MIN, std::min(N_THREADS_MAX,
(int) sysconf(_SC_NPROCESSORS_ONLN) -
N_THREADS_HEADROOM));
LOGi("%s: Using %d threads", __func__, n_threads);
// Context parameters setup
llama_context_params ctx_params = llama_context_default_params();
const int trained_context_size = llama_model_n_ctx_train(model);
if (n_ctx > trained_context_size) {
LOGw("%s: Model was trained with only %d context size! Enforcing %d context size...",
__func__, trained_context_size, n_ctx);
}
ctx_params.n_ctx = n_ctx;
ctx_params.n_batch = BATCH_SIZE;
ctx_params.n_ubatch = BATCH_SIZE;
ctx_params.n_threads = n_threads;
ctx_params.n_threads_batch = n_threads;
auto *context = llama_init_from_model(g_model, ctx_params);
if (context == nullptr) {
LOGe("%s: llama_new_context_with_model() returned null)", __func__);
}
return context;
}
static common_sampler *new_sampler(float temp) {
common_params_sampling sparams;
sparams.temp = temp;
return common_sampler_init(g_model, sparams);
}
extern "C"
JNIEXPORT jint JNICALL
Java_com_arm_aichat_internal_InferenceEngineImpl_prepare(JNIEnv * /*env*/, jobject /*unused*/) {
auto *context = init_context(g_model);
if (!context) { return 1; }
g_context = context;
g_batch = llama_batch_init(BATCH_SIZE, 0, 1);
g_chat_templates = common_chat_templates_init(g_model, "");
g_sampler = new_sampler(DEFAULT_SAMPLER_TEMP);
return 0;
}
static std::string get_backend() {
std::vector<std::string> backends;
for (size_t i = 0; i < ggml_backend_reg_count(); i++) {
auto *reg = ggml_backend_reg_get(i);
std::string name = ggml_backend_reg_name(reg);
if (name != "CPU") {
backends.push_back(ggml_backend_reg_name(reg));
}
}
return backends.empty() ? "CPU" : join(backends, ",");
}
extern "C"
JNIEXPORT jstring JNICALL
Java_com_arm_aichat_internal_InferenceEngineImpl_systemInfo(JNIEnv *env, jobject /*unused*/) {
return env->NewStringUTF(llama_print_system_info());
}
extern "C"
JNIEXPORT jstring JNICALL
Java_com_arm_aichat_internal_InferenceEngineImpl_benchModel(JNIEnv *env, jobject /*unused*/, jint pp, jint tg,
jint pl, jint nr) {
auto *context = init_context(g_model, pp);
if (!context) {
const auto *const err_msg = "Fail to init_context! Bench aborted.";
LOGe(err_msg);
return env->NewStringUTF(err_msg);
}
auto pp_avg = 0.0;
auto tg_avg = 0.0;
auto pp_std = 0.0;
auto tg_std = 0.0;
const uint32_t n_ctx = llama_n_ctx(context);
LOGi("n_ctx = %d", n_ctx);
int i, j;
int nri;
for (nri = 0; nri < nr; nri++) {
LOGi("Benchmark prompt processing (pp = %d)", pp);
common_batch_clear(g_batch);
const int n_tokens = pp;
for (i = 0; i < n_tokens; i++) {
common_batch_add(g_batch, 0, i, {0}, false);
}
g_batch.logits[g_batch.n_tokens - 1] = true;
llama_memory_clear(llama_get_memory(context), false);
const auto t_pp_start = ggml_time_us();
if (llama_decode(context, g_batch) != 0) {
LOGe("llama_decode() failed during prompt processing");
}
const auto t_pp_end = ggml_time_us();
// bench text generation
LOGi("Benchmark text generation (tg = %d)", tg);
llama_memory_clear(llama_get_memory(context), false);
const auto t_tg_start = ggml_time_us();
for (i = 0; i < tg; i++) {
common_batch_clear(g_batch);
for (j = 0; j < pl; j++) {
common_batch_add(g_batch, 0, i, {j}, true);
}
if (llama_decode(context, g_batch) != 0) {
LOGe("llama_decode() failed during text generation");
}
}
const auto t_tg_end = ggml_time_us();
llama_memory_clear(llama_get_memory(context), false);
const auto t_pp = double(t_pp_end - t_pp_start) / 1000000.0;
const auto t_tg = double(t_tg_end - t_tg_start) / 1000000.0;
const auto speed_pp = double(pp) / t_pp;
const auto speed_tg = double(pl * tg) / t_tg;
pp_avg += speed_pp;
tg_avg += speed_tg;
pp_std += speed_pp * speed_pp;
tg_std += speed_tg * speed_tg;
LOGi("pp %f t/s, tg %f t/s", speed_pp, speed_tg);
}
llama_free(context);
pp_avg /= double(nr);
tg_avg /= double(nr);
if (nr > 1) {
pp_std = sqrt(pp_std / double(nr - 1) - pp_avg * pp_avg * double(nr) / double(nr - 1));
tg_std = sqrt(tg_std / double(nr - 1) - tg_avg * tg_avg * double(nr) / double(nr - 1));
} else {
pp_std = 0;
tg_std = 0;
}
char model_desc[128];
llama_model_desc(g_model, model_desc, sizeof(model_desc));
const auto model_size = double(llama_model_size(g_model)) / 1024.0 / 1024.0 / 1024.0;
const auto model_n_params = double(llama_model_n_params(g_model)) / 1e9;
const auto backend = get_backend();
std::stringstream result;
result << std::setprecision(3);
result << "| model | size | params | backend | test | t/s |\n";
result << "| --- | --- | --- | --- | --- | --- |\n";
result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | "
<< backend << " | pp " << pp << " | " << pp_avg << " ± " << pp_std << " |\n";
result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | "
<< backend << " | tg " << tg << " | " << tg_avg << " ± " << tg_std << " |\n";
return env->NewStringUTF(result.str().c_str());
}
/**
* Completion loop's long-term states:
* - chat management
* - position tracking
*/
constexpr const char *ROLE_SYSTEM = "system";
constexpr const char *ROLE_USER = "user";
constexpr const char *ROLE_ASSISTANT = "assistant";
static std::vector<common_chat_msg> chat_msgs;
static llama_pos system_prompt_position;
static llama_pos current_position;
static void reset_long_term_states(const bool clear_kv_cache = true) {
chat_msgs.clear();
system_prompt_position = 0;
current_position = 0;
if (clear_kv_cache)
llama_memory_clear(llama_get_memory(g_context), false);
}
/**
* TODO-hyin: implement sliding-window version as a better alternative
*
* Context shifting by discarding the older half of the tokens appended after system prompt:
* - take the [system_prompt_position] first tokens from the original prompt
* - take half of the last (system_prompt_position - system_prompt_position) tokens
* - recompute the logits in batches
*/
static void shift_context() {
const int n_discard = (current_position - system_prompt_position) / 2;
LOGi("%s: Discarding %d tokens", __func__, n_discard);
llama_memory_seq_rm(llama_get_memory(g_context), 0, system_prompt_position, system_prompt_position + n_discard);
llama_memory_seq_add(llama_get_memory(g_context), 0, system_prompt_position + n_discard, current_position, -n_discard);
current_position -= n_discard;
LOGi("%s: Context shifting done! Current position: %d", __func__, current_position);
}
static std::string chat_add_and_format(const std::string &role, const std::string &content) {
common_chat_msg new_msg;
new_msg.role = role;
new_msg.content = content;
auto formatted = common_chat_format_single(
g_chat_templates.get(), chat_msgs, new_msg, role == ROLE_USER, /* use_jinja */ false);
chat_msgs.push_back(new_msg);
LOGi("%s: Formatted and added %s message: \n%s\n", __func__, role.c_str(), formatted.c_str());
return formatted;
}
/**
* Completion loop's short-term states:
* - stop generation position
* - token chars caching
* - current assistant message being generated
*/
static llama_pos stop_generation_position;
static std::string cached_token_chars;
static std::ostringstream assistant_ss;
static void reset_short_term_states() {
stop_generation_position = 0;
cached_token_chars.clear();
assistant_ss.str("");
}
static int decode_tokens_in_batches(
llama_context *context,
llama_batch &batch,
const llama_tokens &tokens,
const llama_pos start_pos,
const bool compute_last_logit = false) {
// Process tokens in batches using the global batch
LOGd("%s: Decode %d tokens starting at position %d", __func__, (int) tokens.size(), start_pos);
for (int i = 0; i < (int) tokens.size(); i += BATCH_SIZE) {
const int cur_batch_size = std::min((int) tokens.size() - i, BATCH_SIZE);
common_batch_clear(batch);
LOGv("%s: Preparing a batch size of %d starting at: %d", __func__, cur_batch_size, i);
// Shift context if current batch cannot fit into the context
if (start_pos + i + cur_batch_size >= DEFAULT_CONTEXT_SIZE - OVERFLOW_HEADROOM) {
LOGw("%s: Current batch won't fit into context! Shifting...", __func__);
shift_context();
}
// Add tokens to the batch with proper positions
for (int j = 0; j < cur_batch_size; j++) {
const llama_token token_id = tokens[i + j];
const llama_pos position = start_pos + i + j;
const bool want_logit = compute_last_logit && (i + j == tokens.size() - 1);
common_batch_add(batch, token_id, position, {0}, want_logit);
}
// Decode this batch
const int decode_result = llama_decode(context, batch);
if (decode_result) {
LOGe("%s: llama_decode failed w/ %d", __func__, decode_result);
return 1;
}
}
return 0;
}
extern "C"
JNIEXPORT jint JNICALL
Java_com_arm_aichat_internal_InferenceEngineImpl_processSystemPrompt(
JNIEnv *env,
jobject /*unused*/,
jstring jsystem_prompt
) {
// Reset long-term & short-term states
reset_long_term_states();
reset_short_term_states();
// Obtain system prompt from JEnv
const auto *system_prompt = env->GetStringUTFChars(jsystem_prompt, nullptr);
LOGd("%s: System prompt received: \n%s", __func__, system_prompt);
std::string formatted_system_prompt(system_prompt);
env->ReleaseStringUTFChars(jsystem_prompt, system_prompt);
// Format system prompt if applicable
const bool has_chat_template = common_chat_templates_was_explicit(g_chat_templates.get());
if (has_chat_template) {
formatted_system_prompt = chat_add_and_format(ROLE_SYSTEM, system_prompt);
}
// Tokenize system prompt
const auto system_tokens = common_tokenize(g_context, formatted_system_prompt,
has_chat_template, has_chat_template);
for (auto id: system_tokens) {
LOGv("token: `%s`\t -> `%d`", common_token_to_piece(g_context, id).c_str(), id);
}
// Handle context overflow
const int max_batch_size = DEFAULT_CONTEXT_SIZE - OVERFLOW_HEADROOM;
if ((int) system_tokens.size() > max_batch_size) {
LOGe("%s: System prompt too long for context! %d tokens, max: %d",
__func__, (int) system_tokens.size(), max_batch_size);
return 1;
}
// Decode system tokens in batches
if (decode_tokens_in_batches(g_context, g_batch, system_tokens, current_position)) {
LOGe("%s: llama_decode() failed!", __func__);
return 2;
}
// Update position
system_prompt_position = current_position = (int) system_tokens.size();
return 0;
}
extern "C"
JNIEXPORT jint JNICALL
Java_com_arm_aichat_internal_InferenceEngineImpl_processUserPrompt(
JNIEnv *env,
jobject /*unused*/,
jstring juser_prompt,
jint n_predict
) {
// Reset short-term states
reset_short_term_states();
// Obtain and tokenize user prompt
const auto *const user_prompt = env->GetStringUTFChars(juser_prompt, nullptr);
LOGd("%s: User prompt received: \n%s", __func__, user_prompt);
std::string formatted_user_prompt(user_prompt);
env->ReleaseStringUTFChars(juser_prompt, user_prompt);
// Format user prompt if applicable
const bool has_chat_template = common_chat_templates_was_explicit(g_chat_templates.get());
if (has_chat_template) {
formatted_user_prompt = chat_add_and_format(ROLE_USER, user_prompt);
}
// Decode formatted user prompts
auto user_tokens = common_tokenize(g_context, formatted_user_prompt, has_chat_template, has_chat_template);
for (auto id: user_tokens) {
LOGv("token: `%s`\t -> `%d`", common_token_to_piece(g_context, id).c_str(), id);
}
// Ensure user prompt doesn't exceed the context size by truncating if necessary.
const int user_prompt_size = (int) user_tokens.size();
const int max_batch_size = DEFAULT_CONTEXT_SIZE - OVERFLOW_HEADROOM;
if (user_prompt_size > max_batch_size) {
const int skipped_tokens = user_prompt_size - max_batch_size;
user_tokens.resize(max_batch_size);
LOGw("%s: User prompt too long! Skipped %d tokens!", __func__, skipped_tokens);
}
// Decode user tokens in batches
if (decode_tokens_in_batches(g_context, g_batch, user_tokens, current_position, true)) {
LOGe("%s: llama_decode() failed!", __func__);
return 2;
}
// Update position
current_position += user_prompt_size;
stop_generation_position = current_position + user_prompt_size + n_predict;
return 0;
}
static bool is_valid_utf8(const char *string) {
if (!string) { return true; }
const auto *bytes = (const unsigned char *) string;
int num;
while (*bytes != 0x00) {
if ((*bytes & 0x80) == 0x00) {
// U+0000 to U+007F
num = 1;
} else if ((*bytes & 0xE0) == 0xC0) {
// U+0080 to U+07FF
num = 2;
} else if ((*bytes & 0xF0) == 0xE0) {
// U+0800 to U+FFFF
num = 3;
} else if ((*bytes & 0xF8) == 0xF0) {
// U+10000 to U+10FFFF
num = 4;
} else {
return false;
}
bytes += 1;
for (int i = 1; i < num; ++i) {
if ((*bytes & 0xC0) != 0x80) {
return false;
}
bytes += 1;
}
}
return true;
}
extern "C"
JNIEXPORT jstring JNICALL
Java_com_arm_aichat_internal_InferenceEngineImpl_generateNextToken(
JNIEnv *env,
jobject /*unused*/
) {
// Infinite text generation via context shifting
if (current_position >= DEFAULT_CONTEXT_SIZE - OVERFLOW_HEADROOM) {
LOGw("%s: Context full! Shifting...", __func__);
shift_context();
}
// Stop if reaching the marked position
if (current_position >= stop_generation_position) {
LOGw("%s: STOP: hitting stop position: %d", __func__, stop_generation_position);
return nullptr;
}
// Sample next token
const auto new_token_id = common_sampler_sample(g_sampler, g_context, -1);
common_sampler_accept(g_sampler, new_token_id, true);
// Populate the batch with new token, then decode
common_batch_clear(g_batch);
common_batch_add(g_batch, new_token_id, current_position, {0}, true);
if (llama_decode(g_context, g_batch) != 0) {
LOGe("%s: llama_decode() failed for generated token", __func__);
return nullptr;
}
// Update position
current_position++;
// Stop if next token is EOG
if (llama_vocab_is_eog(llama_model_get_vocab(g_model), new_token_id)) {
LOGd("id: %d,\tIS EOG!\nSTOP.", new_token_id);
chat_add_and_format(ROLE_ASSISTANT, assistant_ss.str());
return nullptr;
}
// If not EOG, convert to text
auto new_token_chars = common_token_to_piece(g_context, new_token_id);
cached_token_chars += new_token_chars;
// Create and return a valid UTF-8 Java string
jstring result = nullptr;
if (is_valid_utf8(cached_token_chars.c_str())) {
result = env->NewStringUTF(cached_token_chars.c_str());
LOGv("id: %d,\tcached: `%s`,\tnew: `%s`", new_token_id, cached_token_chars.c_str(), new_token_chars.c_str());
assistant_ss << cached_token_chars;
cached_token_chars.clear();
} else {
LOGv("id: %d,\tappend to cache", new_token_id);
result = env->NewStringUTF("");
}
return result;
}
extern "C"
JNIEXPORT void JNICALL
Java_com_arm_aichat_internal_InferenceEngineImpl_unload(JNIEnv * /*unused*/, jobject /*unused*/) {
// Reset long-term & short-term states
reset_long_term_states();
reset_short_term_states();
// Free up resources
common_sampler_free(g_sampler);
g_chat_templates.reset();
llama_batch_free(g_batch);
llama_free(g_context);
llama_model_free(g_model);
}
extern "C"
JNIEXPORT void JNICALL
Java_com_arm_aichat_internal_InferenceEngineImpl_shutdown(JNIEnv *, jobject /*unused*/) {
llama_backend_free();
}

View File

@@ -0,0 +1,61 @@
//
// Created by Han Yin on 10/31/25.
//
#ifndef AICHAT_LOGGING_H
#define AICHAT_LOGGING_H
#endif //AICHAT_LOGGING_H
#pragma once
#include <android/log.h>
#ifndef LOG_TAG
#define LOG_TAG "ai-chat"
#endif
#ifndef LOG_MIN_LEVEL
#if defined(NDEBUG)
#define LOG_MIN_LEVEL ANDROID_LOG_INFO
#else
#define LOG_MIN_LEVEL ANDROID_LOG_VERBOSE
#endif
#endif
static inline int ai_should_log(int prio) {
return __android_log_is_loggable(prio, LOG_TAG, LOG_MIN_LEVEL);
}
#if LOG_MIN_LEVEL <= ANDROID_LOG_VERBOSE
#define LOGv(...) do { if (ai_should_log(ANDROID_LOG_VERBOSE)) __android_log_print(ANDROID_LOG_VERBOSE, LOG_TAG, __VA_ARGS__); } while (0)
#else
#define LOGv(...) ((void)0)
#endif
#if LOG_MIN_LEVEL <= ANDROID_LOG_DEBUG
#define LOGd(...) do { if (ai_should_log(ANDROID_LOG_DEBUG)) __android_log_print(ANDROID_LOG_DEBUG, LOG_TAG, __VA_ARGS__); } while (0)
#else
#define LOGd(...) ((void)0)
#endif
#define LOGi(...) do { if (ai_should_log(ANDROID_LOG_INFO )) __android_log_print(ANDROID_LOG_INFO , LOG_TAG, __VA_ARGS__); } while (0)
#define LOGw(...) do { if (ai_should_log(ANDROID_LOG_WARN )) __android_log_print(ANDROID_LOG_WARN , LOG_TAG, __VA_ARGS__); } while (0)
#define LOGe(...) do { if (ai_should_log(ANDROID_LOG_ERROR)) __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, __VA_ARGS__); } while (0)
static inline int android_log_prio_from_ggml(enum ggml_log_level level) {
switch (level) {
case GGML_LOG_LEVEL_ERROR: return ANDROID_LOG_ERROR;
case GGML_LOG_LEVEL_WARN: return ANDROID_LOG_WARN;
case GGML_LOG_LEVEL_INFO: return ANDROID_LOG_INFO;
case GGML_LOG_LEVEL_DEBUG: return ANDROID_LOG_DEBUG;
default: return ANDROID_LOG_DEFAULT;
}
}
static inline void aichat_android_log_callback(enum ggml_log_level level,
const char* text,
void* /*user*/) {
const int prio = android_log_prio_from_ggml(level);
if (!ai_should_log(prio)) return;
__android_log_write(prio, LOG_TAG, text);
}

View File

@@ -0,0 +1,14 @@
package com.arm.aichat
import android.content.Context
import com.arm.aichat.internal.InferenceEngineImpl
/**
* Main entry point for Arm's AI Chat library.
*/
object AiChat {
/**
* Get the inference engine single instance.
*/
fun getInferenceEngine(context: Context) = InferenceEngineImpl.getInstance(context)
}

View File

@@ -0,0 +1,89 @@
package com.arm.aichat
import com.arm.aichat.InferenceEngine.State
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.StateFlow
/**
* Interface defining the core LLM inference operations.
*/
interface InferenceEngine {
/**
* Current state of the inference engine
*/
val state: StateFlow<State>
/**
* Load a model from the given path.
*
* @throws UnsupportedArchitectureException if model architecture not supported
*/
suspend fun loadModel(pathToModel: String)
/**
* Sends a system prompt to the loaded model
*/
suspend fun setSystemPrompt(systemPrompt: String)
/**
* Sends a user prompt to the loaded model and returns a Flow of generated tokens.
*/
fun sendUserPrompt(message: String, predictLength: Int = DEFAULT_PREDICT_LENGTH): Flow<String>
/**
* Runs a benchmark with the specified parameters.
*/
suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String
/**
* Unloads the currently loaded model.
*/
fun cleanUp()
/**
* Cleans up resources when the engine is no longer needed.
*/
fun destroy()
/**
* States of the inference engine
*/
sealed class State {
object Uninitialized : State()
object Initializing : State()
object Initialized : State()
object LoadingModel : State()
object UnloadingModel : State()
object ModelReady : State()
object Benchmarking : State()
object ProcessingSystemPrompt : State()
object ProcessingUserPrompt : State()
object Generating : State()
data class Error(val exception: Exception) : State()
}
companion object {
const val DEFAULT_PREDICT_LENGTH = 1024
}
}
val State.isUninterruptible
get() = this is State.Initializing ||
this is State.LoadingModel ||
this is State.UnloadingModel ||
this is State.Benchmarking ||
this is State.ProcessingSystemPrompt ||
this is State.ProcessingUserPrompt
val State.isModelLoaded: Boolean
get() = this is State.ModelReady ||
this is State.Benchmarking ||
this is State.ProcessingSystemPrompt ||
this is State.ProcessingUserPrompt ||
this is State.Generating
class UnsupportedArchitectureException : Exception()

Some files were not shown because too many files have changed in this diff Show More