Sync from upstream llama.cpp repository
45
examples/CMakeLists.txt
Normal file
@@ -0,0 +1,45 @@
|
||||
# dependencies
|
||||
|
||||
find_package(Threads REQUIRED)
|
||||
|
||||
# third-party
|
||||
|
||||
# ...
|
||||
|
||||
# flags
|
||||
|
||||
llama_add_compile_flags()
|
||||
|
||||
# examples
|
||||
|
||||
if (EMSCRIPTEN)
|
||||
else()
|
||||
add_subdirectory(batched)
|
||||
add_subdirectory(debug)
|
||||
add_subdirectory(embedding)
|
||||
add_subdirectory(eval-callback)
|
||||
|
||||
add_subdirectory(gguf-hash)
|
||||
add_subdirectory(gguf)
|
||||
add_subdirectory(idle)
|
||||
add_subdirectory(lookahead)
|
||||
add_subdirectory(lookup)
|
||||
add_subdirectory(parallel)
|
||||
add_subdirectory(passkey)
|
||||
add_subdirectory(retrieval)
|
||||
add_subdirectory(save-load-state)
|
||||
add_subdirectory(simple)
|
||||
add_subdirectory(simple-chat)
|
||||
add_subdirectory(speculative)
|
||||
add_subdirectory(speculative-simple)
|
||||
add_subdirectory(gen-docs)
|
||||
add_subdirectory(training)
|
||||
add_subdirectory(diffusion)
|
||||
if (NOT GGML_BACKEND_DL)
|
||||
add_subdirectory(convert-llama2c-to-ggml)
|
||||
# these examples use the backends directly and cannot be built with dynamic loading
|
||||
if (GGML_SYCL)
|
||||
add_subdirectory(sycl)
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
9
examples/batched.swift/.gitignore
vendored
Normal file
@@ -0,0 +1,9 @@
|
||||
.DS_Store
|
||||
/.build
|
||||
/Packages
|
||||
xcuserdata/
|
||||
DerivedData/
|
||||
.swiftpm/configuration/registries.json
|
||||
.swiftpm/xcode/package.xcworkspace/contents.xcworkspacedata
|
||||
.netrc
|
||||
batched_swift
|
||||
6
examples/batched.swift/Makefile
Executable file
@@ -0,0 +1,6 @@
|
||||
.PHONY: build
|
||||
|
||||
build:
|
||||
xcodebuild -scheme llama-batched-swift -destination "generic/platform=macOS" -derivedDataPath build
|
||||
rm -f ./llama-batched-swift
|
||||
ln -s ./build/Build/Products/Debug/llama-batched-swift ./llama-batched-swift
|
||||
22
examples/batched.swift/Package.swift
Normal file
@@ -0,0 +1,22 @@
|
||||
// swift-tools-version: 5.5
|
||||
// The swift-tools-version declares the minimum version of Swift required to build this package.
|
||||
|
||||
import PackageDescription
|
||||
|
||||
let package = Package(
|
||||
name: "llama-batched-swift",
|
||||
platforms: [.macOS(.v12)],
|
||||
dependencies: [
|
||||
.package(name: "llama", path: "../../"),
|
||||
],
|
||||
targets: [
|
||||
// Targets are the basic building blocks of a package, defining a module or a test suite.
|
||||
// Targets can depend on other targets in this package and products from dependencies.
|
||||
.executableTarget(
|
||||
name: "llama-batched-swift",
|
||||
dependencies: ["llama"],
|
||||
path: "Sources",
|
||||
linkerSettings: [.linkedFramework("Foundation"), .linkedFramework("AppKit")]
|
||||
),
|
||||
]
|
||||
)
|
||||
5
examples/batched.swift/README.md
Normal file
@@ -0,0 +1,5 @@
|
||||
This is a swift clone of `examples/batched`.
|
||||
|
||||
```bash
|
||||
$ ./llama-batched-swift MODEL_PATH [PROMPT] [PARALLEL]
|
||||
```
|
||||
256
examples/batched.swift/Sources/main.swift
Normal file
@@ -0,0 +1,256 @@
|
||||
import Foundation
|
||||
import llama
|
||||
|
||||
let arguments = CommandLine.arguments
|
||||
|
||||
// Check that we have at least one argument (the model path)
|
||||
guard arguments.count > 1 else {
|
||||
print("Usage: swift MODEL_PATH [PROMPT] [PARALLEL]")
|
||||
exit(1)
|
||||
}
|
||||
|
||||
let modelPath: String = arguments[1]
|
||||
let prompt: String = arguments.count > 2 ? arguments[2] : "Hello my name is"
|
||||
let n_parallel: Int = arguments.count > 3 && Int(arguments[3]) != nil ? Int(arguments[3])! : 1
|
||||
|
||||
// total length of the sequences including the prompt
|
||||
let n_len: Int = 32
|
||||
|
||||
// init LLM
|
||||
llama_backend_init()
|
||||
defer {
|
||||
llama_backend_free()
|
||||
}
|
||||
|
||||
let model_params = llama_model_default_params()
|
||||
guard let model = llama_model_load_from_file(modelPath.cString(using: .utf8), model_params) else {
|
||||
print("Failed to load model")
|
||||
exit(1)
|
||||
}
|
||||
defer {
|
||||
llama_model_free(model)
|
||||
}
|
||||
|
||||
guard let vocab = llama_model_get_vocab(model) else {
|
||||
print("Failed to get vocab")
|
||||
exit(1)
|
||||
}
|
||||
|
||||
var tokens = tokenize(text: prompt, add_bos: true)
|
||||
|
||||
let n_kv_req = UInt32(tokens.count) + UInt32((n_len - Int(tokens.count)) * n_parallel)
|
||||
|
||||
var context_params = llama_context_default_params()
|
||||
context_params.n_ctx = n_kv_req
|
||||
context_params.n_batch = UInt32(max(n_len, n_parallel))
|
||||
context_params.n_threads = 8
|
||||
context_params.n_threads_batch = 8
|
||||
|
||||
let context = llama_init_from_model(model, context_params)
|
||||
guard context != nil else {
|
||||
print("Failed to initialize context")
|
||||
exit(1)
|
||||
}
|
||||
defer {
|
||||
llama_free(context)
|
||||
}
|
||||
|
||||
var sparams = llama_sampler_chain_default_params()
|
||||
|
||||
let smpl = llama_sampler_chain_init(sparams)
|
||||
guard smpl != nil else {
|
||||
print("Failed to initialize sampling")
|
||||
exit(1)
|
||||
}
|
||||
defer {
|
||||
llama_sampler_free(smpl)
|
||||
}
|
||||
|
||||
llama_sampler_chain_add(smpl, llama_sampler_init_top_k(40));
|
||||
llama_sampler_chain_add(smpl, llama_sampler_init_top_p(0.9, 1));
|
||||
llama_sampler_chain_add(smpl, llama_sampler_init_temp (0.4));
|
||||
llama_sampler_chain_add(smpl, llama_sampler_init_dist (1234));
|
||||
|
||||
let n_ctx = llama_n_ctx(context)
|
||||
|
||||
print("\nn_len = \(n_len), n_ctx = \(n_ctx), n_batch = \(context_params.n_batch), n_parallel = \(n_parallel), n_kv_req = \(n_kv_req)\n")
|
||||
|
||||
if n_kv_req > n_ctx {
|
||||
print("error: n_kv_req (%d) > n_ctx, the required KV cache size is not big enough\n", n_kv_req)
|
||||
exit(1)
|
||||
}
|
||||
|
||||
var buffer: [CChar] = []
|
||||
for id: llama_token in tokens {
|
||||
print(token_to_piece(token: id, buffer: &buffer) ?? "", terminator: "")
|
||||
}
|
||||
|
||||
print("\n")
|
||||
|
||||
var batch = llama_batch_init(max(Int32(tokens.count), Int32(n_parallel)), 0, 1)
|
||||
defer {
|
||||
llama_batch_free(batch)
|
||||
}
|
||||
|
||||
// evaluate the initial prompt
|
||||
batch.n_tokens = Int32(tokens.count)
|
||||
|
||||
for (i, token) in tokens.enumerated() {
|
||||
batch.token[i] = token
|
||||
batch.pos[i] = Int32(i)
|
||||
batch.n_seq_id[i] = 1
|
||||
// batch.seq_id[i][0] = 0
|
||||
// TODO: is this the proper way to do this?
|
||||
if let seq_id = batch.seq_id[i] {
|
||||
seq_id[0] = 0
|
||||
}
|
||||
batch.logits[i] = 0
|
||||
}
|
||||
|
||||
// llama_decode will output logits only for the last token of the prompt
|
||||
batch.logits[Int(batch.n_tokens) - 1] = 1
|
||||
|
||||
if llama_decode(context, batch) != 0 {
|
||||
print("llama_decode() failed")
|
||||
exit(1)
|
||||
}
|
||||
|
||||
for i in 1 ..< n_parallel {
|
||||
llama_memory_seq_cp(llama_get_memory(context), 0, Int32(i), 0, batch.n_tokens)
|
||||
}
|
||||
|
||||
if n_parallel > 1 {
|
||||
print("generating \(n_parallel) sequences ...\n")
|
||||
}
|
||||
|
||||
var streams: [String] = .init(repeating: "", count: n_parallel)
|
||||
var streamBuffers: [[CChar]] = .init(repeating: [], count: n_parallel)
|
||||
var i_batch = [Int32](repeating: batch.n_tokens - 1, count: n_parallel)
|
||||
|
||||
var n_cur = batch.n_tokens
|
||||
var n_decode = 0
|
||||
|
||||
let t_main_start = ggml_time_us()
|
||||
|
||||
while n_cur <= n_len {
|
||||
// prepare the next batch
|
||||
batch.n_tokens = 0
|
||||
|
||||
// sample the next token for each parallel sequence / stream
|
||||
for i in 0 ..< n_parallel {
|
||||
if i_batch[i] < 0 {
|
||||
// the stream has already finished
|
||||
continue
|
||||
}
|
||||
|
||||
let new_token_id = llama_sampler_sample(smpl, context, i_batch[i])
|
||||
|
||||
// is it an end of stream? -> mark the stream as finished
|
||||
if llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_len {
|
||||
i_batch[i] = -1
|
||||
// print("")
|
||||
if n_parallel > 1 {
|
||||
print("stream \(i) finished at n_cur = \(n_cur)")
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
let nextStringPiece = token_to_piece(token: new_token_id, buffer: &streamBuffers[i]) ?? ""
|
||||
|
||||
// if there is only one stream, we print immediately to stdout
|
||||
if n_parallel == 1 {
|
||||
print(nextStringPiece, terminator: "")
|
||||
}
|
||||
streams[i] += nextStringPiece
|
||||
|
||||
// push this new token for next evaluation
|
||||
batch.token[Int(batch.n_tokens)] = new_token_id
|
||||
batch.pos[Int(batch.n_tokens)] = n_cur
|
||||
batch.n_seq_id[Int(batch.n_tokens)] = 1
|
||||
if let seq_id = batch.seq_id[Int(batch.n_tokens)] {
|
||||
seq_id[0] = Int32(i)
|
||||
}
|
||||
batch.logits[Int(batch.n_tokens)] = 1
|
||||
|
||||
i_batch[i] = batch.n_tokens
|
||||
|
||||
batch.n_tokens += 1
|
||||
|
||||
n_decode += 1
|
||||
}
|
||||
|
||||
// all streams are finished
|
||||
if batch.n_tokens == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
n_cur += 1
|
||||
|
||||
// evaluate the current batch with the transformer model
|
||||
if llama_decode(context, batch) != 0 {
|
||||
print("llama_decode() failed")
|
||||
exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
if n_parallel > 1 {
|
||||
print("\n")
|
||||
for (i, stream) in streams.enumerated() {
|
||||
print("sequence \(i):\n\n\(prompt)\(stream)\n")
|
||||
}
|
||||
}
|
||||
|
||||
let t_main_end = ggml_time_us()
|
||||
|
||||
print("decoded \(n_decode) tokens in \(String(format: "%.2f", Double(t_main_end - t_main_start) / 1_000_000.0)) s, speed: \(String(format: "%.2f", Double(n_decode) / (Double(t_main_end - t_main_start) / 1_000_000.0))) t/s\n\n")
|
||||
|
||||
llama_perf_sampler_print(smpl)
|
||||
llama_perf_context_print(context)
|
||||
|
||||
private func tokenize(text: String, add_bos: Bool) -> [llama_token] {
|
||||
let utf8Count = text.utf8.count
|
||||
let n_tokens = utf8Count + (add_bos ? 1 : 0)
|
||||
let tokens = UnsafeMutablePointer<llama_token>.allocate(capacity: n_tokens)
|
||||
let tokenCount = llama_tokenize(vocab, text, Int32(utf8Count), tokens, Int32(n_tokens), add_bos, /*special tokens*/ false)
|
||||
var swiftTokens: [llama_token] = []
|
||||
for i in 0 ..< tokenCount {
|
||||
swiftTokens.append(tokens[Int(i)])
|
||||
}
|
||||
tokens.deallocate()
|
||||
return swiftTokens
|
||||
}
|
||||
|
||||
private func token_to_piece(token: llama_token, buffer: inout [CChar]) -> String? {
|
||||
var result = [CChar](repeating: 0, count: 8)
|
||||
let nTokens = llama_token_to_piece(vocab, token, &result, Int32(result.count), 0, false)
|
||||
if nTokens < 0 {
|
||||
let actualTokensCount = -Int(nTokens)
|
||||
result = .init(repeating: 0, count: actualTokensCount)
|
||||
let check = llama_token_to_piece(
|
||||
vocab,
|
||||
token,
|
||||
&result,
|
||||
Int32(result.count),
|
||||
0,
|
||||
false
|
||||
)
|
||||
assert(check == actualTokensCount)
|
||||
} else {
|
||||
result.removeLast(result.count - Int(nTokens))
|
||||
}
|
||||
if buffer.isEmpty, let utfString = String(cString: result + [0], encoding: .utf8) {
|
||||
return utfString
|
||||
} else {
|
||||
buffer.append(contentsOf: result)
|
||||
let data = Data(buffer.map { UInt8(bitPattern: $0) })
|
||||
if buffer.count >= 4 { // 4 bytes is the max length of a utf8 character so if we're here we need to reset the buffer
|
||||
buffer = []
|
||||
}
|
||||
guard let bufferString = String(data: data, encoding: .utf8) else {
|
||||
return nil
|
||||
}
|
||||
buffer = []
|
||||
return bufferString
|
||||
}
|
||||
}
|
||||
5
examples/batched/CMakeLists.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
set(TARGET llama-batched)
|
||||
add_executable(${TARGET} batched.cpp)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
44
examples/batched/README.md
Normal file
@@ -0,0 +1,44 @@
|
||||
# llama.cpp/example/batched
|
||||
|
||||
The example demonstrates batched generation from a given prompt
|
||||
|
||||
```bash
|
||||
./llama-batched -m ./models/llama-7b-v2/ggml-model-f16.gguf -p "Hello my name is" -np 4 --kv-unified
|
||||
|
||||
...
|
||||
|
||||
main: n_len = 32, n_ctx = 2048, n_parallel = 4, n_kv_req = 113
|
||||
|
||||
Hello my name is
|
||||
|
||||
main: generating 4 sequences ...
|
||||
|
||||
main: stream 0 finished
|
||||
main: stream 1 finished
|
||||
main: stream 2 finished
|
||||
main: stream 3 finished
|
||||
|
||||
sequence 0:
|
||||
|
||||
Hello my name is Shirley. I am a 25-year-old female who has been working for over 5 years as a b
|
||||
|
||||
sequence 1:
|
||||
|
||||
Hello my name is Renee and I'm a 32 year old female from the United States. I'm looking for a man between
|
||||
|
||||
sequence 2:
|
||||
|
||||
Hello my name is Diana. I am looking for a housekeeping job. I have experience with children and have my own transportation. I am
|
||||
|
||||
sequence 3:
|
||||
|
||||
Hello my name is Cody. I am a 3 year old neutered male. I am a very friendly cat. I am very playful and
|
||||
|
||||
main: decoded 108 tokens in 3.57 s, speed: 30.26 t/s
|
||||
|
||||
llama_print_timings: load time = 587.00 ms
|
||||
llama_print_timings: sample time = 2.56 ms / 112 runs ( 0.02 ms per token, 43664.72 tokens per second)
|
||||
llama_print_timings: prompt eval time = 4089.11 ms / 118 tokens ( 34.65 ms per token, 28.86 tokens per second)
|
||||
llama_print_timings: eval time = 0.00 ms / 1 runs ( 0.00 ms per token, inf tokens per second)
|
||||
llama_print_timings: total time = 4156.04 ms
|
||||
```
|
||||
261
examples/batched/batched.cpp
Normal file
@@ -0,0 +1,261 @@
|
||||
#include "arg.h"
|
||||
#include "common.h"
|
||||
#include "log.h"
|
||||
#include "llama.h"
|
||||
#include "sampling.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdio>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
static void print_usage(int, char ** argv) {
|
||||
LOG("\nexample usage:\n");
|
||||
LOG("\n %s -m model.gguf -p \"Hello my name is\" -n 32 -np 4\n", argv[0]);
|
||||
LOG("\n");
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
common_params params;
|
||||
|
||||
params.prompt = "Hello my name is";
|
||||
params.n_predict = 32;
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_BATCHED, print_usage)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
common_init();
|
||||
|
||||
// number of parallel batches
|
||||
int n_parallel = params.n_parallel;
|
||||
|
||||
// total length of the sequences including the prompt
|
||||
int n_predict = params.n_predict;
|
||||
|
||||
// init LLM
|
||||
|
||||
llama_backend_init();
|
||||
llama_numa_init(params.numa);
|
||||
|
||||
// initialize the model
|
||||
|
||||
llama_model_params model_params = common_model_params_to_llama(params);
|
||||
|
||||
llama_model * model = llama_model_load_from_file(params.model.path.c_str(), model_params);
|
||||
|
||||
if (model == NULL) {
|
||||
LOG_ERR("%s: error: unable to load model\n" , __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
|
||||
// tokenize the prompt
|
||||
|
||||
std::vector<llama_token> tokens_list;
|
||||
tokens_list = common_tokenize(vocab, params.prompt, true);
|
||||
|
||||
const int n_kv_req = tokens_list.size() + (n_predict - tokens_list.size())*n_parallel;
|
||||
|
||||
// initialize the context
|
||||
|
||||
llama_context_params ctx_params = common_context_params_to_llama(params);
|
||||
|
||||
ctx_params.n_ctx = n_kv_req;
|
||||
ctx_params.n_batch = std::max(n_predict, n_parallel);
|
||||
|
||||
auto sparams = llama_sampler_chain_default_params();
|
||||
sparams.no_perf = false;
|
||||
|
||||
std::vector<llama_sampler_seq_config> sampler_configs;
|
||||
|
||||
for (int32_t i = 0; i < n_parallel; ++i) {
|
||||
llama_sampler * smpl = llama_sampler_chain_init(sparams);
|
||||
|
||||
llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.sampling.top_k));
|
||||
llama_sampler_chain_add(smpl, llama_sampler_init_top_p(params.sampling.top_p, params.sampling.min_keep));
|
||||
llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.sampling.temp));
|
||||
llama_sampler_chain_add(smpl, llama_sampler_init_dist (params.sampling.seed));
|
||||
|
||||
sampler_configs.push_back({ i, smpl });
|
||||
}
|
||||
|
||||
if (params.sampling.backend_sampling) {
|
||||
ctx_params.samplers = sampler_configs.data();
|
||||
ctx_params.n_samplers = sampler_configs.size();
|
||||
}
|
||||
|
||||
llama_context * ctx = llama_init_from_model(model, ctx_params);
|
||||
|
||||
if (ctx == NULL) {
|
||||
LOG_ERR("%s: error: failed to create the llama_context\n" , __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
const int n_ctx = llama_n_ctx(ctx);
|
||||
|
||||
LOG_INF("\n%s: n_predict = %d, n_ctx = %d, n_batch = %u, n_parallel = %d, n_kv_req = %d\n", __func__, n_predict, n_ctx, ctx_params.n_batch, n_parallel, n_kv_req);
|
||||
|
||||
// make sure the KV cache is big enough to hold all the prompt and generated tokens
|
||||
if (n_kv_req > n_ctx) {
|
||||
LOG_ERR("%s: error: n_kv_req (%d) > n_ctx, the required KV cache size is not big enough\n", __func__, n_kv_req);
|
||||
LOG_ERR("%s: either reduce n_parallel or increase n_ctx\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
// print the prompt token-by-token
|
||||
|
||||
LOG("\n");
|
||||
|
||||
for (auto id : tokens_list) {
|
||||
LOG("%s", common_token_to_piece(ctx, id).c_str());
|
||||
}
|
||||
|
||||
// create a llama_batch
|
||||
// we use this object to submit token data for decoding
|
||||
llama_batch batch = llama_batch_init(std::max(tokens_list.size(), (size_t) n_parallel), 0, n_parallel);
|
||||
|
||||
std::vector<llama_seq_id> seq_ids(n_parallel, 0);
|
||||
for (int32_t i = 0; i < n_parallel; ++i) {
|
||||
seq_ids[i] = i;
|
||||
}
|
||||
|
||||
// evaluate the initial prompt
|
||||
for (size_t i = 0; i < tokens_list.size(); ++i) {
|
||||
common_batch_add(batch, tokens_list[i], i, seq_ids, false);
|
||||
}
|
||||
GGML_ASSERT(batch.n_tokens == (int) tokens_list.size());
|
||||
|
||||
if (llama_model_has_encoder(model)) {
|
||||
if (llama_encode(ctx, batch)) {
|
||||
LOG_ERR("%s : failed to eval\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
|
||||
if (decoder_start_token_id == LLAMA_TOKEN_NULL) {
|
||||
decoder_start_token_id = llama_vocab_bos(vocab);
|
||||
}
|
||||
|
||||
common_batch_clear(batch);
|
||||
common_batch_add(batch, decoder_start_token_id, 0, seq_ids, false);
|
||||
}
|
||||
|
||||
// llama_decode will output logits only for the last token of the prompt
|
||||
batch.logits[batch.n_tokens - 1] = true;
|
||||
|
||||
if (llama_decode(ctx, batch) != 0) {
|
||||
LOG_ERR("%s: llama_decode() failed\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
//// assign the system KV cache to all parallel sequences
|
||||
//// this way, the parallel sequences will "reuse" the prompt tokens without having to copy them
|
||||
//for (int32_t i = 1; i < n_parallel; ++i) {
|
||||
// llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
|
||||
//}
|
||||
|
||||
if (n_parallel > 1) {
|
||||
LOG("\n\n%s: generating %d sequences ...\n", __func__, n_parallel);
|
||||
}
|
||||
|
||||
// main loop
|
||||
|
||||
// we will store the parallel decoded sequences in this vector
|
||||
std::vector<std::string> streams(n_parallel);
|
||||
|
||||
// remember the batch index of the last token for each parallel sequence
|
||||
// we need this to determine which logits to sample from
|
||||
std::vector<int32_t> i_batch(n_parallel, batch.n_tokens - 1);
|
||||
|
||||
int n_cur = batch.n_tokens;
|
||||
int n_decode = 0;
|
||||
|
||||
const auto t_main_start = ggml_time_us();
|
||||
|
||||
while (n_cur <= n_predict) {
|
||||
// prepare the next batch
|
||||
common_batch_clear(batch);
|
||||
|
||||
// sample the next token for each parallel sequence / stream
|
||||
for (int32_t i = 0; i < n_parallel; ++i) {
|
||||
if (i_batch[i] < 0) {
|
||||
// the stream has already finished
|
||||
continue;
|
||||
}
|
||||
|
||||
const llama_token new_token_id = llama_sampler_sample(sampler_configs[i].sampler, ctx, i_batch[i]);
|
||||
|
||||
// is it an end of generation? -> mark the stream as finished
|
||||
if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_predict) {
|
||||
i_batch[i] = -1;
|
||||
LOG("\n");
|
||||
if (n_parallel > 1) {
|
||||
LOG_INF("%s: stream %d finished at n_cur = %d", __func__, i, n_cur);
|
||||
}
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
// if there is only one stream, we print immediately to stdout
|
||||
if (n_parallel == 1) {
|
||||
LOG("%s", common_token_to_piece(ctx, new_token_id).c_str());
|
||||
}
|
||||
|
||||
streams[i] += common_token_to_piece(ctx, new_token_id);
|
||||
|
||||
i_batch[i] = batch.n_tokens;
|
||||
|
||||
// push this new token for next evaluation
|
||||
common_batch_add(batch, new_token_id, n_cur, { i }, true);
|
||||
|
||||
n_decode += 1;
|
||||
}
|
||||
|
||||
// all streams are finished
|
||||
if (batch.n_tokens == 0) {
|
||||
break;
|
||||
}
|
||||
|
||||
n_cur += 1;
|
||||
|
||||
// evaluate the current batch with the transformer model
|
||||
if (llama_decode(ctx, batch)) {
|
||||
LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1);
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
if (n_parallel > 1) {
|
||||
LOG("\n");
|
||||
|
||||
for (int32_t i = 0; i < n_parallel; ++i) {
|
||||
LOG("sequence %d:\n\n%s%s\n\n", i, params.prompt.c_str(), streams[i].c_str());
|
||||
}
|
||||
}
|
||||
|
||||
const auto t_main_end = ggml_time_us();
|
||||
|
||||
LOG_INF("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
|
||||
__func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
|
||||
|
||||
LOG("\n");
|
||||
llama_perf_sampler_print(sampler_configs[0].sampler);
|
||||
llama_perf_context_print(ctx);
|
||||
|
||||
fprintf(stderr, "\n");
|
||||
|
||||
llama_batch_free(batch);
|
||||
|
||||
for (auto & sampler_config : sampler_configs) {
|
||||
llama_sampler_free(sampler_config.sampler);
|
||||
}
|
||||
|
||||
llama_free(ctx);
|
||||
llama_model_free(model);
|
||||
|
||||
llama_backend_free();
|
||||
|
||||
return 0;
|
||||
}
|
||||
5
examples/convert-llama2c-to-ggml/CMakeLists.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
set(TARGET llama-convert-llama2c-to-ggml)
|
||||
add_executable(${TARGET} convert-llama2c-to-ggml.cpp)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
25
examples/convert-llama2c-to-ggml/README.md
Normal file
@@ -0,0 +1,25 @@
|
||||
## Convert llama2.c model to ggml
|
||||
|
||||
This example reads weights from project [llama2.c](https://github.com/karpathy/llama2.c) and saves them in ggml compatible format. The vocab that is available in `models/ggml-vocab.bin` is used by default.
|
||||
|
||||
To convert the model first download the models from the [llama2.c](https://github.com/karpathy/llama2.c) repository.
|
||||
|
||||
```
|
||||
usage: ./llama-convert-llama2c-to-ggml [options]
|
||||
|
||||
options:
|
||||
-h, --help show this help message and exit
|
||||
--copy-vocab-from-model FNAME path of gguf llama model or llama2.c vocabulary from which to copy vocab (default 'models/7B/ggml-model-f16.gguf')
|
||||
--llama2c-model FNAME [REQUIRED] model path from which to load Karpathy's llama2.c model
|
||||
--llama2c-output-model FNAME model path to save the converted llama2.c model (default ak_llama_model.bin')
|
||||
```
|
||||
|
||||
An example command using a model from [karpathy/tinyllamas](https://huggingface.co/karpathy/tinyllamas) is as follows:
|
||||
|
||||
`$ ./llama-convert-llama2c-to-ggml --copy-vocab-from-model llama-2-7b-chat.gguf.q2_K.bin --llama2c-model stories42M.bin --llama2c-output-model stories42M.gguf.bin`
|
||||
|
||||
Note: The vocabulary for `stories260K.bin` should be its own tokenizer `tok512.bin` found in [karpathy/tinyllamas/stories260K](https://huggingface.co/karpathy/tinyllamas/tree/main/stories260K).
|
||||
|
||||
Now you can use the model with a command like:
|
||||
|
||||
`$ ./llama-cli -m stories42M.gguf.bin -p "One day, Lily met a Shoggoth" -n 500 -c 256`
|
||||
941
examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp
Normal file
@@ -0,0 +1,941 @@
|
||||
#include "ggml.h"
|
||||
#include "gguf.h"
|
||||
|
||||
#include "llama.h"
|
||||
#include "common.h"
|
||||
#include "log.h"
|
||||
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include <cassert>
|
||||
#include <climits>
|
||||
#include <cstring>
|
||||
#include <cstdarg>
|
||||
#include <cinttypes>
|
||||
#include <ctime>
|
||||
#include <random>
|
||||
#include <stdexcept>
|
||||
#include <sstream>
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
|
||||
// GGUF keys & tensor names.
|
||||
|
||||
#define KV_GENERAL_ARCHITECTURE "general.architecture"
|
||||
#define KV_GENERAL_NAME "general.name"
|
||||
|
||||
#define KV_TOKENIZER_MODEL "tokenizer.ggml.model"
|
||||
#define KV_TOKENIZER_LIST "tokenizer.ggml.tokens"
|
||||
#define KV_TOKENIZER_TOKEN_TYPE "tokenizer.ggml.token_type"
|
||||
#define KV_TOKENIZER_SCORES "tokenizer.ggml.scores"
|
||||
#define KV_TOKENIZER_BOS_ID "tokenizer.ggml.bos_token_id"
|
||||
#define KV_TOKENIZER_EOS_ID "tokenizer.ggml.eos_token_id"
|
||||
#define KV_TOKENIZER_UNK_ID "tokenizer.ggml.unknown_token_id"
|
||||
#define KV_TOKENIZER_SEP_ID "tokenizer.ggml.seperator_token_id"
|
||||
#define KV_TOKENIZER_PAD_ID "tokenizer.ggml.padding_token_id"
|
||||
#define KV_TOKENIZER_HF_JSON "tokenizer.huggingface.json"
|
||||
|
||||
#define KV_CONTEXT_LENGTH "llama.context_length"
|
||||
#define KV_EMBEDDING_LENGTH "llama.embedding_length"
|
||||
#define KV_BLOCK_COUNT "llama.block_count"
|
||||
#define KV_FEED_FORWARD_LENGTH "llama.feed_forward_length"
|
||||
#define KV_ATTENTION_HEAD_COUNT "llama.attention.head_count"
|
||||
#define KV_ATTENTION_HEAD_COUNT_KV "llama.attention.head_count_kv"
|
||||
#define KV_ATTENTION_LAYERNORM_RMS_EPS "llama.attention.layer_norm_rms_epsilon"
|
||||
#define KV_ROPE_DIMENSION_COUNT "llama.rope.dimension_count"
|
||||
|
||||
#define TN_TOKEN_EMBD "token_embd.weight"
|
||||
#define TN_OUTPUT_NORM "output_norm.weight"
|
||||
#define TN_OUTPUT "output.weight"
|
||||
#define TN_ATTN_NORM "blk.%d.attn_norm.weight"
|
||||
#define TN_ATTN_Q "blk.%d.attn_q.weight"
|
||||
#define TN_ATTN_K "blk.%d.attn_k.weight"
|
||||
#define TN_ATTN_V "blk.%d.attn_v.weight"
|
||||
#define TN_ATTN_OUTPUT "blk.%d.attn_output.weight"
|
||||
#define TN_FFN_NORM "blk.%d.ffn_norm.weight"
|
||||
#define TN_FFN_GATE "blk.%d.ffn_gate.weight"
|
||||
#define TN_FFN_DOWN "blk.%d.ffn_down.weight"
|
||||
#define TN_FFN_UP "blk.%d.ffn_up.weight"
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
#pragma warning(disable: 4244 4267) // possible loss of data
|
||||
#endif
|
||||
|
||||
#define LLAMA_FILE_MAGIC_GGJT 0x67676a74u // 'ggjt'
|
||||
#define LLAMA_FILE_VERSION_GGJT_V3 3
|
||||
|
||||
#define TOKENIZER_NAME "llama"
|
||||
#define UNKNOWN_TOKEN_ID 0
|
||||
#define BOS_TOKEN_ID 1
|
||||
#define EOS_TOKEN_ID 2
|
||||
|
||||
//////////////////////////////////////// llama2.c model structs and functions to load models, alloc memory etc.
|
||||
typedef struct {
|
||||
int dim; // transformer dimension
|
||||
int hidden_dim; // for ffn layers
|
||||
int n_layers; // number of layers
|
||||
int n_heads; // number of query heads
|
||||
int n_kv_heads; // number of key/value heads (can be < query heads because of multiquery)
|
||||
int vocab_size; // vocabulary size, usually 256 (byte-level)
|
||||
int seq_len; // max sequence length
|
||||
} Config;
|
||||
|
||||
struct TransformerWeights {
|
||||
// token embedding table
|
||||
std::vector<float> token_embedding_table; // (vocab_size, dim)
|
||||
// weights for rmsnorms
|
||||
std::vector<float> rms_att_weight; // (layer, dim) rmsnorm weights
|
||||
std::vector<float> rms_ffn_weight; // (layer, dim)
|
||||
// weights for matmuls
|
||||
std::vector<float> wq; // (layer, dim, dim)
|
||||
std::vector<float> wk; // (layer, dim, dim)
|
||||
std::vector<float> wv; // (layer, dim, dim)
|
||||
std::vector<float> wo; // (layer, dim, dim)
|
||||
// weights for ffn
|
||||
std::vector<float> w1; // (layer, hidden_dim, dim)
|
||||
std::vector<float> w2; // (layer, dim, hidden_dim)
|
||||
std::vector<float> w3; // (layer, hidden_dim, dim)
|
||||
// final rmsnorm
|
||||
std::vector<float> rms_final_weight; // (dim,)
|
||||
// freq_cis for RoPE relatively positional embeddings
|
||||
// std::vector<float> freq_cis_real; // (seq_len, dim/2)
|
||||
// std::vector<float> freq_cis_imag; // (seq_len, dim/2)
|
||||
// (optional) classifier weights for the logits, on the last layer
|
||||
std::vector<float> wcls;
|
||||
};
|
||||
|
||||
static void alloc_weights(TransformerWeights * w, const Config * p, bool shared_weights) {
|
||||
const int n_multiqueries = p->n_kv_heads <= 0 || p->n_kv_heads >= p->n_heads ? 1 : p->n_heads / p->n_kv_heads;
|
||||
try {
|
||||
w->token_embedding_table.resize(p->vocab_size * p->dim);
|
||||
LOG_INF("%s: Allocating [%d] x [%d] = [%d] float space for w->token_embedding_table\n",__func__,p->vocab_size , p->dim, p->vocab_size * p->dim);
|
||||
|
||||
w->rms_att_weight.resize(p->n_layers * p->dim);
|
||||
LOG_INF("%s: Allocating [%d] x [%d] = [%d] float space for w->rms_att_weight\n",__func__,p->n_layers, p->dim, p->n_layers * p->dim);
|
||||
|
||||
w->rms_ffn_weight.resize(p->n_layers * p->dim);
|
||||
LOG_INF("%s: Allocating [%d] x [%d] = [%d] float space for w->rms_ffn_weight\n",__func__,p->n_layers , p->dim, p->n_layers * p->dim);
|
||||
|
||||
w->wq.resize(p->n_layers * p->dim * p->dim);
|
||||
LOG_INF("%s: Allocating [%d] x [%d] x [%d] = [%d] float space for w->wq\n",__func__,p->n_layers, p->dim, p->dim, p->n_layers * p->dim * p->dim);
|
||||
|
||||
w->wk.resize(p->n_layers * p->dim * p->dim / n_multiqueries);
|
||||
LOG_INF("%s: Allocating [%d] x [%d] x [%d] = [%d] float space for w->wk\n",__func__,p->n_layers, p->dim, p->dim / n_multiqueries, p->n_layers * p->dim * p->dim / n_multiqueries);
|
||||
|
||||
w->wv.resize(p->n_layers * p->dim * p->dim / n_multiqueries);
|
||||
LOG_INF("%s: Allocating [%d] x [%d] x [%d] = [%d] float space for w->wv\n",__func__, p->n_layers, p->dim, p->dim / n_multiqueries, p->n_layers * p->dim * p->dim / n_multiqueries);
|
||||
|
||||
w->wo.resize(p->n_layers * p->dim * p->dim);
|
||||
LOG_INF("%s: Allocating [%d] x [%d] x [%d] = [%d] float space for w->wo\n",__func__,p->n_layers, p->dim, p->dim, p->n_layers * p->dim * p->dim);
|
||||
|
||||
w->w1.resize(p->n_layers * p->hidden_dim * p->dim);
|
||||
LOG_INF("%s: Allocating [%d] x [%d] x [%d] = [%d] float space for w->w1\n",__func__,p->n_layers, p->hidden_dim, p->dim, p->n_layers * p->hidden_dim * p->dim);
|
||||
|
||||
w->w2.resize(p->n_layers * p->hidden_dim * p->dim);
|
||||
LOG_INF("%s: Allocating [%d] x [%d] x [%d] = [%d] float space for w->w2\n",__func__,p->n_layers, p->dim, p->hidden_dim, p->n_layers * p->hidden_dim * p->dim);
|
||||
|
||||
w->w3.resize(p->n_layers * p->hidden_dim * p->dim);
|
||||
LOG_INF("%s: Allocating [%d] x [%d] x [%d] = [%d] float space for w->w3\n",__func__,p->n_layers, p->hidden_dim, p->dim, p->n_layers * p->hidden_dim * p->dim);
|
||||
|
||||
w->rms_final_weight.resize(p->dim);
|
||||
LOG_INF("%s: Allocating [%d] float space for w->rms_final_weight\n",__func__,p->dim);
|
||||
|
||||
if (shared_weights) {
|
||||
w->wcls = {};
|
||||
} else {
|
||||
w->wcls.resize(p->vocab_size * p->dim);
|
||||
LOG_INF("%s: Allocating [%d] x [%d] = [%d] float space for w->wcls\n",__func__,p->vocab_size , p->dim, p->vocab_size * p->dim);
|
||||
}
|
||||
}
|
||||
catch (std::length_error &) {
|
||||
die("Invalid configuration. Failed to allocate memory for weights");
|
||||
}
|
||||
}
|
||||
|
||||
static int checkpoint_init_weights(TransformerWeights * w, const Config * p, FILE * f, bool shared_weights) {
|
||||
if (fread(w->token_embedding_table.data(), sizeof(float), w->token_embedding_table.size(), f) != w->token_embedding_table.size()) return 1;
|
||||
if (fread(w->rms_att_weight.data(), sizeof(float), w->rms_att_weight.size(), f) != w->rms_att_weight.size()) return 1;
|
||||
if (fread(w->wq.data(), sizeof(float), w->wq.size(), f) != w->wq.size()) return 1;
|
||||
if (fread(w->wk.data(), sizeof(float), w->wk.size(), f) != w->wk.size()) return 1;
|
||||
if (fread(w->wv.data(), sizeof(float), w->wv.size(), f) != w->wv.size()) return 1;
|
||||
if (fread(w->wo.data(), sizeof(float), w->wo.size(), f) != w->wo.size()) return 1;
|
||||
if (fread(w->rms_ffn_weight.data(), sizeof(float), w->rms_ffn_weight.size(), f) != w->rms_ffn_weight.size()) return 1;
|
||||
if (fread(w->w1.data(), sizeof(float), w->w1.size(), f) != w->w1.size()) return 1;
|
||||
if (fread(w->w2.data(), sizeof(float), w->w2.size(), f) != w->w2.size()) return 1;
|
||||
if (fread(w->w3.data(), sizeof(float), w->w3.size(), f) != w->w3.size()) return 1;
|
||||
if (fread(w->rms_final_weight.data(), sizeof(float), w->rms_final_weight.size(), f) != w->rms_final_weight.size()) return 1;
|
||||
|
||||
// Skip freq_cis_real & freq_cis_imag
|
||||
int head_size = p->dim / p->n_heads;
|
||||
fseek(f, p->seq_len * head_size * sizeof(float), SEEK_CUR);
|
||||
|
||||
if (!shared_weights && fread(w->wcls.data(), sizeof(float), w->wcls.size(), f) != w->wcls.size()) return 1;
|
||||
|
||||
// Check we didn't forget to read anything
|
||||
auto curr = ftell(f);
|
||||
fseek(f, 0, SEEK_END);
|
||||
auto end = ftell(f);
|
||||
if (curr != end) {
|
||||
LOG_ERR("%s: Error: failed to read the checkpoint file to the end (curr = %ld, end = %ld)\n", __func__, curr, end);
|
||||
return 1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
static void print_sample_weights(TransformerWeights *w){
|
||||
LOG_INF("----- Quick print of first of the weight vales of all the variables\n");
|
||||
LOG_INF("%f\n", w->token_embedding_table[0]);
|
||||
LOG_INF("%f\n", w->rms_att_weight[0]);
|
||||
LOG_INF("%f\n", w->rms_ffn_weight[0]);
|
||||
|
||||
LOG_INF("%f\n", w->wq[0]);
|
||||
LOG_INF("%f\n", w->wk[0]);
|
||||
LOG_INF("%f\n", w->wv[0]);
|
||||
LOG_INF("%f\n", w->wo[0]);
|
||||
LOG_INF("%f\n", w->w1[0]);
|
||||
LOG_INF("%f\n", w->w2[0]);
|
||||
LOG_INF("%f\n", w->w3[0]);
|
||||
LOG_INF("%f\n", w->rms_att_weight[0]);
|
||||
if (!w->wcls.empty()) LOG_INF("%f\n", w->wcls[0]);
|
||||
}
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
//////////////////////////////////////// ggml structs and functions required to load models, configs and save the model.
|
||||
|
||||
struct my_llama_vocab {
|
||||
using id = int32_t;
|
||||
using token = std::string;
|
||||
using ttype = llama_token_type;
|
||||
|
||||
struct token_data {
|
||||
token text;
|
||||
float score;
|
||||
ttype type;
|
||||
};
|
||||
|
||||
std::unordered_map<token, id> token_to_id;
|
||||
std::vector<token_data> id_to_token;
|
||||
};
|
||||
|
||||
struct my_llama_hparams {
|
||||
uint32_t n_vocab = 32000;
|
||||
uint32_t n_ctx = 512; // this is provided as user input?
|
||||
uint32_t n_embd = 4096;
|
||||
uint32_t n_ff = 11008;
|
||||
uint32_t n_mult = 4;
|
||||
uint32_t n_head = 32;
|
||||
uint32_t n_head_kv = 32;
|
||||
uint32_t n_layer = 32;
|
||||
uint32_t n_rot = 64;
|
||||
|
||||
bool operator!=(const my_llama_hparams& other) const {
|
||||
return memcmp(this, &other, sizeof(my_llama_hparams));
|
||||
}
|
||||
};
|
||||
|
||||
struct my_llama_layer {
|
||||
// normalization
|
||||
struct ggml_tensor * attention_norm;
|
||||
|
||||
// attention
|
||||
struct ggml_tensor * wq;
|
||||
struct ggml_tensor * wk;
|
||||
struct ggml_tensor * wv;
|
||||
struct ggml_tensor * wo;
|
||||
|
||||
// normalization
|
||||
struct ggml_tensor * ffn_norm;
|
||||
|
||||
// ff
|
||||
struct ggml_tensor * w1;
|
||||
struct ggml_tensor * w2;
|
||||
struct ggml_tensor * w3;
|
||||
};
|
||||
|
||||
struct my_llama_model {
|
||||
struct ggml_context * ctx = NULL;
|
||||
|
||||
std::string name;
|
||||
|
||||
my_llama_hparams hparams;
|
||||
|
||||
struct ggml_tensor * tok_embeddings;
|
||||
|
||||
struct ggml_tensor * norm;
|
||||
struct ggml_tensor * output;
|
||||
|
||||
std::vector<my_llama_layer> layers;
|
||||
|
||||
uint32_t train_its = 0;
|
||||
uint32_t train_samples = 0;
|
||||
uint32_t train_tokens = 0;
|
||||
};
|
||||
|
||||
struct train_params {
|
||||
const char * fn_vocab_model;
|
||||
const char * fn_llama2c_model;
|
||||
const char * fn_llama2c_output_model;
|
||||
const char * fn_train_data;
|
||||
const char * fn_checkpoint_in;
|
||||
const char * fn_checkpoint_out;
|
||||
const char * fn_model_out;
|
||||
|
||||
uint32_t seed;
|
||||
|
||||
int n_ctx;
|
||||
int n_embd;
|
||||
int n_mult;
|
||||
int n_head;
|
||||
int n_layer;
|
||||
int n_rotmax;
|
||||
|
||||
int n_threads;
|
||||
int n_batch;
|
||||
int n_examples;
|
||||
int n_predict;
|
||||
|
||||
int print_info_interval;
|
||||
int print_details_interval;
|
||||
|
||||
bool samples_start_after_nl;
|
||||
bool use_adam;
|
||||
bool use_flash;
|
||||
bool use_scratch;
|
||||
|
||||
// only adam
|
||||
int warmup;
|
||||
int cos_decay_steps;
|
||||
float cos_decay_restart;
|
||||
float cos_decay_alpha;
|
||||
|
||||
int lbfgs_n_iter;
|
||||
int adam_n_iter;
|
||||
float adam_alpha;
|
||||
float adam_decay;
|
||||
|
||||
int mem_model_gb;
|
||||
int mem_compute_gb;
|
||||
int mem_compute0_gb;
|
||||
int mem_compute1_gb;
|
||||
};
|
||||
|
||||
static void print_params(struct my_llama_hparams * params) {
|
||||
LOG_INF("%s: n_vocab: %u\n", __func__, params->n_vocab);
|
||||
LOG_INF("%s: n_ctx: %u\n", __func__, params->n_ctx);
|
||||
LOG_INF("%s: n_embd: %u\n", __func__, params->n_embd);
|
||||
LOG_INF("%s: n_mult: %u\n", __func__, params->n_mult);
|
||||
LOG_INF("%s: n_head: %u\n", __func__, params->n_head);
|
||||
LOG_INF("%s: n_head_kv: %u\n", __func__, params->n_head_kv);
|
||||
LOG_INF("%s: n_ff: %u\n", __func__, params->n_ff);
|
||||
LOG_INF("%s: n_layer: %u\n", __func__, params->n_layer);
|
||||
LOG_INF("%s: n_rot: %u\n", __func__, params->n_rot);
|
||||
}
|
||||
|
||||
static void print_tensor_info(const struct ggml_context * ctx) {
|
||||
for (auto * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
|
||||
LOG_INF("%s: Allocating ", __func__);
|
||||
int64_t total = 1;
|
||||
int i = 0;
|
||||
for (; i < ggml_n_dims(t); ++i) {
|
||||
if (i > 0) { LOG_INF("x "); }
|
||||
LOG_INF("[%" PRId64 "] ", t->ne[i]);
|
||||
total *= t->ne[i];
|
||||
}
|
||||
if (i > 1) { LOG_INF("= [%" PRId64 "] ", total); }
|
||||
LOG_INF("float space for %s\n", ggml_get_name(t));
|
||||
}
|
||||
}
|
||||
|
||||
static void init_model(struct my_llama_model * model) {
|
||||
const auto & hparams = model->hparams;
|
||||
|
||||
const uint32_t n_embd = hparams.n_embd;
|
||||
const uint32_t n_layer = hparams.n_layer;
|
||||
const uint32_t n_vocab = hparams.n_vocab;
|
||||
|
||||
const uint32_t n_multiqueries = hparams.n_head_kv <= 0 || hparams.n_head_kv >= hparams.n_head ? 1 : hparams.n_head / hparams.n_head_kv;
|
||||
|
||||
const uint32_t n_ff = hparams.n_ff;
|
||||
struct ggml_context * ctx = model->ctx;
|
||||
|
||||
model->train_its = 0;
|
||||
model->train_samples = 0;
|
||||
model->train_tokens = 0;
|
||||
|
||||
model->tok_embeddings = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_vocab);
|
||||
model->norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
model->output = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_vocab);
|
||||
|
||||
ggml_set_name(model->tok_embeddings, "tok_embeddings.weight");
|
||||
ggml_set_name(model->norm, "norm.weight");
|
||||
ggml_set_name(model->output, "output.weight");
|
||||
|
||||
model->layers.resize(n_layer);
|
||||
for (uint32_t i = 0; i < n_layer; ++i) {
|
||||
auto & layer = model->layers[i];
|
||||
|
||||
std::string layers_i = "layers." + std::to_string(i);
|
||||
|
||||
layer.attention_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
|
||||
layer.wq = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_embd);
|
||||
layer.wk = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_embd / n_multiqueries);
|
||||
layer.wv = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_embd / n_multiqueries);
|
||||
layer.wo = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_embd);
|
||||
|
||||
layer.ffn_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
|
||||
|
||||
layer.w1 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ff);
|
||||
layer.w2 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_ff, n_embd);
|
||||
layer.w3 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ff);
|
||||
|
||||
ggml_set_name(layer.attention_norm, (layers_i + ".attention_norm.weight").c_str());
|
||||
|
||||
ggml_set_name(layer.wq, (layers_i + ".attention.wq.weight").c_str());
|
||||
ggml_set_name(layer.wk, (layers_i + ".attention.wk.weight").c_str());
|
||||
ggml_set_name(layer.wv, (layers_i + ".attention.wv.weight").c_str());
|
||||
ggml_set_name(layer.wo, (layers_i + ".attention.wo.weight").c_str());
|
||||
|
||||
ggml_set_name(layer.ffn_norm, (layers_i + ".ffn_norm.weight").c_str());
|
||||
|
||||
ggml_format_name(layer.w1, "%s.feed_forward.w1.weight", layers_i.c_str());
|
||||
ggml_format_name(layer.w2, "%s.feed_forward.w2.weight", layers_i.c_str());
|
||||
ggml_format_name(layer.w3, "%s.feed_forward.w3.weight", layers_i.c_str());
|
||||
}
|
||||
|
||||
print_tensor_info(ctx);
|
||||
}
|
||||
|
||||
static float get_f32_2d(struct ggml_tensor * tensor, int64_t i0, int64_t i1) {
|
||||
float * ptr = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1]);
|
||||
return *ptr;
|
||||
}
|
||||
|
||||
static int32_t get_i32_2d(struct ggml_tensor * tensor, int64_t i0, int64_t i1) {
|
||||
int32_t * ptr = (int32_t *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1]);
|
||||
return *ptr;
|
||||
}
|
||||
|
||||
static void print_row(struct ggml_tensor * probs, int i) {
|
||||
for (int k = 0; k < probs->ne[0]; ++k) {
|
||||
float p = get_f32_2d(probs, k, i);
|
||||
LOG(" %f", p);
|
||||
}
|
||||
LOG("\n");
|
||||
}
|
||||
|
||||
static void print_matrix(struct ggml_tensor * probs) {
|
||||
assert(ggml_is_matrix(probs));
|
||||
for (int i = 0; i < probs->ne[1]; ++i) {
|
||||
for (int k = 0; k < probs->ne[0]; ++k) {
|
||||
float p = get_f32_2d(probs, k, i);
|
||||
LOG(" %.2f", p);
|
||||
}
|
||||
LOG("\n");
|
||||
}
|
||||
}
|
||||
|
||||
struct my_llama_file {
|
||||
// use FILE * so we don't have to re-open the file to mmap
|
||||
FILE * fp;
|
||||
size_t size;
|
||||
|
||||
my_llama_file(const char * fname, const char * mode) {
|
||||
fp = std::fopen(fname, mode);
|
||||
if (fp == NULL) {
|
||||
size = 0;
|
||||
} else {
|
||||
seek(0, SEEK_END);
|
||||
size = tell();
|
||||
seek(0, SEEK_SET);
|
||||
}
|
||||
}
|
||||
|
||||
size_t tell() const {
|
||||
#ifdef _WIN32
|
||||
__int64 ret = _ftelli64(fp);
|
||||
#else
|
||||
long ret = std::ftell(fp);
|
||||
#endif
|
||||
GGML_ASSERT(ret != -1); // this really shouldn't fail
|
||||
return (size_t) ret;
|
||||
}
|
||||
|
||||
void seek(size_t offset, int whence) {
|
||||
#ifdef _WIN32
|
||||
int ret = _fseeki64(fp, (__int64) offset, whence);
|
||||
#else
|
||||
int ret = std::fseek(fp, (long) offset, whence);
|
||||
#endif
|
||||
GGML_ASSERT(ret == 0); // same
|
||||
}
|
||||
|
||||
void read_raw(void * ptr, size_t size) {
|
||||
if (size == 0) {
|
||||
return;
|
||||
}
|
||||
errno = 0;
|
||||
std::size_t ret = std::fread(ptr, size, 1, fp);
|
||||
if (ferror(fp)) {
|
||||
die_fmt("fread failed: %s", strerror(errno));
|
||||
}
|
||||
if (ret != 1) {
|
||||
die("unexpectedly reached end of file");
|
||||
}
|
||||
}
|
||||
|
||||
std::uint32_t read_u32() {
|
||||
std::uint32_t ret;
|
||||
read_raw(&ret, sizeof(ret));
|
||||
return ret;
|
||||
}
|
||||
std::float_t read_f32() {
|
||||
std::float_t ret;
|
||||
read_raw(&ret, sizeof(ret));
|
||||
return ret;
|
||||
}
|
||||
|
||||
std::string read_string(std::uint32_t len) {
|
||||
std::vector<char> chars(len);
|
||||
read_raw(chars.data(), len);
|
||||
return std::string(chars.data(), len);
|
||||
}
|
||||
|
||||
~my_llama_file() {
|
||||
if (fp) {
|
||||
std::fclose(fp);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
static bool is_ggml_file(const char * filename) {
|
||||
my_llama_file file(filename, "rb");
|
||||
if (file.size < 4) {
|
||||
return false;
|
||||
}
|
||||
std::string magic = file.read_string(4);
|
||||
return magic == GGUF_MAGIC;
|
||||
}
|
||||
|
||||
static std::string llama_escape_whitespaces(const std::string & text) {
|
||||
std::ostringstream out;
|
||||
for (char c : text) {
|
||||
if (c == ' ') out << "\xe2\x96\x81";
|
||||
else out << c;
|
||||
}
|
||||
return out.str();
|
||||
}
|
||||
|
||||
static void load_vocab(const char * filename, const Config * config, struct my_llama_vocab * vocab) {
|
||||
if (is_ggml_file(filename)) {
|
||||
LOG_INF("%s: Loading vocabulary from gguf file %s\n", __func__, filename);
|
||||
struct ggml_context * ctx_data = NULL;
|
||||
|
||||
struct gguf_init_params params = {
|
||||
/*.no_alloc = */ false,
|
||||
/*.ctx = */ &ctx_data,
|
||||
};
|
||||
|
||||
struct gguf_context * ctx = gguf_init_from_file(filename, params);
|
||||
GGML_ASSERT(ctx != NULL);
|
||||
|
||||
const int model_idx = gguf_find_key(ctx, KV_TOKENIZER_MODEL);
|
||||
GGML_ASSERT(model_idx >= 0);
|
||||
std::string tokenizer_name = gguf_get_val_str(ctx, model_idx);
|
||||
GGML_ASSERT(tokenizer_name == TOKENIZER_NAME);
|
||||
|
||||
const int token_idx = gguf_find_key(ctx, KV_TOKENIZER_LIST);
|
||||
GGML_ASSERT(token_idx >= 0);
|
||||
|
||||
const int score_idx = gguf_find_key(ctx, KV_TOKENIZER_SCORES);
|
||||
GGML_ASSERT(score_idx >= 0);
|
||||
const float * scores = (const float * ) gguf_get_arr_data(ctx, score_idx);
|
||||
|
||||
const int toktype_idx = gguf_find_key(ctx, KV_TOKENIZER_TOKEN_TYPE);
|
||||
GGML_ASSERT(toktype_idx >= 0);
|
||||
const int * toktypes = (const int * ) gguf_get_arr_data(ctx, toktype_idx);
|
||||
|
||||
const uint32_t n_vocab = gguf_get_arr_n(ctx, token_idx);
|
||||
if (n_vocab != static_cast<uint32_t>(config->vocab_size)) {
|
||||
die_fmt("vocab size mismatch: (gguf) %u != (llama2c) %d", n_vocab, config->vocab_size);
|
||||
}
|
||||
|
||||
vocab->id_to_token.resize(n_vocab);
|
||||
|
||||
for (uint32_t i = 0; i < n_vocab; i++) {
|
||||
std::string word = gguf_get_arr_str(ctx, token_idx, i);
|
||||
|
||||
vocab->token_to_id[word] = i;
|
||||
|
||||
auto & token_data = vocab->id_to_token[i];
|
||||
token_data.text = std::move(word);
|
||||
token_data.score = scores[i];
|
||||
token_data.type = (llama_token_type) toktypes[i];
|
||||
}
|
||||
ggml_free(ctx_data);
|
||||
gguf_free(ctx);
|
||||
} else {
|
||||
// assume llama2.c vocabulary
|
||||
LOG_INF("%s: Assuming llama2.c vocabulary since %s is not a gguf file\n", __func__, filename);
|
||||
my_llama_file file(filename, "rb");
|
||||
if (!file.fp) {
|
||||
die_fmt("%s: %s", strerror(errno), filename);
|
||||
}
|
||||
const int n_vocab = config->vocab_size;
|
||||
/* uint32_t max_token_length = */ file.read_u32(); // unused
|
||||
vocab->id_to_token.resize(n_vocab);
|
||||
for (my_llama_vocab::id id=0; id<n_vocab; ++id) {
|
||||
float_t score = file.read_f32();
|
||||
uint32_t len = file.read_u32();
|
||||
std::string text = file.read_string(len);
|
||||
|
||||
unsigned char byte_val;
|
||||
my_llama_vocab::ttype type = LLAMA_TOKEN_TYPE_NORMAL;
|
||||
if (id == UNKNOWN_TOKEN_ID) {
|
||||
text = "<unk>";
|
||||
type = LLAMA_TOKEN_TYPE_UNKNOWN;
|
||||
} else if (id == BOS_TOKEN_ID) {
|
||||
text = "<s>";
|
||||
type = LLAMA_TOKEN_TYPE_CONTROL;
|
||||
} else if (id == EOS_TOKEN_ID) {
|
||||
text = "</s>";
|
||||
type = LLAMA_TOKEN_TYPE_CONTROL;
|
||||
} else if (text.empty()) {
|
||||
type = LLAMA_TOKEN_TYPE_CONTROL;
|
||||
} else if (sscanf(text.c_str(), "<0x%02hhX>", &byte_val) == 1) {
|
||||
// Text of byte tokens is already in the expected format.
|
||||
type = LLAMA_TOKEN_TYPE_BYTE;
|
||||
} else {
|
||||
type = LLAMA_TOKEN_TYPE_NORMAL;
|
||||
}
|
||||
text = llama_escape_whitespaces(text);
|
||||
|
||||
vocab->id_to_token[id].text = text;
|
||||
vocab->id_to_token[id].score = score;
|
||||
vocab->id_to_token[id].type = type;
|
||||
vocab->token_to_id.emplace(text, id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void convert_weights_ak_to_gg(struct ggml_tensor * gg_weights, const float * karpathy_weights) {
|
||||
int size = 1;
|
||||
for (int dim = 0; dim < ggml_n_dims(gg_weights); ++dim) {
|
||||
size *= gg_weights->ne[dim];
|
||||
}
|
||||
for (int ct = 0; ct < size; ++ct) {
|
||||
int64_t i0 = 0; int64_t i1 = 0;
|
||||
int64_t i2 = 0; int64_t i3 = 0;
|
||||
ggml_unravel_index(gg_weights, ct, &i0, &i1, &i2, &i3);
|
||||
ggml_set_f32_nd(gg_weights, i0, i1, i2, i3, karpathy_weights[ct]);
|
||||
}
|
||||
}
|
||||
|
||||
static void save_as_llama_model(
|
||||
struct my_llama_vocab * vocab, struct my_llama_model * model, TransformerWeights* w, const char * filename
|
||||
) {
|
||||
// convert AK weights into GG weights one by one.
|
||||
// w->token_embedding_table -> model->tok_embeddings
|
||||
// float* -> struct ggml_tensor
|
||||
convert_weights_ak_to_gg(model->tok_embeddings, w->token_embedding_table.data());
|
||||
convert_weights_ak_to_gg(model->output, !w->wcls.empty() ? w->wcls.data() : w->token_embedding_table.data());
|
||||
|
||||
convert_weights_ak_to_gg(model->norm, w->rms_final_weight.data());
|
||||
//print_row(model->norm, 0);
|
||||
|
||||
// for rms-att-weight
|
||||
int row_length = model->hparams.n_embd;
|
||||
int n_ff = model->hparams.n_ff;
|
||||
|
||||
const uint32_t n_multiqueries = model->hparams.n_head_kv <= 0 || model->hparams.n_head_kv >= model->hparams.n_head ? 1 : model->hparams.n_head / model->hparams.n_head_kv;
|
||||
|
||||
for (uint32_t i = 0; i < model->hparams.n_layer; ++i){
|
||||
auto & layer = model->layers[i];
|
||||
// 1d
|
||||
convert_weights_ak_to_gg(layer.attention_norm, &w->rms_att_weight[i*row_length]);
|
||||
convert_weights_ak_to_gg(layer.ffn_norm , &w->rms_ffn_weight[i*row_length]);
|
||||
|
||||
// from 3d matrix layer x dim x dim to 2d matrix dim x dim
|
||||
convert_weights_ak_to_gg(layer.wq , &w->wq[i*row_length*row_length]);
|
||||
convert_weights_ak_to_gg(layer.wo , &w->wo[i*row_length*row_length]);
|
||||
// from 3d matrix layer x dim x dim to 2d matrix dim x dim / n_multiqueries
|
||||
convert_weights_ak_to_gg(layer.wk , &w->wk[i*row_length*row_length/n_multiqueries]);
|
||||
convert_weights_ak_to_gg(layer.wv , &w->wv[i*row_length*row_length/n_multiqueries]);
|
||||
|
||||
convert_weights_ak_to_gg(layer.w1 , &w->w1[i*row_length*n_ff]);
|
||||
convert_weights_ak_to_gg(layer.w2 , &w->w2[i*n_ff*row_length]);
|
||||
convert_weights_ak_to_gg(layer.w3 , &w->w3[i*row_length*n_ff]);
|
||||
}
|
||||
|
||||
struct gguf_context * ctx = gguf_init_empty();
|
||||
|
||||
std::vector<const char*> tokens;
|
||||
std::vector<float> scores;
|
||||
std::vector<llama_token_type> token_types;
|
||||
for (const my_llama_vocab::token_data & token_data : vocab->id_to_token) {
|
||||
tokens.push_back(token_data.text.c_str());
|
||||
scores.push_back(token_data.score);
|
||||
token_types.push_back(token_data.type);
|
||||
}
|
||||
gguf_set_arr_str(ctx, KV_TOKENIZER_LIST, tokens.data(), tokens.size());
|
||||
gguf_set_arr_data(ctx, KV_TOKENIZER_SCORES, GGUF_TYPE_FLOAT32, scores.data(), scores.size());
|
||||
gguf_set_arr_data(ctx, KV_TOKENIZER_TOKEN_TYPE, GGUF_TYPE_INT32, token_types.data(), token_types.size());
|
||||
|
||||
gguf_set_val_str(ctx, KV_TOKENIZER_MODEL, TOKENIZER_NAME);
|
||||
|
||||
gguf_set_val_str(ctx, KV_GENERAL_ARCHITECTURE, "llama");
|
||||
gguf_set_val_str(ctx, KV_GENERAL_NAME, "llama");
|
||||
|
||||
// special tokens
|
||||
gguf_set_val_u32(ctx, KV_TOKENIZER_UNK_ID, UNKNOWN_TOKEN_ID);
|
||||
gguf_set_val_u32(ctx, KV_TOKENIZER_BOS_ID, BOS_TOKEN_ID);
|
||||
gguf_set_val_u32(ctx, KV_TOKENIZER_EOS_ID, EOS_TOKEN_ID);
|
||||
gguf_set_val_u32(ctx, KV_TOKENIZER_SEP_ID, LLAMA_TOKEN_NULL);
|
||||
gguf_set_val_u32(ctx, KV_TOKENIZER_PAD_ID, LLAMA_TOKEN_NULL);
|
||||
|
||||
gguf_set_val_u32(ctx, KV_CONTEXT_LENGTH, model->hparams.n_ctx);
|
||||
gguf_set_val_u32(ctx, KV_EMBEDDING_LENGTH, model->hparams.n_embd);
|
||||
gguf_set_val_u32(ctx, KV_FEED_FORWARD_LENGTH, model->hparams.n_ff);
|
||||
gguf_set_val_u32(ctx, KV_ATTENTION_HEAD_COUNT, model->hparams.n_head);
|
||||
gguf_set_val_u32(ctx, KV_ATTENTION_HEAD_COUNT, model->hparams.n_head);
|
||||
gguf_set_val_u32(ctx, KV_ATTENTION_HEAD_COUNT_KV, model->hparams.n_head_kv);
|
||||
gguf_set_val_u32(ctx, KV_BLOCK_COUNT, model->hparams.n_layer);
|
||||
gguf_set_val_u32(ctx, KV_ROPE_DIMENSION_COUNT, model->hparams.n_rot);
|
||||
gguf_set_val_f32(ctx, KV_ATTENTION_LAYERNORM_RMS_EPS, 1e-5f);
|
||||
|
||||
// write tensors
|
||||
ggml_set_name(model->tok_embeddings, TN_TOKEN_EMBD);
|
||||
gguf_add_tensor(ctx, model->tok_embeddings);
|
||||
|
||||
ggml_set_name(model->norm, TN_OUTPUT_NORM);
|
||||
gguf_add_tensor(ctx, model->norm);
|
||||
|
||||
ggml_set_name(model->output, TN_OUTPUT);
|
||||
gguf_add_tensor(ctx, model->output);
|
||||
|
||||
for (uint32_t i = 0; i < model->hparams.n_layer; ++i) {
|
||||
auto & layer = model->layers[i];
|
||||
|
||||
ggml_format_name(layer.wq, TN_ATTN_Q, i);
|
||||
gguf_add_tensor(ctx, layer.wq);
|
||||
|
||||
ggml_format_name(layer.wk, TN_ATTN_K, i);
|
||||
gguf_add_tensor(ctx, layer.wk);
|
||||
|
||||
ggml_format_name(layer.wv, TN_ATTN_V, i);
|
||||
gguf_add_tensor(ctx, layer.wv);
|
||||
|
||||
ggml_format_name(layer.wo, TN_ATTN_OUTPUT, i);
|
||||
gguf_add_tensor(ctx, layer.wo);
|
||||
|
||||
ggml_format_name(layer.attention_norm, TN_ATTN_NORM, i);
|
||||
gguf_add_tensor(ctx, layer.attention_norm);
|
||||
|
||||
ggml_format_name(layer.w1, TN_FFN_GATE, i);
|
||||
gguf_add_tensor(ctx, layer.w1);
|
||||
|
||||
ggml_format_name(layer.w2, TN_FFN_DOWN, i);
|
||||
gguf_add_tensor(ctx, layer.w2);
|
||||
|
||||
ggml_format_name(layer.w3, TN_FFN_UP, i);
|
||||
gguf_add_tensor(ctx, layer.w3);
|
||||
|
||||
ggml_format_name(layer.ffn_norm, TN_FFN_NORM, i);
|
||||
gguf_add_tensor(ctx, layer.ffn_norm);
|
||||
}
|
||||
|
||||
gguf_write_to_file(ctx, filename, false);
|
||||
gguf_free(ctx);
|
||||
}
|
||||
|
||||
static struct train_params get_default_train_params() {
|
||||
struct train_params params;
|
||||
params.fn_vocab_model = "models/7B/ggml-model-f16.gguf";
|
||||
params.fn_llama2c_output_model = "ak_llama_model.bin";
|
||||
params.fn_train_data = "shakespeare.txt";
|
||||
params.fn_checkpoint_in = "checkpoint.bin";
|
||||
params.fn_checkpoint_out = "checkpoint.bin";
|
||||
params.fn_model_out = "ggml-checkpoint-f32.bin";
|
||||
|
||||
params.seed = -1;
|
||||
|
||||
params.n_ctx = 128;
|
||||
params.n_embd = 256;
|
||||
params.n_mult = 256;
|
||||
params.n_head = 8;
|
||||
params.n_layer = 16;
|
||||
params.n_rotmax = 64;
|
||||
|
||||
params.n_threads = 6;
|
||||
params.n_batch = 8;
|
||||
params.n_examples = 8;
|
||||
params.n_predict = 1024;
|
||||
|
||||
params.print_info_interval = 1;
|
||||
params.print_details_interval = 2;
|
||||
|
||||
params.samples_start_after_nl = false;
|
||||
params.use_adam = true;
|
||||
params.use_flash = false;
|
||||
params.use_scratch = true;
|
||||
|
||||
// only adam
|
||||
params.warmup = 100;
|
||||
params.cos_decay_steps = 1000;
|
||||
params.cos_decay_restart = 1.1f;
|
||||
params.cos_decay_alpha = 0.0f;
|
||||
|
||||
params.lbfgs_n_iter = 16;
|
||||
params.adam_n_iter = 16;
|
||||
params.adam_alpha = 1e-3f;
|
||||
params.adam_decay = 1e-3f;
|
||||
|
||||
params.mem_model_gb = 2;
|
||||
params.mem_compute_gb = 24;
|
||||
params.mem_compute0_gb = 8;
|
||||
params.mem_compute1_gb = 2;
|
||||
|
||||
return params;
|
||||
}
|
||||
|
||||
static void print_usage(int /*argc*/, char ** argv, const struct train_params * params) {
|
||||
fprintf(stderr, "usage: %s [options]\n", argv[0]);
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "options:\n");
|
||||
fprintf(stderr, " -h, --help show this help message and exit\n");
|
||||
fprintf(stderr, " --copy-vocab-from-model FNAME path of gguf llama model or llama2.c vocabulary from which to copy vocab (default '%s')\n", params->fn_vocab_model);
|
||||
fprintf(stderr, " --llama2c-model FNAME [REQUIRED] model path from which to load Karpathy's llama2.c model\n");
|
||||
fprintf(stderr, " --llama2c-output-model FNAME model path to save the converted llama2.c model (default %s')\n", params->fn_llama2c_output_model);
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
static bool params_parse(int argc, char ** argv, struct train_params * params) {
|
||||
bool invalid_param = false;
|
||||
bool reqd_param_found = false;
|
||||
std::string arg;
|
||||
struct train_params default_params = get_default_train_params();
|
||||
const std::string arg_prefix = "--";
|
||||
|
||||
for (int i = 1; i < argc; i++) {
|
||||
arg = argv[i];
|
||||
if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) {
|
||||
std::replace(arg.begin(), arg.end(), '_', '-');
|
||||
}
|
||||
|
||||
if (arg == "--copy-vocab-from-model") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
params->fn_vocab_model = argv[i];
|
||||
} else if (arg == "--llama2c-model") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
reqd_param_found = true;
|
||||
params->fn_llama2c_model = argv[i];
|
||||
} else if (arg == "--llama2c-output-model") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
params->fn_llama2c_output_model = argv[i];
|
||||
} else if (arg == "-h" || arg == "--help") {
|
||||
print_usage(argc, argv, &default_params);
|
||||
exit(0);
|
||||
} else {
|
||||
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
||||
print_usage(argc, argv, &default_params);
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
if (invalid_param) {
|
||||
fprintf(stderr, "error: invalid parameter for argument: %s\n", arg.c_str());
|
||||
print_usage(argc, argv, &default_params);
|
||||
exit(1);
|
||||
}
|
||||
if (!reqd_param_found){
|
||||
fprintf(stderr, "error: please specify a llama2.c .bin file to be converted with argument --llama2c-model\n");
|
||||
print_usage(argc, argv, &default_params);
|
||||
exit(1);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static std::string basename(const std::string &path) {
|
||||
size_t pos = path.find_last_of("/\\");
|
||||
if (pos == std::string::npos) {
|
||||
return path;
|
||||
}
|
||||
return path.substr(pos + 1);
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
common_init();
|
||||
|
||||
struct train_params params = get_default_train_params();
|
||||
if (!params_parse(argc, argv, ¶ms)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
Config config;
|
||||
TransformerWeights weights = {};
|
||||
{
|
||||
LOG_INF("%s: Loading llama2c model from %s\n", __func__, params.fn_llama2c_model);
|
||||
FILE * file = fopen(params.fn_llama2c_model, "rb");
|
||||
if (!file) {
|
||||
LOG_ERR("%s: Unable to open the checkpoint file %s!\n", __func__, params.fn_llama2c_model);
|
||||
return 1;
|
||||
}
|
||||
// read in the config header
|
||||
if (fread(&config, sizeof(Config), 1, file) != 1) {
|
||||
LOG_ERR("%s: Unable to read llama2c config from %s!\n",__func__,params.fn_llama2c_model);
|
||||
return 1;
|
||||
}
|
||||
auto shared_weights = config.vocab_size > 0;
|
||||
config.vocab_size = abs(config.vocab_size);
|
||||
|
||||
// read in the Transformer weights
|
||||
alloc_weights(&weights, &config, shared_weights);
|
||||
if (checkpoint_init_weights(&weights, &config, file, shared_weights)) {
|
||||
LOG_ERR("%s: Unable to initialize transformer weights from %s!",__func__,params.fn_llama2c_model);
|
||||
return 1;
|
||||
}
|
||||
fclose(file);
|
||||
}
|
||||
|
||||
struct my_llama_vocab vocab;
|
||||
load_vocab(params.fn_vocab_model, &config, &vocab);
|
||||
|
||||
struct my_llama_model model;
|
||||
model.hparams.n_vocab = config.vocab_size; //llama_vocab_n_vocab(lctx);
|
||||
model.hparams.n_ctx = params.n_ctx;
|
||||
model.hparams.n_embd = config.dim; //params.n_embd;
|
||||
model.hparams.n_ff = config.hidden_dim;
|
||||
model.hparams.n_mult = 32;//params.n_mult;
|
||||
model.hparams.n_head = config.n_heads; //params.n_head;
|
||||
model.hparams.n_head_kv = config.n_kv_heads;
|
||||
model.hparams.n_layer = config.n_layers; //params.n_layer;
|
||||
model.hparams.n_rot = std::min((uint32_t)params.n_rotmax, model.hparams.n_embd / model.hparams.n_head);
|
||||
|
||||
print_params(&model.hparams);
|
||||
|
||||
struct ggml_init_params lcparams;
|
||||
lcparams.mem_size = 1024ll*1024ll*1024ll*((size_t) params.mem_model_gb);
|
||||
lcparams.mem_buffer = NULL;
|
||||
lcparams.no_alloc = false;
|
||||
|
||||
model.ctx = ggml_init(lcparams);
|
||||
|
||||
init_model(&model);
|
||||
model.name = basename(params.fn_llama2c_model);
|
||||
save_as_llama_model(&vocab, &model, &weights, params.fn_llama2c_output_model);
|
||||
|
||||
LOG_INF("%s: Saving llama.c model file %s in ggml format at %s\n", __func__, params.fn_llama2c_model, params.fn_llama2c_output_model);
|
||||
|
||||
ggml_free(model.ctx);
|
||||
return 0;
|
||||
}
|
||||
1462
examples/convert_legacy_llama.py
Executable file
5
examples/debug/CMakeLists.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
set(TARGET llama-debug)
|
||||
add_executable(${TARGET} debug.cpp)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
54
examples/debug/README.md
Normal file
@@ -0,0 +1,54 @@
|
||||
# llama.cpp/examples/debug
|
||||
|
||||
This is a utility intended to help debug a model by registering a callback that
|
||||
logs GGML operations and tensor data. It can also store the generated logits or
|
||||
embeddings as well as the prompt and token ids for comparision with the original
|
||||
model.
|
||||
|
||||
### Usage
|
||||
|
||||
```shell
|
||||
llama-debug \
|
||||
--hf-repo ggml-org/models \
|
||||
--hf-file phi-2/ggml-model-q4_0.gguf \
|
||||
--model phi-2-q4_0.gguf \
|
||||
--prompt hello \
|
||||
--save-logits \
|
||||
--verbose
|
||||
```
|
||||
The tensor data is logged as debug and required the --verbose flag. The reason
|
||||
for this is that while useful for a model with many layers there can be a lot of
|
||||
output. You can filter the tensor names using the `--tensor-filter` option.
|
||||
|
||||
A recommended approach is to first run without `--verbose` and see if the
|
||||
generated logits/embeddings are close to the original model. If they are not,
|
||||
then it might be required to inspect tensor by tensor and in that case it is
|
||||
useful to enable the `--verbose` flag along with `--tensor-filter` to focus on
|
||||
specific tensors.
|
||||
|
||||
### Options
|
||||
This example supports all standard `llama.cpp` options and also accepts the
|
||||
following options:
|
||||
```console
|
||||
$ llama-debug --help
|
||||
...
|
||||
|
||||
----- example-specific params -----
|
||||
|
||||
--save-logits save final logits to files for verification (default: false)
|
||||
--logits-output-dir PATH directory for saving logits output files (default: data)
|
||||
--tensor-filter REGEX filter tensor names for debug output (regex pattern, can be specified multiple times)
|
||||
```
|
||||
|
||||
### Output Files
|
||||
|
||||
When `--save-logits` is enabled, the following files are created in the output
|
||||
directory:
|
||||
|
||||
* `llamacpp-<model>[-embeddings].bin` - Binary output (logits or embeddings)
|
||||
* `llamacpp-<model>[-embeddings].txt` - Text output (logits or embeddings, one per line)
|
||||
* `llamacpp-<model>[-embeddings]-prompt.txt` - Prompt text and token IDs
|
||||
* `llamacpp-<model>[-embeddings]-tokens.bin` - Binary token IDs for programmatic comparison
|
||||
|
||||
These files can be compared against the original model's output to verify the
|
||||
converted model.
|
||||
253
examples/debug/debug.cpp
Normal file
@@ -0,0 +1,253 @@
|
||||
#include "debug.h"
|
||||
#include "arg.h"
|
||||
#include "common.h"
|
||||
#include "log.h"
|
||||
#include "llama.h"
|
||||
|
||||
#include <cstdlib>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <filesystem>
|
||||
#include <fstream>
|
||||
#include <regex>
|
||||
|
||||
static void print_usage(int /*argc*/, char ** argv) {
|
||||
const std::string usage_template = R"(
|
||||
example usage:
|
||||
|
||||
Print tensors:
|
||||
|
||||
{prog} -m model.gguf -p "Hello my name is" --verbose
|
||||
|
||||
The tensors to be printed can be filtered with --tensor-filter option.
|
||||
|
||||
Save logits/embeddings:
|
||||
|
||||
{prog} -m model.gguf -p "Hello my name is" --save-logits
|
||||
|
||||
Add --embedding to save embeddings)" "\n";
|
||||
|
||||
// Fix the source code indentation above that is introduced by the raw string literal.
|
||||
std::string usage = std::regex_replace(usage_template, std::regex("\\n {8}"), "\n");
|
||||
usage = std::regex_replace(usage, std::regex("\\{prog\\}"), argv[0]);
|
||||
LOG("%s\n", usage.c_str());
|
||||
}
|
||||
|
||||
static bool has_pooling(llama_context * ctx) {
|
||||
switch (llama_pooling_type(ctx)) {
|
||||
case LLAMA_POOLING_TYPE_NONE:
|
||||
case LLAMA_POOLING_TYPE_UNSPECIFIED:
|
||||
return false;
|
||||
default:
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
struct output_data {
|
||||
float * data_ptr = nullptr;
|
||||
int data_size = 0;
|
||||
std::string type_suffix;
|
||||
std::vector<float> embd_norm;
|
||||
std::string prompt;
|
||||
std::vector<llama_token> tokens;
|
||||
|
||||
output_data(llama_context * ctx, const llama_model * model, const common_params & params) {
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
const bool add_bos = llama_vocab_get_add_bos(vocab);
|
||||
|
||||
tokens = common_tokenize(ctx, params.prompt, add_bos);
|
||||
prompt = params.prompt;
|
||||
|
||||
if (params.embedding) {
|
||||
const int n_embd = llama_model_n_embd_out(model);
|
||||
const bool pooling = has_pooling(ctx);
|
||||
const int n_embd_count = pooling ? 1 : tokens.size();
|
||||
const int n_floats = n_embd * n_embd_count;
|
||||
|
||||
float * embd_raw = pooling ? llama_get_embeddings_seq(ctx, 0) : llama_get_embeddings(ctx);
|
||||
if (embd_raw == nullptr) {
|
||||
throw std::runtime_error("failed to get embeddings from the model");
|
||||
}
|
||||
|
||||
LOG_DBG("pooling_enabled: %s\n", pooling ? "true" : "false");
|
||||
LOG_DBG("n_embd: %d\n", n_embd);
|
||||
LOG_DBG("n_floats: %d\n", n_floats);
|
||||
LOG_DBG("n_embd_count: %d\n", n_embd_count);
|
||||
|
||||
data_ptr = embd_raw;
|
||||
data_size = n_floats;
|
||||
type_suffix = "-embeddings";
|
||||
|
||||
if (params.embd_normalize >= 0) {
|
||||
embd_norm.resize(n_floats);
|
||||
for (int i = 0; i < n_embd_count; i++) {
|
||||
common_embd_normalize(embd_raw+i*n_embd, embd_norm.data()+i*n_embd, n_embd, params.embd_normalize);
|
||||
}
|
||||
data_ptr = embd_norm.data();
|
||||
}
|
||||
} else {
|
||||
const float * logits = llama_get_logits_ith(ctx, tokens.size() - 1);
|
||||
const int n_logits = llama_vocab_n_tokens(vocab);
|
||||
|
||||
data_ptr = const_cast<float*>(logits);
|
||||
data_size = n_logits;
|
||||
type_suffix = "";
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
static void save_output_data(const output_data & output, const std::string & model_name, const std::string & output_dir) {
|
||||
std::filesystem::create_directory(output_dir);
|
||||
auto base_path = std::filesystem::path{output_dir} / ("llamacpp-" + model_name + output.type_suffix);
|
||||
|
||||
// Save logits/embeddings to binary file.
|
||||
{
|
||||
std::filesystem::path filepath{base_path.string() + ".bin"};
|
||||
std::ofstream file{filepath, std::ios::binary};
|
||||
if (!file) {
|
||||
throw std::runtime_error("failed to open binary output file: " + filepath.string());
|
||||
}
|
||||
file.write(reinterpret_cast<const char*>(output.data_ptr), output.data_size * sizeof(float));
|
||||
LOG("Data saved to %s\n", filepath.c_str());
|
||||
}
|
||||
|
||||
// Save logits/embeddings to text file.
|
||||
{
|
||||
std::filesystem::path filepath{base_path.string() + ".txt"};
|
||||
std::ofstream file{filepath};
|
||||
if (!file) {
|
||||
throw std::runtime_error("failed to open text output file: " + filepath.string());
|
||||
}
|
||||
for (int i = 0; i < output.data_size; i++) {
|
||||
file << i << ": " << output.data_ptr[i] << '\n';
|
||||
}
|
||||
LOG("Data saved to %s\n", filepath.c_str());
|
||||
}
|
||||
|
||||
// Save prompt and tokens to text file.
|
||||
{
|
||||
std::filesystem::path filepath{base_path.string() + "-prompt.txt"};
|
||||
std::ofstream file{filepath};
|
||||
if (!file) {
|
||||
throw std::runtime_error("failed to open prompt output file: " + filepath.string());
|
||||
}
|
||||
|
||||
file << "prompt: " << output.prompt << '\n';
|
||||
file << "n_tokens: " << output.tokens.size() << '\n';
|
||||
|
||||
file << "token ids: ";
|
||||
for (size_t i = 0; i < output.tokens.size(); i++) {
|
||||
file << output.tokens[i];
|
||||
if (i + 1 < output.tokens.size()) {
|
||||
file << ", ";
|
||||
}
|
||||
}
|
||||
file << '\n';
|
||||
LOG("Prompt saved to %s\n", filepath.c_str());
|
||||
}
|
||||
|
||||
// Save token ids to binary file.
|
||||
{
|
||||
std::filesystem::path filepath{base_path.string() + "-tokens.bin"};
|
||||
std::ofstream file{filepath, std::ios::binary};
|
||||
if (!file) {
|
||||
throw std::runtime_error("failed to open tokens binary file: " + filepath.string());
|
||||
}
|
||||
file.write(reinterpret_cast<const char*>(output.tokens.data()), output.tokens.size() * sizeof(llama_token));
|
||||
LOG("Tokens saved to %s\n", filepath.c_str());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
static void print_tokenized_prompt(llama_context * ctx, const std::vector<llama_token> & tokens, const std::string & prompt) {
|
||||
const llama_model * model = llama_get_model(ctx);
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
|
||||
LOG("Model add_bos: %s\n", llama_vocab_get_add_bos(vocab) ? "true" : "false");
|
||||
LOG("Input prompt: \"%s\"\n", prompt.c_str());
|
||||
LOG("Token ids (%zu):\n", tokens.size());
|
||||
|
||||
for (auto id : tokens) {
|
||||
std::string piece(128, '\0');
|
||||
int n = llama_token_to_piece(vocab, id, piece.data(), piece.size(), 0, true);
|
||||
if (n < 0) {
|
||||
LOG_ERR("failed to convert token %d to piece\n", id);
|
||||
continue;
|
||||
}
|
||||
piece.resize(n);
|
||||
LOG("%s(%d) ", piece.c_str(), id);
|
||||
}
|
||||
LOG("\n");
|
||||
}
|
||||
|
||||
static bool run(llama_context * ctx, const common_params & params) {
|
||||
const llama_model * model = llama_get_model(ctx);
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
|
||||
const bool add_bos = llama_vocab_get_add_bos(vocab);
|
||||
|
||||
std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, add_bos);
|
||||
|
||||
if (tokens.empty()) {
|
||||
LOG_ERR("%s : there are not input tokens to process - (try to provide a prompt with '-p')\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) {
|
||||
LOG_ERR("%s : failed to eval\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
print_tokenized_prompt(ctx, tokens, params.prompt);
|
||||
|
||||
if (params.save_logits) {
|
||||
output_data output {ctx, model, params};
|
||||
std::filesystem::path model_path{params.model.path};
|
||||
std::string model_name{model_path.stem().string()};
|
||||
save_output_data(output, model_name, params.logits_output_dir);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
common_params params;
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_DEBUG, print_usage)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
common_init();
|
||||
|
||||
llama_backend_init();
|
||||
llama_numa_init(params.numa);
|
||||
|
||||
base_callback_data cb_data(params, params.tensor_filter);
|
||||
|
||||
auto llama_init = common_init_from_params(params);
|
||||
|
||||
auto * model = llama_init->model();
|
||||
auto * ctx = llama_init->context();
|
||||
|
||||
if (model == nullptr || ctx == nullptr) {
|
||||
LOG_ERR("%s : failed to init\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
{
|
||||
LOG_INF("\n");
|
||||
LOG_INF("%s\n", common_params_get_system_info(params).c_str());
|
||||
LOG_INF("\n");
|
||||
}
|
||||
|
||||
if (!run(ctx, params)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
LOG("\n");
|
||||
llama_perf_context_print(ctx);
|
||||
|
||||
llama_backend_free();
|
||||
|
||||
return 0;
|
||||
}
|
||||
49
examples/deprecation-warning/README.md
Normal file
@@ -0,0 +1,49 @@
|
||||
# Migration notice for binary filenames
|
||||
|
||||
> [!IMPORTANT]
|
||||
[2024 Jun 12] Binaries have been renamed w/ a `llama-` prefix. `main` is now `llama-cli`, `server` is `llama-server`, etc (https://github.com/ggerganov/llama.cpp/pull/7809)
|
||||
|
||||
This migration was important, but it is a breaking change that may not always be immediately obvious to users.
|
||||
|
||||
Please update all scripts and workflows to use the new binary names.
|
||||
|
||||
| Old Filename | New Filename |
|
||||
| ---- | ---- |
|
||||
| main | llama-cli |
|
||||
| server | llama-server |
|
||||
| llama-bench | llama-bench |
|
||||
| embedding | llama-embedding |
|
||||
| quantize | llama-quantize |
|
||||
| tokenize | llama-tokenize |
|
||||
| export-lora | llama-export-lora |
|
||||
| libllava.a | libllava.a |
|
||||
| baby-llama | llama-baby-llama |
|
||||
| batched | llama-batched |
|
||||
| batched-bench | llama-batched-bench |
|
||||
| benchmark-matmult | llama-benchmark-matmult |
|
||||
| convert-llama2c-to-ggml | llama-convert-llama2c-to-ggml |
|
||||
| eval-callback | llama-eval-callback |
|
||||
| gbnf-validator | llama-gbnf-validator |
|
||||
| gguf | llama-gguf |
|
||||
| gguf-split | llama-gguf-split |
|
||||
| gritlm | llama-gritlm |
|
||||
| imatrix | llama-imatrix |
|
||||
| infill | llama-infill |
|
||||
| llava-cli | llama-llava-cli |
|
||||
| lookahead | llama-lookahead |
|
||||
| lookup | llama-lookup |
|
||||
| lookup-create | llama-lookup-create |
|
||||
| lookup-merge | llama-lookup-merge |
|
||||
| lookup-stats | llama-lookup-stats |
|
||||
| parallel | llama-parallel |
|
||||
| passkey | llama-passkey |
|
||||
| perplexity | llama-perplexity |
|
||||
| q8dot | llama-q8dot |
|
||||
| quantize-stats | llama-quantize-stats |
|
||||
| retrieval | llama-retrieval |
|
||||
| save-load-state | llama-save-load-state |
|
||||
| simple | llama-simple |
|
||||
| speculative | llama-speculative |
|
||||
| vdot | llama-vdot |
|
||||
| tests/test-c.o | tests/test-c.o |
|
||||
|
||||
35
examples/deprecation-warning/deprecation-warning.cpp
Normal file
@@ -0,0 +1,35 @@
|
||||
// Warns users that this filename was deprecated, and provides a link for more information.
|
||||
|
||||
#include <cstdio>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
|
||||
// Main
|
||||
int main(int argc, char** argv) {
|
||||
std::string filename = "main";
|
||||
if (argc >= 1) {
|
||||
filename = argv[0];
|
||||
}
|
||||
|
||||
// Get only the program name from the full path
|
||||
auto pos = filename.find_last_of("/\\");
|
||||
if (pos != std::string::npos) {
|
||||
filename = filename.substr(pos+1);
|
||||
}
|
||||
|
||||
// Append "llama-" to the beginning of filename to get the replacemnt filename
|
||||
auto replacement_filename = "llama-" + filename;
|
||||
|
||||
// The exception is if the filename is "main", then our replacement filename is "llama-cli"
|
||||
if (filename == "main") {
|
||||
replacement_filename = "llama-cli";
|
||||
}
|
||||
|
||||
fprintf(stdout, "\n");
|
||||
fprintf(stdout, "WARNING: The binary '%s' is deprecated.\n", filename.c_str());
|
||||
fprintf(stdout, " Please use '%s' instead.\n", replacement_filename.c_str());
|
||||
fprintf(stdout, " See https://github.com/ggerganov/llama.cpp/tree/master/examples/deprecation-warning/README.md for more information.\n");
|
||||
fprintf(stdout, "\n");
|
||||
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
5
examples/diffusion/CMakeLists.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
set(TARGET llama-diffusion-cli)
|
||||
add_executable(${TARGET} diffusion-cli.cpp)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_link_libraries(${TARGET} PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
59
examples/diffusion/README.md
Normal file
@@ -0,0 +1,59 @@
|
||||
# Diffusion Text Generation
|
||||
|
||||
This directory contains implementations for Diffusion LLMs (DLLMs)
|
||||
|
||||
More Info:
|
||||
- https://github.com/ggml-org/llama.cpp/pull/14644
|
||||
- https://github.com/ggml-org/llama.cpp/pull/14771
|
||||
|
||||
## Parameters
|
||||
The diffusion CLI supports various parameters to control the generation process:
|
||||
|
||||
### Core Diffusion Parameters
|
||||
- `--diffusion-steps`: Number of diffusion steps (default: 256)
|
||||
- `--diffusion-algorithm`: Algorithm for token selection
|
||||
- `0`: ORIGIN - Token will be generated in a purely random order from https://arxiv.org/abs/2107.03006.
|
||||
- `1`: ENTROPY_BASED - Entropy-based selection
|
||||
- `2`: MARGIN_BASED - Margin-based selection
|
||||
- `3`: RANDOM - Random selection
|
||||
- `4`: CONFIDENCE_BASED - Confidence-based selection (default)
|
||||
- More documentation here https://github.com/DreamLM/Dream
|
||||
- `--diffusion-visual`: Enable live visualization during generation
|
||||
|
||||
### Scheduling Parameters
|
||||
Choose one of the following scheduling methods:
|
||||
|
||||
**Timestep-based scheduling:**
|
||||
- `--diffusion-eps`: Epsilon value for timestep scheduling (e.g., 0.001)
|
||||
|
||||
**Block-based scheduling:**
|
||||
- `--diffusion-block-length`: Block size for block-based scheduling (e.g., 32)
|
||||
|
||||
### Sampling Parameters
|
||||
- `--temp`: Temperature for sampling (0.0 = greedy/deterministic, higher = more random)
|
||||
- `--top-k`: Top-k filtering for sampling
|
||||
- `--top-p`: Top-p (nucleus) filtering for sampling
|
||||
- `--seed`: Random seed for reproducibility
|
||||
|
||||
### Model Parameters
|
||||
- `-m`: Path to the GGUF model file
|
||||
- `-p`: Input prompt text
|
||||
- `-ub`: Maximum sequence length (ubatch size)
|
||||
- `-c`: Context size
|
||||
- `-b`: Batch size
|
||||
|
||||
### Examples
|
||||
#### Dream architechture:
|
||||
```
|
||||
llama-diffusion-cli -m dream7b.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-eps 0.001 --diffusion-algorithm 3 --diffusion-steps 256 --diffusion-visual
|
||||
```
|
||||
|
||||
#### LLaDA architechture:
|
||||
```
|
||||
llama-diffusion-cli -m llada-8b.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-block-length 32 --diffusion-steps 256 --diffusion-visual
|
||||
```
|
||||
|
||||
#### RND1 architecture:
|
||||
```
|
||||
llama-diffusion-cli -m RND1-Base-0910.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-algorithm 1 --diffusion-steps 256 --diffusion-visual --temp 0.5 --diffusion-eps 0.001
|
||||
```
|
||||
694
examples/diffusion/diffusion-cli.cpp
Normal file
@@ -0,0 +1,694 @@
|
||||
#include "arg.h"
|
||||
#include "chat.h"
|
||||
#include "common.h"
|
||||
#include "llama.h"
|
||||
#include "log.h"
|
||||
|
||||
#include <limits.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <cstring>
|
||||
#include <limits>
|
||||
#include <random>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
enum diffusion_algorithm { ORIGIN = 0, ENTROPY_BASED = 1, MARGIN_BASED = 2, RANDOM = 3, CONFIDENCE_BASED = 4 };
|
||||
|
||||
// Unified transfer scheduling methods
|
||||
enum transfer_schedule {
|
||||
TIMESTEP_BASED = 0, // Dream-style: (1.0 - s/t) * remaining
|
||||
BLOCK_BASED = 1, // LLaDA-style: process in blocks with get_num_transfer_tokens
|
||||
};
|
||||
|
||||
typedef bool (*diffusion_step_callback_t)(int32_t step,
|
||||
int32_t total_steps,
|
||||
const llama_token * tokens,
|
||||
int32_t n_tokens,
|
||||
void * user_data);
|
||||
|
||||
struct diffusion_params {
|
||||
int32_t steps = 0;
|
||||
float temperature = 0;
|
||||
llama_token mask_token_id = LLAMA_TOKEN_NULL;
|
||||
diffusion_step_callback_t step_callback = nullptr;
|
||||
void * step_callback_user_data = nullptr;
|
||||
int32_t seed = 0;
|
||||
bool visual_mode = false;
|
||||
bool shift_logits = false; // Shift logits by -1 after decode
|
||||
|
||||
float top_p = 0.;
|
||||
int32_t top_k = 0.;
|
||||
|
||||
diffusion_algorithm algorithm = CONFIDENCE_BASED;
|
||||
transfer_schedule schedule = TIMESTEP_BASED;
|
||||
|
||||
float cfg_scale = 0.; // Config scale for classifier-free guidance
|
||||
float eps = 0.; // Timestep scheduling
|
||||
int32_t block_length = 0; // Block size (for block scheduling)
|
||||
float alg_temp = 0; // algorithm temperature (0.0 = deterministic)
|
||||
bool add_gumbel_noise = false; // Add gumbel noise to the logits if temp > 0.0
|
||||
|
||||
int32_t max_length = 0; // Maximum sequence length
|
||||
};
|
||||
|
||||
struct callback_data {
|
||||
diffusion_params * diff_params;
|
||||
const llama_vocab * vocab;
|
||||
int32_t n_input;
|
||||
};
|
||||
|
||||
static float calculate_confidence(const llama_token_data_array & cur_p,
|
||||
diffusion_algorithm algorithm,
|
||||
std::mt19937 & rng) {
|
||||
switch (algorithm) {
|
||||
case CONFIDENCE_BASED:
|
||||
return cur_p.data[cur_p.selected].p; // Selected token probability
|
||||
|
||||
case ENTROPY_BASED:
|
||||
{
|
||||
float entropy = 0.0f;
|
||||
const float epsilon = 1e-10f;
|
||||
for (size_t i = 0; i < cur_p.size; i++) {
|
||||
float prob = cur_p.data[i].p;
|
||||
entropy += prob * logf(prob + epsilon);
|
||||
}
|
||||
return -entropy; // Higher entropy = lower confidence
|
||||
}
|
||||
|
||||
case MARGIN_BASED:
|
||||
return (cur_p.size > 1) ? cur_p.data[0].p - cur_p.data[1].p : cur_p.data[0].p;
|
||||
|
||||
case RANDOM:
|
||||
{
|
||||
std::uniform_real_distribution<float> uniform(0.0f, 1.0f);
|
||||
return uniform(rng); // Random confidence
|
||||
}
|
||||
|
||||
case ORIGIN:
|
||||
return cur_p.data[cur_p.selected].p;
|
||||
|
||||
default:
|
||||
return 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
// Unified transfer count calculation function
|
||||
static int32_t calculate_transfer_count(int32_t step,
|
||||
int32_t total_steps,
|
||||
int32_t remaining_masked,
|
||||
transfer_schedule schedule,
|
||||
float eps,
|
||||
const std::vector<int32_t> & num_transfer_tokens = {}) {
|
||||
switch (schedule) {
|
||||
case TIMESTEP_BASED:
|
||||
{
|
||||
float t = 1.0f - (float) step / total_steps * (1.0f - eps);
|
||||
float s = 1.0f - (float) (step + 1) / total_steps * (1.0f - eps);
|
||||
float p_transfer = (step < total_steps - 1) ? (1.0f - s / t) : 1.0f;
|
||||
return (int32_t) (remaining_masked * p_transfer);
|
||||
}
|
||||
|
||||
case BLOCK_BASED:
|
||||
if (!num_transfer_tokens.empty() && step < (int32_t) num_transfer_tokens.size()) {
|
||||
return num_transfer_tokens[step];
|
||||
}
|
||||
return remaining_masked / (total_steps - step); // Fallback
|
||||
|
||||
default:
|
||||
return remaining_masked / (total_steps - step);
|
||||
}
|
||||
}
|
||||
|
||||
static bool diffusion_step_callback(int32_t step,
|
||||
int32_t total_steps,
|
||||
const llama_token * tokens,
|
||||
int32_t n_tokens,
|
||||
void * user_data) {
|
||||
(void) user_data;
|
||||
|
||||
callback_data * data = static_cast<callback_data *>(user_data);
|
||||
|
||||
auto print_progress_bar = [](int32_t step, int32_t total_steps) {
|
||||
int progress_percent = (step * 100) / total_steps;
|
||||
int progress_bars = (step * 50) / total_steps;
|
||||
LOG_INF("\rdiffusion step: %d/%d [%s%s] %d%%",
|
||||
step,
|
||||
total_steps,
|
||||
std::string(progress_bars, '=').c_str(),
|
||||
std::string(50 - progress_bars, ' ').c_str(),
|
||||
progress_percent);
|
||||
};
|
||||
|
||||
if (data->diff_params->visual_mode) {
|
||||
// Visual mode: clear
|
||||
LOG_INF("\033[2J\033[H"); // Clear screen and move cursor to top-left
|
||||
|
||||
print_progress_bar(step, total_steps);
|
||||
|
||||
LOG_INF("\n");
|
||||
|
||||
std::string current_text = " ";
|
||||
|
||||
for (int32_t i = data->n_input; i < n_tokens; i++) {
|
||||
std::string token_str;
|
||||
if (tokens[i] != llama_vocab_mask(data->vocab)) {
|
||||
char piece[256];
|
||||
int n_chars = llama_token_to_piece(data->vocab, tokens[i], piece, sizeof(piece), 0, false);
|
||||
if (n_chars > 0) {
|
||||
piece[n_chars] = '\0';
|
||||
token_str = piece;
|
||||
}
|
||||
} else {
|
||||
token_str = " ";
|
||||
}
|
||||
|
||||
current_text += token_str;
|
||||
}
|
||||
|
||||
LOG_INF("%s\n", current_text.c_str());
|
||||
} else {
|
||||
print_progress_bar(step, total_steps);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static void add_gumbel_noise(float * logits, int32_t n_vocab, float temperature, std::mt19937 & rng) {
|
||||
if (temperature == 0.0f) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::uniform_real_distribution<double> uniform(0.0, 1.0);
|
||||
for (int32_t i = 0; i < n_vocab; i++) {
|
||||
double noise = uniform(rng);
|
||||
// Prevent log(0)
|
||||
noise = std::max(noise, 1e-20);
|
||||
double gumbel_noise = std::pow(-std::log(noise), temperature);
|
||||
logits[i] = std::exp(logits[i]) / gumbel_noise;
|
||||
}
|
||||
}
|
||||
|
||||
static std::vector<int32_t> get_num_transfer_tokens(int32_t mask_count, int32_t steps) {
|
||||
std::vector<int32_t> num_transfer_tokens(steps);
|
||||
|
||||
int32_t base = mask_count / steps;
|
||||
int32_t remainder = mask_count % steps;
|
||||
|
||||
for (int32_t i = 0; i < steps; i++) {
|
||||
num_transfer_tokens[i] = base + (i < remainder ? 1 : 0);
|
||||
}
|
||||
|
||||
return num_transfer_tokens;
|
||||
}
|
||||
|
||||
static void diffusion_generate(llama_context * ctx,
|
||||
const llama_token * input_tokens,
|
||||
llama_token * output_tokens,
|
||||
int32_t n_input,
|
||||
const diffusion_params & params,
|
||||
int32_t & n_generated) {
|
||||
n_generated = 0;
|
||||
if (!ctx || !input_tokens || !output_tokens || n_input <= 0 || params.max_length <= n_input) {
|
||||
return;
|
||||
}
|
||||
|
||||
const llama_model * model = llama_get_model(ctx);
|
||||
|
||||
// Initialize with input and pad with mask tokens
|
||||
std::copy(input_tokens, input_tokens + n_input, output_tokens);
|
||||
std::fill(output_tokens + n_input, output_tokens + params.max_length, params.mask_token_id);
|
||||
|
||||
std::mt19937 rng(params.seed);
|
||||
|
||||
llama_set_causal_attn(ctx, false);
|
||||
|
||||
int32_t n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(model));
|
||||
|
||||
std::vector<llama_token_data> candidates(n_vocab);
|
||||
std::vector<llama_token_data> conf_candidates;
|
||||
conf_candidates.reserve(params.max_length);
|
||||
std::vector<int32_t> mask_positions;
|
||||
mask_positions.reserve(params.max_length);
|
||||
|
||||
// Setup sampler chain
|
||||
struct llama_sampler * sampler = llama_sampler_chain_init(llama_sampler_chain_default_params());
|
||||
if (params.top_k > 0) {
|
||||
llama_sampler_chain_add(sampler, llama_sampler_init_top_k(params.top_k));
|
||||
}
|
||||
if (params.top_p < 1.0f) {
|
||||
llama_sampler_chain_add(sampler, llama_sampler_init_top_p(params.top_p, 1));
|
||||
}
|
||||
if (params.temperature > 0.0f) {
|
||||
llama_sampler_chain_add(sampler, llama_sampler_init_temp(params.temperature));
|
||||
}
|
||||
llama_sampler_chain_add(sampler, llama_sampler_init_dist(params.seed));
|
||||
|
||||
struct llama_sampler * dist_sampler = llama_sampler_init_dist(params.seed);
|
||||
|
||||
llama_batch batch = llama_batch_init(params.max_length, 0, 1);
|
||||
batch.n_tokens = params.max_length;
|
||||
|
||||
// Pre-allocate buffers for CFG if needed
|
||||
int32_t logits_size = n_vocab * params.max_length;
|
||||
std::vector<float> cond_logits_buffer;
|
||||
std::vector<llama_token> un_x_buffer;
|
||||
if (params.cfg_scale > 0.0f) {
|
||||
cond_logits_buffer.resize(logits_size);
|
||||
un_x_buffer.resize(params.max_length);
|
||||
}
|
||||
|
||||
// For block-based processing
|
||||
std::vector<int32_t> num_transfer_tokens;
|
||||
int32_t num_blocks = 1;
|
||||
int32_t steps_per_block = params.steps;
|
||||
|
||||
if (params.schedule == BLOCK_BASED) {
|
||||
GGML_ASSERT(params.max_length % params.block_length == 0);
|
||||
num_blocks = params.max_length / params.block_length;
|
||||
GGML_ASSERT(params.steps % num_blocks == 0);
|
||||
steps_per_block = params.steps / num_blocks;
|
||||
}
|
||||
|
||||
std::vector<float> confidence(params.max_length);
|
||||
|
||||
int64_t total_sampling_time = 0;
|
||||
int64_t total_time = 0;
|
||||
int64_t time_start = ggml_time_us();
|
||||
|
||||
for (int block_num = 0; block_num < num_blocks; block_num++) {
|
||||
int32_t block_start = (params.schedule == BLOCK_BASED) ? n_input + block_num * params.block_length : 0;
|
||||
int32_t block_end = (params.schedule == BLOCK_BASED) ?
|
||||
std::min(n_input + (block_num + 1) * params.block_length, params.max_length) :
|
||||
params.max_length;
|
||||
|
||||
// Count masked tokens in current block for block-based processing
|
||||
if (params.schedule == BLOCK_BASED) {
|
||||
int32_t block_mask_count = 0;
|
||||
for (int i = block_start; i < block_end; i++) {
|
||||
if (output_tokens[i] == params.mask_token_id) {
|
||||
block_mask_count++;
|
||||
}
|
||||
}
|
||||
num_transfer_tokens = get_num_transfer_tokens(block_mask_count, steps_per_block);
|
||||
}
|
||||
|
||||
for (int32_t step = 0; step < steps_per_block; step++) {
|
||||
int32_t global_step = block_num * steps_per_block + step;
|
||||
|
||||
if (params.step_callback) {
|
||||
if (!params.step_callback(
|
||||
global_step, params.steps, output_tokens, params.max_length, params.step_callback_user_data)) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Setup batch
|
||||
for (int32_t i = 0; i < params.max_length; i++) {
|
||||
batch.token[i] = output_tokens[i];
|
||||
batch.pos[i] = i;
|
||||
batch.n_seq_id[i] = 1;
|
||||
batch.seq_id[i][0] = 0;
|
||||
batch.logits[i] = 1;
|
||||
}
|
||||
|
||||
float * logits = nullptr;
|
||||
|
||||
if (params.cfg_scale > 0.0f) {
|
||||
int ret = llama_decode(ctx, batch);
|
||||
if (ret != 0) {
|
||||
LOG_ERR("Failed to generate conditional");
|
||||
break;
|
||||
}
|
||||
float * cond_logits_ptr = llama_get_logits(ctx);
|
||||
std::memcpy(cond_logits_buffer.data(), cond_logits_ptr, logits_size * sizeof(float));
|
||||
|
||||
// Unconditional generation (mask input)
|
||||
std::copy(output_tokens, output_tokens + params.max_length, un_x_buffer.begin());
|
||||
for (int32_t i = 0; i < n_input; i++) {
|
||||
un_x_buffer[i] = params.mask_token_id;
|
||||
}
|
||||
|
||||
for (int32_t i = 0; i < params.max_length; i++) {
|
||||
batch.token[i] = un_x_buffer[i];
|
||||
}
|
||||
ret = llama_decode(ctx, batch);
|
||||
if (ret != 0) {
|
||||
LOG_ERR("Failed to generate unconditional");
|
||||
break;
|
||||
}
|
||||
float * uncond_logits = llama_get_logits(ctx);
|
||||
|
||||
// Apply CFG
|
||||
for (int32_t i = 0; i < logits_size; i++) {
|
||||
cond_logits_buffer[i] =
|
||||
uncond_logits[i] + (params.cfg_scale + 1.0f) * (cond_logits_buffer[i] - uncond_logits[i]);
|
||||
}
|
||||
logits = cond_logits_buffer.data();
|
||||
} else {
|
||||
int ret = llama_decode(ctx, batch);
|
||||
if (ret != 0) {
|
||||
LOG_ERR("%s: failed to decode at step %d, ret = %d\n", __func__, global_step, ret);
|
||||
break;
|
||||
}
|
||||
logits = llama_get_logits(ctx);
|
||||
}
|
||||
|
||||
if (!logits) {
|
||||
LOG_ERR("%s: failed to get logits at step %d\n", __func__, global_step);
|
||||
break;
|
||||
}
|
||||
|
||||
auto get_logits_for_pos = [&](int32_t pos) -> const float * {
|
||||
if (params.shift_logits) {
|
||||
return pos == 0 ? logits : logits + (pos - 1) * n_vocab;
|
||||
}
|
||||
return logits + (pos) *n_vocab;
|
||||
};
|
||||
|
||||
int64_t time_start_sampling = ggml_time_us();
|
||||
|
||||
mask_positions.clear();
|
||||
for (int32_t i = 0; i < params.max_length; i++) {
|
||||
if (output_tokens[i] == params.mask_token_id) {
|
||||
// For block-based, only consider current block
|
||||
if (params.schedule != BLOCK_BASED || (i >= block_start && i < block_end)) {
|
||||
mask_positions.push_back(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (mask_positions.empty()) {
|
||||
break;
|
||||
}
|
||||
|
||||
if (params.add_gumbel_noise && params.temperature > 0.0f) {
|
||||
add_gumbel_noise(logits, n_vocab, params.temperature, rng);
|
||||
}
|
||||
|
||||
if (params.algorithm == ORIGIN) {
|
||||
int32_t transfer_count = calculate_transfer_count(
|
||||
step, steps_per_block, mask_positions.size(), params.schedule, params.eps, num_transfer_tokens);
|
||||
float p_transfer = (float) transfer_count / mask_positions.size();
|
||||
|
||||
for (int32_t pos : mask_positions) {
|
||||
if (std::uniform_real_distribution<float>(0.0f, 1.0f)(rng) < p_transfer) {
|
||||
const float * pos_logits = get_logits_for_pos(pos);
|
||||
for (int32_t token_id = 0; token_id < n_vocab; token_id++) {
|
||||
candidates[token_id].id = token_id;
|
||||
candidates[token_id].logit = pos_logits[token_id];
|
||||
candidates[token_id].p = 0.0f;
|
||||
}
|
||||
|
||||
llama_token_data_array cur_p = {
|
||||
candidates.data(),
|
||||
(size_t) n_vocab,
|
||||
-1,
|
||||
false,
|
||||
};
|
||||
|
||||
llama_sampler_apply(sampler, &cur_p);
|
||||
output_tokens[pos] = cur_p.data[cur_p.selected].id;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
std::vector<std::pair<float, int32_t>> confidences;
|
||||
std::vector<llama_token> sampled_tokens(mask_positions.size());
|
||||
|
||||
for (size_t i = 0; i < mask_positions.size(); i++) {
|
||||
int32_t pos = mask_positions[i];
|
||||
const float * pos_logits = get_logits_for_pos(pos);
|
||||
|
||||
for (int32_t token_id = 0; token_id < n_vocab; token_id++) {
|
||||
candidates[token_id].logit = pos_logits[token_id];
|
||||
candidates[token_id].p = 0.0f;
|
||||
candidates[token_id].id = token_id;
|
||||
}
|
||||
|
||||
llama_token_data_array cur_p = {
|
||||
candidates.data(),
|
||||
candidates.size(),
|
||||
-1,
|
||||
false,
|
||||
};
|
||||
|
||||
llama_sampler_apply(sampler, &cur_p);
|
||||
llama_token sampled_token = cur_p.data[cur_p.selected].id;
|
||||
|
||||
float conf = calculate_confidence(cur_p, params.algorithm, rng);
|
||||
|
||||
sampled_tokens[i] = sampled_token;
|
||||
confidences.emplace_back(conf, i);
|
||||
}
|
||||
|
||||
int32_t transfer_count = calculate_transfer_count(
|
||||
step, steps_per_block, mask_positions.size(), params.schedule, params.eps, num_transfer_tokens);
|
||||
|
||||
if (transfer_count > 0) {
|
||||
if (params.alg_temp == 0.0f) {
|
||||
std::partial_sort(confidences.begin(),
|
||||
confidences.begin() + std::min(transfer_count, (int32_t) confidences.size()),
|
||||
confidences.end(),
|
||||
[](const std::pair<float, int32_t> & a, const std::pair<float, int32_t> & b) {
|
||||
if (a.first != b.first) {
|
||||
return a.first > b.first;
|
||||
}
|
||||
return a.second < b.second;
|
||||
});
|
||||
|
||||
for (int32_t i = 0; i < std::min(transfer_count, (int32_t) confidences.size()); i++) {
|
||||
int32_t mask_idx = confidences[i].second;
|
||||
int32_t pos = mask_positions[mask_idx];
|
||||
output_tokens[pos] = sampled_tokens[mask_idx];
|
||||
}
|
||||
} else {
|
||||
conf_candidates.clear();
|
||||
for (size_t i = 0; i < confidences.size(); i++) {
|
||||
float conf_logit = confidences[i].first / params.alg_temp;
|
||||
conf_candidates.emplace_back(llama_token_data{ (int32_t) i, conf_logit, 0.0f });
|
||||
}
|
||||
|
||||
llama_token_data_array conf_array = {
|
||||
conf_candidates.data(),
|
||||
conf_candidates.size(),
|
||||
-1,
|
||||
false,
|
||||
};
|
||||
|
||||
for (int32_t i = 0; i < std::min(transfer_count, (int32_t) confidences.size()); i++) {
|
||||
llama_sampler_apply(dist_sampler, &conf_array);
|
||||
int32_t selected_idx = conf_array.selected;
|
||||
int32_t mask_idx = selected_idx;
|
||||
int32_t pos = mask_positions[mask_idx];
|
||||
output_tokens[pos] = sampled_tokens[mask_idx];
|
||||
|
||||
conf_candidates[selected_idx].p = 0.0f;
|
||||
conf_array.selected = -1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int64_t time_end_sampling = ggml_time_us();
|
||||
total_sampling_time += time_end_sampling - time_start_sampling;
|
||||
}
|
||||
}
|
||||
|
||||
int64_t time_end = ggml_time_us();
|
||||
total_time += time_end - time_start;
|
||||
|
||||
LOG_INF("\ntotal time: %0.2fms, time per step: %0.2fms, sampling time per step: %0.2fms\n",
|
||||
total_time / 1000.0,
|
||||
total_time / 1000.0 / params.steps,
|
||||
total_sampling_time / 1000.0 / params.steps);
|
||||
|
||||
llama_batch_free(batch);
|
||||
llama_sampler_free(sampler);
|
||||
llama_sampler_free(dist_sampler);
|
||||
|
||||
n_generated = params.max_length;
|
||||
}
|
||||
|
||||
static std::string format_input_text(const std::string & prompt, const std::string & system_prompt, bool use_chat_template, llama_model * model) {
|
||||
if (!use_chat_template) {
|
||||
return prompt;
|
||||
}
|
||||
|
||||
auto chat_templates = common_chat_templates_init(model, "");
|
||||
common_chat_templates_inputs inputs;
|
||||
common_chat_msg system_msg;
|
||||
|
||||
if (!system_prompt.empty()) {
|
||||
system_msg.role = "system";
|
||||
system_msg.content = system_prompt;
|
||||
inputs.messages.push_back(system_msg);
|
||||
}
|
||||
|
||||
common_chat_msg user_msg;
|
||||
user_msg.role = "user";
|
||||
user_msg.content = prompt;
|
||||
|
||||
inputs.messages.push_back(user_msg);
|
||||
inputs.add_generation_prompt = true;
|
||||
|
||||
auto result = common_chat_templates_apply(chat_templates.get(), inputs);
|
||||
|
||||
return result.prompt;
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
ggml_time_init();
|
||||
|
||||
common_params params;
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_DIFFUSION)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
common_init();
|
||||
llama_backend_init();
|
||||
|
||||
llama_model_params model_params = llama_model_default_params();
|
||||
model_params.n_gpu_layers = params.n_gpu_layers;
|
||||
model_params.devices = params.devices.data();
|
||||
model_params.use_mmap = params.use_mmap;
|
||||
model_params.use_direct_io = params.use_direct_io;
|
||||
model_params.use_mlock = params.use_mlock;
|
||||
model_params.check_tensors = params.check_tensors;
|
||||
|
||||
llama_model * model = llama_model_load_from_file(params.model.path.c_str(), model_params);
|
||||
if (!model) {
|
||||
LOG_ERR("error: failed to load model '%s'\n", params.model.path.c_str());
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (!llama_model_is_diffusion(model)) {
|
||||
LOG_ERR("error: unsupported model for diffusion");
|
||||
llama_model_free(model);
|
||||
return 1;
|
||||
}
|
||||
|
||||
llama_context_params ctx_params = llama_context_default_params();
|
||||
ctx_params.n_ctx = params.n_ctx;
|
||||
ctx_params.n_batch = params.n_batch;
|
||||
ctx_params.n_ubatch = params.n_ubatch;
|
||||
ctx_params.flash_attn_type = params.flash_attn_type;
|
||||
ctx_params.no_perf = params.no_perf;
|
||||
ctx_params.type_k = params.cache_type_k;
|
||||
ctx_params.type_v = params.cache_type_v;
|
||||
|
||||
llama_context * ctx = llama_init_from_model(model, ctx_params);
|
||||
if (!ctx) {
|
||||
LOG_ERR("error: failed to create context\n");
|
||||
llama_model_free(model);
|
||||
return 1;
|
||||
}
|
||||
|
||||
llama_set_n_threads(ctx, params.cpuparams.n_threads, params.cpuparams_batch.n_threads);
|
||||
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
|
||||
std::string formatted_prompt = format_input_text(params.prompt, params.system_prompt, params.enable_chat_template, model);
|
||||
|
||||
std::vector<llama_token> input_tokens = common_tokenize(vocab,
|
||||
formatted_prompt,
|
||||
/*add special tokens*/ true,
|
||||
/*parse special*/ true);
|
||||
|
||||
int n_input = input_tokens.size();
|
||||
|
||||
if (n_input >= params.n_ctx) {
|
||||
LOG_ERR("error: input too long (%d tokens), max context is %d\n", n_input, params.n_ctx);
|
||||
llama_free(ctx);
|
||||
llama_model_free(model);
|
||||
return 1;
|
||||
}
|
||||
|
||||
llama_token mask_token_id = llama_vocab_mask(vocab);
|
||||
|
||||
GGML_ASSERT(mask_token_id != LLAMA_TOKEN_NULL);
|
||||
|
||||
bool visual_mode = params.diffusion.visual_mode;
|
||||
|
||||
int32_t n_generated = 0;
|
||||
std::vector<llama_token> output_tokens(params.n_ubatch);
|
||||
|
||||
struct diffusion_params diff_params;
|
||||
|
||||
char shift_logits_str[8];
|
||||
if (llama_model_meta_val_str(model, "diffusion.shift_logits", shift_logits_str, sizeof(shift_logits_str)) >= 0) {
|
||||
diff_params.shift_logits = (strcmp(shift_logits_str, "true") == 0);
|
||||
} else {
|
||||
diff_params.shift_logits = true;
|
||||
}
|
||||
|
||||
//Use either eps or block length, but not both
|
||||
GGML_ASSERT((params.diffusion.eps == 0) ^ (params.diffusion.block_length == 0));
|
||||
|
||||
if (params.diffusion.eps) {
|
||||
diff_params.schedule = TIMESTEP_BASED;
|
||||
diff_params.eps = params.diffusion.eps;
|
||||
} else if (params.diffusion.block_length) {
|
||||
diff_params.schedule = BLOCK_BASED;
|
||||
diff_params.block_length = params.diffusion.block_length;
|
||||
}
|
||||
|
||||
diff_params.mask_token_id = mask_token_id;
|
||||
diff_params.seed = params.sampling.seed;
|
||||
diff_params.temperature = params.sampling.temp;
|
||||
diff_params.steps = params.diffusion.steps;
|
||||
diff_params.algorithm = static_cast<diffusion_algorithm>(params.diffusion.algorithm);
|
||||
diff_params.max_length = params.n_ubatch;
|
||||
diff_params.top_p = params.sampling.top_p;
|
||||
diff_params.top_k = params.sampling.top_k;
|
||||
diff_params.visual_mode = params.diffusion.visual_mode;
|
||||
diff_params.add_gumbel_noise = params.diffusion.add_gumbel_noise;
|
||||
|
||||
diff_params.step_callback = diffusion_step_callback;
|
||||
callback_data cb_data = { &diff_params, vocab, n_input };
|
||||
diff_params.step_callback_user_data = &cb_data;
|
||||
|
||||
const char * alg_names[] = { "ORIGIN", "ENTROPY_BASED", "MARGIN_BASED", "RANDOM", "CONFIDENCE_BASED" };
|
||||
const char * sched_names[] = { "TIMESTEP_BASED", "BLOCK_BASED" };
|
||||
const char * alg_name =
|
||||
(diff_params.algorithm >= 0 && diff_params.algorithm <= 4) ? alg_names[diff_params.algorithm] : "UNKNOWN";
|
||||
const char * sched_name =
|
||||
(diff_params.schedule >= 0 && diff_params.schedule <= 1) ? sched_names[diff_params.schedule] : "UNKNOWN";
|
||||
|
||||
LOG_INF("diffusion_params: - %-25s llama_token = %d\n", "mask_token_id", mask_token_id);
|
||||
LOG_INF("diffusion_params: - %-25s u32 = %d\n", "steps", diff_params.steps);
|
||||
LOG_INF("diffusion_params: - %-25s u32 = %d\n", "max_length", diff_params.max_length);
|
||||
LOG_INF("diffusion_params: - %-25s enum = %d (%s)\n", "algorithm", diff_params.algorithm, alg_name);
|
||||
LOG_INF("diffusion_params: - %-25s enum = %d (%s)\n", "schedule", diff_params.schedule, sched_name);
|
||||
LOG_INF("diffusion_params: - %-25s f32 = %.3f\n", "temperature", diff_params.temperature);
|
||||
if (diff_params.schedule == TIMESTEP_BASED) {
|
||||
LOG_INF("diffusion_params: - %-25s f32 = %.6f\n", "eps", diff_params.eps);
|
||||
LOG_INF("diffusion_params: - %-25s f32 = %.3f\n", "alg_temp", diff_params.alg_temp);
|
||||
}
|
||||
if (diff_params.schedule == BLOCK_BASED) {
|
||||
LOG_INF("diffusion_params: - %-25s u32 = %d\n", "block_length", diff_params.block_length);
|
||||
LOG_INF("diffusion_params: - %-25s f32 = %.3f\n", "cfg_scale", diff_params.cfg_scale);
|
||||
}
|
||||
|
||||
diffusion_generate(ctx, input_tokens.data(), output_tokens.data(), n_input, diff_params, n_generated);
|
||||
|
||||
if (n_generated > 0) {
|
||||
if (visual_mode) {
|
||||
//clear screen and move cursor to top-left
|
||||
LOG_INF("\033[2J\033[H");
|
||||
}
|
||||
|
||||
output_tokens.erase(output_tokens.begin(), output_tokens.begin() + n_input);
|
||||
std::string output_data = common_detokenize(vocab, output_tokens, false);
|
||||
LOG_INF("\n%s\n", output_data.c_str());
|
||||
} else {
|
||||
LOG_INF("Error: diffusion generation failed\n");
|
||||
}
|
||||
|
||||
llama_free(ctx);
|
||||
llama_model_free(model);
|
||||
llama_backend_free();
|
||||
|
||||
return 0;
|
||||
}
|
||||
5
examples/embedding/CMakeLists.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
set(TARGET llama-embedding)
|
||||
add_executable(${TARGET} embedding.cpp)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
61
examples/embedding/README.md
Normal file
@@ -0,0 +1,61 @@
|
||||
# llama.cpp/example/embedding
|
||||
|
||||
This example demonstrates generate high-dimensional embedding vector of a given text with llama.cpp.
|
||||
|
||||
## Quick Start
|
||||
|
||||
To get started right away, run the following command, making sure to use the correct path for the model you have:
|
||||
|
||||
### Unix-based systems (Linux, macOS, etc.):
|
||||
|
||||
```bash
|
||||
./llama-embedding -m ./path/to/model --pooling mean --log-disable -p "Hello World!" 2>/dev/null
|
||||
```
|
||||
|
||||
### Windows:
|
||||
|
||||
```powershell
|
||||
llama-embedding.exe -m ./path/to/model --pooling mean --log-disable -p "Hello World!" 2>$null
|
||||
```
|
||||
|
||||
The above command will output space-separated float values.
|
||||
|
||||
## extra parameters
|
||||
### --embd-normalize $integer$
|
||||
| $integer$ | description | formula |
|
||||
|-----------|---------------------|---------|
|
||||
| $-1$ | none |
|
||||
| $0$ | max absolute int16 | $\Large{{32760 * x_i} \over\max \lvert x_i\rvert}$
|
||||
| $1$ | taxicab | $\Large{x_i \over\sum \lvert x_i\rvert}$
|
||||
| $2$ | euclidean (default) | $\Large{x_i \over\sqrt{\sum x_i^2}}$
|
||||
| $>2$ | p-norm | $\Large{x_i \over\sqrt[p]{\sum \lvert x_i\rvert^p}}$
|
||||
|
||||
### --embd-output-format $'string'$
|
||||
| $'string'$ | description | |
|
||||
|------------|------------------------------|--|
|
||||
| '' | same as before | (default)
|
||||
| 'array' | single embeddings | $[[x_1,...,x_n]]$
|
||||
| | multiple embeddings | $[[x_1,...,x_n],[x_1,...,x_n],...,[x_1,...,x_n]]$
|
||||
| 'json' | openai style |
|
||||
| 'json+' | add cosine similarity matrix |
|
||||
| 'raw' | plain text output |
|
||||
|
||||
### --embd-separator $"string"$
|
||||
| $"string"$ | |
|
||||
|--------------|-|
|
||||
| "\n" | (default)
|
||||
| "<#embSep#>" | for example
|
||||
| "<#sep#>" | other example
|
||||
|
||||
## examples
|
||||
### Unix-based systems (Linux, macOS, etc.):
|
||||
|
||||
```bash
|
||||
./llama-embedding -p 'Castle<#sep#>Stronghold<#sep#>Dog<#sep#>Cat' --pooling mean --embd-separator '<#sep#>' --embd-normalize 2 --embd-output-format '' -m './path/to/model.gguf' --n-gpu-layers 99 --log-disable 2>/dev/null
|
||||
```
|
||||
|
||||
### Windows:
|
||||
|
||||
```powershell
|
||||
llama-embedding.exe -p 'Castle<#sep#>Stronghold<#sep#>Dog<#sep#>Cat' --pooling mean --embd-separator '<#sep#>' --embd-normalize 2 --embd-output-format '' -m './path/to/model.gguf' --n-gpu-layers 99 --log-disable 2>/dev/null
|
||||
```
|
||||
411
examples/embedding/embedding.cpp
Normal file
@@ -0,0 +1,411 @@
|
||||
#include "arg.h"
|
||||
#include "common.h"
|
||||
#include "log.h"
|
||||
#include "llama.h"
|
||||
|
||||
#include <ctime>
|
||||
#include <algorithm>
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
#pragma warning(disable: 4244 4267) // possible loss of data
|
||||
#endif
|
||||
|
||||
static std::vector<std::string> split_lines(const std::string & s, const std::string & separator = "\n") {
|
||||
std::vector<std::string> lines;
|
||||
size_t start = 0;
|
||||
size_t end = s.find(separator);
|
||||
|
||||
while (end != std::string::npos) {
|
||||
lines.push_back(s.substr(start, end - start));
|
||||
start = end + separator.length();
|
||||
end = s.find(separator, start);
|
||||
}
|
||||
|
||||
lines.push_back(s.substr(start)); // Add the last part
|
||||
|
||||
return lines;
|
||||
}
|
||||
|
||||
static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) {
|
||||
size_t n_tokens = tokens.size();
|
||||
for (size_t i = 0; i < n_tokens; i++) {
|
||||
common_batch_add(batch, tokens[i], i, { seq_id }, true);
|
||||
}
|
||||
}
|
||||
|
||||
static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd_out, int embd_norm) {
|
||||
const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
|
||||
|
||||
// clear previous kv_cache values (irrelevant for embeddings)
|
||||
llama_memory_clear(llama_get_memory(ctx), true);
|
||||
|
||||
// run model
|
||||
LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
|
||||
if (llama_decode(ctx, batch) < 0) {
|
||||
LOG_ERR("%s : failed to process\n", __func__);
|
||||
}
|
||||
|
||||
for (int i = 0; i < batch.n_tokens; i++) {
|
||||
if (!batch.logits[i]) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const float * embd = nullptr;
|
||||
int embd_pos = 0;
|
||||
|
||||
if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
|
||||
// try to get token embeddings
|
||||
embd = llama_get_embeddings_ith(ctx, i);
|
||||
embd_pos = i;
|
||||
GGML_ASSERT(embd != NULL && "failed to get token embeddings");
|
||||
} else {
|
||||
// try to get sequence embeddings - supported only when pooling_type is not NONE
|
||||
embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
|
||||
embd_pos = batch.seq_id[i][0];
|
||||
GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
|
||||
}
|
||||
|
||||
float * out = output + embd_pos * n_embd_out;
|
||||
common_embd_normalize(embd, out, n_embd_out, embd_norm);
|
||||
}
|
||||
}
|
||||
|
||||
// plain, pipe-friendly output: one embedding per line
|
||||
static void print_raw_embeddings(const float * emb,
|
||||
int n_embd_count,
|
||||
int n_embd,
|
||||
const llama_model * model,
|
||||
enum llama_pooling_type pooling_type,
|
||||
int embd_normalize) {
|
||||
const uint32_t n_cls_out = llama_model_n_cls_out(model);
|
||||
const bool is_rank = (pooling_type == LLAMA_POOLING_TYPE_RANK);
|
||||
const int cols = is_rank ? std::min<int>(n_embd, (int) n_cls_out) : n_embd;
|
||||
|
||||
for (int j = 0; j < n_embd_count; ++j) {
|
||||
for (int i = 0; i < cols; ++i) {
|
||||
if (embd_normalize == 0) {
|
||||
LOG("%1.0f%s", emb[j * n_embd + i], (i + 1 < cols ? " " : ""));
|
||||
} else {
|
||||
LOG("%1.7f%s", emb[j * n_embd + i], (i + 1 < cols ? " " : ""));
|
||||
}
|
||||
}
|
||||
LOG("\n");
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
common_params params;
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_EMBEDDING)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
common_init();
|
||||
|
||||
params.embedding = true;
|
||||
|
||||
// get max number of sequences per batch
|
||||
const int n_seq_max = llama_max_parallel_sequences();
|
||||
|
||||
// if the number of prompts that would be encoded is known in advance, it's more efficient to specify the
|
||||
// --parallel argument accordingly. for convenience, if not specified, we fallback to unified KV cache
|
||||
// in order to support any number of prompts
|
||||
if (params.n_parallel == 1) {
|
||||
LOG_INF("%s: n_parallel == 1 -> unified KV cache is enabled\n", __func__);
|
||||
params.kv_unified = true;
|
||||
params.n_parallel = n_seq_max;
|
||||
}
|
||||
|
||||
// utilize the full context
|
||||
if (params.n_batch < params.n_ctx) {
|
||||
LOG_WRN("%s: setting batch size to %d\n", __func__, params.n_ctx);
|
||||
params.n_batch = params.n_ctx;
|
||||
}
|
||||
|
||||
// for non-causal models, batch size must be equal to ubatch size
|
||||
if (params.attention_type != LLAMA_ATTENTION_TYPE_CAUSAL) {
|
||||
params.n_ubatch = params.n_batch;
|
||||
}
|
||||
|
||||
llama_backend_init();
|
||||
llama_numa_init(params.numa);
|
||||
|
||||
// load the model
|
||||
auto llama_init = common_init_from_params(params);
|
||||
|
||||
auto * model = llama_init->model();
|
||||
auto * ctx = llama_init->context();
|
||||
|
||||
if (model == NULL) {
|
||||
LOG_ERR("%s: unable to load model\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
|
||||
const int n_ctx_train = llama_model_n_ctx_train(model);
|
||||
const int n_ctx = llama_n_ctx(ctx);
|
||||
|
||||
const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
|
||||
|
||||
if (llama_model_has_encoder(model) && llama_model_has_decoder(model)) {
|
||||
LOG_ERR("%s: computing embeddings in encoder-decoder models is not supported\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (n_ctx > n_ctx_train) {
|
||||
LOG_WRN("%s: warning: model was trained on only %d context tokens (%d specified)\n",
|
||||
__func__, n_ctx_train, n_ctx);
|
||||
}
|
||||
|
||||
// print system information
|
||||
{
|
||||
LOG_INF("\n");
|
||||
LOG_INF("%s\n", common_params_get_system_info(params).c_str());
|
||||
}
|
||||
|
||||
// split the prompt into lines
|
||||
std::vector<std::string> prompts = split_lines(params.prompt, params.embd_sep);
|
||||
|
||||
// max batch size
|
||||
const uint64_t n_batch = params.n_batch;
|
||||
|
||||
// get added sep and eos token, if any
|
||||
const std::string added_sep_token = llama_vocab_get_add_sep(vocab) ? llama_vocab_get_text(vocab, llama_vocab_sep(vocab)) : "";
|
||||
const std::string added_eos_token = llama_vocab_get_add_eos(vocab) ? llama_vocab_get_text(vocab, llama_vocab_eos(vocab)) : "";
|
||||
const char * rerank_prompt = llama_model_chat_template(model, "rerank");
|
||||
|
||||
// tokenize the prompts and trim
|
||||
std::vector<std::vector<int32_t>> inputs;
|
||||
for (const auto & prompt : prompts) {
|
||||
std::vector<llama_token> inp;
|
||||
|
||||
// split classification pairs and insert expected separator tokens
|
||||
if (pooling_type == LLAMA_POOLING_TYPE_RANK && prompt.find(params.cls_sep) != std::string::npos) {
|
||||
std::vector<std::string> pairs = split_lines(prompt, params.cls_sep);
|
||||
if (rerank_prompt != nullptr) {
|
||||
const std::string query = pairs[0];
|
||||
const std::string doc = pairs[1];
|
||||
std::string final_prompt = rerank_prompt;
|
||||
string_replace_all(final_prompt, "{query}" , query);
|
||||
string_replace_all(final_prompt, "{document}", doc );
|
||||
inp = common_tokenize(vocab, final_prompt, true, true);
|
||||
} else {
|
||||
std::string final_prompt;
|
||||
for (size_t i = 0; i < pairs.size(); i++) {
|
||||
final_prompt += pairs[i];
|
||||
if (i != pairs.size() - 1) {
|
||||
if (!added_eos_token.empty()) {
|
||||
final_prompt += added_eos_token;
|
||||
}
|
||||
if (!added_sep_token.empty()) {
|
||||
final_prompt += added_sep_token;
|
||||
}
|
||||
}
|
||||
}
|
||||
inp = common_tokenize(ctx, final_prompt, true, true);
|
||||
}
|
||||
} else {
|
||||
inp = common_tokenize(ctx, prompt, true, true);
|
||||
}
|
||||
if (inp.size() > n_batch) {
|
||||
LOG_ERR("%s: number of tokens in input line (%lld) exceeds batch size (%lld), increase batch size and re-run\n",
|
||||
__func__, (long long int) inp.size(), (long long int) n_batch);
|
||||
return 1;
|
||||
}
|
||||
inputs.push_back(inp);
|
||||
}
|
||||
|
||||
// check if the last token is SEP/EOS
|
||||
// it should be automatically added by the tokenizer when 'tokenizer.ggml.add_eos_token' is set to 'true'
|
||||
for (auto & inp : inputs) {
|
||||
if (inp.empty() || (inp.back() != llama_vocab_sep(vocab) && inp.back() != llama_vocab_eos(vocab))) {
|
||||
LOG_WRN("%s: last token in the prompt is not SEP or EOS\n", __func__);
|
||||
LOG_WRN("%s: 'tokenizer.ggml.add_eos_token' should be set to 'true' in the GGUF header\n", __func__);
|
||||
}
|
||||
}
|
||||
|
||||
// tokenization stats
|
||||
if (params.verbose_prompt) {
|
||||
for (int i = 0; i < (int) inputs.size(); i++) {
|
||||
LOG_INF("%s: prompt %d: '%s'\n", __func__, i, prompts[i].c_str());
|
||||
LOG_INF("%s: number of tokens in prompt = %zu\n", __func__, inputs[i].size());
|
||||
for (int j = 0; j < (int) inputs[i].size(); j++) {
|
||||
LOG("%6d -> '%s'\n", inputs[i][j], common_token_to_piece(ctx, inputs[i][j]).c_str());
|
||||
}
|
||||
LOG("\n\n");
|
||||
}
|
||||
}
|
||||
|
||||
// initialize batch
|
||||
const int n_prompts = prompts.size();
|
||||
struct llama_batch batch = llama_batch_init(n_batch, 0, 1);
|
||||
|
||||
// count number of embeddings
|
||||
int n_embd_count = 0;
|
||||
if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
|
||||
for (int k = 0; k < n_prompts; k++) {
|
||||
n_embd_count += inputs[k].size();
|
||||
}
|
||||
} else {
|
||||
n_embd_count = n_prompts;
|
||||
}
|
||||
|
||||
// allocate output
|
||||
const int n_embd_out = llama_model_n_embd_out(model);
|
||||
std::vector<float> embeddings(n_embd_count * n_embd_out, 0);
|
||||
float * emb = embeddings.data();
|
||||
|
||||
// break into batches
|
||||
int e = 0; // number of embeddings already stored
|
||||
int s = 0; // number of prompts in current batch
|
||||
for (int k = 0; k < n_prompts; k++) {
|
||||
// clamp to n_batch tokens
|
||||
auto & inp = inputs[k];
|
||||
|
||||
const uint64_t n_toks = inp.size();
|
||||
|
||||
// encode if at capacity
|
||||
if (batch.n_tokens + n_toks > n_batch || s >= n_seq_max) {
|
||||
float * out = emb + e * n_embd_out;
|
||||
batch_decode(ctx, batch, out, s, n_embd_out, params.embd_normalize);
|
||||
e += pooling_type == LLAMA_POOLING_TYPE_NONE ? batch.n_tokens : s;
|
||||
s = 0;
|
||||
common_batch_clear(batch);
|
||||
}
|
||||
|
||||
// add to batch
|
||||
batch_add_seq(batch, inp, s);
|
||||
s += 1;
|
||||
}
|
||||
|
||||
// final batch
|
||||
float * out = emb + e * n_embd_out;
|
||||
batch_decode(ctx, batch, out, s, n_embd_out, params.embd_normalize);
|
||||
|
||||
if (params.embd_out.empty()) {
|
||||
LOG("\n");
|
||||
|
||||
if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
|
||||
for (int j = 0; j < n_embd_count; j++) {
|
||||
LOG("embedding %d: ", j);
|
||||
for (int i = 0; i < std::min(3, n_embd_out); i++) {
|
||||
if (params.embd_normalize == 0) {
|
||||
LOG("%6.0f ", emb[j * n_embd_out + i]);
|
||||
} else {
|
||||
LOG("%9.6f ", emb[j * n_embd_out + i]);
|
||||
}
|
||||
}
|
||||
LOG(" ... ");
|
||||
for (int i = n_embd_out - 3; i < n_embd_out; i++) {
|
||||
if (params.embd_normalize == 0) {
|
||||
LOG("%6.0f ", emb[j * n_embd_out + i]);
|
||||
} else {
|
||||
LOG("%9.6f ", emb[j * n_embd_out + i]);
|
||||
}
|
||||
}
|
||||
LOG("\n");
|
||||
}
|
||||
} else if (pooling_type == LLAMA_POOLING_TYPE_RANK) {
|
||||
const uint32_t n_cls_out = llama_model_n_cls_out(model);
|
||||
std::vector<std::string> cls_out_labels;
|
||||
|
||||
for (uint32_t i = 0; i < n_cls_out; i++) {
|
||||
const char * label = llama_model_cls_label(model, i);
|
||||
const std::string label_i(label == nullptr ? "" : label);
|
||||
cls_out_labels.emplace_back(label_i.empty() ? std::to_string(i) : label_i);
|
||||
}
|
||||
|
||||
for (int j = 0; j < n_embd_count; j++) {
|
||||
for (uint32_t i = 0; i < n_cls_out; i++) {
|
||||
// NOTE: if you change this log - update the tests in ci/run.sh
|
||||
if (n_cls_out == 1) {
|
||||
LOG("rerank score %d: %8.3f\n", j, emb[j * n_embd_out]);
|
||||
} else {
|
||||
LOG("rerank score %d: %8.3f [%s]\n", j, emb[j * n_embd_out + i], cls_out_labels[i].c_str());
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// print the first part of the embeddings or for a single prompt, the full embedding
|
||||
for (int j = 0; j < n_prompts; j++) {
|
||||
LOG("embedding %d: ", j);
|
||||
for (int i = 0; i < (n_prompts > 1 ? std::min(16, n_embd_out) : n_embd_out); i++) {
|
||||
if (params.embd_normalize == 0) {
|
||||
LOG("%6.0f ", emb[j * n_embd_out + i]);
|
||||
} else {
|
||||
LOG("%9.6f ", emb[j * n_embd_out + i]);
|
||||
}
|
||||
}
|
||||
LOG("\n");
|
||||
}
|
||||
|
||||
// print cosine similarity matrix
|
||||
if (n_prompts > 1) {
|
||||
LOG("\n");
|
||||
LOG("cosine similarity matrix:\n\n");
|
||||
for (int i = 0; i < n_prompts; i++) {
|
||||
LOG("%6.6s ", prompts[i].c_str());
|
||||
}
|
||||
LOG("\n");
|
||||
for (int i = 0; i < n_prompts; i++) {
|
||||
for (int j = 0; j < n_prompts; j++) {
|
||||
float sim = common_embd_similarity_cos(emb + i * n_embd_out, emb + j * n_embd_out, n_embd_out);
|
||||
LOG("%6.2f ", sim);
|
||||
}
|
||||
LOG("%1.10s", prompts[i].c_str());
|
||||
LOG("\n");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (params.embd_out == "json" || params.embd_out == "json+" || params.embd_out == "array") {
|
||||
const bool notArray = params.embd_out != "array";
|
||||
|
||||
LOG(notArray ? "{\n \"object\": \"list\",\n \"data\": [\n" : "[");
|
||||
for (int j = 0;;) { // at least one iteration (one prompt)
|
||||
if (notArray) LOG(" {\n \"object\": \"embedding\",\n \"index\": %d,\n \"embedding\": ",j);
|
||||
LOG("[");
|
||||
for (int i = 0;;) { // at least one iteration (n_embd > 0)
|
||||
LOG(params.embd_normalize == 0 ? "%1.0f" : "%1.7f", emb[j * n_embd_out + i]);
|
||||
i++;
|
||||
if (i < n_embd_out) LOG(","); else break;
|
||||
}
|
||||
LOG(notArray ? "]\n }" : "]");
|
||||
j++;
|
||||
if (j < n_embd_count) LOG(notArray ? ",\n" : ","); else break;
|
||||
}
|
||||
LOG(notArray ? "\n ]" : "]\n");
|
||||
|
||||
if (params.embd_out == "json+" && n_prompts > 1) {
|
||||
LOG(",\n \"cosineSimilarity\": [\n");
|
||||
for (int i = 0;;) { // at least two iteration (n_embd_count > 1)
|
||||
LOG(" [");
|
||||
for (int j = 0;;) { // at least two iteration (n_embd_count > 1)
|
||||
float sim = common_embd_similarity_cos(emb + i * n_embd_out, emb + j * n_embd_out, n_embd_out);
|
||||
LOG("%6.2f", sim);
|
||||
j++;
|
||||
if (j < n_embd_count) LOG(", "); else break;
|
||||
}
|
||||
LOG(" ]");
|
||||
i++;
|
||||
if (i < n_embd_count) LOG(",\n"); else break;
|
||||
}
|
||||
LOG("\n ]");
|
||||
}
|
||||
|
||||
if (notArray) LOG("\n}\n");
|
||||
} else if (params.embd_out == "raw") {
|
||||
print_raw_embeddings(emb, n_embd_count, n_embd_out, model, pooling_type, params.embd_normalize);
|
||||
}
|
||||
|
||||
LOG("\n");
|
||||
llama_perf_context_print(ctx);
|
||||
|
||||
// clean up
|
||||
llama_batch_free(batch);
|
||||
llama_backend_free();
|
||||
|
||||
return 0;
|
||||
}
|
||||
26
examples/eval-callback/CMakeLists.txt
Normal file
@@ -0,0 +1,26 @@
|
||||
set(TARGET llama-eval-callback)
|
||||
add_executable(${TARGET} eval-callback.cpp)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
|
||||
if(LLAMA_BUILD_TESTS)
|
||||
if(NOT ${CMAKE_SYSTEM_PROCESSOR} MATCHES "s390x")
|
||||
set(MODEL_NAME "tinyllamas/stories15M-q4_0.gguf")
|
||||
set(MODEL_HASH "SHA256=66967fbece6dbe97886593fdbb73589584927e29119ec31f08090732d1861739")
|
||||
else()
|
||||
set(MODEL_NAME "tinyllamas/stories15M-be.Q4_0.gguf")
|
||||
set(MODEL_HASH "SHA256=9aec857937849d976f30397e97eb1cabb53eb9dcb1ce4611ba8247fb5f44c65d")
|
||||
endif()
|
||||
set(MODEL_DEST "${CMAKE_BINARY_DIR}/${MODEL_NAME}")
|
||||
set(TEST_TARGET test-eval-callback)
|
||||
add_test(NAME ${TEST_TARGET}-download-model COMMAND ${CMAKE_COMMAND}
|
||||
-DDEST=${MODEL_DEST}
|
||||
-DNAME=${MODEL_NAME}
|
||||
-DHASH=${MODEL_HASH}
|
||||
-P ${CMAKE_SOURCE_DIR}/cmake/download-models.cmake
|
||||
)
|
||||
set_tests_properties(${TEST_TARGET}-download-model PROPERTIES FIXTURES_SETUP ${TEST_TARGET}-download-model)
|
||||
add_test(NAME ${TEST_TARGET} COMMAND llama-eval-callback -m "${MODEL_DEST}" --prompt hello --seed 42 -ngl 0)
|
||||
set_tests_properties(${TEST_TARGET} PROPERTIES FIXTURES_REQUIRED ${TEST_TARGET}-download-model)
|
||||
endif()
|
||||
95
examples/eval-callback/README.md
Normal file
@@ -0,0 +1,95 @@
|
||||
# llama.cpp/examples/eval-callback
|
||||
|
||||
A simple example which demonstrates how to use callback during the inference.
|
||||
It simply prints to the console all operations and tensor data.
|
||||
|
||||
Usage:
|
||||
|
||||
```shell
|
||||
llama-eval-callback \
|
||||
--hf-repo ggml-org/models \
|
||||
--hf-file phi-2/ggml-model-q4_0.gguf \
|
||||
--model phi-2-q4_0.gguf \
|
||||
--prompt hello \
|
||||
--seed 42 \
|
||||
-ngl 33
|
||||
```
|
||||
|
||||
Will print:
|
||||
|
||||
```shell
|
||||
llm_load_tensors: offloaded 33/33 layers to GPU
|
||||
...
|
||||
llama_new_context_with_model: n_ctx = 512
|
||||
...
|
||||
llama_new_context_with_model: CUDA0 compute buffer size = 105.00 MiB
|
||||
llama_new_context_with_model: CUDA_Host compute buffer size = 6.01 MiB
|
||||
llama_new_context_with_model: graph nodes = 1225
|
||||
llama_new_context_with_model: graph splits = 2
|
||||
ggml_debug: inp_embd = (f32) GET_ROWS(token_embd.weight{2560, 51200, 1, 1}, inp_tokens{1, 1, 1, 1}}) = {2560, 1, 1, 1}
|
||||
[
|
||||
[
|
||||
[ -0.0181, 0.0272, 0.0272, ...],
|
||||
],
|
||||
]
|
||||
ggml_debug: norm-0 = (f32) NORM(CUDA0#inp_embd#0{2560, 1, 1, 1}, }) = {2560, 1, 1, 1}
|
||||
[
|
||||
[
|
||||
[ -0.6989, 1.0636, 1.0636, ...],
|
||||
],
|
||||
]
|
||||
ggml_debug: norm_w-0 = (f32) MUL(norm-0{2560, 1, 1, 1}, blk.0.attn_norm.weight{2560, 1, 1, 1}}) = {2560, 1, 1, 1}
|
||||
[
|
||||
[
|
||||
[ -0.1800, 0.2817, 0.2632, ...],
|
||||
],
|
||||
]
|
||||
ggml_debug: attn_norm-0 = (f32) ADD(norm_w-0{2560, 1, 1, 1}, blk.0.attn_norm.bias{2560, 1, 1, 1}}) = {2560, 1, 1, 1}
|
||||
[
|
||||
[
|
||||
[ -0.1863, 0.2970, 0.2604, ...],
|
||||
],
|
||||
]
|
||||
ggml_debug: wqkv-0 = (f32) MUL_MAT(blk.0.attn_qkv.weight{2560, 7680, 1, 1}, attn_norm-0{2560, 1, 1, 1}}) = {7680, 1, 1, 1}
|
||||
[
|
||||
[
|
||||
[ -1.1238, 1.2876, -1.8086, ...],
|
||||
],
|
||||
]
|
||||
ggml_debug: bqkv-0 = (f32) ADD(wqkv-0{7680, 1, 1, 1}, blk.0.attn_qkv.bias{7680, 1, 1, 1}}) = {7680, 1, 1, 1}
|
||||
[
|
||||
[
|
||||
[ -1.1135, 1.4604, -1.9226, ...],
|
||||
],
|
||||
]
|
||||
ggml_debug: bqkv-0 (view) = (f32) VIEW(bqkv-0{7680, 1, 1, 1}, }) = {2560, 1, 1, 1}
|
||||
[
|
||||
[
|
||||
[ -1.1135, 1.4604, -1.9226, ...],
|
||||
],
|
||||
]
|
||||
ggml_debug: Qcur-0 = (f32) CONT(bqkv-0 (view){2560, 1, 1, 1}, }) = {2560, 1, 1, 1}
|
||||
[
|
||||
[
|
||||
[ -1.1135, 1.4604, -1.9226, ...],
|
||||
],
|
||||
]
|
||||
ggml_debug: Qcur-0 (reshaped) = (f32) RESHAPE(Qcur-0{2560, 1, 1, 1}, }) = {80, 32, 1, 1}
|
||||
[
|
||||
[
|
||||
[ -1.1135, 1.4604, -1.9226, ...],
|
||||
[ -0.3608, 0.5076, -1.8866, ...],
|
||||
[ 1.7643, 0.0273, -2.1065, ...],
|
||||
...
|
||||
],
|
||||
]
|
||||
ggml_debug: Qcur-0 = (f32) ROPE(Qcur-0 (reshaped){80, 32, 1, 1}, CUDA0#inp_pos#0{1, 1, 1, 1}}) = {80, 32, 1, 1}
|
||||
[
|
||||
[
|
||||
[ -1.1135, 1.4604, -1.9226, ...],
|
||||
[ -0.3608, 0.5076, -1.8866, ...],
|
||||
[ 1.7643, 0.0273, -2.1065, ...],
|
||||
...
|
||||
],
|
||||
]
|
||||
```
|
||||
80
examples/eval-callback/eval-callback.cpp
Normal file
@@ -0,0 +1,80 @@
|
||||
#include "arg.h"
|
||||
#include "common.h"
|
||||
#include "debug.h"
|
||||
#include "log.h"
|
||||
#include "llama.h"
|
||||
#include "llama-cpp.h"
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
static bool run(llama_context * ctx, const common_params & params) {
|
||||
const llama_model * model = llama_get_model(ctx);
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
|
||||
const bool add_bos = llama_vocab_get_add_bos(vocab);
|
||||
|
||||
std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, add_bos);
|
||||
|
||||
if (tokens.empty()) {
|
||||
LOG_ERR("%s : there are not input tokens to process - (try to provide a prompt with '-p')\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) {
|
||||
LOG_ERR("%s : failed to eval\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
base_callback_data cb_data;
|
||||
|
||||
common_params params;
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
common_init();
|
||||
|
||||
llama_backend_init();
|
||||
llama_numa_init(params.numa);
|
||||
|
||||
// pass the callback to the backend scheduler
|
||||
// it will be executed for each node during the graph computation
|
||||
params.cb_eval = common_debug_cb_eval<false>;
|
||||
params.cb_eval_user_data = &cb_data;
|
||||
params.warmup = false;
|
||||
|
||||
// init
|
||||
auto llama_init = common_init_from_params(params);
|
||||
|
||||
auto * model = llama_init->model();
|
||||
auto * ctx = llama_init->context();
|
||||
|
||||
if (model == nullptr || ctx == nullptr) {
|
||||
LOG_ERR("%s : failed to init\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
// print system information
|
||||
{
|
||||
LOG_INF("\n");
|
||||
LOG_INF("%s\n", common_params_get_system_info(params).c_str());
|
||||
LOG_INF("\n");
|
||||
}
|
||||
|
||||
bool OK = run(ctx, params);
|
||||
if (!OK) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
LOG("\n");
|
||||
llama_perf_context_print(ctx);
|
||||
|
||||
llama_backend_free();
|
||||
|
||||
return 0;
|
||||
}
|
||||
5
examples/gen-docs/CMakeLists.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
set(TARGET llama-gen-docs)
|
||||
add_executable(${TARGET} gen-docs.cpp)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
142
examples/gen-docs/gen-docs.cpp
Normal file
@@ -0,0 +1,142 @@
|
||||
#include "arg.h"
|
||||
#include "common.h"
|
||||
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
|
||||
// Export usage message (-h) to markdown format
|
||||
// Automatically update the markdown docs
|
||||
|
||||
#define HELP_START_MARKER "<!-- HELP_START -->"
|
||||
#define HELP_END_MARKER "<!-- HELP_END -->"
|
||||
#define NOTE_MESSAGE "<!-- IMPORTANT: The list below is auto-generated by llama-gen-docs; do NOT modify it manually -->"
|
||||
|
||||
struct md_file {
|
||||
llama_example ex;
|
||||
std::string fname;
|
||||
std::string specific_section_header;
|
||||
};
|
||||
|
||||
std::vector<md_file> md_files = {
|
||||
{LLAMA_EXAMPLE_CLI, "tools/cli/README.md", "CLI-specific params"},
|
||||
{LLAMA_EXAMPLE_COMPLETION, "tools/completion/README.md", "Completion-specific params"},
|
||||
{LLAMA_EXAMPLE_SERVER, "tools/server/README.md", "Server-specific params"},
|
||||
};
|
||||
|
||||
static void write_table_header(std::ostringstream & ss) {
|
||||
ss << "| Argument | Explanation |\n";
|
||||
ss << "| -------- | ----------- |\n";
|
||||
}
|
||||
|
||||
static void write_table_entry(std::ostringstream & ss, const common_arg & opt) {
|
||||
ss << "| `";
|
||||
// args
|
||||
auto all_args = opt.get_args();
|
||||
for (const auto & arg : all_args) {
|
||||
if (arg == all_args.front()) {
|
||||
ss << arg;
|
||||
if (all_args.size() > 1) ss << ", ";
|
||||
} else {
|
||||
ss << arg << (arg != all_args.back() ? ", " : "");
|
||||
}
|
||||
}
|
||||
// value hint
|
||||
if (opt.value_hint) {
|
||||
std::string md_value_hint(opt.value_hint);
|
||||
string_replace_all(md_value_hint, "|", "\\|");
|
||||
ss << " " << md_value_hint;
|
||||
}
|
||||
if (opt.value_hint_2) {
|
||||
std::string md_value_hint_2(opt.value_hint_2);
|
||||
string_replace_all(md_value_hint_2, "|", "\\|");
|
||||
ss << " " << md_value_hint_2;
|
||||
}
|
||||
// help text
|
||||
std::string md_help(opt.help);
|
||||
md_help = string_strip(md_help);
|
||||
string_replace_all(md_help, "\n", "<br/>");
|
||||
string_replace_all(md_help, "|", "\\|");
|
||||
ss << "` | " << md_help << " |\n";
|
||||
}
|
||||
|
||||
static void write_table(std::ostringstream & ss, std::vector<common_arg *> & opts) {
|
||||
write_table_header(ss);
|
||||
for (const auto & opt : opts) {
|
||||
write_table_entry(ss, *opt);
|
||||
}
|
||||
}
|
||||
|
||||
static void write_help(std::ostringstream & ss, const md_file & md) {
|
||||
common_params params;
|
||||
auto ctx_arg = common_params_parser_init(params, md.ex);
|
||||
|
||||
std::vector<common_arg *> common_options;
|
||||
std::vector<common_arg *> sparam_options;
|
||||
std::vector<common_arg *> specific_options;
|
||||
for (auto & opt : ctx_arg.options) {
|
||||
// in case multiple LLAMA_EXAMPLE_* are set, we prioritize the LLAMA_EXAMPLE_* matching current example
|
||||
if (opt.is_sparam) {
|
||||
sparam_options.push_back(&opt);
|
||||
} else if (opt.in_example(ctx_arg.ex)) {
|
||||
specific_options.push_back(&opt);
|
||||
} else {
|
||||
common_options.push_back(&opt);
|
||||
}
|
||||
}
|
||||
|
||||
ss << HELP_START_MARKER << "\n\n";
|
||||
|
||||
ss << NOTE_MESSAGE << "\n\n";
|
||||
|
||||
ss << "### Common params\n\n";
|
||||
write_table(ss, common_options);
|
||||
ss << "\n\n### Sampling params\n\n";
|
||||
write_table(ss, sparam_options);
|
||||
ss << "\n\n### " << md.specific_section_header << "\n\n";
|
||||
write_table(ss, specific_options);
|
||||
|
||||
ss << "\n" << HELP_END_MARKER;
|
||||
}
|
||||
|
||||
int main(int, char **) {
|
||||
for (const auto & md : md_files) {
|
||||
std::ifstream infile(md.fname);
|
||||
if (!infile.is_open()) {
|
||||
fprintf(stderr, "failed to open file '%s' for reading\n", md.fname.c_str());
|
||||
return 1;
|
||||
}
|
||||
|
||||
std::ostringstream ss;
|
||||
ss << infile.rdbuf();
|
||||
infile.close();
|
||||
|
||||
std::string content = ss.str();
|
||||
|
||||
size_t help_start = content.find(HELP_START_MARKER);
|
||||
size_t help_end = content.find(HELP_END_MARKER);
|
||||
|
||||
if (help_start == std::string::npos || help_end == std::string::npos || help_end <= help_start) {
|
||||
fprintf(stderr, "failed to find help markers in file '%s'\n", md.fname.c_str());
|
||||
return 1;
|
||||
}
|
||||
|
||||
std::ostringstream new_help_ss;
|
||||
write_help(new_help_ss, md);
|
||||
std::string new_help = new_help_ss.str();
|
||||
|
||||
content = content.substr(0, help_start) + new_help + content.substr(help_end + strlen(HELP_END_MARKER));
|
||||
|
||||
std::ofstream outfile(md.fname);
|
||||
if (!outfile.is_open()) {
|
||||
fprintf(stderr, "failed to open file '%s' for writing\n", md.fname.c_str());
|
||||
return 1;
|
||||
}
|
||||
outfile << content;
|
||||
outfile.close();
|
||||
|
||||
printf("Updated help in '%s'\n", md.fname.c_str());
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
22
examples/gguf-hash/CMakeLists.txt
Normal file
@@ -0,0 +1,22 @@
|
||||
set(TARGET llama-gguf-hash)
|
||||
add_executable(${TARGET} gguf-hash.cpp)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
|
||||
# clibs dependencies
|
||||
include_directories(deps/)
|
||||
|
||||
add_library(xxhash OBJECT deps/xxhash/xxhash.c deps/xxhash/xxhash.h)
|
||||
target_link_libraries(${TARGET} PRIVATE xxhash)
|
||||
|
||||
add_library(sha1 OBJECT deps/sha1/sha1.c deps/sha1/sha1.h)
|
||||
target_link_libraries(${TARGET} PRIVATE sha1)
|
||||
if (NOT MSVC)
|
||||
# disable warnings in 3rd party code
|
||||
target_compile_options(sha1 PRIVATE -w)
|
||||
endif()
|
||||
|
||||
add_library(sha256 OBJECT deps/sha256/sha256.c deps/sha256/sha256.h)
|
||||
target_link_libraries(${TARGET} PRIVATE sha256)
|
||||
|
||||
target_link_libraries(${TARGET} PRIVATE ggml ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
206
examples/gguf-hash/README.md
Normal file
@@ -0,0 +1,206 @@
|
||||
|
||||
# llama-gguf-hash
|
||||
|
||||
CLI to hash GGUF files to detect difference on a per model and per tensor level.
|
||||
|
||||
**Command line options:**
|
||||
|
||||
- `--help`: display help message
|
||||
- `--xxh64`: use xhash 64bit hash mode (default)
|
||||
- `--sha1`: use sha1
|
||||
- `--uuid`: use uuid
|
||||
- `--sha256`: use sha256
|
||||
- `--all`: use all hash
|
||||
- `--no-layer`: exclude per layer hash
|
||||
- `--uuid`: generate UUIDv5 ID
|
||||
- `-c`, `--check <manifest>`: verify against a manifest
|
||||
|
||||
## About
|
||||
|
||||
While most POSIX systems already have hash checking programs like sha256sum, it
|
||||
is designed to check entire files. This is not ideal for our purpose if we want
|
||||
to check for consistency of the tensor data even if the metadata content of the
|
||||
gguf KV store has been updated.
|
||||
|
||||
This program is designed to hash a gguf tensor payload on a 'per tensor layer'
|
||||
in addition to a 'entire tensor model' hash. The intent is that the entire
|
||||
tensor layer can be checked first but if there is any detected inconsistencies,
|
||||
then the per tensor hash can be used to narrow down the specific tensor layer
|
||||
that has inconsistencies.
|
||||
|
||||
For Maintainers:
|
||||
- Detection of tensor inconsistency during development and automated tests
|
||||
- This is served by xxh64 which is fast
|
||||
- This is also served by having per tensor layer to assist in narrowing down
|
||||
the location of the faulty tensor layer
|
||||
- This is also served by sha1 which is much slower but more widely supported
|
||||
|
||||
For Model Creators:
|
||||
- Optional consistent UUID generation based on model tensor content
|
||||
- This is served by UUIDv5 which is useful for databases keys
|
||||
- llama.cpp UUIDv5 Namespace: `ef001206-dadc-5f6d-a15f-3359e577d4e5`
|
||||
- Made via UUIDv5 URL namespace of `en.wikipedia.org/wiki/Llama.cpp`
|
||||
|
||||
For Model Users:
|
||||
- Assurance of tensor layer integrity even if metadata was updated
|
||||
- This is served by sha256 which is still considered very secure as of 2024
|
||||
|
||||
### Design Note
|
||||
|
||||
- The default behavior of this program if no arguments is provided is to hash
|
||||
using xxhash's xxh32 mode because it is very fast and is primarily targeted
|
||||
towards maintainers who may want to use this in automated tests.
|
||||
- xxhash support xxh32 and xxh128 for 32bit hash and 128bit hash respectively
|
||||
however we picked 64bit xxhash as most computers are 64bit as of 2024 and thus
|
||||
would have a better affinity to calculating hash that is 64bit in size.
|
||||
|
||||
## Compile Example
|
||||
|
||||
```bash
|
||||
cmake -B build -DCMAKE_BUILD_TYPE=Debug -DLLAMA_FATAL_WARNINGS=ON
|
||||
make -C build clean
|
||||
make -C build llama-gguf-hash VERBOSE=1
|
||||
./build/bin/llama-gguf-hash test.gguf
|
||||
./build/bin/llama-gguf-hash --xxh64 test.gguf
|
||||
./build/bin/llama-gguf-hash --sha1 test.gguf
|
||||
./build/bin/llama-gguf-hash --uuid test.gguf
|
||||
./build/bin/llama-gguf-hash --sha256 test.gguf
|
||||
```
|
||||
|
||||
## Generation and Verification Example
|
||||
|
||||
To generate we may use this command
|
||||
|
||||
```bash
|
||||
./llama-gguf-hash --all test.gguf > test.gguf.manifest
|
||||
```
|
||||
|
||||
Which would generate a manifest that looks like below, which contains multiple hash type and per tensor layer hashes as well
|
||||
(This excludes UUID as that is an ID not a hash)
|
||||
|
||||
```bash
|
||||
xxh64 f66e9cd66a4396a0 test.gguf:tensor_0
|
||||
sha1 59f79ecefd8125a996fdf419239051a7e99e5f20 test.gguf:tensor_0
|
||||
sha256 c0510d38fa060c46265e0160a85c7243096b01dd31c2f355bdbb5516b20de1bd test.gguf:tensor_0
|
||||
xxh64 7d3a1f9ac04d0537 test.gguf:tensor_1
|
||||
sha1 4765f592eacf096df4628ba59476af94d767080a test.gguf:tensor_1
|
||||
sha256 8514cbcc73692a2c56bd7a33a022edd5ff819614bd23b19915d7224387f397a7 test.gguf:tensor_1
|
||||
xxh64 a0af5d700049693b test.gguf:tensor_2
|
||||
sha1 25cbfbad4513cc348e2c95ebdee69d6ff2fd8753 test.gguf:tensor_2
|
||||
sha256 947e6b36e20f2cc95e1d2ce1c1669d813d574657ac6b5ac5196158d454d35180 test.gguf:tensor_2
|
||||
xxh64 e83fddf559d7b6a6 test.gguf:tensor_3
|
||||
sha1 a9cba73e2d90f2ee3dae2548caa42bef3fe6a96c test.gguf:tensor_3
|
||||
sha256 423b044e016d8ac73c39f23f60bf01bedef5ecb03c0230accd824c91fe86f1a1 test.gguf:tensor_3
|
||||
xxh64 1257733306b7992d test.gguf:tensor_4
|
||||
sha1 d7bc61db93bb685ce9d598da89717c66729b7543 test.gguf:tensor_4
|
||||
sha256 79737cb3912d4201384cf7f16a1a37ff7823f23ea796cb205b6ca361ab9e3ebf test.gguf:tensor_4
|
||||
xxh64 d238d16ba4711e58 test.gguf:tensor_5
|
||||
sha1 0706566c198fe1072f37e0a5135b4b5f23654c52 test.gguf:tensor_5
|
||||
sha256 60949be8298eced0ecdde64487643d018407bd261691e061d9e9c3dbc9fd358b test.gguf:tensor_5
|
||||
xxh64 3fbc3b65ab8c7f39 test.gguf:tensor_6
|
||||
sha1 73922a0727226a409049f6fc3172a52219ca6f00 test.gguf:tensor_6
|
||||
sha256 574f4c46ff384a3b9a225eb955d2a871847a2e8b3fa59387a8252832e92ef7b0 test.gguf:tensor_6
|
||||
xxh64 c22021c29854f093 test.gguf:tensor_7
|
||||
sha1 efc39cece6a951188fc41e354c73bbfe6813d447 test.gguf:tensor_7
|
||||
sha256 4c0410cd3c500f078ae5b21e8dc9eb79e29112713b2ab58a882f82a3868d4d75 test.gguf:tensor_7
|
||||
xxh64 936df61f5d64261f test.gguf:tensor_8
|
||||
sha1 c2490296d789a4f34398a337fed8377d943d9f06 test.gguf:tensor_8
|
||||
sha256 c4401313feeba0261275c3b25bd2d8fe40ce04e0f440c2980ed0e9674c30ff01 test.gguf:tensor_8
|
||||
xxh64 93fd20c64421c081 test.gguf:tensor_9
|
||||
sha1 7047ce1e78437a6884337a3751c7ee0421918a65 test.gguf:tensor_9
|
||||
sha256 23d57cf0d7a6e90b0b3616b41300e0cd354781e812add854a5f95aa55f2bc514 test.gguf:tensor_9
|
||||
xxh64 5a54d3aad816f302 test.gguf
|
||||
sha1 d15be52c4ff213e823cb6dd13af7ee2f978e7042 test.gguf
|
||||
sha256 7dd641b32f59b60dbd4b5420c4b0f6321ccf48f58f6ae201a3dbc4a58a27c6e4 test.gguf
|
||||
```
|
||||
|
||||
We can then use the normal check command which will by default check for the highest security strength hash and verify against that:
|
||||
|
||||
```bash
|
||||
$ ./llama-gguf-hash --check test.gguf.manifest test.gguf
|
||||
manifest test.gguf.manifest sha256 sha1 xxh64
|
||||
sha256 c0510d38fa060c46265e0160a85c7243096b01dd31c2f355bdbb5516b20de1bd test.gguf:tensor_0 - Ok
|
||||
sha256 8514cbcc73692a2c56bd7a33a022edd5ff819614bd23b19915d7224387f397a7 test.gguf:tensor_1 - Ok
|
||||
sha256 947e6b36e20f2cc95e1d2ce1c1669d813d574657ac6b5ac5196158d454d35180 test.gguf:tensor_2 - Ok
|
||||
sha256 423b044e016d8ac73c39f23f60bf01bedef5ecb03c0230accd824c91fe86f1a1 test.gguf:tensor_3 - Ok
|
||||
sha256 79737cb3912d4201384cf7f16a1a37ff7823f23ea796cb205b6ca361ab9e3ebf test.gguf:tensor_4 - Ok
|
||||
sha256 60949be8298eced0ecdde64487643d018407bd261691e061d9e9c3dbc9fd358b test.gguf:tensor_5 - Ok
|
||||
sha256 574f4c46ff384a3b9a225eb955d2a871847a2e8b3fa59387a8252832e92ef7b0 test.gguf:tensor_6 - Ok
|
||||
sha256 4c0410cd3c500f078ae5b21e8dc9eb79e29112713b2ab58a882f82a3868d4d75 test.gguf:tensor_7 - Ok
|
||||
sha256 c4401313feeba0261275c3b25bd2d8fe40ce04e0f440c2980ed0e9674c30ff01 test.gguf:tensor_8 - Ok
|
||||
sha256 23d57cf0d7a6e90b0b3616b41300e0cd354781e812add854a5f95aa55f2bc514 test.gguf:tensor_9 - Ok
|
||||
sha256 7dd641b32f59b60dbd4b5420c4b0f6321ccf48f58f6ae201a3dbc4a58a27c6e4 test.gguf - Ok
|
||||
|
||||
Verification results for test.gguf.manifest - Success
|
||||
```
|
||||
|
||||
Or we may explicitly ask for a faster hash like:
|
||||
|
||||
```bash
|
||||
$ ./llama-gguf-hash --check test.gguf.manifest --xxh64 test.gguf
|
||||
manifest test.gguf.manifest sha256 sha1 xxh64
|
||||
xxh64 f66e9cd66a4396a0 test.gguf:tensor_0 - Ok
|
||||
xxh64 7d3a1f9ac04d0537 test.gguf:tensor_1 - Ok
|
||||
xxh64 a0af5d700049693b test.gguf:tensor_2 - Ok
|
||||
xxh64 e83fddf559d7b6a6 test.gguf:tensor_3 - Ok
|
||||
xxh64 1257733306b7992d test.gguf:tensor_4 - Ok
|
||||
xxh64 d238d16ba4711e58 test.gguf:tensor_5 - Ok
|
||||
xxh64 3fbc3b65ab8c7f39 test.gguf:tensor_6 - Ok
|
||||
xxh64 c22021c29854f093 test.gguf:tensor_7 - Ok
|
||||
xxh64 936df61f5d64261f test.gguf:tensor_8 - Ok
|
||||
xxh64 93fd20c64421c081 test.gguf:tensor_9 - Ok
|
||||
xxh64 5a54d3aad816f302 test.gguf - Ok
|
||||
|
||||
Verification results for test.gguf.manifest - Success
|
||||
```
|
||||
|
||||
Or maybe we want to just check that all the hash is valid:
|
||||
|
||||
```bash
|
||||
$./llama-gguf-hash --check test.gguf.manifest --all test.gguf.manifest
|
||||
manifest test.gguf.manifest sha256 sha1 xxh64
|
||||
xxh64 f66e9cd66a4396a0 test.gguf:tensor_0 - Ok
|
||||
sha1 59f79ecefd8125a996fdf419239051a7e99e5f20 test.gguf:tensor_0 - Ok
|
||||
sha256 c0510d38fa060c46265e0160a85c7243096b01dd31c2f355bdbb5516b20de1bd test.gguf:tensor_0 - Ok
|
||||
xxh64 7d3a1f9ac04d0537 test.gguf:tensor_1 - Ok
|
||||
sha1 4765f592eacf096df4628ba59476af94d767080a test.gguf:tensor_1 - Ok
|
||||
sha256 8514cbcc73692a2c56bd7a33a022edd5ff819614bd23b19915d7224387f397a7 test.gguf:tensor_1 - Ok
|
||||
xxh64 a0af5d700049693b test.gguf:tensor_2 - Ok
|
||||
sha1 25cbfbad4513cc348e2c95ebdee69d6ff2fd8753 test.gguf:tensor_2 - Ok
|
||||
sha256 947e6b36e20f2cc95e1d2ce1c1669d813d574657ac6b5ac5196158d454d35180 test.gguf:tensor_2 - Ok
|
||||
xxh64 e83fddf559d7b6a6 test.gguf:tensor_3 - Ok
|
||||
sha1 a9cba73e2d90f2ee3dae2548caa42bef3fe6a96c test.gguf:tensor_3 - Ok
|
||||
sha256 423b044e016d8ac73c39f23f60bf01bedef5ecb03c0230accd824c91fe86f1a1 test.gguf:tensor_3 - Ok
|
||||
xxh64 1257733306b7992d test.gguf:tensor_4 - Ok
|
||||
sha1 d7bc61db93bb685ce9d598da89717c66729b7543 test.gguf:tensor_4 - Ok
|
||||
sha256 79737cb3912d4201384cf7f16a1a37ff7823f23ea796cb205b6ca361ab9e3ebf test.gguf:tensor_4 - Ok
|
||||
xxh64 d238d16ba4711e58 test.gguf:tensor_5 - Ok
|
||||
sha1 0706566c198fe1072f37e0a5135b4b5f23654c52 test.gguf:tensor_5 - Ok
|
||||
sha256 60949be8298eced0ecdde64487643d018407bd261691e061d9e9c3dbc9fd358b test.gguf:tensor_5 - Ok
|
||||
xxh64 3fbc3b65ab8c7f39 test.gguf:tensor_6 - Ok
|
||||
sha1 73922a0727226a409049f6fc3172a52219ca6f00 test.gguf:tensor_6 - Ok
|
||||
sha256 574f4c46ff384a3b9a225eb955d2a871847a2e8b3fa59387a8252832e92ef7b0 test.gguf:tensor_6 - Ok
|
||||
xxh64 c22021c29854f093 test.gguf:tensor_7 - Ok
|
||||
sha1 efc39cece6a951188fc41e354c73bbfe6813d447 test.gguf:tensor_7 - Ok
|
||||
sha256 4c0410cd3c500f078ae5b21e8dc9eb79e29112713b2ab58a882f82a3868d4d75 test.gguf:tensor_7 - Ok
|
||||
xxh64 936df61f5d64261f test.gguf:tensor_8 - Ok
|
||||
sha1 c2490296d789a4f34398a337fed8377d943d9f06 test.gguf:tensor_8 - Ok
|
||||
sha256 c4401313feeba0261275c3b25bd2d8fe40ce04e0f440c2980ed0e9674c30ff01 test.gguf:tensor_8 - Ok
|
||||
xxh64 93fd20c64421c081 test.gguf:tensor_9 - Ok
|
||||
sha1 7047ce1e78437a6884337a3751c7ee0421918a65 test.gguf:tensor_9 - Ok
|
||||
sha256 23d57cf0d7a6e90b0b3616b41300e0cd354781e812add854a5f95aa55f2bc514 test.gguf:tensor_9 - Ok
|
||||
xxh64 5a54d3aad816f302 test.gguf - Ok
|
||||
sha1 d15be52c4ff213e823cb6dd13af7ee2f978e7042 test.gguf - Ok
|
||||
sha256 7dd641b32f59b60dbd4b5420c4b0f6321ccf48f58f6ae201a3dbc4a58a27c6e4 test.gguf - Ok
|
||||
|
||||
Verification results for test.gguf.manifest - Success
|
||||
```
|
||||
|
||||
|
||||
## Crypto/Hash Libraries Used
|
||||
|
||||
These micro c libraries dependencies was installed via the [clib c package manager](https://github.com/clibs)
|
||||
|
||||
- https://github.com/Cyan4973/xxHash
|
||||
- https://github.com/clibs/sha1/
|
||||
- https://github.com/jb55/sha256.c
|
||||
13
examples/gguf-hash/deps/rotate-bits/package.json
Normal file
@@ -0,0 +1,13 @@
|
||||
{
|
||||
"name": "rotate-bits",
|
||||
"version": "0.1.1",
|
||||
"repo": "jb55/rotate-bits.h",
|
||||
"description": "rotate bits",
|
||||
"keywords": ["rotl", "rotr"],
|
||||
"src": ["rotate-bits.h"],
|
||||
"license": "Public Domain",
|
||||
"development": {
|
||||
"thlorenz/tap.c": "*"
|
||||
}
|
||||
}
|
||||
|
||||
46
examples/gguf-hash/deps/rotate-bits/rotate-bits.h
Normal file
@@ -0,0 +1,46 @@
|
||||
|
||||
|
||||
#ifndef __ROTATE_DEFS_H
|
||||
#define __ROTATE_DEFS_H
|
||||
|
||||
#ifdef _MSC_VER
|
||||
|
||||
#include <stdlib.h>
|
||||
|
||||
#define ROTL32(v, n) _rotl((v), (n))
|
||||
#define ROTL64(v, n) _rotl64((v), (n))
|
||||
|
||||
#define ROTR32(v, n) _rotr((v), (n))
|
||||
#define ROTR64(v, n) _rotr64((v), (n))
|
||||
|
||||
#else
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
#define U8V(v) ((uint8_t)(v) & 0xFFU)
|
||||
#define U16V(v) ((uint16_t)(v) & 0xFFFFU)
|
||||
#define U32V(v) ((uint32_t)(v) & 0xFFFFFFFFU)
|
||||
#define U64V(v) ((uint64_t)(v) & 0xFFFFFFFFFFFFFFFFU)
|
||||
|
||||
#define ROTL32(v, n) \
|
||||
(U32V((uint32_t)(v) << (n)) | ((uint32_t)(v) >> (32 - (n))))
|
||||
|
||||
// tests fail if we don't have this cast...
|
||||
#define ROTL64(v, n) \
|
||||
(U64V((uint64_t)(v) << (n)) | ((uint64_t)(v) >> (64 - (n))))
|
||||
|
||||
#define ROTR32(v, n) ROTL32(v, 32 - (n))
|
||||
#define ROTR64(v, n) ROTL64(v, 64 - (n))
|
||||
|
||||
#endif
|
||||
|
||||
#define ROTL8(v, n) \
|
||||
(U8V((uint8_t)(v) << (n)) | ((uint8_t)(v) >> (8 - (n))))
|
||||
|
||||
#define ROTL16(v, n) \
|
||||
(U16V((uint16_t)(v) << (n)) | ((uint16_t)(v) >> (16 - (n))))
|
||||
|
||||
#define ROTR8(v, n) ROTL8(v, 8 - (n))
|
||||
#define ROTR16(v, n) ROTL16(v, 16 - (n))
|
||||
|
||||
#endif
|
||||
9
examples/gguf-hash/deps/sha1/package.json
Normal file
@@ -0,0 +1,9 @@
|
||||
{
|
||||
"name": "sha1",
|
||||
"version": "0.0.1",
|
||||
"repo": "clibs/sha1",
|
||||
"description": "sha1 hash algorithm",
|
||||
"keywords": ["sha1", "hash"],
|
||||
"license": "public domain",
|
||||
"src": ["sha1.c", "sha1.h"]
|
||||
}
|
||||
295
examples/gguf-hash/deps/sha1/sha1.c
Normal file
@@ -0,0 +1,295 @@
|
||||
/*
|
||||
SHA-1 in C
|
||||
By Steve Reid <steve@edmweb.com>
|
||||
100% Public Domain
|
||||
|
||||
Test Vectors (from FIPS PUB 180-1)
|
||||
"abc"
|
||||
A9993E36 4706816A BA3E2571 7850C26C 9CD0D89D
|
||||
"abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq"
|
||||
84983E44 1C3BD26E BAAE4AA1 F95129E5 E54670F1
|
||||
A million repetitions of "a"
|
||||
34AA973C D4C4DAA4 F61EEB2B DBAD2731 6534016F
|
||||
*/
|
||||
|
||||
/* #define LITTLE_ENDIAN * This should be #define'd already, if true. */
|
||||
/* #define SHA1HANDSOFF * Copies data before messing with it. */
|
||||
|
||||
#define SHA1HANDSOFF
|
||||
|
||||
#include <stdio.h>
|
||||
#include <string.h>
|
||||
|
||||
/* for uint32_t */
|
||||
#include <stdint.h>
|
||||
|
||||
#include "sha1.h"
|
||||
|
||||
|
||||
#define rol(value, bits) (((value) << (bits)) | ((value) >> (32 - (bits))))
|
||||
|
||||
/* blk0() and blk() perform the initial expand. */
|
||||
/* I got the idea of expanding during the round function from SSLeay */
|
||||
#if BYTE_ORDER == LITTLE_ENDIAN
|
||||
#define blk0(i) (block->l[i] = (rol(block->l[i],24)&0xFF00FF00) \
|
||||
|(rol(block->l[i],8)&0x00FF00FF))
|
||||
#elif BYTE_ORDER == BIG_ENDIAN
|
||||
#define blk0(i) block->l[i]
|
||||
#else
|
||||
#error "Endianness not defined!"
|
||||
#endif
|
||||
#define blk(i) (block->l[i&15] = rol(block->l[(i+13)&15]^block->l[(i+8)&15] \
|
||||
^block->l[(i+2)&15]^block->l[i&15],1))
|
||||
|
||||
/* (R0+R1), R2, R3, R4 are the different operations used in SHA1 */
|
||||
#define R0(v,w,x,y,z,i) z+=((w&(x^y))^y)+blk0(i)+0x5A827999+rol(v,5);w=rol(w,30);
|
||||
#define R1(v,w,x,y,z,i) z+=((w&(x^y))^y)+blk(i)+0x5A827999+rol(v,5);w=rol(w,30);
|
||||
#define R2(v,w,x,y,z,i) z+=(w^x^y)+blk(i)+0x6ED9EBA1+rol(v,5);w=rol(w,30);
|
||||
#define R3(v,w,x,y,z,i) z+=(((w|x)&y)|(w&x))+blk(i)+0x8F1BBCDC+rol(v,5);w=rol(w,30);
|
||||
#define R4(v,w,x,y,z,i) z+=(w^x^y)+blk(i)+0xCA62C1D6+rol(v,5);w=rol(w,30);
|
||||
|
||||
|
||||
/* Hash a single 512-bit block. This is the core of the algorithm. */
|
||||
|
||||
void SHA1Transform(
|
||||
uint32_t state[5],
|
||||
const unsigned char buffer[64]
|
||||
)
|
||||
{
|
||||
uint32_t a, b, c, d, e;
|
||||
|
||||
typedef union
|
||||
{
|
||||
unsigned char c[64];
|
||||
uint32_t l[16];
|
||||
} CHAR64LONG16;
|
||||
|
||||
#ifdef SHA1HANDSOFF
|
||||
CHAR64LONG16 block[1]; /* use array to appear as a pointer */
|
||||
|
||||
memcpy(block, buffer, 64);
|
||||
#else
|
||||
/* The following had better never be used because it causes the
|
||||
* pointer-to-const buffer to be cast into a pointer to non-const.
|
||||
* And the result is written through. I threw a "const" in, hoping
|
||||
* this will cause a diagnostic.
|
||||
*/
|
||||
CHAR64LONG16 *block = (const CHAR64LONG16 *) buffer;
|
||||
#endif
|
||||
/* Copy context->state[] to working vars */
|
||||
a = state[0];
|
||||
b = state[1];
|
||||
c = state[2];
|
||||
d = state[3];
|
||||
e = state[4];
|
||||
/* 4 rounds of 20 operations each. Loop unrolled. */
|
||||
R0(a, b, c, d, e, 0);
|
||||
R0(e, a, b, c, d, 1);
|
||||
R0(d, e, a, b, c, 2);
|
||||
R0(c, d, e, a, b, 3);
|
||||
R0(b, c, d, e, a, 4);
|
||||
R0(a, b, c, d, e, 5);
|
||||
R0(e, a, b, c, d, 6);
|
||||
R0(d, e, a, b, c, 7);
|
||||
R0(c, d, e, a, b, 8);
|
||||
R0(b, c, d, e, a, 9);
|
||||
R0(a, b, c, d, e, 10);
|
||||
R0(e, a, b, c, d, 11);
|
||||
R0(d, e, a, b, c, 12);
|
||||
R0(c, d, e, a, b, 13);
|
||||
R0(b, c, d, e, a, 14);
|
||||
R0(a, b, c, d, e, 15);
|
||||
R1(e, a, b, c, d, 16);
|
||||
R1(d, e, a, b, c, 17);
|
||||
R1(c, d, e, a, b, 18);
|
||||
R1(b, c, d, e, a, 19);
|
||||
R2(a, b, c, d, e, 20);
|
||||
R2(e, a, b, c, d, 21);
|
||||
R2(d, e, a, b, c, 22);
|
||||
R2(c, d, e, a, b, 23);
|
||||
R2(b, c, d, e, a, 24);
|
||||
R2(a, b, c, d, e, 25);
|
||||
R2(e, a, b, c, d, 26);
|
||||
R2(d, e, a, b, c, 27);
|
||||
R2(c, d, e, a, b, 28);
|
||||
R2(b, c, d, e, a, 29);
|
||||
R2(a, b, c, d, e, 30);
|
||||
R2(e, a, b, c, d, 31);
|
||||
R2(d, e, a, b, c, 32);
|
||||
R2(c, d, e, a, b, 33);
|
||||
R2(b, c, d, e, a, 34);
|
||||
R2(a, b, c, d, e, 35);
|
||||
R2(e, a, b, c, d, 36);
|
||||
R2(d, e, a, b, c, 37);
|
||||
R2(c, d, e, a, b, 38);
|
||||
R2(b, c, d, e, a, 39);
|
||||
R3(a, b, c, d, e, 40);
|
||||
R3(e, a, b, c, d, 41);
|
||||
R3(d, e, a, b, c, 42);
|
||||
R3(c, d, e, a, b, 43);
|
||||
R3(b, c, d, e, a, 44);
|
||||
R3(a, b, c, d, e, 45);
|
||||
R3(e, a, b, c, d, 46);
|
||||
R3(d, e, a, b, c, 47);
|
||||
R3(c, d, e, a, b, 48);
|
||||
R3(b, c, d, e, a, 49);
|
||||
R3(a, b, c, d, e, 50);
|
||||
R3(e, a, b, c, d, 51);
|
||||
R3(d, e, a, b, c, 52);
|
||||
R3(c, d, e, a, b, 53);
|
||||
R3(b, c, d, e, a, 54);
|
||||
R3(a, b, c, d, e, 55);
|
||||
R3(e, a, b, c, d, 56);
|
||||
R3(d, e, a, b, c, 57);
|
||||
R3(c, d, e, a, b, 58);
|
||||
R3(b, c, d, e, a, 59);
|
||||
R4(a, b, c, d, e, 60);
|
||||
R4(e, a, b, c, d, 61);
|
||||
R4(d, e, a, b, c, 62);
|
||||
R4(c, d, e, a, b, 63);
|
||||
R4(b, c, d, e, a, 64);
|
||||
R4(a, b, c, d, e, 65);
|
||||
R4(e, a, b, c, d, 66);
|
||||
R4(d, e, a, b, c, 67);
|
||||
R4(c, d, e, a, b, 68);
|
||||
R4(b, c, d, e, a, 69);
|
||||
R4(a, b, c, d, e, 70);
|
||||
R4(e, a, b, c, d, 71);
|
||||
R4(d, e, a, b, c, 72);
|
||||
R4(c, d, e, a, b, 73);
|
||||
R4(b, c, d, e, a, 74);
|
||||
R4(a, b, c, d, e, 75);
|
||||
R4(e, a, b, c, d, 76);
|
||||
R4(d, e, a, b, c, 77);
|
||||
R4(c, d, e, a, b, 78);
|
||||
R4(b, c, d, e, a, 79);
|
||||
/* Add the working vars back into context.state[] */
|
||||
state[0] += a;
|
||||
state[1] += b;
|
||||
state[2] += c;
|
||||
state[3] += d;
|
||||
state[4] += e;
|
||||
/* Wipe variables */
|
||||
a = b = c = d = e = 0;
|
||||
#ifdef SHA1HANDSOFF
|
||||
memset(block, '\0', sizeof(block));
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
/* SHA1Init - Initialize new context */
|
||||
|
||||
void SHA1Init(
|
||||
SHA1_CTX * context
|
||||
)
|
||||
{
|
||||
/* SHA1 initialization constants */
|
||||
context->state[0] = 0x67452301;
|
||||
context->state[1] = 0xEFCDAB89;
|
||||
context->state[2] = 0x98BADCFE;
|
||||
context->state[3] = 0x10325476;
|
||||
context->state[4] = 0xC3D2E1F0;
|
||||
context->count[0] = context->count[1] = 0;
|
||||
}
|
||||
|
||||
|
||||
/* Run your data through this. */
|
||||
|
||||
void SHA1Update(
|
||||
SHA1_CTX * context,
|
||||
const unsigned char *data,
|
||||
uint32_t len
|
||||
)
|
||||
{
|
||||
uint32_t i;
|
||||
|
||||
uint32_t j;
|
||||
|
||||
j = context->count[0];
|
||||
if ((context->count[0] += len << 3) < j)
|
||||
context->count[1]++;
|
||||
context->count[1] += (len >> 29);
|
||||
j = (j >> 3) & 63;
|
||||
if ((j + len) > 63)
|
||||
{
|
||||
memcpy(&context->buffer[j], data, (i = 64 - j));
|
||||
SHA1Transform(context->state, context->buffer);
|
||||
for (; i + 63 < len; i += 64)
|
||||
{
|
||||
SHA1Transform(context->state, &data[i]);
|
||||
}
|
||||
j = 0;
|
||||
}
|
||||
else
|
||||
i = 0;
|
||||
memcpy(&context->buffer[j], &data[i], len - i);
|
||||
}
|
||||
|
||||
|
||||
/* Add padding and return the message digest. */
|
||||
|
||||
void SHA1Final(
|
||||
unsigned char digest[20],
|
||||
SHA1_CTX * context
|
||||
)
|
||||
{
|
||||
unsigned i;
|
||||
|
||||
unsigned char finalcount[8];
|
||||
|
||||
unsigned char c;
|
||||
|
||||
#if 0 /* untested "improvement" by DHR */
|
||||
/* Convert context->count to a sequence of bytes
|
||||
* in finalcount. Second element first, but
|
||||
* big-endian order within element.
|
||||
* But we do it all backwards.
|
||||
*/
|
||||
unsigned char *fcp = &finalcount[8];
|
||||
|
||||
for (i = 0; i < 2; i++)
|
||||
{
|
||||
uint32_t t = context->count[i];
|
||||
|
||||
int j;
|
||||
|
||||
for (j = 0; j < 4; t >>= 8, j++)
|
||||
*--fcp = (unsigned char) t}
|
||||
#else
|
||||
for (i = 0; i < 8; i++)
|
||||
{
|
||||
finalcount[i] = (unsigned char) ((context->count[(i >= 4 ? 0 : 1)] >> ((3 - (i & 3)) * 8)) & 255); /* Endian independent */
|
||||
}
|
||||
#endif
|
||||
c = 0200;
|
||||
SHA1Update(context, &c, 1);
|
||||
while ((context->count[0] & 504) != 448)
|
||||
{
|
||||
c = 0000;
|
||||
SHA1Update(context, &c, 1);
|
||||
}
|
||||
SHA1Update(context, finalcount, 8); /* Should cause a SHA1Transform() */
|
||||
for (i = 0; i < 20; i++)
|
||||
{
|
||||
digest[i] = (unsigned char)
|
||||
((context->state[i >> 2] >> ((3 - (i & 3)) * 8)) & 255);
|
||||
}
|
||||
/* Wipe variables */
|
||||
memset(context, '\0', sizeof(*context));
|
||||
memset(&finalcount, '\0', sizeof(finalcount));
|
||||
}
|
||||
|
||||
void SHA1(
|
||||
char *hash_out,
|
||||
const char *str,
|
||||
uint32_t len)
|
||||
{
|
||||
SHA1_CTX ctx;
|
||||
unsigned int ii;
|
||||
|
||||
SHA1Init(&ctx);
|
||||
for (ii=0; ii<len; ii+=1)
|
||||
SHA1Update(&ctx, (const unsigned char*)str + ii, 1);
|
||||
SHA1Final((unsigned char *)hash_out, &ctx);
|
||||
}
|
||||
|
||||
52
examples/gguf-hash/deps/sha1/sha1.h
Normal file
@@ -0,0 +1,52 @@
|
||||
#ifndef SHA1_H
|
||||
#define SHA1_H
|
||||
|
||||
/*
|
||||
SHA-1 in C
|
||||
By Steve Reid <steve@edmweb.com>
|
||||
100% Public Domain
|
||||
*/
|
||||
|
||||
#include "stdint.h"
|
||||
|
||||
#if defined(__cplusplus)
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
typedef struct
|
||||
{
|
||||
uint32_t state[5];
|
||||
uint32_t count[2];
|
||||
unsigned char buffer[64];
|
||||
} SHA1_CTX;
|
||||
|
||||
void SHA1Transform(
|
||||
uint32_t state[5],
|
||||
const unsigned char buffer[64]
|
||||
);
|
||||
|
||||
void SHA1Init(
|
||||
SHA1_CTX * context
|
||||
);
|
||||
|
||||
void SHA1Update(
|
||||
SHA1_CTX * context,
|
||||
const unsigned char *data,
|
||||
uint32_t len
|
||||
);
|
||||
|
||||
void SHA1Final(
|
||||
unsigned char digest[20],
|
||||
SHA1_CTX * context
|
||||
);
|
||||
|
||||
void SHA1(
|
||||
char *hash_out,
|
||||
const char *str,
|
||||
uint32_t len);
|
||||
|
||||
#if defined(__cplusplus)
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif /* SHA1_H */
|
||||
15
examples/gguf-hash/deps/sha256/package.json
Normal file
@@ -0,0 +1,15 @@
|
||||
{
|
||||
"name": "sha256",
|
||||
"version": "0.0.2",
|
||||
"repo": "jb55/sha256.c",
|
||||
"description": "sha256 in c",
|
||||
"keywords": ["sha256", "sha2"],
|
||||
"src": ["sha256.c", "sha256.h"],
|
||||
"dependencies": {
|
||||
"jb55/rotate-bits.h": "0.1.1"
|
||||
},
|
||||
"development": {
|
||||
"thlorenz/tap.c": "*"
|
||||
}
|
||||
}
|
||||
|
||||
221
examples/gguf-hash/deps/sha256/sha256.c
Normal file
@@ -0,0 +1,221 @@
|
||||
/* Crypto/Sha256.c -- SHA-256 Hash
|
||||
2010-06-11 : Igor Pavlov : Public domain
|
||||
This code is based on public domain code from Wei Dai's Crypto++ library. */
|
||||
|
||||
#include "rotate-bits/rotate-bits.h"
|
||||
#include "sha256.h"
|
||||
|
||||
/* define it for speed optimization */
|
||||
#define _SHA256_UNROLL
|
||||
#define _SHA256_UNROLL2
|
||||
|
||||
void
|
||||
sha256_init(sha256_t *p)
|
||||
{
|
||||
p->state[0] = 0x6a09e667;
|
||||
p->state[1] = 0xbb67ae85;
|
||||
p->state[2] = 0x3c6ef372;
|
||||
p->state[3] = 0xa54ff53a;
|
||||
p->state[4] = 0x510e527f;
|
||||
p->state[5] = 0x9b05688c;
|
||||
p->state[6] = 0x1f83d9ab;
|
||||
p->state[7] = 0x5be0cd19;
|
||||
p->count = 0;
|
||||
}
|
||||
|
||||
#define S0(x) (ROTR32(x, 2) ^ ROTR32(x,13) ^ ROTR32(x, 22))
|
||||
#define S1(x) (ROTR32(x, 6) ^ ROTR32(x,11) ^ ROTR32(x, 25))
|
||||
#define s0(x) (ROTR32(x, 7) ^ ROTR32(x,18) ^ (x >> 3))
|
||||
#define s1(x) (ROTR32(x,17) ^ ROTR32(x,19) ^ (x >> 10))
|
||||
|
||||
#define blk0(i) (W[i] = data[i])
|
||||
#define blk2(i) (W[i&15] += s1(W[(i-2)&15]) + W[(i-7)&15] + s0(W[(i-15)&15]))
|
||||
|
||||
#define Ch(x,y,z) (z^(x&(y^z)))
|
||||
#define Maj(x,y,z) ((x&y)|(z&(x|y)))
|
||||
|
||||
#define a(i) T[(0-(i))&7]
|
||||
#define b(i) T[(1-(i))&7]
|
||||
#define c(i) T[(2-(i))&7]
|
||||
#define d(i) T[(3-(i))&7]
|
||||
#define e(i) T[(4-(i))&7]
|
||||
#define f(i) T[(5-(i))&7]
|
||||
#define g(i) T[(6-(i))&7]
|
||||
#define h(i) T[(7-(i))&7]
|
||||
|
||||
|
||||
#ifdef _SHA256_UNROLL2
|
||||
|
||||
#define R(a,b,c,d,e,f,g,h, i) h += S1(e) + Ch(e,f,g) + K[i+j] + (j?blk2(i):blk0(i));\
|
||||
d += h; h += S0(a) + Maj(a, b, c)
|
||||
|
||||
#define RX_8(i) \
|
||||
R(a,b,c,d,e,f,g,h, i); \
|
||||
R(h,a,b,c,d,e,f,g, (i+1)); \
|
||||
R(g,h,a,b,c,d,e,f, (i+2)); \
|
||||
R(f,g,h,a,b,c,d,e, (i+3)); \
|
||||
R(e,f,g,h,a,b,c,d, (i+4)); \
|
||||
R(d,e,f,g,h,a,b,c, (i+5)); \
|
||||
R(c,d,e,f,g,h,a,b, (i+6)); \
|
||||
R(b,c,d,e,f,g,h,a, (i+7))
|
||||
|
||||
#else
|
||||
|
||||
#define R(i) h(i) += S1(e(i)) + Ch(e(i),f(i),g(i)) + K[i+j] + (j?blk2(i):blk0(i));\
|
||||
d(i) += h(i); h(i) += S0(a(i)) + Maj(a(i), b(i), c(i))
|
||||
|
||||
#ifdef _SHA256_UNROLL
|
||||
|
||||
#define RX_8(i) R(i+0); R(i+1); R(i+2); R(i+3); R(i+4); R(i+5); R(i+6); R(i+7);
|
||||
|
||||
#endif
|
||||
|
||||
#endif
|
||||
|
||||
static const uint32_t K[64] = {
|
||||
0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5,
|
||||
0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5,
|
||||
0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3,
|
||||
0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174,
|
||||
0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc,
|
||||
0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da,
|
||||
0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7,
|
||||
0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967,
|
||||
0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13,
|
||||
0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85,
|
||||
0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3,
|
||||
0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070,
|
||||
0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5,
|
||||
0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3,
|
||||
0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208,
|
||||
0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2
|
||||
};
|
||||
|
||||
static void
|
||||
sha256_transform(uint32_t *state, const uint32_t *data)
|
||||
{
|
||||
uint32_t W[16] = {0};
|
||||
unsigned j;
|
||||
#ifdef _SHA256_UNROLL2
|
||||
uint32_t a,b,c,d,e,f,g,h;
|
||||
a = state[0];
|
||||
b = state[1];
|
||||
c = state[2];
|
||||
d = state[3];
|
||||
e = state[4];
|
||||
f = state[5];
|
||||
g = state[6];
|
||||
h = state[7];
|
||||
#else
|
||||
uint32_t T[8];
|
||||
for (j = 0; j < 8; j++)
|
||||
T[j] = state[j];
|
||||
#endif
|
||||
|
||||
for (j = 0; j < 64; j += 16)
|
||||
{
|
||||
#if defined(_SHA256_UNROLL) || defined(_SHA256_UNROLL2)
|
||||
RX_8(0); RX_8(8);
|
||||
#else
|
||||
unsigned i;
|
||||
for (i = 0; i < 16; i++) { R(i); }
|
||||
#endif
|
||||
}
|
||||
|
||||
#ifdef _SHA256_UNROLL2
|
||||
state[0] += a;
|
||||
state[1] += b;
|
||||
state[2] += c;
|
||||
state[3] += d;
|
||||
state[4] += e;
|
||||
state[5] += f;
|
||||
state[6] += g;
|
||||
state[7] += h;
|
||||
#else
|
||||
for (j = 0; j < 8; j++)
|
||||
state[j] += T[j];
|
||||
#endif
|
||||
|
||||
/* Wipe variables */
|
||||
/* memset(W, 0, sizeof(W)); */
|
||||
/* memset(T, 0, sizeof(T)); */
|
||||
}
|
||||
|
||||
#undef S0
|
||||
#undef S1
|
||||
#undef s0
|
||||
#undef s1
|
||||
|
||||
static void
|
||||
sha256_write_byte_block(sha256_t *p)
|
||||
{
|
||||
uint32_t data32[16];
|
||||
unsigned i;
|
||||
for (i = 0; i < 16; i++)
|
||||
data32[i] =
|
||||
((uint32_t)(p->buffer[i * 4 ]) << 24) +
|
||||
((uint32_t)(p->buffer[i * 4 + 1]) << 16) +
|
||||
((uint32_t)(p->buffer[i * 4 + 2]) << 8) +
|
||||
((uint32_t)(p->buffer[i * 4 + 3]));
|
||||
sha256_transform(p->state, data32);
|
||||
}
|
||||
|
||||
|
||||
void
|
||||
sha256_hash(unsigned char *buf, const unsigned char *data, size_t size)
|
||||
{
|
||||
sha256_t hash;
|
||||
sha256_init(&hash);
|
||||
sha256_update(&hash, data, size);
|
||||
sha256_final(&hash, buf);
|
||||
}
|
||||
|
||||
|
||||
void
|
||||
sha256_update(sha256_t *p, const unsigned char *data, size_t size)
|
||||
{
|
||||
uint32_t curBufferPos = (uint32_t)p->count & 0x3F;
|
||||
while (size > 0)
|
||||
{
|
||||
p->buffer[curBufferPos++] = *data++;
|
||||
p->count++;
|
||||
size--;
|
||||
if (curBufferPos == 64)
|
||||
{
|
||||
curBufferPos = 0;
|
||||
sha256_write_byte_block(p);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void
|
||||
sha256_final(sha256_t *p, unsigned char *digest)
|
||||
{
|
||||
uint64_t lenInBits = (p->count << 3);
|
||||
uint32_t curBufferPos = (uint32_t)p->count & 0x3F;
|
||||
unsigned i;
|
||||
p->buffer[curBufferPos++] = 0x80;
|
||||
while (curBufferPos != (64 - 8))
|
||||
{
|
||||
curBufferPos &= 0x3F;
|
||||
if (curBufferPos == 0)
|
||||
sha256_write_byte_block(p);
|
||||
p->buffer[curBufferPos++] = 0;
|
||||
}
|
||||
for (i = 0; i < 8; i++)
|
||||
{
|
||||
p->buffer[curBufferPos++] = (unsigned char)(lenInBits >> 56);
|
||||
lenInBits <<= 8;
|
||||
}
|
||||
sha256_write_byte_block(p);
|
||||
|
||||
for (i = 0; i < 8; i++)
|
||||
{
|
||||
*digest++ = (unsigned char)(p->state[i] >> 24);
|
||||
*digest++ = (unsigned char)(p->state[i] >> 16);
|
||||
*digest++ = (unsigned char)(p->state[i] >> 8);
|
||||
*digest++ = (unsigned char)(p->state[i]);
|
||||
}
|
||||
sha256_init(p);
|
||||
}
|
||||
24
examples/gguf-hash/deps/sha256/sha256.h
Normal file
@@ -0,0 +1,24 @@
|
||||
/* Sha256.h -- SHA-256 Hash
|
||||
2010-06-11 : Igor Pavlov : Public domain */
|
||||
|
||||
#ifndef __CRYPTO_SHA256_H
|
||||
#define __CRYPTO_SHA256_H
|
||||
|
||||
#include <stdlib.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#define SHA256_DIGEST_SIZE 32
|
||||
|
||||
typedef struct sha256_t
|
||||
{
|
||||
uint32_t state[8];
|
||||
uint64_t count;
|
||||
unsigned char buffer[64];
|
||||
} sha256_t;
|
||||
|
||||
void sha256_init(sha256_t *p);
|
||||
void sha256_update(sha256_t *p, const unsigned char *data, size_t size);
|
||||
void sha256_final(sha256_t *p, unsigned char *digest);
|
||||
void sha256_hash(unsigned char *buf, const unsigned char *data, size_t size);
|
||||
|
||||
#endif
|
||||
12
examples/gguf-hash/deps/xxhash/clib.json
Normal file
@@ -0,0 +1,12 @@
|
||||
{
|
||||
"name": "xxhash",
|
||||
"version": "0.8.2",
|
||||
"repo": "Cyan4973/xxhash",
|
||||
"description": "Extremely fast non-cryptographic hash algorithm",
|
||||
"keywords": ["xxhash", "hashing"],
|
||||
"license": "BSD-2-Clause",
|
||||
"src": [
|
||||
"xxhash.c",
|
||||
"xxhash.h"
|
||||
]
|
||||
}
|
||||
42
examples/gguf-hash/deps/xxhash/xxhash.c
Normal file
@@ -0,0 +1,42 @@
|
||||
/*
|
||||
* xxHash - Extremely Fast Hash algorithm
|
||||
* Copyright (C) 2012-2023 Yann Collet
|
||||
*
|
||||
* BSD 2-Clause License (https://www.opensource.org/licenses/bsd-license.php)
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are
|
||||
* met:
|
||||
*
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above
|
||||
* copyright notice, this list of conditions and the following disclaimer
|
||||
* in the documentation and/or other materials provided with the
|
||||
* distribution.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
* You can contact the author at:
|
||||
* - xxHash homepage: https://www.xxhash.com
|
||||
* - xxHash source repository: https://github.com/Cyan4973/xxHash
|
||||
*/
|
||||
|
||||
/*
|
||||
* xxhash.c instantiates functions defined in xxhash.h
|
||||
*/
|
||||
|
||||
#define XXH_STATIC_LINKING_ONLY /* access advanced declarations */
|
||||
#define XXH_IMPLEMENTATION /* access definitions */
|
||||
|
||||
#include "xxhash.h"
|
||||
7093
examples/gguf-hash/deps/xxhash/xxhash.h
Normal file
694
examples/gguf-hash/gguf-hash.cpp
Normal file
@@ -0,0 +1,694 @@
|
||||
#include "ggml.h"
|
||||
#include "gguf.h"
|
||||
|
||||
#include <cstdlib> /* abort() */
|
||||
#include <cstddef>
|
||||
#include <cstdio>
|
||||
#include <string>
|
||||
#include <stdexcept>
|
||||
#include <algorithm>
|
||||
#include <cstring>
|
||||
|
||||
#include <sstream>
|
||||
#include <fstream>
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
#include "xxhash/xxhash.h"
|
||||
#include "sha1/sha1.h"
|
||||
#include "sha256/sha256.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
// uuid.uuid5(uuid.NAMESPACE_URL, 'en.wikipedia.org/wiki/Llama.cpp')
|
||||
#define UUID_NAMESPACE_LLAMA_CPP "ef001206-dadc-5f6d-a15f-3359e577d4e5"
|
||||
#define UUID_NAMESPACE_LLAMA_CPP_HEX 0xef, 0x00, 0x12, 0x06, 0xda, 0xdc, 0x5f, 0x6d, 0xa1, 0x5f, 0x33, 0x59, 0xe5, 0x77, 0xd4, 0xe5
|
||||
|
||||
|
||||
#define HASH_TYPE_SHA256_STR "sha256"
|
||||
#define HASH_TYPE_SHA1_STR "sha1"
|
||||
#define HASH_TYPE_XXH64_STR "xxh64"
|
||||
#define HASH_TYPE_UUID_STR "uuid"
|
||||
|
||||
|
||||
typedef enum {
|
||||
HASH_EXIT_SUCCESS = 0, // All hash has been generated or validated
|
||||
HASH_EXIT_FAILURE = 1, // Generic Failure
|
||||
HASH_EXIT_MISMATCH = 2, // Hash mismatched during validation
|
||||
HASH_EXIT_MANIFEST_MISSING_ENTRY = 3, // Hash attempted validation but missing entry in manifest
|
||||
HASH_EXIT_MANIFEST_UNKNOWN_HASH = 4, // Manifest is present, but we do not know any hash format within it
|
||||
HASH_EXIT_MANIFEST_FILE_ERROR = 5 // Manifest is either missing or not a known format
|
||||
} hash_exit_code_t;
|
||||
|
||||
|
||||
typedef enum {
|
||||
HASH_MANIFEST_NOT_FOUND,
|
||||
HASH_MANIFEST_MISMATCH,
|
||||
HASH_MANIFEST_OK,
|
||||
} hash_manifest_result_t;
|
||||
|
||||
|
||||
struct hash_params {
|
||||
std::string input;
|
||||
bool xxh64 = false;
|
||||
bool sha1 = false;
|
||||
bool sha256 = false;
|
||||
bool uuid = false;
|
||||
|
||||
bool no_layer = false;
|
||||
|
||||
bool manifest_is_usable = false;
|
||||
std::string manifest_file;
|
||||
};
|
||||
|
||||
struct manifest_check_params {
|
||||
bool xxh64 = false;
|
||||
bool sha1 = false;
|
||||
bool sha256 = false;
|
||||
bool uuid = false;
|
||||
};
|
||||
|
||||
static char const * hash_manifest_result_to_str(hash_manifest_result_t value) {
|
||||
switch (value) {
|
||||
case HASH_MANIFEST_NOT_FOUND: return "Not Found";
|
||||
case HASH_MANIFEST_MISMATCH: return "Mismatch";
|
||||
case HASH_MANIFEST_OK: return "Ok";
|
||||
}
|
||||
return "?";
|
||||
}
|
||||
|
||||
static char const * hash_exit_code_to_str(hash_exit_code_t value) {
|
||||
switch (value) {
|
||||
case HASH_EXIT_SUCCESS: return "Success";
|
||||
case HASH_EXIT_FAILURE: return "Failure";
|
||||
case HASH_EXIT_MISMATCH: return "Mismatch";
|
||||
case HASH_EXIT_MANIFEST_MISSING_ENTRY: return "Manifest Missing Entry";
|
||||
case HASH_EXIT_MANIFEST_UNKNOWN_HASH: return "Manifest Unknown Hash";
|
||||
case HASH_EXIT_MANIFEST_FILE_ERROR: return "Manifest File Error";
|
||||
}
|
||||
return "?";
|
||||
}
|
||||
|
||||
static void hash_print_usage(const char * executable) {
|
||||
const hash_params default_params;
|
||||
printf("\n");
|
||||
printf("usage: %s [options] GGUF_IN\n", executable);
|
||||
printf("\n");
|
||||
printf("Hash a GGUF file");
|
||||
printf("\n");
|
||||
printf("options:\n");
|
||||
printf(" -h, --help show this help message and exit\n");
|
||||
printf(" --xxh64 use xxh64 hash\n");
|
||||
printf(" --sha1 use sha1 hash\n");
|
||||
printf(" --sha256 use sha256 hash\n");
|
||||
printf(" --all use all hash\n");
|
||||
printf(" --no-layer exclude per layer hash\n");
|
||||
printf(" --uuid generate UUIDv5 ID\n");
|
||||
printf(" -c, --check <manifest> verify against a manifest\n");
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
static void hash_params_parse_ex(int argc, const char ** argv, hash_params & params) {
|
||||
std::string arg;
|
||||
bool invalid_param = false;
|
||||
const std::string arg_prefix = "--";
|
||||
|
||||
int arg_idx = 1;
|
||||
for (; arg_idx < argc && strncmp(argv[arg_idx], "--", 2) == 0; arg_idx++) {
|
||||
arg = argv[arg_idx];
|
||||
if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) {
|
||||
std::replace(arg.begin(), arg.end(), '_', '-');
|
||||
}
|
||||
|
||||
bool arg_found = false;
|
||||
if (arg == "-h" || arg == "--help") {
|
||||
hash_print_usage(argv[0]);
|
||||
exit(0);
|
||||
}
|
||||
|
||||
if (arg == "--xxh64") {
|
||||
arg_found = true;
|
||||
params.xxh64 = true;
|
||||
}
|
||||
|
||||
if (arg == "--sha1") {
|
||||
arg_found = true;
|
||||
params.sha1 = true;
|
||||
}
|
||||
|
||||
if (arg == "--uuid") {
|
||||
arg_found = true;
|
||||
params.uuid = true;
|
||||
}
|
||||
|
||||
if (arg == "--sha256") {
|
||||
arg_found = true;
|
||||
params.sha256 = true;
|
||||
}
|
||||
|
||||
if (arg == "--all") {
|
||||
arg_found = true;
|
||||
params.sha256 = true;
|
||||
params.sha1 = true;
|
||||
params.xxh64 = true;
|
||||
}
|
||||
|
||||
if (arg == "--no-layer") {
|
||||
arg_found = true;
|
||||
params.no_layer = true;
|
||||
}
|
||||
|
||||
if (arg == "-c" || arg == "--check") {
|
||||
if (++arg_idx >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
arg_found = true;
|
||||
params.manifest_file = argv[arg_idx];
|
||||
}
|
||||
|
||||
if (!arg_found) {
|
||||
throw std::invalid_argument("error: unknown argument: " + arg);
|
||||
}
|
||||
}
|
||||
|
||||
if (invalid_param) {
|
||||
throw std::invalid_argument("error: invalid parameter for argument:" + arg);
|
||||
}
|
||||
|
||||
if (argc - arg_idx < 1) {
|
||||
throw std::invalid_argument("error: bad arguments");
|
||||
}
|
||||
|
||||
params.input = argv[arg_idx++];
|
||||
}
|
||||
|
||||
static bool hash_params_parse(int argc, const char ** argv, hash_params & params) {
|
||||
bool result = true;
|
||||
try {
|
||||
hash_params_parse_ex(argc, argv, params);
|
||||
}
|
||||
catch (const std::invalid_argument & ex) {
|
||||
fprintf(stderr, "%s\n", ex.what());
|
||||
hash_print_usage(argv[0]);
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
static bool manifest_type(const std::string & manifest_file, manifest_check_params & manifest_check) {
|
||||
if (manifest_file.empty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::ifstream file(manifest_file);
|
||||
if (!file.is_open()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::string manifest_entry_line;
|
||||
while (getline(file, manifest_entry_line)) {
|
||||
// hash_type_str hash_str tensor_name
|
||||
// e.g. 'xxh64 f66e9cd66a4396a0 test.gguf:tensor_0'
|
||||
std::istringstream line_stream(manifest_entry_line);
|
||||
std::string file_hash_type;
|
||||
if (line_stream >> file_hash_type) {
|
||||
if (file_hash_type == HASH_TYPE_SHA256_STR) {
|
||||
manifest_check.sha256 = true;
|
||||
} else if (file_hash_type == HASH_TYPE_SHA1_STR) {
|
||||
manifest_check.sha1 = true;
|
||||
} else if (file_hash_type == HASH_TYPE_XXH64_STR) {
|
||||
manifest_check.xxh64 = true;
|
||||
} else if (file_hash_type == HASH_TYPE_UUID_STR) {
|
||||
manifest_check.uuid = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static hash_manifest_result_t manifest_verify(const std::string& manifest_file, const std::string& hash_type_str, const std::string& hash_str, const std::string& tensor_name) {
|
||||
if (manifest_file.empty()) {
|
||||
return HASH_MANIFEST_NOT_FOUND;
|
||||
}
|
||||
|
||||
std::ifstream file(manifest_file);
|
||||
if (!file.is_open()) {
|
||||
return HASH_MANIFEST_NOT_FOUND;
|
||||
}
|
||||
|
||||
std::string manifest_entry_line;
|
||||
while (getline(file, manifest_entry_line)) {
|
||||
std::istringstream line_stream(manifest_entry_line);
|
||||
std::string file_hash_type;
|
||||
std::string file_hash;
|
||||
std::string file_tensor_name;
|
||||
if (line_stream >> file_hash_type >> file_hash >> file_tensor_name) {
|
||||
// Line parsed. Check hash validity
|
||||
|
||||
if (file_hash_type != hash_type_str) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (file_tensor_name != tensor_name) {
|
||||
continue;
|
||||
}
|
||||
|
||||
return (file_hash == hash_str) ? HASH_MANIFEST_OK : HASH_MANIFEST_MISMATCH;
|
||||
}
|
||||
}
|
||||
|
||||
return HASH_MANIFEST_NOT_FOUND;
|
||||
}
|
||||
|
||||
static void generate_uuidv5(const unsigned char sha1_digest[20], unsigned char uuid[16]) {
|
||||
// Ref: https://www.rfc-editor.org/rfc/rfc9562.html#section-5.5
|
||||
// Assumes that digest was processed correctly with the expected namespace
|
||||
for (int i = 0; i < 16; i++) {
|
||||
uuid[i] = sha1_digest[i];
|
||||
}
|
||||
|
||||
// Set bits corresponding to UUID ver 5
|
||||
uuid[ 6] &= ~(0xF << 4);
|
||||
uuid[ 6] |= (5 << 4);
|
||||
|
||||
// Set bits corresponding to UUID variant 0b10XX
|
||||
uuid[ 8] &= ~(0xc << 4);
|
||||
uuid[ 8] |= (0x8 << 4);
|
||||
}
|
||||
|
||||
static hash_exit_code_t gguf_hash(const hash_params & hash_params) {
|
||||
const std::string & fname = hash_params.input;
|
||||
struct ggml_context * ctx_data = NULL;
|
||||
|
||||
struct gguf_init_params params = {
|
||||
/*.no_alloc = */ false,
|
||||
/*.ctx = */ &ctx_data,
|
||||
};
|
||||
|
||||
// xxh64 init
|
||||
XXH64_state_t* xxh64_model_hash_state = NULL;
|
||||
if (hash_params.xxh64) {
|
||||
xxh64_model_hash_state = XXH64_createState();
|
||||
if (xxh64_model_hash_state==NULL) {
|
||||
abort();
|
||||
}
|
||||
|
||||
XXH64_hash_t const seed = 0;
|
||||
if (XXH64_reset(xxh64_model_hash_state, seed) == XXH_ERROR) {
|
||||
abort();
|
||||
}
|
||||
}
|
||||
|
||||
// sha1 init
|
||||
SHA1_CTX sha1_model_hash_ctx;
|
||||
if (hash_params.sha1) {
|
||||
SHA1Init(&sha1_model_hash_ctx);
|
||||
}
|
||||
|
||||
// sha256 init
|
||||
sha256_t sha256_model_hash_ctx;
|
||||
if (hash_params.sha256) {
|
||||
sha256_init(&sha256_model_hash_ctx);
|
||||
}
|
||||
|
||||
// sha1 for uuid init
|
||||
SHA1_CTX sha1_for_uuid_ctx;
|
||||
if (hash_params.uuid) {
|
||||
unsigned char const uuidv5_namespace[] = {UUID_NAMESPACE_LLAMA_CPP_HEX};
|
||||
SHA1Init(&sha1_for_uuid_ctx);
|
||||
SHA1Update( &sha1_for_uuid_ctx, (unsigned char const *)uuidv5_namespace, sizeof(uuidv5_namespace));
|
||||
}
|
||||
|
||||
struct gguf_context * ctx = gguf_init_from_file(fname.c_str(), params);
|
||||
const int n_tensors = gguf_get_n_tensors(ctx);
|
||||
bool tensor_layer_in_manifest = false;
|
||||
bool model_in_manifest = false;
|
||||
bool tensor_layer_has_mismatch = false;
|
||||
bool model_has_mismatch = false;
|
||||
for (int i = 0; i < n_tensors; ++i) {
|
||||
const char * name = gguf_get_tensor_name(ctx, i);
|
||||
struct ggml_tensor * cur = ggml_get_tensor(ctx_data, name);
|
||||
auto n_bytes = ggml_nbytes(cur);
|
||||
auto *raw_data = cur->data;
|
||||
const std::string tensor_layer_name = fname + ":" + name;
|
||||
|
||||
if (hash_params.xxh64) {
|
||||
|
||||
if (!hash_params.no_layer) {
|
||||
// Per Layer Hash
|
||||
XXH64_hash_t hash = XXH64(raw_data, n_bytes, 0);
|
||||
|
||||
char hex_result[17];
|
||||
for (int offset = 0; offset < 8; offset++) {
|
||||
unsigned int shift_bits_by = (8 * (8 - offset - 1));
|
||||
snprintf( ( hex_result + (2*offset)), sizeof(hex_result) - (2*offset), "%02x", (unsigned char) (hash >> shift_bits_by)&0xff);
|
||||
}
|
||||
|
||||
if (hash_params.manifest_is_usable) {
|
||||
hash_manifest_result_t verify_result = manifest_verify(hash_params.manifest_file, HASH_TYPE_XXH64_STR, hex_result, tensor_layer_name);
|
||||
|
||||
switch (verify_result) {
|
||||
case HASH_MANIFEST_NOT_FOUND:
|
||||
break;
|
||||
case HASH_MANIFEST_MISMATCH:
|
||||
tensor_layer_in_manifest = true;
|
||||
tensor_layer_has_mismatch = true;
|
||||
break;
|
||||
case HASH_MANIFEST_OK:
|
||||
tensor_layer_in_manifest = true;
|
||||
break;
|
||||
}
|
||||
|
||||
printf("%-8s %-s %s - %s\n", HASH_TYPE_XXH64_STR, hex_result, tensor_layer_name.c_str(), hash_manifest_result_to_str(verify_result));
|
||||
} else {
|
||||
printf("%-8s %-s %s\n", HASH_TYPE_XXH64_STR, hex_result, tensor_layer_name.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
// Overall Model Hash
|
||||
if (XXH64_update(xxh64_model_hash_state, raw_data, n_bytes) == XXH_ERROR) abort();
|
||||
}
|
||||
|
||||
if (hash_params.sha1) {
|
||||
|
||||
if (!hash_params.no_layer) {
|
||||
// Per Layer Hash
|
||||
char result[21]; // sha1 outputs 20 bytes
|
||||
SHA1( result, (const char *)raw_data, n_bytes);
|
||||
|
||||
char hex_result[41] = {0};
|
||||
for (int offset = 0; offset < 20; offset++) {
|
||||
snprintf( ( hex_result + (2*offset)), sizeof(hex_result) - (2*offset), "%02x", result[offset]&0xff);
|
||||
}
|
||||
|
||||
if (hash_params.manifest_is_usable) {
|
||||
hash_manifest_result_t verify_result = manifest_verify(hash_params.manifest_file, HASH_TYPE_SHA1_STR, hex_result, tensor_layer_name);
|
||||
|
||||
switch (verify_result) {
|
||||
case HASH_MANIFEST_NOT_FOUND:
|
||||
break;
|
||||
case HASH_MANIFEST_MISMATCH:
|
||||
tensor_layer_in_manifest = true;
|
||||
tensor_layer_has_mismatch = true;
|
||||
break;
|
||||
case HASH_MANIFEST_OK:
|
||||
tensor_layer_in_manifest = true;
|
||||
break;
|
||||
}
|
||||
|
||||
printf("%-8s %-s %s - %s\n", HASH_TYPE_SHA1_STR, hex_result, tensor_layer_name.c_str(), hash_manifest_result_to_str(verify_result));
|
||||
} else {
|
||||
printf("%-8s %-s %s\n", HASH_TYPE_SHA1_STR, hex_result, tensor_layer_name.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
// Overall Model Hash
|
||||
SHA1Update( &sha1_model_hash_ctx, (unsigned char const *)raw_data, n_bytes);
|
||||
}
|
||||
|
||||
if (hash_params.sha256) {
|
||||
|
||||
if (!hash_params.no_layer) {
|
||||
// Per Layer Hash
|
||||
unsigned char result[SHA256_DIGEST_SIZE]; // sha256 outputs 32 bytes
|
||||
sha256_hash((unsigned char*) result, (const unsigned char *)raw_data, n_bytes);
|
||||
|
||||
char hex_result[SHA256_DIGEST_SIZE * 2 + 1] = {0};
|
||||
for (int offset = 0; offset < SHA256_DIGEST_SIZE; offset++) {
|
||||
snprintf( ( hex_result + (2*offset)), sizeof(hex_result) - (2*offset), "%02x", result[offset]&0xff);
|
||||
}
|
||||
|
||||
if (hash_params.manifest_is_usable) {
|
||||
hash_manifest_result_t verify_result = manifest_verify(hash_params.manifest_file, HASH_TYPE_SHA256_STR, hex_result, tensor_layer_name);
|
||||
|
||||
switch (verify_result) {
|
||||
case HASH_MANIFEST_NOT_FOUND:
|
||||
break;
|
||||
case HASH_MANIFEST_MISMATCH:
|
||||
tensor_layer_in_manifest = true;
|
||||
tensor_layer_has_mismatch = true;
|
||||
break;
|
||||
case HASH_MANIFEST_OK:
|
||||
tensor_layer_in_manifest = true;
|
||||
break;
|
||||
}
|
||||
|
||||
printf("%-8s %-s %s - %s\n", HASH_TYPE_SHA256_STR, hex_result, tensor_layer_name.c_str(), hash_manifest_result_to_str(verify_result));
|
||||
} else {
|
||||
printf("%-8s %-s %s\n", HASH_TYPE_SHA256_STR, hex_result, tensor_layer_name.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
// Overall Model Hash
|
||||
sha256_update( &sha256_model_hash_ctx, (unsigned char const *)raw_data, n_bytes);
|
||||
}
|
||||
|
||||
if (hash_params.uuid) {
|
||||
SHA1Update( &sha1_for_uuid_ctx, (unsigned char const *)raw_data, n_bytes);
|
||||
}
|
||||
}
|
||||
|
||||
if (hash_params.xxh64) {
|
||||
XXH64_hash_t const hash = XXH64_digest(xxh64_model_hash_state);
|
||||
|
||||
char hex_result[17];
|
||||
for (int offset = 0; offset < 8; offset++) {
|
||||
unsigned int shift_bits_by = (8 * (8 - offset - 1));
|
||||
snprintf( ( hex_result + (2*offset)), sizeof(hex_result) - (2*offset), "%02x", (unsigned char) (hash >> shift_bits_by)&0xff);
|
||||
}
|
||||
|
||||
if (hash_params.manifest_is_usable) {
|
||||
hash_manifest_result_t verify_result = manifest_verify(hash_params.manifest_file, HASH_TYPE_XXH64_STR, hex_result, fname);
|
||||
|
||||
switch (verify_result) {
|
||||
case HASH_MANIFEST_NOT_FOUND:
|
||||
break;
|
||||
case HASH_MANIFEST_MISMATCH:
|
||||
model_in_manifest = true;
|
||||
model_has_mismatch = true;
|
||||
break;
|
||||
case HASH_MANIFEST_OK:
|
||||
model_in_manifest = true;
|
||||
break;
|
||||
}
|
||||
|
||||
printf("%-8s %-s %s - %s\n", HASH_TYPE_XXH64_STR, hex_result, fname.c_str(), hash_manifest_result_to_str(verify_result));
|
||||
} else {
|
||||
printf("%-8s %-s %s\n", HASH_TYPE_XXH64_STR, hex_result, fname.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
if (hash_params.sha1) {
|
||||
unsigned char result[21];
|
||||
SHA1Final(result, &sha1_model_hash_ctx);
|
||||
|
||||
char hex_result[41];
|
||||
for (int offset = 0; offset < 20; offset++) {
|
||||
snprintf( ( hex_result + (2*offset)), sizeof(hex_result) - (2*offset), "%02x", result[offset]&0xff);
|
||||
}
|
||||
|
||||
if (hash_params.manifest_is_usable) {
|
||||
hash_manifest_result_t verify_result = manifest_verify(hash_params.manifest_file, HASH_TYPE_SHA1_STR, hex_result, fname);
|
||||
|
||||
switch (verify_result) {
|
||||
case HASH_MANIFEST_NOT_FOUND:
|
||||
break;
|
||||
case HASH_MANIFEST_MISMATCH:
|
||||
model_in_manifest = true;
|
||||
model_has_mismatch = true;
|
||||
break;
|
||||
case HASH_MANIFEST_OK:
|
||||
model_in_manifest = true;
|
||||
break;
|
||||
}
|
||||
|
||||
printf("%-8s %-s %s - %s\n", HASH_TYPE_SHA1_STR, hex_result, fname.c_str(), hash_manifest_result_to_str(verify_result));
|
||||
} else {
|
||||
printf("%-8s %-s %s\n", HASH_TYPE_SHA1_STR, hex_result, fname.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
if (hash_params.sha256) {
|
||||
unsigned char result[SHA256_DIGEST_SIZE]; // sha256 outputs 32 bytes
|
||||
sha256_final( &sha256_model_hash_ctx, result);
|
||||
|
||||
char hex_result[SHA256_DIGEST_SIZE * 2 + 1] = {0};
|
||||
for (int offset = 0; offset < SHA256_DIGEST_SIZE; offset++) {
|
||||
snprintf( ( hex_result + (2*offset)), sizeof(hex_result) - (2*offset), "%02x", result[offset]&0xff);
|
||||
}
|
||||
|
||||
if (hash_params.manifest_is_usable) {
|
||||
hash_manifest_result_t verify_result = manifest_verify(hash_params.manifest_file, HASH_TYPE_SHA256_STR, hex_result, fname);
|
||||
|
||||
switch (verify_result) {
|
||||
case HASH_MANIFEST_NOT_FOUND:
|
||||
break;
|
||||
case HASH_MANIFEST_MISMATCH:
|
||||
model_in_manifest = true;
|
||||
model_has_mismatch = true;
|
||||
break;
|
||||
case HASH_MANIFEST_OK:
|
||||
model_in_manifest = true;
|
||||
break;
|
||||
}
|
||||
|
||||
printf("%-8s %-s %s - %s\n", HASH_TYPE_SHA256_STR, hex_result, fname.c_str(), hash_manifest_result_to_str(verify_result));
|
||||
} else {
|
||||
printf("%-8s %-s %s\n", HASH_TYPE_SHA256_STR, hex_result, fname.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
if (hash_params.uuid) {
|
||||
unsigned char result[21];
|
||||
SHA1Final(result, &sha1_for_uuid_ctx);
|
||||
|
||||
unsigned char uuid[16];
|
||||
generate_uuidv5(result, uuid);
|
||||
|
||||
char string_buffer[37] = {0};
|
||||
snprintf(string_buffer, sizeof(string_buffer), "%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x",
|
||||
uuid[0], uuid[1], uuid[2], uuid[3],
|
||||
uuid[4], uuid[5], uuid[6], uuid[7],
|
||||
uuid[8], uuid[9], uuid[10], uuid[11],
|
||||
uuid[12], uuid[13], uuid[14], uuid[15]);
|
||||
|
||||
if (hash_params.manifest_is_usable) {
|
||||
hash_manifest_result_t verify_result = manifest_verify(hash_params.manifest_file, HASH_TYPE_SHA256_STR, string_buffer, fname);
|
||||
|
||||
switch (verify_result) {
|
||||
case HASH_MANIFEST_NOT_FOUND:
|
||||
break;
|
||||
case HASH_MANIFEST_MISMATCH:
|
||||
model_in_manifest = true;
|
||||
model_has_mismatch = true;
|
||||
break;
|
||||
case HASH_MANIFEST_OK:
|
||||
model_in_manifest = true;
|
||||
break;
|
||||
}
|
||||
|
||||
printf("%-8s %-s %s - %s\n", HASH_TYPE_UUID_STR, string_buffer, fname.c_str(), hash_manifest_result_to_str(verify_result));
|
||||
} else {
|
||||
printf("%-8s %-s %s\n", HASH_TYPE_UUID_STR, string_buffer, fname.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
ggml_free(ctx_data);
|
||||
gguf_free(ctx);
|
||||
|
||||
|
||||
if (hash_params.manifest_is_usable) {
|
||||
// In hash verification mode
|
||||
|
||||
if (!model_in_manifest) {
|
||||
// model missing in manifest?
|
||||
|
||||
// Check tensor layer...
|
||||
if (!tensor_layer_in_manifest) {
|
||||
// Still missing? Maybe we are reading the wrong manifest.
|
||||
return HASH_EXIT_MANIFEST_MISSING_ENTRY;
|
||||
}
|
||||
|
||||
if (tensor_layer_has_mismatch) {
|
||||
// Per tensor check found error
|
||||
return HASH_EXIT_FAILURE;
|
||||
}
|
||||
|
||||
// All per tensor layer checks passed? Sounds good enough.
|
||||
return HASH_EXIT_SUCCESS;
|
||||
}
|
||||
|
||||
// Overall model check passed, but let's check per layer just in case
|
||||
// If missing, we don't care too much as the overall model checked
|
||||
if (tensor_layer_in_manifest && tensor_layer_has_mismatch) {
|
||||
return HASH_EXIT_FAILURE;
|
||||
}
|
||||
|
||||
if (model_has_mismatch) {
|
||||
// model has failed hash somewhere in the model
|
||||
return HASH_EXIT_FAILURE;
|
||||
}
|
||||
|
||||
// All checks appears to be fine
|
||||
return HASH_EXIT_SUCCESS;
|
||||
}
|
||||
|
||||
// In hash generation mode
|
||||
return HASH_EXIT_SUCCESS;
|
||||
}
|
||||
|
||||
int main(int argc, const char ** argv) {
|
||||
hash_params params;
|
||||
manifest_check_params manifest_check;
|
||||
hash_params_parse(argc, argv, params);
|
||||
|
||||
if (!params.manifest_file.empty()) {
|
||||
if (!manifest_type(params.manifest_file, manifest_check)) {
|
||||
printf("ERROR cannot open manifest %s", params.manifest_file.c_str());
|
||||
return HASH_EXIT_MANIFEST_FILE_ERROR;
|
||||
}
|
||||
|
||||
if (!manifest_check.sha256 && !manifest_check.sha1 && !manifest_check.xxh64 && !manifest_check.uuid) {
|
||||
printf("ERROR manifest does not have any known hash format in %s", params.manifest_file.c_str());
|
||||
return HASH_EXIT_MANIFEST_UNKNOWN_HASH;
|
||||
}
|
||||
|
||||
printf("manifest %s", params.manifest_file.c_str());
|
||||
|
||||
if (manifest_check.sha256) {
|
||||
printf(" sha256");
|
||||
}
|
||||
|
||||
if (manifest_check.sha1) {
|
||||
printf(" sha1");
|
||||
}
|
||||
|
||||
if (manifest_check.xxh64) {
|
||||
printf(" xxh64");
|
||||
}
|
||||
|
||||
if (manifest_check.uuid) {
|
||||
printf(" uuid");
|
||||
}
|
||||
|
||||
printf("\n");
|
||||
|
||||
// Autoselect the highest security hash if manifest is provided but
|
||||
// the user has not specifically defined the hash they care about
|
||||
if (!params.xxh64 && !params.sha1 && !params.uuid && !params.sha256) {
|
||||
// User has not selected a specific value, pick most secure hash
|
||||
if (manifest_check.sha256) {
|
||||
params.sha256 = true;
|
||||
} else if (manifest_check.sha1) {
|
||||
params.sha1 = true;
|
||||
} else if (manifest_check.xxh64) {
|
||||
params.xxh64 = true;
|
||||
} else if (manifest_check.uuid) {
|
||||
params.uuid = true;
|
||||
}
|
||||
}
|
||||
|
||||
params.manifest_is_usable = true;
|
||||
}
|
||||
|
||||
// By default if no swich argument provided, assume xxh64
|
||||
if (!params.xxh64 && !params.sha1 && !params.uuid && !params.sha256) {
|
||||
params.xxh64 = true;
|
||||
}
|
||||
|
||||
hash_exit_code_t exit_code = gguf_hash(params);
|
||||
|
||||
if (params.manifest_is_usable) {
|
||||
printf("\nVerification results for %s - %s\n", params.manifest_file.c_str(), hash_exit_code_to_str(exit_code));
|
||||
}
|
||||
|
||||
return exit_code;
|
||||
}
|
||||
5
examples/gguf/CMakeLists.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
set(TARGET llama-gguf)
|
||||
add_executable(${TARGET} gguf.cpp)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_link_libraries(${TARGET} PRIVATE ggml ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_17)
|
||||
270
examples/gguf/gguf.cpp
Normal file
@@ -0,0 +1,270 @@
|
||||
#include "ggml.h"
|
||||
#include "gguf.h"
|
||||
|
||||
#include <cstdio>
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
|
||||
#undef MIN
|
||||
#undef MAX
|
||||
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
||||
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
||||
|
||||
template <typename T>
|
||||
static std::string to_string(const T & val) {
|
||||
std::stringstream ss;
|
||||
ss << val;
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
static bool gguf_ex_write(const std::string & fname) {
|
||||
struct gguf_context * ctx = gguf_init_empty();
|
||||
|
||||
gguf_set_val_u8 (ctx, "some.parameter.uint8", 0x12);
|
||||
gguf_set_val_i8 (ctx, "some.parameter.int8", -0x13);
|
||||
gguf_set_val_u16 (ctx, "some.parameter.uint16", 0x1234);
|
||||
gguf_set_val_i16 (ctx, "some.parameter.int16", -0x1235);
|
||||
gguf_set_val_u32 (ctx, "some.parameter.uint32", 0x12345678);
|
||||
gguf_set_val_i32 (ctx, "some.parameter.int32", -0x12345679);
|
||||
gguf_set_val_f32 (ctx, "some.parameter.float32", 0.123456789f);
|
||||
gguf_set_val_u64 (ctx, "some.parameter.uint64", 0x123456789abcdef0ull);
|
||||
gguf_set_val_i64 (ctx, "some.parameter.int64", -0x123456789abcdef1ll);
|
||||
gguf_set_val_f64 (ctx, "some.parameter.float64", 0.1234567890123456789);
|
||||
gguf_set_val_bool(ctx, "some.parameter.bool", true);
|
||||
gguf_set_val_str (ctx, "some.parameter.string", "hello world");
|
||||
|
||||
gguf_set_arr_data(ctx, "some.parameter.arr.i16", GGUF_TYPE_INT16, std::vector<int16_t>{ 1, 2, 3, 4, }.data(), 4);
|
||||
gguf_set_arr_data(ctx, "some.parameter.arr.f32", GGUF_TYPE_FLOAT32, std::vector<float>{ 3.145f, 2.718f, 1.414f, }.data(), 3);
|
||||
gguf_set_arr_str (ctx, "some.parameter.arr.str", std::vector<const char *>{ "hello", "world", "!" }.data(), 3);
|
||||
|
||||
struct ggml_init_params params = {
|
||||
/*.mem_size =*/ 128ull*1024ull*1024ull,
|
||||
/*.mem_buffer =*/ NULL,
|
||||
/*.no_alloc =*/ false,
|
||||
};
|
||||
|
||||
struct ggml_context * ctx_data = ggml_init(params);
|
||||
|
||||
const int n_tensors = 10;
|
||||
|
||||
// tensor infos
|
||||
for (int i = 0; i < n_tensors; ++i) {
|
||||
const std::string name = "tensor_" + to_string(i);
|
||||
|
||||
int64_t ne[GGML_MAX_DIMS] = { 1 };
|
||||
int32_t n_dims = rand() % GGML_MAX_DIMS + 1;
|
||||
|
||||
for (int j = 0; j < n_dims; ++j) {
|
||||
ne[j] = rand() % 10 + 1;
|
||||
}
|
||||
|
||||
struct ggml_tensor * cur = ggml_new_tensor(ctx_data, GGML_TYPE_F32, n_dims, ne);
|
||||
ggml_set_name(cur, name.c_str());
|
||||
|
||||
{
|
||||
float * data = (float *) cur->data;
|
||||
for (int j = 0; j < ggml_nelements(cur); ++j) {
|
||||
data[j] = 100 + i;
|
||||
}
|
||||
}
|
||||
|
||||
gguf_add_tensor(ctx, cur);
|
||||
}
|
||||
|
||||
gguf_write_to_file(ctx, fname.c_str(), false);
|
||||
|
||||
printf("%s: wrote file '%s;\n", __func__, fname.c_str());
|
||||
|
||||
ggml_free(ctx_data);
|
||||
gguf_free(ctx);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// just read tensor info
|
||||
static bool gguf_ex_read_0(const std::string & fname) {
|
||||
struct gguf_init_params params = {
|
||||
/*.no_alloc = */ false,
|
||||
/*.ctx = */ NULL,
|
||||
};
|
||||
|
||||
struct gguf_context * ctx = gguf_init_from_file(fname.c_str(), params);
|
||||
|
||||
if (!ctx) {
|
||||
fprintf(stderr, "%s: failed to load '%s'\n", __func__, fname.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
printf("%s: version: %d\n", __func__, gguf_get_version(ctx));
|
||||
printf("%s: alignment: %zu\n", __func__, gguf_get_alignment(ctx));
|
||||
printf("%s: data offset: %zu\n", __func__, gguf_get_data_offset(ctx));
|
||||
|
||||
// kv
|
||||
{
|
||||
const int n_kv = gguf_get_n_kv(ctx);
|
||||
|
||||
printf("%s: n_kv: %d\n", __func__, n_kv);
|
||||
|
||||
for (int i = 0; i < n_kv; ++i) {
|
||||
const char * key = gguf_get_key(ctx, i);
|
||||
|
||||
printf("%s: kv[%d]: key = %s\n", __func__, i, key);
|
||||
}
|
||||
}
|
||||
|
||||
// find kv string
|
||||
{
|
||||
const char * findkey = "some.parameter.string";
|
||||
|
||||
const int keyidx = gguf_find_key(ctx, findkey);
|
||||
if (keyidx == -1) {
|
||||
printf("%s: find key: %s not found.\n", __func__, findkey);
|
||||
} else {
|
||||
const char * key_value = gguf_get_val_str(ctx, keyidx);
|
||||
printf("%s: find key: %s found, kv[%d] value = %s\n", __func__, findkey, keyidx, key_value);
|
||||
}
|
||||
}
|
||||
|
||||
// tensor info
|
||||
{
|
||||
const int n_tensors = gguf_get_n_tensors(ctx);
|
||||
|
||||
printf("%s: n_tensors: %d\n", __func__, n_tensors);
|
||||
|
||||
for (int i = 0; i < n_tensors; ++i) {
|
||||
const char * name = gguf_get_tensor_name (ctx, i);
|
||||
const size_t size = gguf_get_tensor_size (ctx, i);
|
||||
const size_t offset = gguf_get_tensor_offset(ctx, i);
|
||||
|
||||
printf("%s: tensor[%d]: name = %s, size = %zu, offset = %zu\n", __func__, i, name, size, offset);
|
||||
}
|
||||
}
|
||||
|
||||
gguf_free(ctx);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// read and create ggml_context containing the tensors and their data
|
||||
static bool gguf_ex_read_1(const std::string & fname, bool check_data) {
|
||||
struct ggml_context * ctx_data = NULL;
|
||||
|
||||
struct gguf_init_params params = {
|
||||
/*.no_alloc = */ false,
|
||||
/*.ctx = */ &ctx_data,
|
||||
};
|
||||
|
||||
struct gguf_context * ctx = gguf_init_from_file(fname.c_str(), params);
|
||||
|
||||
printf("%s: version: %d\n", __func__, gguf_get_version(ctx));
|
||||
printf("%s: alignment: %zu\n", __func__, gguf_get_alignment(ctx));
|
||||
printf("%s: data offset: %zu\n", __func__, gguf_get_data_offset(ctx));
|
||||
|
||||
// kv
|
||||
{
|
||||
const int n_kv = gguf_get_n_kv(ctx);
|
||||
|
||||
printf("%s: n_kv: %d\n", __func__, n_kv);
|
||||
|
||||
for (int i = 0; i < n_kv; ++i) {
|
||||
const char * key = gguf_get_key(ctx, i);
|
||||
|
||||
printf("%s: kv[%d]: key = %s\n", __func__, i, key);
|
||||
}
|
||||
}
|
||||
|
||||
// tensor info
|
||||
{
|
||||
const int n_tensors = gguf_get_n_tensors(ctx);
|
||||
|
||||
printf("%s: n_tensors: %d\n", __func__, n_tensors);
|
||||
|
||||
for (int i = 0; i < n_tensors; ++i) {
|
||||
const char * name = gguf_get_tensor_name (ctx, i);
|
||||
const size_t size = gguf_get_tensor_size (ctx, i);
|
||||
const size_t offset = gguf_get_tensor_offset(ctx, i);
|
||||
const auto type = gguf_get_tensor_type (ctx, i);
|
||||
|
||||
const char * type_name = ggml_type_name(type);
|
||||
const size_t type_size = ggml_type_size(type);
|
||||
const size_t n_elements = size / type_size;
|
||||
|
||||
printf("%s: tensor[%d]: name = %s, size = %zu, offset = %zu, type = %s, n_elts = %zu\n", __func__, i, name, size, offset, type_name, n_elements);
|
||||
}
|
||||
}
|
||||
|
||||
// data
|
||||
{
|
||||
const int n_tensors = gguf_get_n_tensors(ctx);
|
||||
|
||||
for (int i = 0; i < n_tensors; ++i) {
|
||||
printf("%s: reading tensor %d data\n", __func__, i);
|
||||
|
||||
const char * name = gguf_get_tensor_name(ctx, i);
|
||||
|
||||
struct ggml_tensor * cur = ggml_get_tensor(ctx_data, name);
|
||||
|
||||
printf("%s: tensor[%d]: n_dims = %d, ne = (%d, %d, %d, %d), name = %s, data = %p\n",
|
||||
__func__, i, ggml_n_dims(cur), int(cur->ne[0]), int(cur->ne[1]), int(cur->ne[2]), int(cur->ne[3]), cur->name, cur->data);
|
||||
|
||||
// print first 10 elements
|
||||
const float * data = (const float *) cur->data;
|
||||
|
||||
printf("%s data[:10] : ", name);
|
||||
for (int j = 0; j < MIN(10, ggml_nelements(cur)); ++j) {
|
||||
printf("%f ", data[j]);
|
||||
}
|
||||
printf("\n\n");
|
||||
|
||||
// check data
|
||||
if (check_data) {
|
||||
const float * data = (const float *) cur->data;
|
||||
for (int j = 0; j < ggml_nelements(cur); ++j) {
|
||||
if (data[j] != 100 + i) {
|
||||
fprintf(stderr, "%s: tensor[%d], data[%d]: found %f, expected %f\n", __func__, i, j, data[j], float(100 + i));
|
||||
gguf_free(ctx);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
printf("%s: ctx_data size: %zu\n", __func__, ggml_get_mem_size(ctx_data));
|
||||
|
||||
ggml_free(ctx_data);
|
||||
gguf_free(ctx);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
if (argc < 3) {
|
||||
printf("usage: %s data.gguf r|w [n]\n", argv[0]);
|
||||
printf("r: read data.gguf file\n");
|
||||
printf("w: write data.gguf file\n");
|
||||
printf("n: no check of tensor data\n");
|
||||
return -1;
|
||||
}
|
||||
bool check_data = true;
|
||||
if (argc == 4) {
|
||||
check_data = false;
|
||||
}
|
||||
|
||||
srand(123456);
|
||||
|
||||
const std::string fname(argv[1]);
|
||||
const std::string mode (argv[2]);
|
||||
|
||||
GGML_ASSERT((mode == "r" || mode == "w") && "mode must be r or w");
|
||||
|
||||
if (mode == "w") {
|
||||
GGML_ASSERT(gguf_ex_write(fname) && "failed to write gguf file");
|
||||
} else if (mode == "r") {
|
||||
GGML_ASSERT(gguf_ex_read_0(fname) && "failed to read gguf file");
|
||||
GGML_ASSERT(gguf_ex_read_1(fname, check_data) && "failed to read gguf file");
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
5
examples/idle/CMakeLists.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
set(TARGET llama-idle)
|
||||
add_executable(${TARGET} idle.cpp)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_link_libraries(${TARGET} PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_11)
|
||||
3
examples/idle/README.md
Normal file
@@ -0,0 +1,3 @@
|
||||
# llama.cpp/example/idle
|
||||
|
||||
https://github.com/ggml-org/llama.cpp/pull/17766
|
||||
110
examples/idle/idle.cpp
Normal file
@@ -0,0 +1,110 @@
|
||||
#include "arg.h"
|
||||
#include "common.h"
|
||||
#include "log.h"
|
||||
#include "llama.h"
|
||||
|
||||
#include <cmath>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
|
||||
static void print_usage(int /*argc*/, char ** argv) {
|
||||
printf("\nexample usage:\n");
|
||||
printf("\n %s -m model.gguf [-ngl n_gpu_layers]\n", argv[0]);
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
common_params params;
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON, print_usage)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
common_init();
|
||||
|
||||
// init LLM
|
||||
|
||||
llama_backend_init();
|
||||
llama_numa_init(params.numa);
|
||||
|
||||
// initialize the model
|
||||
|
||||
llama_model_params model_params = common_model_params_to_llama(params);
|
||||
|
||||
llama_model * model = llama_model_load_from_file(params.model.path.c_str(), model_params);
|
||||
|
||||
if (model == NULL) {
|
||||
LOG_ERR("%s: error: unable to load model\n" , __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
|
||||
// we need just a dummy token to evaluate
|
||||
std::vector<llama_token> prompt_tokens(1, llama_vocab_bos(vocab));
|
||||
|
||||
llama_context_params ctx_params = llama_context_default_params();
|
||||
ctx_params.n_ctx = 512;
|
||||
ctx_params.n_batch = 512;
|
||||
ctx_params.no_perf = false;
|
||||
|
||||
llama_context * ctx = llama_init_from_model(model, ctx_params);
|
||||
if (ctx == NULL) {
|
||||
fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size());
|
||||
|
||||
const int n_iters = 3;
|
||||
|
||||
// warm-up
|
||||
llama_decode(ctx, batch);
|
||||
llama_memory_clear(llama_get_memory(ctx), true);
|
||||
llama_synchronize(ctx);
|
||||
|
||||
for (int64_t t_pause_ms = 0; t_pause_ms <= 4000; t_pause_ms += 800) {
|
||||
double t_sum_us = 0.0;
|
||||
double t_sum2_us = 0.0;
|
||||
|
||||
for (int i = 0; i < n_iters; i++) {
|
||||
// this pause is important - it simulates "idle GPU"
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(t_pause_ms));
|
||||
|
||||
const int64_t t_start_us = llama_time_us();
|
||||
|
||||
// this should take constant time
|
||||
llama_decode(ctx, batch);
|
||||
llama_synchronize(ctx);
|
||||
|
||||
const int64_t t_end_us = llama_time_us();
|
||||
|
||||
const double t_cur_us = t_end_us - t_start_us;
|
||||
|
||||
#if 1
|
||||
// print individual decode times
|
||||
printf(" - decode time: %8.2f ms\n", t_cur_us / 1000);
|
||||
#endif
|
||||
|
||||
t_sum_us += t_cur_us;
|
||||
t_sum2_us += t_cur_us * t_cur_us;
|
||||
|
||||
llama_memory_clear(llama_get_memory(ctx), true);
|
||||
llama_synchronize(ctx); // just in case
|
||||
}
|
||||
|
||||
const double t_avg_us = t_sum_us / n_iters;
|
||||
const double t_dev_us = sqrt((t_sum2_us / (n_iters - 1)) - (t_avg_us * t_avg_us * n_iters) / (n_iters - 1));
|
||||
|
||||
printf("iters: %4d, pause: %5d ms, avg decode time: %8.2f +/- %4.2f ms\n", n_iters, (int) t_pause_ms, t_avg_us / 1000, t_dev_us / 1000);
|
||||
fflush(stdout);
|
||||
}
|
||||
|
||||
llama_free(ctx);
|
||||
llama_model_free(model);
|
||||
|
||||
return 0;
|
||||
}
|
||||
82
examples/json_schema_pydantic_example.py
Normal file
@@ -0,0 +1,82 @@
|
||||
# Usage:
|
||||
#! ./llama-server -m some-model.gguf &
|
||||
#! pip install pydantic
|
||||
#! python json_schema_pydantic_example.py
|
||||
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
from annotated_types import MinLen
|
||||
from typing import Annotated, List, Optional
|
||||
import json, requests
|
||||
|
||||
if True:
|
||||
|
||||
def create_completion(*, response_model=None, endpoint="http://localhost:8080/v1/chat/completions", messages, **kwargs):
|
||||
'''
|
||||
Creates a chat completion using an OpenAI-compatible endpoint w/ JSON schema support
|
||||
(llama.cpp server, llama-cpp-python, Anyscale / Together...)
|
||||
|
||||
The response_model param takes a type (+ supports Pydantic) and behaves just as w/ Instructor (see below)
|
||||
'''
|
||||
response_format = None
|
||||
type_adapter = None
|
||||
|
||||
if response_model:
|
||||
type_adapter = TypeAdapter(response_model)
|
||||
schema = type_adapter.json_schema()
|
||||
messages = [{
|
||||
"role": "system",
|
||||
"content": f"You respond in JSON format with the following schema: {json.dumps(schema, indent=2)}"
|
||||
}] + messages
|
||||
response_format={"type": "json_object", "schema": schema}
|
||||
|
||||
data = requests.post(endpoint, headers={"Content-Type": "application/json"},
|
||||
json=dict(messages=messages, response_format=response_format, **kwargs)).json()
|
||||
if 'error' in data:
|
||||
raise Exception(data['error']['message'])
|
||||
|
||||
content = data["choices"][0]["message"]["content"]
|
||||
return type_adapter.validate_json(content) if type_adapter else content
|
||||
|
||||
else:
|
||||
|
||||
# This alternative branch uses Instructor + OpenAI client lib.
|
||||
# Instructor support streamed iterable responses, retry & more.
|
||||
# (see https://python.useinstructor.com/)
|
||||
#! pip install instructor openai
|
||||
import instructor, openai
|
||||
client = instructor.patch(
|
||||
openai.OpenAI(api_key="123", base_url="http://localhost:8080"),
|
||||
mode=instructor.Mode.JSON_SCHEMA)
|
||||
create_completion = client.chat.completions.create
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
class QAPair(BaseModel):
|
||||
class Config:
|
||||
extra = 'forbid' # triggers additionalProperties: false in the JSON schema
|
||||
question: str
|
||||
concise_answer: str
|
||||
justification: str
|
||||
stars: Annotated[int, Field(ge=1, le=5)]
|
||||
|
||||
class PyramidalSummary(BaseModel):
|
||||
class Config:
|
||||
extra = 'forbid' # triggers additionalProperties: false in the JSON schema
|
||||
title: str
|
||||
summary: str
|
||||
question_answers: Annotated[List[QAPair], MinLen(2)]
|
||||
sub_sections: Optional[Annotated[List['PyramidalSummary'], MinLen(2)]]
|
||||
|
||||
print("# Summary\n", create_completion(
|
||||
model="...",
|
||||
response_model=PyramidalSummary,
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": f"""
|
||||
You are a highly efficient corporate document summarizer.
|
||||
Create a pyramidal summary of an imaginary internal document about our company processes
|
||||
(starting high-level, going down to each sub sections).
|
||||
Keep questions short, and answers even shorter (trivia / quizz style).
|
||||
"""
|
||||
}]))
|
||||
837
examples/json_schema_to_grammar.py
Executable file
@@ -0,0 +1,837 @@
|
||||
#!/usr/bin/env python3
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import itertools
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
from typing import Any, List, Optional, Set, Tuple, Union
|
||||
|
||||
def _build_repetition(item_rule, min_items, max_items, separator_rule=None):
|
||||
|
||||
if max_items == 0:
|
||||
return ""
|
||||
|
||||
if min_items == 0 and max_items == 1:
|
||||
return f'{item_rule}?'
|
||||
|
||||
if not separator_rule:
|
||||
if min_items == 1 and max_items is None:
|
||||
return f'{item_rule}+'
|
||||
elif min_items == 0 and max_items is None:
|
||||
return f'{item_rule}*'
|
||||
else:
|
||||
return f'{item_rule}{{{min_items},{max_items if max_items is not None else ""}}}'
|
||||
|
||||
result = item_rule + ' ' + _build_repetition(f'({separator_rule} {item_rule})', min_items - 1 if min_items > 0 else 0, max_items - 1 if max_items is not None else None)
|
||||
return f'({result})?' if min_items == 0 else result
|
||||
|
||||
def _generate_min_max_int(min_value: Optional[int], max_value: Optional[int], out: list, decimals_left: int = 16, top_level: bool = True):
|
||||
has_min = min_value != None
|
||||
has_max = max_value != None
|
||||
|
||||
def digit_range(from_char: str, to_char: str):
|
||||
out.append("[")
|
||||
if from_char == to_char:
|
||||
out.append(from_char)
|
||||
else:
|
||||
out.append(from_char)
|
||||
out.append("-")
|
||||
out.append(to_char)
|
||||
out.append("]")
|
||||
|
||||
def more_digits(min_digits: int, max_digits: int):
|
||||
out.append("[0-9]")
|
||||
if min_digits == max_digits and min_digits == 1:
|
||||
return
|
||||
out.append("{")
|
||||
out.append(str(min_digits))
|
||||
if max_digits != min_digits:
|
||||
out.append(",")
|
||||
if max_digits != sys.maxsize:
|
||||
out.append(str(max_digits))
|
||||
out.append("}")
|
||||
|
||||
def uniform_range(from_str: str, to_str: str):
|
||||
i = 0
|
||||
while i < len(from_str) and from_str[i] == to_str[i]:
|
||||
i += 1
|
||||
if i > 0:
|
||||
out.append("\"")
|
||||
out.append(from_str[:i])
|
||||
out.append("\"")
|
||||
if i < len(from_str):
|
||||
if i > 0:
|
||||
out.append(" ")
|
||||
sub_len = len(from_str) - i - 1
|
||||
if sub_len > 0:
|
||||
from_sub = from_str[i+1:]
|
||||
to_sub = to_str[i+1:]
|
||||
sub_zeros = "0" * sub_len
|
||||
sub_nines = "9" * sub_len
|
||||
|
||||
to_reached = False
|
||||
out.append("(")
|
||||
if from_sub == sub_zeros:
|
||||
digit_range(from_str[i], chr(ord(to_str[i]) - 1))
|
||||
out.append(" ")
|
||||
more_digits(sub_len, sub_len)
|
||||
else:
|
||||
out.append("[")
|
||||
out.append(from_str[i])
|
||||
out.append("] ")
|
||||
out.append("(")
|
||||
uniform_range(from_sub, sub_nines)
|
||||
out.append(")")
|
||||
if ord(from_str[i]) < ord(to_str[i]) - 1:
|
||||
out.append(" | ")
|
||||
if to_sub == sub_nines:
|
||||
digit_range(chr(ord(from_str[i]) + 1), to_str[i])
|
||||
to_reached = True
|
||||
else:
|
||||
digit_range(chr(ord(from_str[i]) + 1), chr(ord(to_str[i]) - 1))
|
||||
out.append(" ")
|
||||
more_digits(sub_len, sub_len)
|
||||
if not to_reached:
|
||||
out.append(" | ")
|
||||
digit_range(to_str[i], to_str[i])
|
||||
out.append(" ")
|
||||
uniform_range(sub_zeros, to_sub)
|
||||
out.append(")")
|
||||
else:
|
||||
out.append("[")
|
||||
out.append(from_str[i])
|
||||
out.append("-")
|
||||
out.append(to_str[i])
|
||||
out.append("]")
|
||||
|
||||
if has_min and has_max:
|
||||
if min_value < 0 and max_value < 0:
|
||||
out.append("\"-\" (")
|
||||
_generate_min_max_int(-max_value, -min_value, out, decimals_left, top_level=True)
|
||||
out.append(")")
|
||||
return
|
||||
|
||||
if min_value < 0:
|
||||
out.append("\"-\" (")
|
||||
_generate_min_max_int(0, -min_value, out, decimals_left, top_level=True)
|
||||
out.append(") | ")
|
||||
min_value = 0
|
||||
|
||||
min_s = str(min_value)
|
||||
max_s = str(max_value)
|
||||
min_digits = len(min_s)
|
||||
max_digits = len(max_s)
|
||||
|
||||
for digits in range(min_digits, max_digits):
|
||||
uniform_range(min_s, "9" * digits)
|
||||
min_s = "1" + "0" * digits
|
||||
out.append(" | ")
|
||||
uniform_range(min_s, max_s)
|
||||
return
|
||||
|
||||
less_decimals = max(decimals_left - 1, 1)
|
||||
|
||||
if has_min:
|
||||
if min_value < 0:
|
||||
out.append("\"-\" (")
|
||||
_generate_min_max_int(None, -min_value, out, decimals_left, top_level=False)
|
||||
out.append(") | [0] | [1-9] ")
|
||||
more_digits(0, decimals_left - 1)
|
||||
elif min_value == 0:
|
||||
if top_level:
|
||||
out.append("[0] | [1-9] ")
|
||||
more_digits(0, less_decimals)
|
||||
else:
|
||||
more_digits(1, decimals_left)
|
||||
elif min_value <= 9:
|
||||
c = str(min_value)
|
||||
range_start = '1' if top_level else '0'
|
||||
if c > range_start:
|
||||
digit_range(range_start, chr(ord(c) - 1))
|
||||
out.append(" ")
|
||||
more_digits(1, less_decimals)
|
||||
out.append(" | ")
|
||||
digit_range(c, "9")
|
||||
out.append(" ")
|
||||
more_digits(0, less_decimals)
|
||||
else:
|
||||
min_s = str(min_value)
|
||||
length = len(min_s)
|
||||
c = min_s[0]
|
||||
|
||||
if c > "1":
|
||||
digit_range("1" if top_level else "0", chr(ord(c) - 1))
|
||||
out.append(" ")
|
||||
more_digits(length, less_decimals)
|
||||
out.append(" | ")
|
||||
digit_range(c, c)
|
||||
out.append(" (")
|
||||
_generate_min_max_int(int(min_s[1:]), None, out, less_decimals, top_level=False)
|
||||
out.append(")")
|
||||
if c < "9":
|
||||
out.append(" | ")
|
||||
digit_range(chr(ord(c) + 1), "9")
|
||||
out.append(" ")
|
||||
more_digits(length - 1, less_decimals)
|
||||
return
|
||||
|
||||
if has_max:
|
||||
if max_value >= 0:
|
||||
if top_level:
|
||||
out.append("\"-\" [1-9] ")
|
||||
more_digits(0, less_decimals)
|
||||
out.append(" | ")
|
||||
_generate_min_max_int(0, max_value, out, decimals_left, top_level=True)
|
||||
else:
|
||||
out.append("\"-\" (")
|
||||
_generate_min_max_int(-max_value, None, out, decimals_left, top_level=False)
|
||||
out.append(")")
|
||||
return
|
||||
|
||||
raise RuntimeError("At least one of min_value or max_value must be set")
|
||||
|
||||
class BuiltinRule:
|
||||
def __init__(self, content: str, deps: list | None = None):
|
||||
self.content = content
|
||||
self.deps = deps or []
|
||||
|
||||
# Constraining spaces to prevent model "running away".
|
||||
SPACE_RULE = '| " " | "\\n"{1,2} [ \\t]{0,20}'
|
||||
|
||||
PRIMITIVE_RULES = {
|
||||
'boolean' : BuiltinRule('("true" | "false") space', []),
|
||||
'decimal-part' : BuiltinRule('[0-9]{1,16}', []),
|
||||
'integral-part': BuiltinRule('[0] | [1-9] [0-9]{0,15}', []),
|
||||
'number' : BuiltinRule('("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space', ['integral-part', 'decimal-part']),
|
||||
'integer' : BuiltinRule('("-"? integral-part) space', ['integral-part']),
|
||||
'value' : BuiltinRule('object | array | string | number | boolean | null', ['object', 'array', 'string', 'number', 'boolean', 'null']),
|
||||
'object' : BuiltinRule('"{" space ( string ":" space value ("," space string ":" space value)* )? "}" space', ['string', 'value']),
|
||||
'array' : BuiltinRule('"[" space ( value ("," space value)* )? "]" space', ['value']),
|
||||
'uuid' : BuiltinRule(r'"\"" [0-9a-fA-F]{8} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{12} "\"" space', []),
|
||||
'char' : BuiltinRule(r'[^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})', []),
|
||||
'string' : BuiltinRule(r'"\"" char* "\"" space', ['char']),
|
||||
'null' : BuiltinRule('"null" space', []),
|
||||
}
|
||||
|
||||
# TODO: support "uri", "email" string formats
|
||||
STRING_FORMAT_RULES = {
|
||||
'date' : BuiltinRule('[0-9]{4} "-" ( "0" [1-9] | "1" [0-2] ) "-" ( \"0\" [1-9] | [1-2] [0-9] | "3" [0-1] )', []),
|
||||
'time' : BuiltinRule('([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9]{3} )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )', []),
|
||||
'date-time' : BuiltinRule('date "T" time', ['date', 'time']),
|
||||
'date-string' : BuiltinRule('"\\"" date "\\"" space', ['date']),
|
||||
'time-string' : BuiltinRule('"\\"" time "\\"" space', ['time']),
|
||||
'date-time-string': BuiltinRule('"\\"" date-time "\\"" space', ['date-time']),
|
||||
}
|
||||
|
||||
DOTALL = '[\\U00000000-\\U0010FFFF]'
|
||||
DOT = '[^\\x0A\\x0D]'
|
||||
|
||||
RESERVED_NAMES = set(["root", "dot", *PRIMITIVE_RULES.keys(), *STRING_FORMAT_RULES.keys()])
|
||||
|
||||
INVALID_RULE_CHARS_RE = re.compile(r'[^a-zA-Z0-9-]+')
|
||||
GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"\\]')
|
||||
GRAMMAR_RANGE_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"\]\-\\]')
|
||||
GRAMMAR_LITERAL_ESCAPES = {'\r': '\\r', '\n': '\\n', '"': '\\"', '-': '\\-', ']': '\\]', '\\': '\\\\'}
|
||||
|
||||
NON_LITERAL_SET = set('|.()[]{}*+?')
|
||||
ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = set('^$.[]()|{}*+?')
|
||||
|
||||
|
||||
class SchemaConverter:
|
||||
def __init__(self, *, prop_order, allow_fetch, dotall, raw_pattern):
|
||||
self._prop_order = prop_order
|
||||
self._allow_fetch = allow_fetch
|
||||
self._dotall = dotall
|
||||
self._raw_pattern = raw_pattern
|
||||
self._rules = {
|
||||
'space': SPACE_RULE,
|
||||
}
|
||||
self._refs = {}
|
||||
self._refs_being_resolved = set()
|
||||
|
||||
def _format_literal(self, literal):
|
||||
escaped = GRAMMAR_LITERAL_ESCAPE_RE.sub(
|
||||
lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)) or m.group(0), literal
|
||||
)
|
||||
return f'"{escaped}"'
|
||||
|
||||
def not_literal(self, literal: str, dotall: bool = True, maybe_escaped_underscores = False) -> str:
|
||||
'''
|
||||
not_literal('a') -> '[^a]'
|
||||
not_literal('abc') -> '([^a] | "a" ([^b] | "b" ([^c])?)?)?'
|
||||
'''
|
||||
assert len(literal) > 0, 'Empty literal not supported'
|
||||
def recurse(i: int):
|
||||
c = literal[i]
|
||||
if maybe_escaped_underscores and c == '_':
|
||||
yield f'[^{c}\\\\]'
|
||||
yield ' | '
|
||||
yield f'"\\\\"? "{c}"'
|
||||
else:
|
||||
yield f'[^{c}]'
|
||||
if i < len(literal) - 1:
|
||||
yield ' | '
|
||||
yield self._format_literal(c)
|
||||
yield ' ('
|
||||
yield from recurse(i + 1)
|
||||
yield ')?'
|
||||
|
||||
return ''.join(('(', *recurse(0), ')'))
|
||||
|
||||
def _not_strings(self, strings):
|
||||
class TrieNode:
|
||||
def __init__(self):
|
||||
self.children = {}
|
||||
self.is_end_of_string = False
|
||||
|
||||
def insert(self, string):
|
||||
node = self
|
||||
for c in string:
|
||||
node = node.children.setdefault(c, TrieNode())
|
||||
node.is_end_of_string = True
|
||||
|
||||
trie = TrieNode()
|
||||
for s in strings:
|
||||
trie.insert(s)
|
||||
|
||||
char_rule = self._add_primitive('char', PRIMITIVE_RULES['char'])
|
||||
out = ['["] ( ']
|
||||
|
||||
def visit(node):
|
||||
rejects = []
|
||||
first = True
|
||||
for c in sorted(node.children.keys()):
|
||||
child = node.children[c]
|
||||
rejects.append(c)
|
||||
if first:
|
||||
first = False
|
||||
else:
|
||||
out.append(' | ')
|
||||
out.append(f'[{c}]')
|
||||
if child.children:
|
||||
out.append(f' (')
|
||||
visit(child)
|
||||
out.append(')')
|
||||
elif child.is_end_of_string:
|
||||
out.append(f' {char_rule}+')
|
||||
if node.children:
|
||||
if not first:
|
||||
out.append(' | ')
|
||||
out.append(f'[^"{"".join(rejects)}] {char_rule}*')
|
||||
visit(trie)
|
||||
|
||||
out.append(f' ){"" if trie.is_end_of_string else "?"} ["] space')
|
||||
return ''.join(out)
|
||||
|
||||
def _add_rule(self, name, rule):
|
||||
esc_name = INVALID_RULE_CHARS_RE.sub('-', name)
|
||||
if esc_name not in self._rules or self._rules[esc_name] == rule:
|
||||
key = esc_name
|
||||
else:
|
||||
i = 0
|
||||
while f'{esc_name}{i}' in self._rules and self._rules[f'{esc_name}{i}'] != rule:
|
||||
i += 1
|
||||
key = f'{esc_name}{i}'
|
||||
self._rules[key] = rule
|
||||
return key
|
||||
|
||||
def resolve_refs(self, schema: dict, url: str):
|
||||
'''
|
||||
Resolves all $ref fields in the given schema, fetching any remote schemas,
|
||||
replacing $ref with absolute reference URL and populating self._refs with the
|
||||
respective referenced (sub)schema dictionaries.
|
||||
'''
|
||||
def visit(n: dict):
|
||||
if isinstance(n, list):
|
||||
return [visit(x) for x in n]
|
||||
elif isinstance(n, dict):
|
||||
ref = n.get('$ref')
|
||||
if ref is not None and ref not in self._refs:
|
||||
if ref.startswith('https://'):
|
||||
assert self._allow_fetch, 'Fetching remote schemas is not allowed (use --allow-fetch for force)'
|
||||
import requests
|
||||
|
||||
frag_split = ref.split('#')
|
||||
base_url = frag_split[0]
|
||||
|
||||
target = self._refs.get(base_url)
|
||||
if target is None:
|
||||
target = self.resolve_refs(requests.get(ref).json(), base_url)
|
||||
self._refs[base_url] = target
|
||||
|
||||
if len(frag_split) == 1 or frag_split[-1] == '':
|
||||
return target
|
||||
elif ref.startswith('#/'):
|
||||
target = schema
|
||||
ref = f'{url}{ref}'
|
||||
n['$ref'] = ref
|
||||
else:
|
||||
raise ValueError(f'Unsupported ref {ref}')
|
||||
|
||||
for sel in ref.split('#')[-1].split('/')[1:]:
|
||||
assert target is not None, f'Error resolving ref {ref}: {sel} not in {target}'
|
||||
if isinstance(target, list):
|
||||
try:
|
||||
sel_index = int(sel)
|
||||
except ValueError:
|
||||
raise ValueError(f'Error resolving ref {ref}: {sel} not in {target}')
|
||||
assert 0 <= sel_index < len(target), f'Error resolving ref {ref}: {sel} not in {target}'
|
||||
target = target[sel_index]
|
||||
else:
|
||||
assert sel in target, f'Error resolving ref {ref}: {sel} not in {target}'
|
||||
target = target[sel]
|
||||
|
||||
self._refs[ref] = target
|
||||
else:
|
||||
for v in n.values():
|
||||
visit(v)
|
||||
|
||||
return n
|
||||
return visit(schema)
|
||||
|
||||
def _generate_union_rule(self, name, alt_schemas):
|
||||
return ' | '.join((
|
||||
self.visit(alt_schema, f'{name}{"-" if name else "alternative-"}{i}')
|
||||
for i, alt_schema in enumerate(alt_schemas)
|
||||
))
|
||||
|
||||
def _visit_pattern(self, pattern, name):
|
||||
'''
|
||||
Transforms a regular expression pattern into a GBNF rule.
|
||||
|
||||
Input: https://json-schema.org/understanding-json-schema/reference/regular_expressions
|
||||
Output: https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md
|
||||
|
||||
Unsupported features: negative/positive lookaheads, greedy/non-greedy modifiers.
|
||||
|
||||
Mostly a 1:1 translation, except for {x} / {x,} / {x,y} quantifiers for which
|
||||
we define sub-rules to keep the output lean.
|
||||
'''
|
||||
|
||||
assert pattern.startswith('^') and pattern.endswith('$'), 'Pattern must start with "^" and end with "$"'
|
||||
pattern = pattern[1:-1]
|
||||
sub_rule_ids = {}
|
||||
|
||||
i = 0
|
||||
length = len(pattern)
|
||||
|
||||
def to_rule(s: tuple[str, bool]) -> str:
|
||||
(txt, is_literal) = s
|
||||
return "\"" + txt + "\"" if is_literal else txt
|
||||
|
||||
def transform() -> tuple[str, bool]:
|
||||
'''
|
||||
Parse a unit at index i (advancing it), and return its string representation + whether it's a literal.
|
||||
'''
|
||||
nonlocal i
|
||||
nonlocal pattern
|
||||
nonlocal sub_rule_ids
|
||||
|
||||
start = i
|
||||
# For each component of this sequence, store its string representation and whether it's a literal.
|
||||
# We only need a flat structure here to apply repetition operators to the last item, and
|
||||
# to merge literals at the and (we're parsing grouped ( sequences ) recursively and don't treat '|' specially
|
||||
# (GBNF's syntax is luckily very close to regular expressions!)
|
||||
seq: list[tuple[str, bool]] = []
|
||||
|
||||
def get_dot():
|
||||
if self._dotall:
|
||||
rule = DOTALL
|
||||
else:
|
||||
# Accept any character... except \n and \r line break chars (\x0A and \xOD)
|
||||
rule = DOT
|
||||
return self._add_rule(f'dot', rule)
|
||||
|
||||
def join_seq():
|
||||
nonlocal seq
|
||||
ret = []
|
||||
for is_literal, g in itertools.groupby(seq, lambda x: x[1]):
|
||||
if is_literal:
|
||||
ret.append((''.join(x[0] for x in g), True))
|
||||
else:
|
||||
ret.extend(g)
|
||||
if len(ret) == 1:
|
||||
return ret[0]
|
||||
return (' '.join(to_rule(x) for x in seq), False)
|
||||
|
||||
while i < length:
|
||||
c = pattern[i]
|
||||
if c == '.':
|
||||
seq.append((get_dot(), False))
|
||||
i += 1
|
||||
elif c == '(':
|
||||
i += 1
|
||||
if i < length:
|
||||
assert pattern[i] != '?', f'Unsupported pattern syntax "{pattern[i]}" at index {i} of /{pattern}/'
|
||||
seq.append((f'({to_rule(transform())})', False))
|
||||
elif c == ')':
|
||||
i += 1
|
||||
assert start > 0 and pattern[start-1] == '(', f'Unbalanced parentheses; start = {start}, i = {i}, pattern = {pattern}'
|
||||
return join_seq()
|
||||
elif c == '[':
|
||||
square_brackets = c
|
||||
i += 1
|
||||
while i < length and pattern[i] != ']':
|
||||
if pattern[i] == '\\':
|
||||
square_brackets += pattern[i:i+2]
|
||||
i += 2
|
||||
else:
|
||||
square_brackets += pattern[i]
|
||||
i += 1
|
||||
assert i < length, f'Unbalanced square brackets; start = {start}, i = {i}, pattern = {pattern}'
|
||||
square_brackets += ']'
|
||||
i += 1
|
||||
seq.append((square_brackets, False))
|
||||
elif c == '|':
|
||||
seq.append(('|', False))
|
||||
i += 1
|
||||
elif c in ('*', '+', '?'):
|
||||
seq[-1] = (to_rule(seq[-1]) + c, False)
|
||||
i += 1
|
||||
elif c == '{':
|
||||
curly_brackets = c
|
||||
i += 1
|
||||
while i < length and pattern[i] != '}':
|
||||
curly_brackets += pattern[i]
|
||||
i += 1
|
||||
assert i < length, f'Unbalanced curly brackets; start = {start}, i = {i}, pattern = {pattern}'
|
||||
curly_brackets += '}'
|
||||
i += 1
|
||||
nums = [s.strip() for s in curly_brackets[1:-1].split(',')]
|
||||
min_times = 0
|
||||
max_times = None
|
||||
try:
|
||||
if len(nums) == 1:
|
||||
min_times = int(nums[0])
|
||||
max_times = min_times
|
||||
else:
|
||||
assert len(nums) == 2
|
||||
min_times = int(nums[0]) if nums[0] else 0
|
||||
max_times = int(nums[1]) if nums[1] else None
|
||||
except ValueError:
|
||||
raise ValueError(f'Invalid quantifier {curly_brackets} in /{pattern}/')
|
||||
|
||||
(sub, sub_is_literal) = seq[-1]
|
||||
|
||||
if not sub_is_literal:
|
||||
id = sub_rule_ids.get(sub)
|
||||
if id is None:
|
||||
id = self._add_rule(f'{name}-{len(sub_rule_ids) + 1}', sub)
|
||||
sub_rule_ids[sub] = id
|
||||
sub = id
|
||||
|
||||
seq[-1] = (_build_repetition(f'"{sub}"' if sub_is_literal else sub, min_times, max_times), False)
|
||||
else:
|
||||
literal = ''
|
||||
while i < length:
|
||||
if pattern[i] == '\\' and i < length - 1:
|
||||
next = pattern[i + 1]
|
||||
if next in ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS:
|
||||
i += 1
|
||||
literal += pattern[i]
|
||||
i += 1
|
||||
else:
|
||||
literal += pattern[i:i+2]
|
||||
i += 2
|
||||
elif pattern[i] == '"' and not self._raw_pattern:
|
||||
literal += '\\"'
|
||||
i += 1
|
||||
elif pattern[i] not in NON_LITERAL_SET and \
|
||||
(i == length - 1 or literal == '' or pattern[i+1] == '.' or pattern[i+1] not in NON_LITERAL_SET):
|
||||
literal += pattern[i]
|
||||
i += 1
|
||||
else:
|
||||
break
|
||||
if literal:
|
||||
seq.append((literal, True))
|
||||
|
||||
return join_seq()
|
||||
|
||||
return self._add_rule(
|
||||
name,
|
||||
to_rule(transform()) if self._raw_pattern \
|
||||
else "\"\\\"\" (" + to_rule(transform()) + ") \"\\\"\" space")
|
||||
|
||||
|
||||
def _resolve_ref(self, ref):
|
||||
ref_fragment = ref.split('#')[-1]
|
||||
ref_name = 'ref' + re.sub(r'[^a-zA-Z0-9-]+', '-', ref_fragment)
|
||||
if ref_name not in self._rules and ref not in self._refs_being_resolved:
|
||||
self._refs_being_resolved.add(ref)
|
||||
resolved = self._refs[ref]
|
||||
ref_name = self.visit(resolved, ref_name)
|
||||
self._refs_being_resolved.remove(ref)
|
||||
return ref_name
|
||||
|
||||
def _generate_constant_rule(self, value):
|
||||
return self._format_literal(json.dumps(value))
|
||||
|
||||
def visit(self, schema, name):
|
||||
schema_type = schema.get('type')
|
||||
schema_format = schema.get('format')
|
||||
rule_name = name + '-' if name in RESERVED_NAMES else name or 'root'
|
||||
|
||||
if (ref := schema.get('$ref')) is not None:
|
||||
return self._add_rule(rule_name, self._resolve_ref(ref))
|
||||
|
||||
elif 'oneOf' in schema or 'anyOf' in schema:
|
||||
return self._add_rule(rule_name, self._generate_union_rule(name, schema.get('oneOf') or schema['anyOf']))
|
||||
|
||||
elif isinstance(schema_type, list):
|
||||
return self._add_rule(rule_name, self._generate_union_rule(name, [{**schema, 'type': t} for t in schema_type]))
|
||||
|
||||
elif 'const' in schema:
|
||||
return self._add_rule(rule_name, self._generate_constant_rule(schema['const']) + ' space')
|
||||
|
||||
elif 'enum' in schema:
|
||||
rule = '(' + ' | '.join((self._generate_constant_rule(v) for v in schema['enum'])) + ') space'
|
||||
return self._add_rule(rule_name, rule)
|
||||
|
||||
elif schema_type in (None, 'object') and \
|
||||
('properties' in schema or \
|
||||
('additionalProperties' in schema and schema['additionalProperties'] is not True)):
|
||||
required = set(schema.get('required', []))
|
||||
properties = list(schema.get('properties', {}).items())
|
||||
return self._add_rule(rule_name, self._build_object_rule(properties, required, name, schema.get('additionalProperties')))
|
||||
|
||||
elif schema_type in (None, 'object', 'string') and 'allOf' in schema:
|
||||
required = set()
|
||||
properties = []
|
||||
enum_sets = []
|
||||
hybrid_name = name
|
||||
def add_component(comp_schema, is_required):
|
||||
if (ref := comp_schema.get('$ref')) is not None:
|
||||
comp_schema = self._refs[ref]
|
||||
|
||||
if 'properties' in comp_schema:
|
||||
for prop_name, prop_schema in comp_schema['properties'].items():
|
||||
properties.append((prop_name, prop_schema))
|
||||
if is_required:
|
||||
required.add(prop_name)
|
||||
|
||||
if 'enum' in comp_schema:
|
||||
enum_sets.append(set(comp_schema['enum']))
|
||||
|
||||
for t in schema['allOf']:
|
||||
if 'anyOf' in t:
|
||||
for tt in t['anyOf']:
|
||||
add_component(tt, is_required=False)
|
||||
else:
|
||||
add_component(t, is_required=True)
|
||||
|
||||
if enum_sets:
|
||||
enum_intersection = enum_sets[0]
|
||||
for s in enum_sets[1:]:
|
||||
enum_intersection &= s
|
||||
|
||||
if enum_intersection:
|
||||
rule = '(' + ' | '.join((self._generate_constant_rule(v) for v in sorted(enum_intersection))) + ') space'
|
||||
return self._add_rule(rule_name, rule)
|
||||
|
||||
return self._add_rule(rule_name, self._build_object_rule(properties, required, hybrid_name, additional_properties=None))
|
||||
|
||||
elif schema_type in (None, 'array') and ('items' in schema or 'prefixItems' in schema):
|
||||
items = schema.get('items') or schema['prefixItems']
|
||||
if isinstance(items, list):
|
||||
return self._add_rule(
|
||||
rule_name,
|
||||
'"[" space ' +
|
||||
' "," space '.join(
|
||||
self.visit(item, f'{name}{"-" if name else ""}tuple-{i}')
|
||||
for i, item in enumerate(items)) +
|
||||
' "]" space')
|
||||
else:
|
||||
item_rule_name = self.visit(items, f'{name}{"-" if name else ""}item')
|
||||
min_items = schema.get("minItems", 0)
|
||||
max_items = schema.get("maxItems")
|
||||
return self._add_rule(rule_name, '"[" space ' + _build_repetition(item_rule_name, min_items, max_items, separator_rule='"," space') + ' "]" space')
|
||||
|
||||
elif schema_type in (None, 'string') and 'pattern' in schema:
|
||||
return self._visit_pattern(schema['pattern'], rule_name)
|
||||
|
||||
elif schema_type in (None, 'string') and re.match(r'^uuid[1-5]?$', schema_format or ''):
|
||||
return self._add_primitive(
|
||||
'root' if rule_name == 'root' else schema_format,
|
||||
PRIMITIVE_RULES['uuid']
|
||||
)
|
||||
|
||||
elif schema_type in (None, 'string') and f'{schema_format}-string' in STRING_FORMAT_RULES:
|
||||
prim_name = f'{schema_format}-string'
|
||||
return self._add_rule(rule_name, self._add_primitive(prim_name, STRING_FORMAT_RULES[prim_name]))
|
||||
|
||||
elif schema_type == 'string' and ('minLength' in schema or 'maxLength' in schema):
|
||||
char_rule = self._add_primitive('char', PRIMITIVE_RULES['char'])
|
||||
min_len = schema.get('minLength', 0)
|
||||
max_len = schema.get('maxLength')
|
||||
|
||||
return self._add_rule(rule_name, r'"\"" ' + _build_repetition(char_rule, min_len, max_len) + r' "\"" space')
|
||||
|
||||
elif schema_type in (None, 'integer') and \
|
||||
('minimum' in schema or 'exclusiveMinimum' in schema or 'maximum' in schema or 'exclusiveMaximum' in schema):
|
||||
min_value = None
|
||||
max_value = None
|
||||
if 'minimum' in schema:
|
||||
min_value = schema['minimum']
|
||||
elif 'exclusiveMinimum' in schema:
|
||||
min_value = schema['exclusiveMinimum'] + 1
|
||||
if 'maximum' in schema:
|
||||
max_value = schema['maximum']
|
||||
elif 'exclusiveMaximum' in schema:
|
||||
max_value = schema['exclusiveMaximum'] - 1
|
||||
|
||||
out = ["("]
|
||||
_generate_min_max_int(min_value, max_value, out)
|
||||
out.append(") space")
|
||||
return self._add_rule(rule_name, ''.join(out))
|
||||
|
||||
elif (schema_type == 'object') or (len(schema) == 0):
|
||||
return self._add_rule(rule_name, self._add_primitive('object', PRIMITIVE_RULES['object']))
|
||||
|
||||
else:
|
||||
assert schema_type in PRIMITIVE_RULES, f'Unrecognized schema: {schema}'
|
||||
# TODO: support minimum, maximum, exclusiveMinimum, exclusiveMaximum at least for zero
|
||||
return self._add_primitive('root' if rule_name == 'root' else schema_type, PRIMITIVE_RULES[schema_type])
|
||||
|
||||
def _add_primitive(self, name: str, rule: BuiltinRule):
|
||||
n = self._add_rule(name, rule.content)
|
||||
|
||||
for dep in rule.deps:
|
||||
dep_rule = PRIMITIVE_RULES.get(dep) or STRING_FORMAT_RULES.get(dep)
|
||||
assert dep_rule, f'Rule {dep} not known'
|
||||
if dep not in self._rules:
|
||||
self._add_primitive(dep, dep_rule)
|
||||
return n
|
||||
|
||||
def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[str], name: str, additional_properties: Optional[Union[bool, Any]]):
|
||||
prop_order = self._prop_order
|
||||
# sort by position in prop_order (if specified) then by original order
|
||||
sorted_props = [kv[0] for _, kv in sorted(enumerate(properties), key=lambda ikv: (prop_order.get(ikv[1][0], len(prop_order)), ikv[0]))]
|
||||
|
||||
prop_kv_rule_names = {}
|
||||
for prop_name, prop_schema in properties:
|
||||
prop_rule_name = self.visit(prop_schema, f'{name}{"-" if name else ""}{prop_name}')
|
||||
prop_kv_rule_names[prop_name] = self._add_rule(
|
||||
f'{name}{"-" if name else ""}{prop_name}-kv',
|
||||
fr'{self._format_literal(json.dumps(prop_name))} space ":" space {prop_rule_name}'
|
||||
)
|
||||
required_props = [k for k in sorted_props if k in required]
|
||||
optional_props = [k for k in sorted_props if k not in required]
|
||||
|
||||
if additional_properties is not None and additional_properties != False:
|
||||
sub_name = f'{name}{"-" if name else ""}additional'
|
||||
value_rule = self.visit(additional_properties, f'{sub_name}-value') if isinstance(additional_properties, dict) else \
|
||||
self._add_primitive('value', PRIMITIVE_RULES['value'])
|
||||
key_rule = self._add_primitive('string', PRIMITIVE_RULES['string']) if not sorted_props \
|
||||
else self._add_rule(f'{sub_name}-k', self._not_strings(sorted_props))
|
||||
|
||||
prop_kv_rule_names["*"] = self._add_rule(
|
||||
f'{sub_name}-kv',
|
||||
f'{key_rule} ":" space {value_rule}'
|
||||
)
|
||||
optional_props.append("*")
|
||||
|
||||
rule = '"{" space '
|
||||
rule += ' "," space '.join(prop_kv_rule_names[k] for k in required_props)
|
||||
|
||||
if optional_props:
|
||||
rule += ' ('
|
||||
if required_props:
|
||||
rule += ' "," space ( '
|
||||
|
||||
def get_recursive_refs(ks, first_is_optional):
|
||||
[k, *rest] = ks
|
||||
kv_rule_name = prop_kv_rule_names[k]
|
||||
comma_ref = f'( "," space {kv_rule_name} )'
|
||||
if first_is_optional:
|
||||
res = comma_ref + ('*' if k == '*' else '?')
|
||||
else:
|
||||
res = kv_rule_name + (' ' + comma_ref + "*" if k == '*' else '')
|
||||
if len(rest) > 0:
|
||||
res += ' ' + self._add_rule(
|
||||
f'{name}{"-" if name else ""}{k}-rest',
|
||||
get_recursive_refs(rest, first_is_optional=True)
|
||||
)
|
||||
return res
|
||||
|
||||
rule += ' | '.join(
|
||||
get_recursive_refs(optional_props[i:], first_is_optional=False)
|
||||
for i in range(len(optional_props))
|
||||
)
|
||||
if required_props:
|
||||
rule += ' )'
|
||||
rule += ' )?'
|
||||
|
||||
rule += ' "}" space'
|
||||
|
||||
return rule
|
||||
|
||||
def format_grammar(self):
|
||||
return '\n'.join(
|
||||
f'{name} ::= {rule}'
|
||||
for name, rule in sorted(self._rules.items(), key=lambda kv: kv[0])
|
||||
)
|
||||
|
||||
|
||||
def main(args_in = None):
|
||||
parser = argparse.ArgumentParser(
|
||||
description='''
|
||||
Generates a grammar (suitable for use in ./llama-cli) that produces JSON conforming to a
|
||||
given JSON schema. Only a subset of JSON schema features are supported; more may be
|
||||
added in the future.
|
||||
''',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--prop-order',
|
||||
default=[],
|
||||
type=lambda s: s.split(','),
|
||||
help='''
|
||||
comma-separated property names defining the order of precedence for object properties;
|
||||
properties not specified here are given lower precedence than those that are, and
|
||||
are kept in their original order from the schema. Required properties are always
|
||||
given precedence over optional properties.
|
||||
'''
|
||||
)
|
||||
parser.add_argument(
|
||||
'--allow-fetch',
|
||||
action='store_true',
|
||||
default=False,
|
||||
help='Whether to allow fetching referenced schemas over HTTPS')
|
||||
parser.add_argument(
|
||||
'--dotall',
|
||||
action='store_true',
|
||||
default=False,
|
||||
help='Whether to treat dot (".") as matching all chars including line breaks in regular expression patterns')
|
||||
parser.add_argument(
|
||||
'--raw-pattern',
|
||||
action='store_true',
|
||||
default=False,
|
||||
help='Treats string patterns as raw patterns w/o quotes (or quote escapes)')
|
||||
|
||||
parser.add_argument('schema', help='file containing JSON schema ("-" for stdin)')
|
||||
args = parser.parse_args(args_in)
|
||||
|
||||
if args.schema.startswith('https://'):
|
||||
url = args.schema
|
||||
import requests
|
||||
schema = requests.get(url).json()
|
||||
elif args.schema == '-':
|
||||
url = 'stdin'
|
||||
schema = json.load(sys.stdin)
|
||||
else:
|
||||
url = f'file://{args.schema}'
|
||||
with open(args.schema) as f:
|
||||
schema = json.load(f)
|
||||
converter = SchemaConverter(
|
||||
prop_order={name: idx for idx, name in enumerate(args.prop_order)},
|
||||
allow_fetch=args.allow_fetch,
|
||||
dotall=args.dotall,
|
||||
raw_pattern=args.raw_pattern)
|
||||
schema = converter.resolve_refs(schema, url)
|
||||
converter.visit(schema, '')
|
||||
print(converter.format_grammar())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
33
examples/llama.android/.gitignore
vendored
Normal file
@@ -0,0 +1,33 @@
|
||||
# Gradle files
|
||||
.gradle/
|
||||
build/
|
||||
|
||||
# Local configuration file (sdk path, etc)
|
||||
local.properties
|
||||
|
||||
# Log/OS Files
|
||||
*.log
|
||||
|
||||
# Android Studio generated files and folders
|
||||
captures/
|
||||
.externalNativeBuild/
|
||||
.cxx/
|
||||
*.apk
|
||||
output.json
|
||||
|
||||
# IntelliJ
|
||||
*.iml
|
||||
.idea/
|
||||
misc.xml
|
||||
deploymentTargetDropDown.xml
|
||||
render.experimental.xml
|
||||
|
||||
# Keystore files
|
||||
*.jks
|
||||
*.keystore
|
||||
|
||||
# Google Services (e.g. APIs or Firebase)
|
||||
google-services.json
|
||||
|
||||
# Android Profiling
|
||||
*.hprof
|
||||
1
examples/llama.android/app/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
/build
|
||||
58
examples/llama.android/app/build.gradle.kts
Normal file
@@ -0,0 +1,58 @@
|
||||
plugins {
|
||||
alias(libs.plugins.android.application)
|
||||
alias(libs.plugins.jetbrains.kotlin.android)
|
||||
}
|
||||
|
||||
android {
|
||||
namespace = "com.example.llama"
|
||||
compileSdk = 36
|
||||
|
||||
defaultConfig {
|
||||
applicationId = "com.example.llama.aichat"
|
||||
|
||||
minSdk = 33
|
||||
targetSdk = 36
|
||||
|
||||
versionCode = 1
|
||||
versionName = "1.0"
|
||||
|
||||
testInstrumentationRunner = "androidx.test.runner.AndroidJUnitRunner"
|
||||
vectorDrawables {
|
||||
useSupportLibrary = true
|
||||
}
|
||||
}
|
||||
|
||||
buildTypes {
|
||||
debug {
|
||||
isMinifyEnabled = true
|
||||
isShrinkResources = true
|
||||
proguardFiles(
|
||||
getDefaultProguardFile("proguard-android.txt"),
|
||||
"proguard-rules.pro"
|
||||
)
|
||||
}
|
||||
release {
|
||||
isMinifyEnabled = true
|
||||
isShrinkResources = true
|
||||
proguardFiles(
|
||||
getDefaultProguardFile("proguard-android-optimize.txt"),
|
||||
"proguard-rules.pro"
|
||||
)
|
||||
}
|
||||
}
|
||||
compileOptions {
|
||||
sourceCompatibility = JavaVersion.VERSION_17
|
||||
targetCompatibility = JavaVersion.VERSION_17
|
||||
}
|
||||
}
|
||||
|
||||
dependencies {
|
||||
implementation(libs.bundles.androidx)
|
||||
implementation(libs.material)
|
||||
|
||||
implementation(project(":lib"))
|
||||
|
||||
testImplementation(libs.junit)
|
||||
androidTestImplementation(libs.androidx.junit)
|
||||
androidTestImplementation(libs.androidx.espresso.core)
|
||||
}
|
||||
29
examples/llama.android/app/proguard-rules.pro
vendored
Normal file
@@ -0,0 +1,29 @@
|
||||
# Add project specific ProGuard rules here.
|
||||
# You can control the set of applied configuration files using the
|
||||
# proguardFiles setting in build.gradle.
|
||||
#
|
||||
# For more details, see
|
||||
# http://developer.android.com/guide/developing/tools/proguard.html
|
||||
|
||||
# If your project uses WebView with JS, uncomment the following
|
||||
# and specify the fully qualified class name to the JavaScript interface
|
||||
# class:
|
||||
#-keepclassmembers class fqcn.of.javascript.interface.for.webview {
|
||||
# public *;
|
||||
#}
|
||||
|
||||
# Uncomment this to preserve the line number information for
|
||||
# debugging stack traces.
|
||||
#-keepattributes SourceFile,LineNumberTable
|
||||
|
||||
# If you keep the line number information, uncomment this to
|
||||
# hide the original source file name.
|
||||
#-renamesourcefileattribute SourceFile
|
||||
|
||||
-keep class com.arm.aichat.* { *; }
|
||||
-keep class com.arm.aichat.gguf.* { *; }
|
||||
|
||||
-assumenosideeffects class android.util.Log {
|
||||
public static int v(...);
|
||||
public static int d(...);
|
||||
}
|
||||
27
examples/llama.android/app/src/main/AndroidManifest.xml
Normal file
@@ -0,0 +1,27 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<manifest xmlns:android="http://schemas.android.com/apk/res/android">
|
||||
|
||||
<application
|
||||
android:allowBackup="true"
|
||||
android:dataExtractionRules="@xml/data_extraction_rules"
|
||||
android:extractNativeLibs="true"
|
||||
android:fullBackupContent="@xml/backup_rules"
|
||||
android:icon="@mipmap/ic_launcher_round"
|
||||
android:label="@string/app_name"
|
||||
android:roundIcon="@mipmap/ic_launcher_round"
|
||||
android:supportsRtl="true"
|
||||
android:theme="@style/Theme.AiChatSample"
|
||||
>
|
||||
|
||||
<activity
|
||||
android:name=".MainActivity"
|
||||
android:exported="true">
|
||||
<intent-filter>
|
||||
<action android:name="android.intent.action.MAIN" />
|
||||
|
||||
<category android:name="android.intent.category.LAUNCHER" />
|
||||
</intent-filter>
|
||||
</activity>
|
||||
</application>
|
||||
|
||||
</manifest>
|
||||
@@ -0,0 +1,275 @@
|
||||
package com.example.llama
|
||||
|
||||
import android.net.Uri
|
||||
import android.os.Bundle
|
||||
import android.util.Log
|
||||
import android.widget.EditText
|
||||
import android.widget.TextView
|
||||
import android.widget.Toast
|
||||
import androidx.activity.addCallback
|
||||
import androidx.activity.enableEdgeToEdge
|
||||
import androidx.activity.result.contract.ActivityResultContracts
|
||||
import androidx.appcompat.app.AppCompatActivity
|
||||
import androidx.lifecycle.lifecycleScope
|
||||
import androidx.recyclerview.widget.LinearLayoutManager
|
||||
import androidx.recyclerview.widget.RecyclerView
|
||||
import com.arm.aichat.AiChat
|
||||
import com.arm.aichat.InferenceEngine
|
||||
import com.arm.aichat.gguf.GgufMetadata
|
||||
import com.arm.aichat.gguf.GgufMetadataReader
|
||||
import com.google.android.material.floatingactionbutton.FloatingActionButton
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.Job
|
||||
import kotlinx.coroutines.flow.onCompletion
|
||||
import kotlinx.coroutines.launch
|
||||
import kotlinx.coroutines.withContext
|
||||
import java.io.File
|
||||
import java.io.FileOutputStream
|
||||
import java.io.InputStream
|
||||
import java.util.UUID
|
||||
|
||||
class MainActivity : AppCompatActivity() {
|
||||
|
||||
// Android views
|
||||
private lateinit var ggufTv: TextView
|
||||
private lateinit var messagesRv: RecyclerView
|
||||
private lateinit var userInputEt: EditText
|
||||
private lateinit var userActionFab: FloatingActionButton
|
||||
|
||||
// Arm AI Chat inference engine
|
||||
private lateinit var engine: InferenceEngine
|
||||
private var generationJob: Job? = null
|
||||
|
||||
// Conversation states
|
||||
private var isModelReady = false
|
||||
private val messages = mutableListOf<Message>()
|
||||
private val lastAssistantMsg = StringBuilder()
|
||||
private val messageAdapter = MessageAdapter(messages)
|
||||
|
||||
override fun onCreate(savedInstanceState: Bundle?) {
|
||||
super.onCreate(savedInstanceState)
|
||||
enableEdgeToEdge()
|
||||
setContentView(R.layout.activity_main)
|
||||
// View model boilerplate and state management is out of this basic sample's scope
|
||||
onBackPressedDispatcher.addCallback { Log.w(TAG, "Ignore back press for simplicity") }
|
||||
|
||||
// Find views
|
||||
ggufTv = findViewById(R.id.gguf)
|
||||
messagesRv = findViewById(R.id.messages)
|
||||
messagesRv.layoutManager = LinearLayoutManager(this).apply { stackFromEnd = true }
|
||||
messagesRv.adapter = messageAdapter
|
||||
userInputEt = findViewById(R.id.user_input)
|
||||
userActionFab = findViewById(R.id.fab)
|
||||
|
||||
// Arm AI Chat initialization
|
||||
lifecycleScope.launch(Dispatchers.Default) {
|
||||
engine = AiChat.getInferenceEngine(applicationContext)
|
||||
}
|
||||
|
||||
// Upon CTA button tapped
|
||||
userActionFab.setOnClickListener {
|
||||
if (isModelReady) {
|
||||
// If model is ready, validate input and send to engine
|
||||
handleUserInput()
|
||||
} else {
|
||||
// Otherwise, prompt user to select a GGUF metadata on the device
|
||||
getContent.launch(arrayOf("*/*"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private val getContent = registerForActivityResult(
|
||||
ActivityResultContracts.OpenDocument()
|
||||
) { uri ->
|
||||
Log.i(TAG, "Selected file uri:\n $uri")
|
||||
uri?.let { handleSelectedModel(it) }
|
||||
}
|
||||
|
||||
/**
|
||||
* Handles the file Uri from [getContent] result
|
||||
*/
|
||||
private fun handleSelectedModel(uri: Uri) {
|
||||
// Update UI states
|
||||
userActionFab.isEnabled = false
|
||||
userInputEt.hint = "Parsing GGUF..."
|
||||
ggufTv.text = "Parsing metadata from selected file \n$uri"
|
||||
|
||||
lifecycleScope.launch(Dispatchers.IO) {
|
||||
// Parse GGUF metadata
|
||||
Log.i(TAG, "Parsing GGUF metadata...")
|
||||
contentResolver.openInputStream(uri)?.use {
|
||||
GgufMetadataReader.create().readStructuredMetadata(it)
|
||||
}?.let { metadata ->
|
||||
// Update UI to show GGUF metadata to user
|
||||
Log.i(TAG, "GGUF parsed: \n$metadata")
|
||||
withContext(Dispatchers.Main) {
|
||||
ggufTv.text = metadata.toString()
|
||||
}
|
||||
|
||||
// Ensure the model file is available
|
||||
val modelName = metadata.filename() + FILE_EXTENSION_GGUF
|
||||
contentResolver.openInputStream(uri)?.use { input ->
|
||||
ensureModelFile(modelName, input)
|
||||
}?.let { modelFile ->
|
||||
loadModel(modelName, modelFile)
|
||||
|
||||
withContext(Dispatchers.Main) {
|
||||
isModelReady = true
|
||||
userInputEt.hint = "Type and send a message!"
|
||||
userInputEt.isEnabled = true
|
||||
userActionFab.setImageResource(R.drawable.outline_send_24)
|
||||
userActionFab.isEnabled = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Prepare the model file within app's private storage
|
||||
*/
|
||||
private suspend fun ensureModelFile(modelName: String, input: InputStream) =
|
||||
withContext(Dispatchers.IO) {
|
||||
File(ensureModelsDirectory(), modelName).also { file ->
|
||||
// Copy the file into local storage if not yet done
|
||||
if (!file.exists()) {
|
||||
Log.i(TAG, "Start copying file to $modelName")
|
||||
withContext(Dispatchers.Main) {
|
||||
userInputEt.hint = "Copying file..."
|
||||
}
|
||||
|
||||
FileOutputStream(file).use { input.copyTo(it) }
|
||||
Log.i(TAG, "Finished copying file to $modelName")
|
||||
} else {
|
||||
Log.i(TAG, "File already exists $modelName")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Load the model file from the app private storage
|
||||
*/
|
||||
private suspend fun loadModel(modelName: String, modelFile: File) =
|
||||
withContext(Dispatchers.IO) {
|
||||
Log.i(TAG, "Loading model $modelName")
|
||||
withContext(Dispatchers.Main) {
|
||||
userInputEt.hint = "Loading model..."
|
||||
}
|
||||
engine.loadModel(modelFile.path)
|
||||
}
|
||||
|
||||
/**
|
||||
* Validate and send the user message into [InferenceEngine]
|
||||
*/
|
||||
private fun handleUserInput() {
|
||||
userInputEt.text.toString().also { userMsg ->
|
||||
if (userMsg.isEmpty()) {
|
||||
Toast.makeText(this, "Input message is empty!", Toast.LENGTH_SHORT).show()
|
||||
} else {
|
||||
userInputEt.text = null
|
||||
userInputEt.isEnabled = false
|
||||
userActionFab.isEnabled = false
|
||||
|
||||
// Update message states
|
||||
messages.add(Message(UUID.randomUUID().toString(), userMsg, true))
|
||||
lastAssistantMsg.clear()
|
||||
messages.add(Message(UUID.randomUUID().toString(), lastAssistantMsg.toString(), false))
|
||||
|
||||
generationJob = lifecycleScope.launch(Dispatchers.Default) {
|
||||
engine.sendUserPrompt(userMsg)
|
||||
.onCompletion {
|
||||
withContext(Dispatchers.Main) {
|
||||
userInputEt.isEnabled = true
|
||||
userActionFab.isEnabled = true
|
||||
}
|
||||
}.collect { token ->
|
||||
withContext(Dispatchers.Main) {
|
||||
val messageCount = messages.size
|
||||
check(messageCount > 0 && !messages[messageCount - 1].isUser)
|
||||
|
||||
messages.removeAt(messageCount - 1).copy(
|
||||
content = lastAssistantMsg.append(token).toString()
|
||||
).let { messages.add(it) }
|
||||
|
||||
messageAdapter.notifyItemChanged(messages.size - 1)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Run a benchmark with the model file
|
||||
*/
|
||||
@Deprecated("This benchmark doesn't accurately indicate GUI performance expected by app developers")
|
||||
private suspend fun runBenchmark(modelName: String, modelFile: File) =
|
||||
withContext(Dispatchers.Default) {
|
||||
Log.i(TAG, "Starts benchmarking $modelName")
|
||||
withContext(Dispatchers.Main) {
|
||||
userInputEt.hint = "Running benchmark..."
|
||||
}
|
||||
engine.bench(
|
||||
pp=BENCH_PROMPT_PROCESSING_TOKENS,
|
||||
tg=BENCH_TOKEN_GENERATION_TOKENS,
|
||||
pl=BENCH_SEQUENCE,
|
||||
nr=BENCH_REPETITION
|
||||
).let { result ->
|
||||
messages.add(Message(UUID.randomUUID().toString(), result, false))
|
||||
withContext(Dispatchers.Main) {
|
||||
messageAdapter.notifyItemChanged(messages.size - 1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Create the `models` directory if not exist.
|
||||
*/
|
||||
private fun ensureModelsDirectory() =
|
||||
File(filesDir, DIRECTORY_MODELS).also {
|
||||
if (it.exists() && !it.isDirectory) { it.delete() }
|
||||
if (!it.exists()) { it.mkdir() }
|
||||
}
|
||||
|
||||
override fun onStop() {
|
||||
generationJob?.cancel()
|
||||
super.onStop()
|
||||
}
|
||||
|
||||
override fun onDestroy() {
|
||||
engine.destroy()
|
||||
super.onDestroy()
|
||||
}
|
||||
|
||||
companion object {
|
||||
private val TAG = MainActivity::class.java.simpleName
|
||||
|
||||
private const val DIRECTORY_MODELS = "models"
|
||||
private const val FILE_EXTENSION_GGUF = ".gguf"
|
||||
|
||||
private const val BENCH_PROMPT_PROCESSING_TOKENS = 512
|
||||
private const val BENCH_TOKEN_GENERATION_TOKENS = 128
|
||||
private const val BENCH_SEQUENCE = 1
|
||||
private const val BENCH_REPETITION = 3
|
||||
}
|
||||
}
|
||||
|
||||
fun GgufMetadata.filename() = when {
|
||||
basic.name != null -> {
|
||||
basic.name?.let { name ->
|
||||
basic.sizeLabel?.let { size ->
|
||||
"$name-$size"
|
||||
} ?: name
|
||||
}
|
||||
}
|
||||
architecture?.architecture != null -> {
|
||||
architecture?.architecture?.let { arch ->
|
||||
basic.uuid?.let { uuid ->
|
||||
"$arch-$uuid"
|
||||
} ?: "$arch-${System.currentTimeMillis()}"
|
||||
}
|
||||
}
|
||||
else -> {
|
||||
"model-${System.currentTimeMillis().toHexString()}"
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,51 @@
|
||||
package com.example.llama
|
||||
|
||||
import android.view.LayoutInflater
|
||||
import android.view.View
|
||||
import android.view.ViewGroup
|
||||
import android.widget.TextView
|
||||
import androidx.recyclerview.widget.RecyclerView
|
||||
|
||||
data class Message(
|
||||
val id: String,
|
||||
val content: String,
|
||||
val isUser: Boolean
|
||||
)
|
||||
|
||||
class MessageAdapter(
|
||||
private val messages: List<Message>
|
||||
) : RecyclerView.Adapter<RecyclerView.ViewHolder>() {
|
||||
|
||||
companion object {
|
||||
private const val VIEW_TYPE_USER = 1
|
||||
private const val VIEW_TYPE_ASSISTANT = 2
|
||||
}
|
||||
|
||||
override fun getItemViewType(position: Int): Int {
|
||||
return if (messages[position].isUser) VIEW_TYPE_USER else VIEW_TYPE_ASSISTANT
|
||||
}
|
||||
|
||||
override fun onCreateViewHolder(parent: ViewGroup, viewType: Int): RecyclerView.ViewHolder {
|
||||
val layoutInflater = LayoutInflater.from(parent.context)
|
||||
return if (viewType == VIEW_TYPE_USER) {
|
||||
val view = layoutInflater.inflate(R.layout.item_message_user, parent, false)
|
||||
UserMessageViewHolder(view)
|
||||
} else {
|
||||
val view = layoutInflater.inflate(R.layout.item_message_assistant, parent, false)
|
||||
AssistantMessageViewHolder(view)
|
||||
}
|
||||
}
|
||||
|
||||
override fun onBindViewHolder(holder: RecyclerView.ViewHolder, position: Int) {
|
||||
val message = messages[position]
|
||||
if (holder is UserMessageViewHolder || holder is AssistantMessageViewHolder) {
|
||||
val textView = holder.itemView.findViewById<TextView>(R.id.msg_content)
|
||||
textView.text = message.content
|
||||
}
|
||||
}
|
||||
|
||||
override fun getItemCount(): Int = messages.size
|
||||
|
||||
class UserMessageViewHolder(view: View) : RecyclerView.ViewHolder(view)
|
||||
class AssistantMessageViewHolder(view: View) : RecyclerView.ViewHolder(view)
|
||||
}
|
||||
@@ -0,0 +1,4 @@
|
||||
<shape xmlns:android="http://schemas.android.com/apk/res/android" android:shape="rectangle">
|
||||
<solid android:color="#E5E5EA" />
|
||||
<corners android:radius="16dp" />
|
||||
</shape>
|
||||
@@ -0,0 +1,4 @@
|
||||
<shape xmlns:android="http://schemas.android.com/apk/res/android" android:shape="rectangle">
|
||||
<solid android:color="#4285F4" />
|
||||
<corners android:radius="16dp" />
|
||||
</shape>
|
||||
@@ -0,0 +1,170 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<vector xmlns:android="http://schemas.android.com/apk/res/android"
|
||||
android:width="108dp"
|
||||
android:height="108dp"
|
||||
android:viewportWidth="108"
|
||||
android:viewportHeight="108">
|
||||
<path
|
||||
android:fillColor="#3DDC84"
|
||||
android:pathData="M0,0h108v108h-108z" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M9,0L9,108"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M19,0L19,108"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M29,0L29,108"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M39,0L39,108"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M49,0L49,108"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M59,0L59,108"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M69,0L69,108"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M79,0L79,108"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M89,0L89,108"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M99,0L99,108"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M0,9L108,9"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M0,19L108,19"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M0,29L108,29"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M0,39L108,39"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M0,49L108,49"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M0,59L108,59"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M0,69L108,69"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M0,79L108,79"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M0,89L108,89"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M0,99L108,99"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M19,29L89,29"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M19,39L89,39"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M19,49L89,49"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M19,59L89,59"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M19,69L89,69"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M19,79L89,79"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M29,19L29,89"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M39,19L39,89"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M49,19L49,89"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M59,19L59,89"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M69,19L69,89"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M79,19L79,89"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
</vector>
|
||||
@@ -0,0 +1,30 @@
|
||||
<vector xmlns:android="http://schemas.android.com/apk/res/android"
|
||||
xmlns:aapt="http://schemas.android.com/aapt"
|
||||
android:width="108dp"
|
||||
android:height="108dp"
|
||||
android:viewportWidth="108"
|
||||
android:viewportHeight="108">
|
||||
<path android:pathData="M31,63.928c0,0 6.4,-11 12.1,-13.1c7.2,-2.6 26,-1.4 26,-1.4l38.1,38.1L107,108.928l-32,-1L31,63.928z">
|
||||
<aapt:attr name="android:fillColor">
|
||||
<gradient
|
||||
android:endX="85.84757"
|
||||
android:endY="92.4963"
|
||||
android:startX="42.9492"
|
||||
android:startY="49.59793"
|
||||
android:type="linear">
|
||||
<item
|
||||
android:color="#44000000"
|
||||
android:offset="0.0" />
|
||||
<item
|
||||
android:color="#00000000"
|
||||
android:offset="1.0" />
|
||||
</gradient>
|
||||
</aapt:attr>
|
||||
</path>
|
||||
<path
|
||||
android:fillColor="#FFFFFF"
|
||||
android:fillType="nonZero"
|
||||
android:pathData="M65.3,45.828l3.8,-6.6c0.2,-0.4 0.1,-0.9 -0.3,-1.1c-0.4,-0.2 -0.9,-0.1 -1.1,0.3l-3.9,6.7c-6.3,-2.8 -13.4,-2.8 -19.7,0l-3.9,-6.7c-0.2,-0.4 -0.7,-0.5 -1.1,-0.3C38.8,38.328 38.7,38.828 38.9,39.228l3.8,6.6C36.2,49.428 31.7,56.028 31,63.928h46C76.3,56.028 71.8,49.428 65.3,45.828zM43.4,57.328c-0.8,0 -1.5,-0.5 -1.8,-1.2c-0.3,-0.7 -0.1,-1.5 0.4,-2.1c0.5,-0.5 1.4,-0.7 2.1,-0.4c0.7,0.3 1.2,1 1.2,1.8C45.3,56.528 44.5,57.328 43.4,57.328L43.4,57.328zM64.6,57.328c-0.8,0 -1.5,-0.5 -1.8,-1.2s-0.1,-1.5 0.4,-2.1c0.5,-0.5 1.4,-0.7 2.1,-0.4c0.7,0.3 1.2,1 1.2,1.8C66.5,56.528 65.6,57.328 64.6,57.328L64.6,57.328z"
|
||||
android:strokeWidth="1"
|
||||
android:strokeColor="#00000000" />
|
||||
</vector>
|
||||
@@ -0,0 +1,10 @@
|
||||
<vector xmlns:android="http://schemas.android.com/apk/res/android"
|
||||
android:width="24dp"
|
||||
android:height="24dp"
|
||||
android:viewportWidth="24"
|
||||
android:viewportHeight="24"
|
||||
android:tint="?attr/colorControlNormal">
|
||||
<path
|
||||
android:fillColor="@android:color/white"
|
||||
android:pathData="M20,6h-8l-2,-2L4,4c-1.1,0 -1.99,0.9 -1.99,2L2,18c0,1.1 0.9,2 2,2h16c1.1,0 2,-0.9 2,-2L22,8c0,-1.1 -0.9,-2 -2,-2zM20,18L4,18L4,8h16v10z"/>
|
||||
</vector>
|
||||
@@ -0,0 +1,11 @@
|
||||
<vector xmlns:android="http://schemas.android.com/apk/res/android"
|
||||
android:width="24dp"
|
||||
android:height="24dp"
|
||||
android:viewportWidth="24"
|
||||
android:viewportHeight="24"
|
||||
android:tint="?attr/colorControlNormal"
|
||||
android:autoMirrored="true">
|
||||
<path
|
||||
android:fillColor="@android:color/white"
|
||||
android:pathData="M4.01,6.03l7.51,3.22 -7.52,-1 0.01,-2.22m7.5,8.72L4,17.97v-2.22l7.51,-1M2.01,3L2,10l15,2 -15,2 0.01,7L23,12 2.01,3z"/>
|
||||
</vector>
|
||||
@@ -0,0 +1,77 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<androidx.constraintlayout.widget.ConstraintLayout xmlns:android="http://schemas.android.com/apk/res/android"
|
||||
xmlns:app="http://schemas.android.com/apk/res-auto"
|
||||
xmlns:tools="http://schemas.android.com/tools"
|
||||
android:id="@+id/main"
|
||||
android:layout_height="match_parent"
|
||||
android:layout_width="match_parent">
|
||||
|
||||
<LinearLayout
|
||||
android:fitsSystemWindows="true"
|
||||
android:layout_width="match_parent"
|
||||
android:layout_height="match_parent"
|
||||
android:orientation="vertical"
|
||||
android:layout_marginEnd="4dp"
|
||||
tools:context=".MainActivity">
|
||||
|
||||
<ScrollView
|
||||
android:layout_width="match_parent"
|
||||
android:layout_height="0dp"
|
||||
android:layout_weight="1"
|
||||
android:fadeScrollbars="false">
|
||||
|
||||
<TextView
|
||||
android:id="@+id/gguf"
|
||||
android:layout_width="match_parent"
|
||||
android:layout_height="wrap_content"
|
||||
android:padding="16dp"
|
||||
android:text="Selected GGUF model's metadata will show here."
|
||||
style="@style/TextAppearance.MaterialComponents.Body2" />
|
||||
|
||||
</ScrollView>
|
||||
|
||||
<com.google.android.material.divider.MaterialDivider
|
||||
android:layout_width="match_parent"
|
||||
android:layout_height="2dp"
|
||||
android:layout_marginHorizontal="16dp" />
|
||||
|
||||
<androidx.recyclerview.widget.RecyclerView
|
||||
android:id="@+id/messages"
|
||||
android:layout_width="match_parent"
|
||||
android:layout_height="0dp"
|
||||
android:layout_weight="4"
|
||||
android:fadeScrollbars="false"
|
||||
android:scrollbars="vertical"
|
||||
app:reverseLayout="true"
|
||||
tools:listitem="@layout/item_message_assistant"/>
|
||||
|
||||
<LinearLayout
|
||||
android:layout_width="match_parent"
|
||||
android:layout_height="wrap_content"
|
||||
android:orientation="horizontal"
|
||||
android:paddingStart="16dp"
|
||||
android:paddingEnd="4dp">
|
||||
|
||||
<EditText
|
||||
android:id="@+id/user_input"
|
||||
android:enabled="false"
|
||||
android:layout_width="0dp"
|
||||
android:layout_weight="1"
|
||||
android:layout_height="match_parent"
|
||||
android:padding="8dp"
|
||||
style="@style/TextAppearance.MaterialComponents.Body2"
|
||||
android:hint="Please first pick a GGUF model file to import." />
|
||||
|
||||
<com.google.android.material.floatingactionbutton.FloatingActionButton
|
||||
android:id="@+id/fab"
|
||||
android:enabled="true"
|
||||
style="@style/Widget.Material3.FloatingActionButton.Primary"
|
||||
android:layout_width="wrap_content"
|
||||
android:layout_height="wrap_content"
|
||||
android:layout_margin="12dp"
|
||||
android:src="@drawable/outline_folder_open_24" />
|
||||
|
||||
</LinearLayout>
|
||||
|
||||
</LinearLayout>
|
||||
</androidx.constraintlayout.widget.ConstraintLayout>
|
||||
@@ -0,0 +1,16 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<LinearLayout xmlns:android="http://schemas.android.com/apk/res/android"
|
||||
android:layout_width="match_parent"
|
||||
android:layout_height="wrap_content"
|
||||
android:layout_marginHorizontal="16dp"
|
||||
android:layout_marginVertical="8dp"
|
||||
android:gravity="start">
|
||||
|
||||
<TextView
|
||||
android:id="@+id/msg_content"
|
||||
android:layout_width="wrap_content"
|
||||
android:layout_height="wrap_content"
|
||||
android:background="@drawable/bg_assistant_message"
|
||||
android:padding="12dp"
|
||||
android:textColor="@android:color/black" />
|
||||
</LinearLayout>
|
||||
@@ -0,0 +1,16 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<LinearLayout xmlns:android="http://schemas.android.com/apk/res/android"
|
||||
android:layout_width="match_parent"
|
||||
android:layout_height="wrap_content"
|
||||
android:layout_marginHorizontal="16dp"
|
||||
android:layout_marginVertical="8dp"
|
||||
android:gravity="end">
|
||||
|
||||
<TextView
|
||||
android:id="@+id/msg_content"
|
||||
android:layout_width="wrap_content"
|
||||
android:layout_height="wrap_content"
|
||||
android:background="@drawable/bg_user_message"
|
||||
android:padding="12dp"
|
||||
android:textColor="@android:color/white" />
|
||||
</LinearLayout>
|
||||
@@ -0,0 +1,6 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<adaptive-icon xmlns:android="http://schemas.android.com/apk/res/android">
|
||||
<background android:drawable="@drawable/ic_launcher_background" />
|
||||
<foreground android:drawable="@drawable/ic_launcher_foreground" />
|
||||
<monochrome android:drawable="@drawable/ic_launcher_foreground" />
|
||||
</adaptive-icon>
|
||||
@@ -0,0 +1,6 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<adaptive-icon xmlns:android="http://schemas.android.com/apk/res/android">
|
||||
<background android:drawable="@drawable/ic_launcher_background" />
|
||||
<foreground android:drawable="@drawable/ic_launcher_foreground" />
|
||||
<monochrome android:drawable="@drawable/ic_launcher_foreground" />
|
||||
</adaptive-icon>
|
||||
|
After Width: | Height: | Size: 1.4 KiB |
|
After Width: | Height: | Size: 2.8 KiB |
|
After Width: | Height: | Size: 982 B |
|
After Width: | Height: | Size: 1.7 KiB |
|
After Width: | Height: | Size: 1.9 KiB |
|
After Width: | Height: | Size: 3.8 KiB |
|
After Width: | Height: | Size: 2.8 KiB |
|
After Width: | Height: | Size: 5.8 KiB |
|
After Width: | Height: | Size: 3.8 KiB |
|
After Width: | Height: | Size: 7.6 KiB |
10
examples/llama.android/app/src/main/res/values/colors.xml
Normal file
@@ -0,0 +1,10 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<resources>
|
||||
<color name="purple_200">#FFBB86FC</color>
|
||||
<color name="purple_500">#FF6200EE</color>
|
||||
<color name="purple_700">#FF3700B3</color>
|
||||
<color name="teal_200">#FF03DAC5</color>
|
||||
<color name="teal_700">#FF018786</color>
|
||||
<color name="black">#FF000000</color>
|
||||
<color name="white">#FFFFFFFF</color>
|
||||
</resources>
|
||||
@@ -0,0 +1,3 @@
|
||||
<resources>
|
||||
<string name="app_name">AI Chat basic sample</string>
|
||||
</resources>
|
||||
10
examples/llama.android/app/src/main/res/values/themes.xml
Normal file
@@ -0,0 +1,10 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<resources>
|
||||
|
||||
<style name="Base.Theme.AiChatSample" parent="Theme.Material3.DayNight.NoActionBar">
|
||||
<!-- Customize your light theme here. -->
|
||||
<!-- <item name="colorPrimary">@color/my_light_primary</item> -->
|
||||
</style>
|
||||
|
||||
<style name="Theme.AiChatSample" parent="Base.Theme.AiChatSample" />
|
||||
</resources>
|
||||
13
examples/llama.android/app/src/main/res/xml/backup_rules.xml
Normal file
@@ -0,0 +1,13 @@
|
||||
<?xml version="1.0" encoding="utf-8"?><!--
|
||||
Sample backup rules file; uncomment and customize as necessary.
|
||||
See https://developer.android.com/guide/topics/data/autobackup
|
||||
for details.
|
||||
Note: This file is ignored for devices older that API 31
|
||||
See https://developer.android.com/about/versions/12/backup-restore
|
||||
-->
|
||||
<full-backup-content>
|
||||
<!--
|
||||
<include domain="sharedpref" path="."/>
|
||||
<exclude domain="sharedpref" path="device.xml"/>
|
||||
-->
|
||||
</full-backup-content>
|
||||
@@ -0,0 +1,19 @@
|
||||
<?xml version="1.0" encoding="utf-8"?><!--
|
||||
Sample data extraction rules file; uncomment and customize as necessary.
|
||||
See https://developer.android.com/about/versions/12/backup-restore#xml-changes
|
||||
for details.
|
||||
-->
|
||||
<data-extraction-rules>
|
||||
<cloud-backup>
|
||||
<!-- TODO: Use <include> and <exclude> to control what is backed up.
|
||||
<include .../>
|
||||
<exclude .../>
|
||||
-->
|
||||
</cloud-backup>
|
||||
<!--
|
||||
<device-transfer>
|
||||
<include .../>
|
||||
<exclude .../>
|
||||
</device-transfer>
|
||||
-->
|
||||
</data-extraction-rules>
|
||||
6
examples/llama.android/build.gradle.kts
Normal file
@@ -0,0 +1,6 @@
|
||||
// Top-level build file where you can add configuration options common to all sub-projects/modules.
|
||||
plugins {
|
||||
alias(libs.plugins.android.application) apply false
|
||||
alias(libs.plugins.android.library) apply false
|
||||
alias(libs.plugins.jetbrains.kotlin.android) apply false
|
||||
}
|
||||
24
examples/llama.android/gradle.properties
Normal file
@@ -0,0 +1,24 @@
|
||||
# Project-wide Gradle settings.
|
||||
# IDE (e.g. Android Studio) users:
|
||||
# Gradle settings configured through the IDE *will override*
|
||||
# any settings specified in this file.
|
||||
# For more details on how to configure your build environment visit
|
||||
# http://www.gradle.org/docs/current/userguide/build_environment.html
|
||||
# Specifies the JVM arguments used for the daemon process.
|
||||
# The setting is particularly useful for tweaking memory settings.
|
||||
org.gradle.jvmargs=-Xmx2048m -Dfile.encoding=UTF-8
|
||||
# When configured, Gradle will run in incubating parallel mode.
|
||||
# This option should only be used with decoupled projects. More details, visit
|
||||
# http://www.gradle.org/docs/current/userguide/multi_project_builds.html#sec:decoupled_projects
|
||||
# org.gradle.parallel=true
|
||||
# AndroidX package structure to make it clearer which packages are bundled with the
|
||||
# Android operating system, and which are packaged with your app's APK
|
||||
# https://developer.android.com/topic/libraries/support-library/androidx-rn
|
||||
android.useAndroidX=true
|
||||
# Kotlin code style for this project: "official" or "obsolete":
|
||||
kotlin.code.style=official
|
||||
# Enables namespacing of each library's R class so that its R class includes only the
|
||||
# resources declared in the library itself and none from the library's dependencies,
|
||||
# thereby reducing the size of the R class for that library
|
||||
android.nonTransitiveRClass=true
|
||||
android.native.buildOutput=verbose
|
||||
53
examples/llama.android/gradle/libs.versions.toml
Normal file
@@ -0,0 +1,53 @@
|
||||
[versions]
|
||||
|
||||
# Plugins
|
||||
agp = "8.13.2"
|
||||
kotlin = "2.3.0"
|
||||
|
||||
# AndroidX
|
||||
activity = "1.12.2"
|
||||
appcompat = "1.7.1"
|
||||
core-ktx = "1.17.0"
|
||||
constraint-layout = "2.2.1"
|
||||
datastore-preferences = "1.2.0"
|
||||
|
||||
# Material
|
||||
material = "1.13.0"
|
||||
|
||||
# Testing
|
||||
espresso-core = "3.7.0"
|
||||
androidx-junit = "1.3.0"
|
||||
junit = "4.13.2"
|
||||
|
||||
|
||||
[plugins]
|
||||
android-application = { id = "com.android.application", version.ref = "agp" }
|
||||
android-library = { id = "com.android.library", version.ref = "agp" }
|
||||
jetbrains-kotlin-android = { id = "org.jetbrains.kotlin.android", version.ref = "kotlin" }
|
||||
|
||||
|
||||
[libraries]
|
||||
|
||||
# AndroidX
|
||||
androidx-activity = { group = "androidx.activity", name = "activity", version.ref = "activity" }
|
||||
androidx-appcompat = { group = "androidx.appcompat", name = "appcompat", version.ref = "appcompat" }
|
||||
androidx-constraintlayout = { group = "androidx.constraintlayout", name = "constraintlayout", version.ref = "constraint-layout" }
|
||||
androidx-core-ktx = { group = "androidx.core", name = "core-ktx", version.ref = "core-ktx" }
|
||||
androidx-datastore-preferences = { group = "androidx.datastore", name = "datastore-preferences", version.ref = "datastore-preferences" }
|
||||
|
||||
#Material
|
||||
material = { group = "com.google.android.material", name = "material", version.ref = "material" }
|
||||
|
||||
# Testing
|
||||
androidx-espresso-core = { group = "androidx.test.espresso", name = "espresso-core", version.ref = "espresso-core" }
|
||||
androidx-junit = { group = "androidx.test.ext", name = "junit", version.ref = "androidx-junit" }
|
||||
junit = { group = "junit", name = "junit", version.ref = "junit" }
|
||||
|
||||
[bundles]
|
||||
androidx = [
|
||||
"androidx-activity",
|
||||
"androidx-appcompat",
|
||||
"androidx-constraintlayout",
|
||||
"androidx-core-ktx",
|
||||
"androidx-datastore-preferences",
|
||||
]
|
||||
BIN
examples/llama.android/gradle/wrapper/gradle-wrapper.jar
vendored
Normal file
6
examples/llama.android/gradle/wrapper/gradle-wrapper.properties
vendored
Normal file
@@ -0,0 +1,6 @@
|
||||
#Tue Apr 01 11:15:06 PDT 2025
|
||||
distributionBase=GRADLE_USER_HOME
|
||||
distributionPath=wrapper/dists
|
||||
distributionUrl=https\://services.gradle.org/distributions/gradle-8.14.3-bin.zip
|
||||
zipStoreBase=GRADLE_USER_HOME
|
||||
zipStorePath=wrapper/dists
|
||||
185
examples/llama.android/gradlew
vendored
Executable file
@@ -0,0 +1,185 @@
|
||||
#!/usr/bin/env sh
|
||||
|
||||
#
|
||||
# Copyright 2015 the original author or authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
##############################################################################
|
||||
##
|
||||
## Gradle start up script for UN*X
|
||||
##
|
||||
##############################################################################
|
||||
|
||||
# Attempt to set APP_HOME
|
||||
# Resolve links: $0 may be a link
|
||||
PRG="$0"
|
||||
# Need this for relative symlinks.
|
||||
while [ -h "$PRG" ] ; do
|
||||
ls=`ls -ld "$PRG"`
|
||||
link=`expr "$ls" : '.*-> \(.*\)$'`
|
||||
if expr "$link" : '/.*' > /dev/null; then
|
||||
PRG="$link"
|
||||
else
|
||||
PRG=`dirname "$PRG"`"/$link"
|
||||
fi
|
||||
done
|
||||
SAVED="`pwd`"
|
||||
cd "`dirname \"$PRG\"`/" >/dev/null
|
||||
APP_HOME="`pwd -P`"
|
||||
cd "$SAVED" >/dev/null
|
||||
|
||||
APP_NAME="Gradle"
|
||||
APP_BASE_NAME=`basename "$0"`
|
||||
|
||||
# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
|
||||
DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"'
|
||||
|
||||
# Use the maximum available, or set MAX_FD != -1 to use that value.
|
||||
MAX_FD="maximum"
|
||||
|
||||
warn () {
|
||||
echo "$*"
|
||||
}
|
||||
|
||||
die () {
|
||||
echo
|
||||
echo "$*"
|
||||
echo
|
||||
exit 1
|
||||
}
|
||||
|
||||
# OS specific support (must be 'true' or 'false').
|
||||
cygwin=false
|
||||
msys=false
|
||||
darwin=false
|
||||
nonstop=false
|
||||
case "`uname`" in
|
||||
CYGWIN* )
|
||||
cygwin=true
|
||||
;;
|
||||
Darwin* )
|
||||
darwin=true
|
||||
;;
|
||||
MINGW* )
|
||||
msys=true
|
||||
;;
|
||||
NONSTOP* )
|
||||
nonstop=true
|
||||
;;
|
||||
esac
|
||||
|
||||
CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar
|
||||
|
||||
|
||||
# Determine the Java command to use to start the JVM.
|
||||
if [ -n "$JAVA_HOME" ] ; then
|
||||
if [ -x "$JAVA_HOME/jre/sh/java" ] ; then
|
||||
# IBM's JDK on AIX uses strange locations for the executables
|
||||
JAVACMD="$JAVA_HOME/jre/sh/java"
|
||||
else
|
||||
JAVACMD="$JAVA_HOME/bin/java"
|
||||
fi
|
||||
if [ ! -x "$JAVACMD" ] ; then
|
||||
die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME
|
||||
|
||||
Please set the JAVA_HOME variable in your environment to match the
|
||||
location of your Java installation."
|
||||
fi
|
||||
else
|
||||
JAVACMD="java"
|
||||
which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
|
||||
|
||||
Please set the JAVA_HOME variable in your environment to match the
|
||||
location of your Java installation."
|
||||
fi
|
||||
|
||||
# Increase the maximum file descriptors if we can.
|
||||
if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then
|
||||
MAX_FD_LIMIT=`ulimit -H -n`
|
||||
if [ $? -eq 0 ] ; then
|
||||
if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then
|
||||
MAX_FD="$MAX_FD_LIMIT"
|
||||
fi
|
||||
ulimit -n $MAX_FD
|
||||
if [ $? -ne 0 ] ; then
|
||||
warn "Could not set maximum file descriptor limit: $MAX_FD"
|
||||
fi
|
||||
else
|
||||
warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT"
|
||||
fi
|
||||
fi
|
||||
|
||||
# For Darwin, add options to specify how the application appears in the dock
|
||||
if $darwin; then
|
||||
GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\""
|
||||
fi
|
||||
|
||||
# For Cygwin or MSYS, switch paths to Windows format before running java
|
||||
if [ "$cygwin" = "true" -o "$msys" = "true" ] ; then
|
||||
APP_HOME=`cygpath --path --mixed "$APP_HOME"`
|
||||
CLASSPATH=`cygpath --path --mixed "$CLASSPATH"`
|
||||
|
||||
JAVACMD=`cygpath --unix "$JAVACMD"`
|
||||
|
||||
# We build the pattern for arguments to be converted via cygpath
|
||||
ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null`
|
||||
SEP=""
|
||||
for dir in $ROOTDIRSRAW ; do
|
||||
ROOTDIRS="$ROOTDIRS$SEP$dir"
|
||||
SEP="|"
|
||||
done
|
||||
OURCYGPATTERN="(^($ROOTDIRS))"
|
||||
# Add a user-defined pattern to the cygpath arguments
|
||||
if [ "$GRADLE_CYGPATTERN" != "" ] ; then
|
||||
OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)"
|
||||
fi
|
||||
# Now convert the arguments - kludge to limit ourselves to /bin/sh
|
||||
i=0
|
||||
for arg in "$@" ; do
|
||||
CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -`
|
||||
CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option
|
||||
|
||||
if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition
|
||||
eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"`
|
||||
else
|
||||
eval `echo args$i`="\"$arg\""
|
||||
fi
|
||||
i=`expr $i + 1`
|
||||
done
|
||||
case $i in
|
||||
0) set -- ;;
|
||||
1) set -- "$args0" ;;
|
||||
2) set -- "$args0" "$args1" ;;
|
||||
3) set -- "$args0" "$args1" "$args2" ;;
|
||||
4) set -- "$args0" "$args1" "$args2" "$args3" ;;
|
||||
5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;;
|
||||
6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;;
|
||||
7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;;
|
||||
8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;;
|
||||
9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;;
|
||||
esac
|
||||
fi
|
||||
|
||||
# Escape application args
|
||||
save () {
|
||||
for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done
|
||||
echo " "
|
||||
}
|
||||
APP_ARGS=`save "$@"`
|
||||
|
||||
# Collect all arguments for the java command, following the shell quoting and substitution rules
|
||||
eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS"
|
||||
|
||||
exec "$JAVACMD" "$@"
|
||||
1
examples/llama.android/lib/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
/build
|
||||
78
examples/llama.android/lib/build.gradle.kts
Normal file
@@ -0,0 +1,78 @@
|
||||
plugins {
|
||||
alias(libs.plugins.android.library)
|
||||
alias(libs.plugins.jetbrains.kotlin.android)
|
||||
}
|
||||
|
||||
android {
|
||||
namespace = "com.arm.aichat"
|
||||
compileSdk = 36
|
||||
|
||||
ndkVersion = "29.0.13113456"
|
||||
|
||||
defaultConfig {
|
||||
minSdk = 33
|
||||
|
||||
testInstrumentationRunner = "androidx.test.runner.AndroidJUnitRunner"
|
||||
consumerProguardFiles("consumer-rules.pro")
|
||||
|
||||
ndk {
|
||||
abiFilters += listOf("arm64-v8a", "x86_64")
|
||||
}
|
||||
externalNativeBuild {
|
||||
cmake {
|
||||
arguments += "-DCMAKE_BUILD_TYPE=Release"
|
||||
arguments += "-DCMAKE_MESSAGE_LOG_LEVEL=DEBUG"
|
||||
arguments += "-DCMAKE_VERBOSE_MAKEFILE=ON"
|
||||
|
||||
arguments += "-DBUILD_SHARED_LIBS=ON"
|
||||
arguments += "-DLLAMA_BUILD_COMMON=ON"
|
||||
arguments += "-DLLAMA_OPENSSL=OFF"
|
||||
|
||||
arguments += "-DGGML_NATIVE=OFF"
|
||||
arguments += "-DGGML_BACKEND_DL=ON"
|
||||
arguments += "-DGGML_CPU_ALL_VARIANTS=ON"
|
||||
arguments += "-DGGML_LLAMAFILE=OFF"
|
||||
}
|
||||
}
|
||||
aarMetadata {
|
||||
minCompileSdk = 35
|
||||
}
|
||||
}
|
||||
externalNativeBuild {
|
||||
cmake {
|
||||
path("src/main/cpp/CMakeLists.txt")
|
||||
version = "3.31.6"
|
||||
}
|
||||
}
|
||||
compileOptions {
|
||||
sourceCompatibility = JavaVersion.VERSION_17
|
||||
targetCompatibility = JavaVersion.VERSION_17
|
||||
}
|
||||
kotlin {
|
||||
jvmToolchain(17)
|
||||
|
||||
compileOptions {
|
||||
targetCompatibility = JavaVersion.VERSION_17
|
||||
}
|
||||
}
|
||||
|
||||
packaging {
|
||||
resources {
|
||||
excludes += "/META-INF/{AL2.0,LGPL2.1}"
|
||||
}
|
||||
}
|
||||
|
||||
publishing {
|
||||
singleVariant("release") {
|
||||
withJavadocJar()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
dependencies {
|
||||
implementation(libs.androidx.core.ktx)
|
||||
implementation(libs.androidx.datastore.preferences)
|
||||
|
||||
testImplementation(libs.junit)
|
||||
androidTestImplementation(libs.androidx.junit)
|
||||
}
|
||||
8
examples/llama.android/lib/consumer-rules.pro
Normal file
@@ -0,0 +1,8 @@
|
||||
-keep class com.arm.aichat.* { *; }
|
||||
-keep class com.arm.aichat.gguf.* { *; }
|
||||
|
||||
-keepclasseswithmembernames class * {
|
||||
native <methods>;
|
||||
}
|
||||
|
||||
-keep class kotlin.Metadata { *; }
|
||||
21
examples/llama.android/lib/proguard-rules.pro
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
# Add project specific ProGuard rules here.
|
||||
# You can control the set of applied configuration files using the
|
||||
# proguardFiles setting in build.gradle.
|
||||
#
|
||||
# For more details, see
|
||||
# http://developer.android.com/guide/developing/tools/proguard.html
|
||||
|
||||
# If your project uses WebView with JS, uncomment the following
|
||||
# and specify the fully qualified class name to the JavaScript interface
|
||||
# class:
|
||||
#-keepclassmembers class fqcn.of.javascript.interface.for.webview {
|
||||
# public *;
|
||||
#}
|
||||
|
||||
# Uncomment this to preserve the line number information for
|
||||
# debugging stack traces.
|
||||
#-keepattributes SourceFile,LineNumberTable
|
||||
|
||||
# If you keep the line number information, uncomment this to
|
||||
# hide the original source file name.
|
||||
#-renamesourcefileattribute SourceFile
|
||||
@@ -0,0 +1,24 @@
|
||||
package android.llama.cpp
|
||||
|
||||
import androidx.test.platform.app.InstrumentationRegistry
|
||||
import androidx.test.ext.junit.runners.AndroidJUnit4
|
||||
|
||||
import org.junit.Test
|
||||
import org.junit.runner.RunWith
|
||||
|
||||
import org.junit.Assert.*
|
||||
|
||||
/**
|
||||
* Instrumented test, which will execute on an Android device.
|
||||
*
|
||||
* See [testing documentation](http://d.android.com/tools/testing).
|
||||
*/
|
||||
@RunWith(AndroidJUnit4::class)
|
||||
class ExampleInstrumentedTest {
|
||||
@Test
|
||||
fun useAppContext() {
|
||||
// Context of the app under test.
|
||||
val appContext = InstrumentationRegistry.getInstrumentation().targetContext
|
||||
assertEquals("android.llama.cpp.test", appContext.packageName)
|
||||
}
|
||||
}
|
||||
4
examples/llama.android/lib/src/main/AndroidManifest.xml
Normal file
@@ -0,0 +1,4 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<manifest xmlns:android="http://schemas.android.com/apk/res/android">
|
||||
|
||||
</manifest>
|
||||
56
examples/llama.android/lib/src/main/cpp/CMakeLists.txt
Normal file
@@ -0,0 +1,56 @@
|
||||
cmake_minimum_required(VERSION 3.31.6)
|
||||
|
||||
project("ai-chat" VERSION 1.0.0 LANGUAGES C CXX)
|
||||
|
||||
set(CMAKE_C_STANDARD 11)
|
||||
set(CMAKE_C_STANDARD_REQUIRED true)
|
||||
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_CXX_STANDARD_REQUIRED true)
|
||||
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS}" CACHE STRING "" FORCE)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}" CACHE STRING "" FORCE)
|
||||
|
||||
# --------------------------------------------------------------------------
|
||||
# AI Chat library
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
if(DEFINED ANDROID_ABI)
|
||||
message(STATUS "Detected Android ABI: ${ANDROID_ABI}")
|
||||
if(ANDROID_ABI STREQUAL "arm64-v8a")
|
||||
set(GGML_SYSTEM_ARCH "ARM")
|
||||
set(GGML_CPU_KLEIDIAI ON)
|
||||
set(GGML_OPENMP ON)
|
||||
elseif(ANDROID_ABI STREQUAL "x86_64")
|
||||
set(GGML_SYSTEM_ARCH "x86")
|
||||
set(GGML_CPU_KLEIDIAI OFF)
|
||||
set(GGML_OPENMP OFF)
|
||||
else()
|
||||
message(FATAL_ERROR "Unsupported ABI: ${ANDROID_ABI}")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
set(LLAMA_SRC ${CMAKE_CURRENT_LIST_DIR}/../../../../../../)
|
||||
add_subdirectory(${LLAMA_SRC} build-llama)
|
||||
|
||||
add_library(${CMAKE_PROJECT_NAME} SHARED
|
||||
ai_chat.cpp)
|
||||
|
||||
target_compile_definitions(${CMAKE_PROJECT_NAME} PRIVATE
|
||||
GGML_SYSTEM_ARCH=${GGML_SYSTEM_ARCH}
|
||||
GGML_CPU_KLEIDIAI=$<BOOL:${GGML_CPU_KLEIDIAI}>
|
||||
GGML_OPENMP=$<BOOL:${GGML_OPENMP}>
|
||||
)
|
||||
|
||||
target_include_directories(${CMAKE_PROJECT_NAME} PRIVATE
|
||||
${LLAMA_SRC}
|
||||
${LLAMA_SRC}/common
|
||||
${LLAMA_SRC}/include
|
||||
${LLAMA_SRC}/ggml/include
|
||||
${LLAMA_SRC}/ggml/src)
|
||||
|
||||
target_link_libraries(${CMAKE_PROJECT_NAME}
|
||||
llama
|
||||
common
|
||||
android
|
||||
log)
|
||||
565
examples/llama.android/lib/src/main/cpp/ai_chat.cpp
Normal file
@@ -0,0 +1,565 @@
|
||||
#include <android/log.h>
|
||||
#include <jni.h>
|
||||
#include <iomanip>
|
||||
#include <cmath>
|
||||
#include <string>
|
||||
#include <unistd.h>
|
||||
#include <sampling.h>
|
||||
|
||||
#include "logging.h"
|
||||
#include "chat.h"
|
||||
#include "common.h"
|
||||
#include "llama.h"
|
||||
|
||||
template<class T>
|
||||
static std::string join(const std::vector<T> &values, const std::string &delim) {
|
||||
std::ostringstream str;
|
||||
for (size_t i = 0; i < values.size(); i++) {
|
||||
str << values[i];
|
||||
if (i < values.size() - 1) { str << delim; }
|
||||
}
|
||||
return str.str();
|
||||
}
|
||||
|
||||
/**
|
||||
* LLama resources: context, model, batch and sampler
|
||||
*/
|
||||
constexpr int N_THREADS_MIN = 2;
|
||||
constexpr int N_THREADS_MAX = 4;
|
||||
constexpr int N_THREADS_HEADROOM = 2;
|
||||
|
||||
constexpr int DEFAULT_CONTEXT_SIZE = 8192;
|
||||
constexpr int OVERFLOW_HEADROOM = 4;
|
||||
constexpr int BATCH_SIZE = 512;
|
||||
constexpr float DEFAULT_SAMPLER_TEMP = 0.3f;
|
||||
|
||||
static llama_model * g_model;
|
||||
static llama_context * g_context;
|
||||
static llama_batch g_batch;
|
||||
static common_chat_templates_ptr g_chat_templates;
|
||||
static common_sampler * g_sampler;
|
||||
|
||||
extern "C"
|
||||
JNIEXPORT void JNICALL
|
||||
Java_com_arm_aichat_internal_InferenceEngineImpl_init(JNIEnv *env, jobject /*unused*/, jstring nativeLibDir) {
|
||||
// Set llama log handler to Android
|
||||
llama_log_set(aichat_android_log_callback, nullptr);
|
||||
|
||||
// Loading all CPU backend variants
|
||||
const auto *path_to_backend = env->GetStringUTFChars(nativeLibDir, 0);
|
||||
LOGi("Loading backends from %s", path_to_backend);
|
||||
ggml_backend_load_all_from_path(path_to_backend);
|
||||
env->ReleaseStringUTFChars(nativeLibDir, path_to_backend);
|
||||
|
||||
// Initialize backends
|
||||
llama_backend_init();
|
||||
LOGi("Backend initiated; Log handler set.");
|
||||
}
|
||||
|
||||
extern "C"
|
||||
JNIEXPORT jint JNICALL
|
||||
Java_com_arm_aichat_internal_InferenceEngineImpl_load(JNIEnv *env, jobject, jstring jmodel_path) {
|
||||
llama_model_params model_params = llama_model_default_params();
|
||||
|
||||
const auto *model_path = env->GetStringUTFChars(jmodel_path, 0);
|
||||
LOGd("%s: Loading model from: \n%s\n", __func__, model_path);
|
||||
|
||||
auto *model = llama_model_load_from_file(model_path, model_params);
|
||||
env->ReleaseStringUTFChars(jmodel_path, model_path);
|
||||
if (!model) {
|
||||
return 1;
|
||||
}
|
||||
g_model = model;
|
||||
return 0;
|
||||
}
|
||||
|
||||
static llama_context *init_context(llama_model *model, const int n_ctx = DEFAULT_CONTEXT_SIZE) {
|
||||
if (!model) {
|
||||
LOGe("%s: model cannot be null", __func__);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Multi-threading setup
|
||||
const int n_threads = std::max(N_THREADS_MIN, std::min(N_THREADS_MAX,
|
||||
(int) sysconf(_SC_NPROCESSORS_ONLN) -
|
||||
N_THREADS_HEADROOM));
|
||||
LOGi("%s: Using %d threads", __func__, n_threads);
|
||||
|
||||
// Context parameters setup
|
||||
llama_context_params ctx_params = llama_context_default_params();
|
||||
const int trained_context_size = llama_model_n_ctx_train(model);
|
||||
if (n_ctx > trained_context_size) {
|
||||
LOGw("%s: Model was trained with only %d context size! Enforcing %d context size...",
|
||||
__func__, trained_context_size, n_ctx);
|
||||
}
|
||||
ctx_params.n_ctx = n_ctx;
|
||||
ctx_params.n_batch = BATCH_SIZE;
|
||||
ctx_params.n_ubatch = BATCH_SIZE;
|
||||
ctx_params.n_threads = n_threads;
|
||||
ctx_params.n_threads_batch = n_threads;
|
||||
auto *context = llama_init_from_model(g_model, ctx_params);
|
||||
if (context == nullptr) {
|
||||
LOGe("%s: llama_new_context_with_model() returned null)", __func__);
|
||||
}
|
||||
return context;
|
||||
}
|
||||
|
||||
static common_sampler *new_sampler(float temp) {
|
||||
common_params_sampling sparams;
|
||||
sparams.temp = temp;
|
||||
return common_sampler_init(g_model, sparams);
|
||||
}
|
||||
|
||||
extern "C"
|
||||
JNIEXPORT jint JNICALL
|
||||
Java_com_arm_aichat_internal_InferenceEngineImpl_prepare(JNIEnv * /*env*/, jobject /*unused*/) {
|
||||
auto *context = init_context(g_model);
|
||||
if (!context) { return 1; }
|
||||
g_context = context;
|
||||
g_batch = llama_batch_init(BATCH_SIZE, 0, 1);
|
||||
g_chat_templates = common_chat_templates_init(g_model, "");
|
||||
g_sampler = new_sampler(DEFAULT_SAMPLER_TEMP);
|
||||
return 0;
|
||||
}
|
||||
|
||||
static std::string get_backend() {
|
||||
std::vector<std::string> backends;
|
||||
for (size_t i = 0; i < ggml_backend_reg_count(); i++) {
|
||||
auto *reg = ggml_backend_reg_get(i);
|
||||
std::string name = ggml_backend_reg_name(reg);
|
||||
if (name != "CPU") {
|
||||
backends.push_back(ggml_backend_reg_name(reg));
|
||||
}
|
||||
}
|
||||
return backends.empty() ? "CPU" : join(backends, ",");
|
||||
}
|
||||
|
||||
extern "C"
|
||||
JNIEXPORT jstring JNICALL
|
||||
Java_com_arm_aichat_internal_InferenceEngineImpl_systemInfo(JNIEnv *env, jobject /*unused*/) {
|
||||
return env->NewStringUTF(llama_print_system_info());
|
||||
}
|
||||
|
||||
extern "C"
|
||||
JNIEXPORT jstring JNICALL
|
||||
Java_com_arm_aichat_internal_InferenceEngineImpl_benchModel(JNIEnv *env, jobject /*unused*/, jint pp, jint tg,
|
||||
jint pl, jint nr) {
|
||||
auto *context = init_context(g_model, pp);
|
||||
if (!context) {
|
||||
const auto *const err_msg = "Fail to init_context! Bench aborted.";
|
||||
LOGe(err_msg);
|
||||
return env->NewStringUTF(err_msg);
|
||||
}
|
||||
|
||||
auto pp_avg = 0.0;
|
||||
auto tg_avg = 0.0;
|
||||
auto pp_std = 0.0;
|
||||
auto tg_std = 0.0;
|
||||
|
||||
const uint32_t n_ctx = llama_n_ctx(context);
|
||||
LOGi("n_ctx = %d", n_ctx);
|
||||
|
||||
int i, j;
|
||||
int nri;
|
||||
for (nri = 0; nri < nr; nri++) {
|
||||
LOGi("Benchmark prompt processing (pp = %d)", pp);
|
||||
|
||||
common_batch_clear(g_batch);
|
||||
|
||||
const int n_tokens = pp;
|
||||
for (i = 0; i < n_tokens; i++) {
|
||||
common_batch_add(g_batch, 0, i, {0}, false);
|
||||
}
|
||||
|
||||
g_batch.logits[g_batch.n_tokens - 1] = true;
|
||||
llama_memory_clear(llama_get_memory(context), false);
|
||||
|
||||
const auto t_pp_start = ggml_time_us();
|
||||
if (llama_decode(context, g_batch) != 0) {
|
||||
LOGe("llama_decode() failed during prompt processing");
|
||||
}
|
||||
const auto t_pp_end = ggml_time_us();
|
||||
|
||||
// bench text generation
|
||||
|
||||
LOGi("Benchmark text generation (tg = %d)", tg);
|
||||
|
||||
llama_memory_clear(llama_get_memory(context), false);
|
||||
const auto t_tg_start = ggml_time_us();
|
||||
for (i = 0; i < tg; i++) {
|
||||
common_batch_clear(g_batch);
|
||||
for (j = 0; j < pl; j++) {
|
||||
common_batch_add(g_batch, 0, i, {j}, true);
|
||||
}
|
||||
|
||||
if (llama_decode(context, g_batch) != 0) {
|
||||
LOGe("llama_decode() failed during text generation");
|
||||
}
|
||||
}
|
||||
const auto t_tg_end = ggml_time_us();
|
||||
|
||||
llama_memory_clear(llama_get_memory(context), false);
|
||||
|
||||
const auto t_pp = double(t_pp_end - t_pp_start) / 1000000.0;
|
||||
const auto t_tg = double(t_tg_end - t_tg_start) / 1000000.0;
|
||||
|
||||
const auto speed_pp = double(pp) / t_pp;
|
||||
const auto speed_tg = double(pl * tg) / t_tg;
|
||||
|
||||
pp_avg += speed_pp;
|
||||
tg_avg += speed_tg;
|
||||
|
||||
pp_std += speed_pp * speed_pp;
|
||||
tg_std += speed_tg * speed_tg;
|
||||
|
||||
LOGi("pp %f t/s, tg %f t/s", speed_pp, speed_tg);
|
||||
}
|
||||
|
||||
llama_free(context);
|
||||
|
||||
pp_avg /= double(nr);
|
||||
tg_avg /= double(nr);
|
||||
|
||||
if (nr > 1) {
|
||||
pp_std = sqrt(pp_std / double(nr - 1) - pp_avg * pp_avg * double(nr) / double(nr - 1));
|
||||
tg_std = sqrt(tg_std / double(nr - 1) - tg_avg * tg_avg * double(nr) / double(nr - 1));
|
||||
} else {
|
||||
pp_std = 0;
|
||||
tg_std = 0;
|
||||
}
|
||||
|
||||
char model_desc[128];
|
||||
llama_model_desc(g_model, model_desc, sizeof(model_desc));
|
||||
|
||||
const auto model_size = double(llama_model_size(g_model)) / 1024.0 / 1024.0 / 1024.0;
|
||||
const auto model_n_params = double(llama_model_n_params(g_model)) / 1e9;
|
||||
|
||||
const auto backend = get_backend();
|
||||
std::stringstream result;
|
||||
result << std::setprecision(3);
|
||||
result << "| model | size | params | backend | test | t/s |\n";
|
||||
result << "| --- | --- | --- | --- | --- | --- |\n";
|
||||
result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | "
|
||||
<< backend << " | pp " << pp << " | " << pp_avg << " ± " << pp_std << " |\n";
|
||||
result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | "
|
||||
<< backend << " | tg " << tg << " | " << tg_avg << " ± " << tg_std << " |\n";
|
||||
return env->NewStringUTF(result.str().c_str());
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Completion loop's long-term states:
|
||||
* - chat management
|
||||
* - position tracking
|
||||
*/
|
||||
constexpr const char *ROLE_SYSTEM = "system";
|
||||
constexpr const char *ROLE_USER = "user";
|
||||
constexpr const char *ROLE_ASSISTANT = "assistant";
|
||||
|
||||
static std::vector<common_chat_msg> chat_msgs;
|
||||
static llama_pos system_prompt_position;
|
||||
static llama_pos current_position;
|
||||
|
||||
static void reset_long_term_states(const bool clear_kv_cache = true) {
|
||||
chat_msgs.clear();
|
||||
system_prompt_position = 0;
|
||||
current_position = 0;
|
||||
|
||||
if (clear_kv_cache)
|
||||
llama_memory_clear(llama_get_memory(g_context), false);
|
||||
}
|
||||
|
||||
/**
|
||||
* TODO-hyin: implement sliding-window version as a better alternative
|
||||
*
|
||||
* Context shifting by discarding the older half of the tokens appended after system prompt:
|
||||
* - take the [system_prompt_position] first tokens from the original prompt
|
||||
* - take half of the last (system_prompt_position - system_prompt_position) tokens
|
||||
* - recompute the logits in batches
|
||||
*/
|
||||
static void shift_context() {
|
||||
const int n_discard = (current_position - system_prompt_position) / 2;
|
||||
LOGi("%s: Discarding %d tokens", __func__, n_discard);
|
||||
llama_memory_seq_rm(llama_get_memory(g_context), 0, system_prompt_position, system_prompt_position + n_discard);
|
||||
llama_memory_seq_add(llama_get_memory(g_context), 0, system_prompt_position + n_discard, current_position, -n_discard);
|
||||
current_position -= n_discard;
|
||||
LOGi("%s: Context shifting done! Current position: %d", __func__, current_position);
|
||||
}
|
||||
|
||||
static std::string chat_add_and_format(const std::string &role, const std::string &content) {
|
||||
common_chat_msg new_msg;
|
||||
new_msg.role = role;
|
||||
new_msg.content = content;
|
||||
auto formatted = common_chat_format_single(
|
||||
g_chat_templates.get(), chat_msgs, new_msg, role == ROLE_USER, /* use_jinja */ false);
|
||||
chat_msgs.push_back(new_msg);
|
||||
LOGi("%s: Formatted and added %s message: \n%s\n", __func__, role.c_str(), formatted.c_str());
|
||||
return formatted;
|
||||
}
|
||||
|
||||
/**
|
||||
* Completion loop's short-term states:
|
||||
* - stop generation position
|
||||
* - token chars caching
|
||||
* - current assistant message being generated
|
||||
*/
|
||||
static llama_pos stop_generation_position;
|
||||
static std::string cached_token_chars;
|
||||
static std::ostringstream assistant_ss;
|
||||
|
||||
static void reset_short_term_states() {
|
||||
stop_generation_position = 0;
|
||||
cached_token_chars.clear();
|
||||
assistant_ss.str("");
|
||||
}
|
||||
|
||||
static int decode_tokens_in_batches(
|
||||
llama_context *context,
|
||||
llama_batch &batch,
|
||||
const llama_tokens &tokens,
|
||||
const llama_pos start_pos,
|
||||
const bool compute_last_logit = false) {
|
||||
// Process tokens in batches using the global batch
|
||||
LOGd("%s: Decode %d tokens starting at position %d", __func__, (int) tokens.size(), start_pos);
|
||||
for (int i = 0; i < (int) tokens.size(); i += BATCH_SIZE) {
|
||||
const int cur_batch_size = std::min((int) tokens.size() - i, BATCH_SIZE);
|
||||
common_batch_clear(batch);
|
||||
LOGv("%s: Preparing a batch size of %d starting at: %d", __func__, cur_batch_size, i);
|
||||
|
||||
// Shift context if current batch cannot fit into the context
|
||||
if (start_pos + i + cur_batch_size >= DEFAULT_CONTEXT_SIZE - OVERFLOW_HEADROOM) {
|
||||
LOGw("%s: Current batch won't fit into context! Shifting...", __func__);
|
||||
shift_context();
|
||||
}
|
||||
|
||||
// Add tokens to the batch with proper positions
|
||||
for (int j = 0; j < cur_batch_size; j++) {
|
||||
const llama_token token_id = tokens[i + j];
|
||||
const llama_pos position = start_pos + i + j;
|
||||
const bool want_logit = compute_last_logit && (i + j == tokens.size() - 1);
|
||||
common_batch_add(batch, token_id, position, {0}, want_logit);
|
||||
}
|
||||
|
||||
// Decode this batch
|
||||
const int decode_result = llama_decode(context, batch);
|
||||
if (decode_result) {
|
||||
LOGe("%s: llama_decode failed w/ %d", __func__, decode_result);
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
extern "C"
|
||||
JNIEXPORT jint JNICALL
|
||||
Java_com_arm_aichat_internal_InferenceEngineImpl_processSystemPrompt(
|
||||
JNIEnv *env,
|
||||
jobject /*unused*/,
|
||||
jstring jsystem_prompt
|
||||
) {
|
||||
// Reset long-term & short-term states
|
||||
reset_long_term_states();
|
||||
reset_short_term_states();
|
||||
|
||||
// Obtain system prompt from JEnv
|
||||
const auto *system_prompt = env->GetStringUTFChars(jsystem_prompt, nullptr);
|
||||
LOGd("%s: System prompt received: \n%s", __func__, system_prompt);
|
||||
std::string formatted_system_prompt(system_prompt);
|
||||
env->ReleaseStringUTFChars(jsystem_prompt, system_prompt);
|
||||
|
||||
// Format system prompt if applicable
|
||||
const bool has_chat_template = common_chat_templates_was_explicit(g_chat_templates.get());
|
||||
if (has_chat_template) {
|
||||
formatted_system_prompt = chat_add_and_format(ROLE_SYSTEM, system_prompt);
|
||||
}
|
||||
|
||||
// Tokenize system prompt
|
||||
const auto system_tokens = common_tokenize(g_context, formatted_system_prompt,
|
||||
has_chat_template, has_chat_template);
|
||||
for (auto id: system_tokens) {
|
||||
LOGv("token: `%s`\t -> `%d`", common_token_to_piece(g_context, id).c_str(), id);
|
||||
}
|
||||
|
||||
// Handle context overflow
|
||||
const int max_batch_size = DEFAULT_CONTEXT_SIZE - OVERFLOW_HEADROOM;
|
||||
if ((int) system_tokens.size() > max_batch_size) {
|
||||
LOGe("%s: System prompt too long for context! %d tokens, max: %d",
|
||||
__func__, (int) system_tokens.size(), max_batch_size);
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Decode system tokens in batches
|
||||
if (decode_tokens_in_batches(g_context, g_batch, system_tokens, current_position)) {
|
||||
LOGe("%s: llama_decode() failed!", __func__);
|
||||
return 2;
|
||||
}
|
||||
|
||||
// Update position
|
||||
system_prompt_position = current_position = (int) system_tokens.size();
|
||||
return 0;
|
||||
}
|
||||
|
||||
extern "C"
|
||||
JNIEXPORT jint JNICALL
|
||||
Java_com_arm_aichat_internal_InferenceEngineImpl_processUserPrompt(
|
||||
JNIEnv *env,
|
||||
jobject /*unused*/,
|
||||
jstring juser_prompt,
|
||||
jint n_predict
|
||||
) {
|
||||
// Reset short-term states
|
||||
reset_short_term_states();
|
||||
|
||||
// Obtain and tokenize user prompt
|
||||
const auto *const user_prompt = env->GetStringUTFChars(juser_prompt, nullptr);
|
||||
LOGd("%s: User prompt received: \n%s", __func__, user_prompt);
|
||||
std::string formatted_user_prompt(user_prompt);
|
||||
env->ReleaseStringUTFChars(juser_prompt, user_prompt);
|
||||
|
||||
// Format user prompt if applicable
|
||||
const bool has_chat_template = common_chat_templates_was_explicit(g_chat_templates.get());
|
||||
if (has_chat_template) {
|
||||
formatted_user_prompt = chat_add_and_format(ROLE_USER, user_prompt);
|
||||
}
|
||||
|
||||
// Decode formatted user prompts
|
||||
auto user_tokens = common_tokenize(g_context, formatted_user_prompt, has_chat_template, has_chat_template);
|
||||
for (auto id: user_tokens) {
|
||||
LOGv("token: `%s`\t -> `%d`", common_token_to_piece(g_context, id).c_str(), id);
|
||||
}
|
||||
|
||||
// Ensure user prompt doesn't exceed the context size by truncating if necessary.
|
||||
const int user_prompt_size = (int) user_tokens.size();
|
||||
const int max_batch_size = DEFAULT_CONTEXT_SIZE - OVERFLOW_HEADROOM;
|
||||
if (user_prompt_size > max_batch_size) {
|
||||
const int skipped_tokens = user_prompt_size - max_batch_size;
|
||||
user_tokens.resize(max_batch_size);
|
||||
LOGw("%s: User prompt too long! Skipped %d tokens!", __func__, skipped_tokens);
|
||||
}
|
||||
|
||||
// Decode user tokens in batches
|
||||
if (decode_tokens_in_batches(g_context, g_batch, user_tokens, current_position, true)) {
|
||||
LOGe("%s: llama_decode() failed!", __func__);
|
||||
return 2;
|
||||
}
|
||||
|
||||
// Update position
|
||||
current_position += user_prompt_size;
|
||||
stop_generation_position = current_position + user_prompt_size + n_predict;
|
||||
return 0;
|
||||
}
|
||||
|
||||
static bool is_valid_utf8(const char *string) {
|
||||
if (!string) { return true; }
|
||||
|
||||
const auto *bytes = (const unsigned char *) string;
|
||||
int num;
|
||||
|
||||
while (*bytes != 0x00) {
|
||||
if ((*bytes & 0x80) == 0x00) {
|
||||
// U+0000 to U+007F
|
||||
num = 1;
|
||||
} else if ((*bytes & 0xE0) == 0xC0) {
|
||||
// U+0080 to U+07FF
|
||||
num = 2;
|
||||
} else if ((*bytes & 0xF0) == 0xE0) {
|
||||
// U+0800 to U+FFFF
|
||||
num = 3;
|
||||
} else if ((*bytes & 0xF8) == 0xF0) {
|
||||
// U+10000 to U+10FFFF
|
||||
num = 4;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
|
||||
bytes += 1;
|
||||
for (int i = 1; i < num; ++i) {
|
||||
if ((*bytes & 0xC0) != 0x80) {
|
||||
return false;
|
||||
}
|
||||
bytes += 1;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
extern "C"
|
||||
JNIEXPORT jstring JNICALL
|
||||
Java_com_arm_aichat_internal_InferenceEngineImpl_generateNextToken(
|
||||
JNIEnv *env,
|
||||
jobject /*unused*/
|
||||
) {
|
||||
// Infinite text generation via context shifting
|
||||
if (current_position >= DEFAULT_CONTEXT_SIZE - OVERFLOW_HEADROOM) {
|
||||
LOGw("%s: Context full! Shifting...", __func__);
|
||||
shift_context();
|
||||
}
|
||||
|
||||
// Stop if reaching the marked position
|
||||
if (current_position >= stop_generation_position) {
|
||||
LOGw("%s: STOP: hitting stop position: %d", __func__, stop_generation_position);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Sample next token
|
||||
const auto new_token_id = common_sampler_sample(g_sampler, g_context, -1);
|
||||
common_sampler_accept(g_sampler, new_token_id, true);
|
||||
|
||||
// Populate the batch with new token, then decode
|
||||
common_batch_clear(g_batch);
|
||||
common_batch_add(g_batch, new_token_id, current_position, {0}, true);
|
||||
if (llama_decode(g_context, g_batch) != 0) {
|
||||
LOGe("%s: llama_decode() failed for generated token", __func__);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Update position
|
||||
current_position++;
|
||||
|
||||
// Stop if next token is EOG
|
||||
if (llama_vocab_is_eog(llama_model_get_vocab(g_model), new_token_id)) {
|
||||
LOGd("id: %d,\tIS EOG!\nSTOP.", new_token_id);
|
||||
chat_add_and_format(ROLE_ASSISTANT, assistant_ss.str());
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// If not EOG, convert to text
|
||||
auto new_token_chars = common_token_to_piece(g_context, new_token_id);
|
||||
cached_token_chars += new_token_chars;
|
||||
|
||||
// Create and return a valid UTF-8 Java string
|
||||
jstring result = nullptr;
|
||||
if (is_valid_utf8(cached_token_chars.c_str())) {
|
||||
result = env->NewStringUTF(cached_token_chars.c_str());
|
||||
LOGv("id: %d,\tcached: `%s`,\tnew: `%s`", new_token_id, cached_token_chars.c_str(), new_token_chars.c_str());
|
||||
|
||||
assistant_ss << cached_token_chars;
|
||||
cached_token_chars.clear();
|
||||
} else {
|
||||
LOGv("id: %d,\tappend to cache", new_token_id);
|
||||
result = env->NewStringUTF("");
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
extern "C"
|
||||
JNIEXPORT void JNICALL
|
||||
Java_com_arm_aichat_internal_InferenceEngineImpl_unload(JNIEnv * /*unused*/, jobject /*unused*/) {
|
||||
// Reset long-term & short-term states
|
||||
reset_long_term_states();
|
||||
reset_short_term_states();
|
||||
|
||||
// Free up resources
|
||||
common_sampler_free(g_sampler);
|
||||
g_chat_templates.reset();
|
||||
llama_batch_free(g_batch);
|
||||
llama_free(g_context);
|
||||
llama_model_free(g_model);
|
||||
}
|
||||
|
||||
extern "C"
|
||||
JNIEXPORT void JNICALL
|
||||
Java_com_arm_aichat_internal_InferenceEngineImpl_shutdown(JNIEnv *, jobject /*unused*/) {
|
||||
llama_backend_free();
|
||||
}
|
||||
61
examples/llama.android/lib/src/main/cpp/logging.h
Normal file
@@ -0,0 +1,61 @@
|
||||
//
|
||||
// Created by Han Yin on 10/31/25.
|
||||
//
|
||||
|
||||
#ifndef AICHAT_LOGGING_H
|
||||
#define AICHAT_LOGGING_H
|
||||
|
||||
#endif //AICHAT_LOGGING_H
|
||||
|
||||
#pragma once
|
||||
#include <android/log.h>
|
||||
|
||||
#ifndef LOG_TAG
|
||||
#define LOG_TAG "ai-chat"
|
||||
#endif
|
||||
|
||||
#ifndef LOG_MIN_LEVEL
|
||||
#if defined(NDEBUG)
|
||||
#define LOG_MIN_LEVEL ANDROID_LOG_INFO
|
||||
#else
|
||||
#define LOG_MIN_LEVEL ANDROID_LOG_VERBOSE
|
||||
#endif
|
||||
#endif
|
||||
|
||||
static inline int ai_should_log(int prio) {
|
||||
return __android_log_is_loggable(prio, LOG_TAG, LOG_MIN_LEVEL);
|
||||
}
|
||||
|
||||
#if LOG_MIN_LEVEL <= ANDROID_LOG_VERBOSE
|
||||
#define LOGv(...) do { if (ai_should_log(ANDROID_LOG_VERBOSE)) __android_log_print(ANDROID_LOG_VERBOSE, LOG_TAG, __VA_ARGS__); } while (0)
|
||||
#else
|
||||
#define LOGv(...) ((void)0)
|
||||
#endif
|
||||
|
||||
#if LOG_MIN_LEVEL <= ANDROID_LOG_DEBUG
|
||||
#define LOGd(...) do { if (ai_should_log(ANDROID_LOG_DEBUG)) __android_log_print(ANDROID_LOG_DEBUG, LOG_TAG, __VA_ARGS__); } while (0)
|
||||
#else
|
||||
#define LOGd(...) ((void)0)
|
||||
#endif
|
||||
|
||||
#define LOGi(...) do { if (ai_should_log(ANDROID_LOG_INFO )) __android_log_print(ANDROID_LOG_INFO , LOG_TAG, __VA_ARGS__); } while (0)
|
||||
#define LOGw(...) do { if (ai_should_log(ANDROID_LOG_WARN )) __android_log_print(ANDROID_LOG_WARN , LOG_TAG, __VA_ARGS__); } while (0)
|
||||
#define LOGe(...) do { if (ai_should_log(ANDROID_LOG_ERROR)) __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, __VA_ARGS__); } while (0)
|
||||
|
||||
static inline int android_log_prio_from_ggml(enum ggml_log_level level) {
|
||||
switch (level) {
|
||||
case GGML_LOG_LEVEL_ERROR: return ANDROID_LOG_ERROR;
|
||||
case GGML_LOG_LEVEL_WARN: return ANDROID_LOG_WARN;
|
||||
case GGML_LOG_LEVEL_INFO: return ANDROID_LOG_INFO;
|
||||
case GGML_LOG_LEVEL_DEBUG: return ANDROID_LOG_DEBUG;
|
||||
default: return ANDROID_LOG_DEFAULT;
|
||||
}
|
||||
}
|
||||
|
||||
static inline void aichat_android_log_callback(enum ggml_log_level level,
|
||||
const char* text,
|
||||
void* /*user*/) {
|
||||
const int prio = android_log_prio_from_ggml(level);
|
||||
if (!ai_should_log(prio)) return;
|
||||
__android_log_write(prio, LOG_TAG, text);
|
||||
}
|
||||
@@ -0,0 +1,14 @@
|
||||
package com.arm.aichat
|
||||
|
||||
import android.content.Context
|
||||
import com.arm.aichat.internal.InferenceEngineImpl
|
||||
|
||||
/**
|
||||
* Main entry point for Arm's AI Chat library.
|
||||
*/
|
||||
object AiChat {
|
||||
/**
|
||||
* Get the inference engine single instance.
|
||||
*/
|
||||
fun getInferenceEngine(context: Context) = InferenceEngineImpl.getInstance(context)
|
||||
}
|
||||
@@ -0,0 +1,89 @@
|
||||
package com.arm.aichat
|
||||
|
||||
import com.arm.aichat.InferenceEngine.State
|
||||
import kotlinx.coroutines.flow.Flow
|
||||
import kotlinx.coroutines.flow.StateFlow
|
||||
|
||||
/**
|
||||
* Interface defining the core LLM inference operations.
|
||||
*/
|
||||
interface InferenceEngine {
|
||||
/**
|
||||
* Current state of the inference engine
|
||||
*/
|
||||
val state: StateFlow<State>
|
||||
|
||||
/**
|
||||
* Load a model from the given path.
|
||||
*
|
||||
* @throws UnsupportedArchitectureException if model architecture not supported
|
||||
*/
|
||||
suspend fun loadModel(pathToModel: String)
|
||||
|
||||
/**
|
||||
* Sends a system prompt to the loaded model
|
||||
*/
|
||||
suspend fun setSystemPrompt(systemPrompt: String)
|
||||
|
||||
/**
|
||||
* Sends a user prompt to the loaded model and returns a Flow of generated tokens.
|
||||
*/
|
||||
fun sendUserPrompt(message: String, predictLength: Int = DEFAULT_PREDICT_LENGTH): Flow<String>
|
||||
|
||||
/**
|
||||
* Runs a benchmark with the specified parameters.
|
||||
*/
|
||||
suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String
|
||||
|
||||
/**
|
||||
* Unloads the currently loaded model.
|
||||
*/
|
||||
fun cleanUp()
|
||||
|
||||
/**
|
||||
* Cleans up resources when the engine is no longer needed.
|
||||
*/
|
||||
fun destroy()
|
||||
|
||||
/**
|
||||
* States of the inference engine
|
||||
*/
|
||||
sealed class State {
|
||||
object Uninitialized : State()
|
||||
object Initializing : State()
|
||||
object Initialized : State()
|
||||
|
||||
object LoadingModel : State()
|
||||
object UnloadingModel : State()
|
||||
object ModelReady : State()
|
||||
|
||||
object Benchmarking : State()
|
||||
object ProcessingSystemPrompt : State()
|
||||
object ProcessingUserPrompt : State()
|
||||
|
||||
object Generating : State()
|
||||
|
||||
data class Error(val exception: Exception) : State()
|
||||
}
|
||||
|
||||
companion object {
|
||||
const val DEFAULT_PREDICT_LENGTH = 1024
|
||||
}
|
||||
}
|
||||
|
||||
val State.isUninterruptible
|
||||
get() = this is State.Initializing ||
|
||||
this is State.LoadingModel ||
|
||||
this is State.UnloadingModel ||
|
||||
this is State.Benchmarking ||
|
||||
this is State.ProcessingSystemPrompt ||
|
||||
this is State.ProcessingUserPrompt
|
||||
|
||||
val State.isModelLoaded: Boolean
|
||||
get() = this is State.ModelReady ||
|
||||
this is State.Benchmarking ||
|
||||
this is State.ProcessingSystemPrompt ||
|
||||
this is State.ProcessingUserPrompt ||
|
||||
this is State.Generating
|
||||
|
||||
class UnsupportedArchitectureException : Exception()
|
||||