Sync from upstream llama.cpp repository
This commit is contained in:
1
examples/llama.android/lib/.gitignore
vendored
Normal file
1
examples/llama.android/lib/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
/build
|
||||
78
examples/llama.android/lib/build.gradle.kts
Normal file
78
examples/llama.android/lib/build.gradle.kts
Normal file
@@ -0,0 +1,78 @@
|
||||
plugins {
|
||||
alias(libs.plugins.android.library)
|
||||
alias(libs.plugins.jetbrains.kotlin.android)
|
||||
}
|
||||
|
||||
android {
|
||||
namespace = "com.arm.aichat"
|
||||
compileSdk = 36
|
||||
|
||||
ndkVersion = "29.0.13113456"
|
||||
|
||||
defaultConfig {
|
||||
minSdk = 33
|
||||
|
||||
testInstrumentationRunner = "androidx.test.runner.AndroidJUnitRunner"
|
||||
consumerProguardFiles("consumer-rules.pro")
|
||||
|
||||
ndk {
|
||||
abiFilters += listOf("arm64-v8a", "x86_64")
|
||||
}
|
||||
externalNativeBuild {
|
||||
cmake {
|
||||
arguments += "-DCMAKE_BUILD_TYPE=Release"
|
||||
arguments += "-DCMAKE_MESSAGE_LOG_LEVEL=DEBUG"
|
||||
arguments += "-DCMAKE_VERBOSE_MAKEFILE=ON"
|
||||
|
||||
arguments += "-DBUILD_SHARED_LIBS=ON"
|
||||
arguments += "-DLLAMA_BUILD_COMMON=ON"
|
||||
arguments += "-DLLAMA_OPENSSL=OFF"
|
||||
|
||||
arguments += "-DGGML_NATIVE=OFF"
|
||||
arguments += "-DGGML_BACKEND_DL=ON"
|
||||
arguments += "-DGGML_CPU_ALL_VARIANTS=ON"
|
||||
arguments += "-DGGML_LLAMAFILE=OFF"
|
||||
}
|
||||
}
|
||||
aarMetadata {
|
||||
minCompileSdk = 35
|
||||
}
|
||||
}
|
||||
externalNativeBuild {
|
||||
cmake {
|
||||
path("src/main/cpp/CMakeLists.txt")
|
||||
version = "3.31.6"
|
||||
}
|
||||
}
|
||||
compileOptions {
|
||||
sourceCompatibility = JavaVersion.VERSION_17
|
||||
targetCompatibility = JavaVersion.VERSION_17
|
||||
}
|
||||
kotlin {
|
||||
jvmToolchain(17)
|
||||
|
||||
compileOptions {
|
||||
targetCompatibility = JavaVersion.VERSION_17
|
||||
}
|
||||
}
|
||||
|
||||
packaging {
|
||||
resources {
|
||||
excludes += "/META-INF/{AL2.0,LGPL2.1}"
|
||||
}
|
||||
}
|
||||
|
||||
publishing {
|
||||
singleVariant("release") {
|
||||
withJavadocJar()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
dependencies {
|
||||
implementation(libs.androidx.core.ktx)
|
||||
implementation(libs.androidx.datastore.preferences)
|
||||
|
||||
testImplementation(libs.junit)
|
||||
androidTestImplementation(libs.androidx.junit)
|
||||
}
|
||||
8
examples/llama.android/lib/consumer-rules.pro
Normal file
8
examples/llama.android/lib/consumer-rules.pro
Normal file
@@ -0,0 +1,8 @@
|
||||
-keep class com.arm.aichat.* { *; }
|
||||
-keep class com.arm.aichat.gguf.* { *; }
|
||||
|
||||
-keepclasseswithmembernames class * {
|
||||
native <methods>;
|
||||
}
|
||||
|
||||
-keep class kotlin.Metadata { *; }
|
||||
21
examples/llama.android/lib/proguard-rules.pro
vendored
Normal file
21
examples/llama.android/lib/proguard-rules.pro
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
# Add project specific ProGuard rules here.
|
||||
# You can control the set of applied configuration files using the
|
||||
# proguardFiles setting in build.gradle.
|
||||
#
|
||||
# For more details, see
|
||||
# http://developer.android.com/guide/developing/tools/proguard.html
|
||||
|
||||
# If your project uses WebView with JS, uncomment the following
|
||||
# and specify the fully qualified class name to the JavaScript interface
|
||||
# class:
|
||||
#-keepclassmembers class fqcn.of.javascript.interface.for.webview {
|
||||
# public *;
|
||||
#}
|
||||
|
||||
# Uncomment this to preserve the line number information for
|
||||
# debugging stack traces.
|
||||
#-keepattributes SourceFile,LineNumberTable
|
||||
|
||||
# If you keep the line number information, uncomment this to
|
||||
# hide the original source file name.
|
||||
#-renamesourcefileattribute SourceFile
|
||||
@@ -0,0 +1,24 @@
|
||||
package android.llama.cpp
|
||||
|
||||
import androidx.test.platform.app.InstrumentationRegistry
|
||||
import androidx.test.ext.junit.runners.AndroidJUnit4
|
||||
|
||||
import org.junit.Test
|
||||
import org.junit.runner.RunWith
|
||||
|
||||
import org.junit.Assert.*
|
||||
|
||||
/**
|
||||
* Instrumented test, which will execute on an Android device.
|
||||
*
|
||||
* See [testing documentation](http://d.android.com/tools/testing).
|
||||
*/
|
||||
@RunWith(AndroidJUnit4::class)
|
||||
class ExampleInstrumentedTest {
|
||||
@Test
|
||||
fun useAppContext() {
|
||||
// Context of the app under test.
|
||||
val appContext = InstrumentationRegistry.getInstrumentation().targetContext
|
||||
assertEquals("android.llama.cpp.test", appContext.packageName)
|
||||
}
|
||||
}
|
||||
4
examples/llama.android/lib/src/main/AndroidManifest.xml
Normal file
4
examples/llama.android/lib/src/main/AndroidManifest.xml
Normal file
@@ -0,0 +1,4 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<manifest xmlns:android="http://schemas.android.com/apk/res/android">
|
||||
|
||||
</manifest>
|
||||
56
examples/llama.android/lib/src/main/cpp/CMakeLists.txt
Normal file
56
examples/llama.android/lib/src/main/cpp/CMakeLists.txt
Normal file
@@ -0,0 +1,56 @@
|
||||
cmake_minimum_required(VERSION 3.31.6)
|
||||
|
||||
project("ai-chat" VERSION 1.0.0 LANGUAGES C CXX)
|
||||
|
||||
set(CMAKE_C_STANDARD 11)
|
||||
set(CMAKE_C_STANDARD_REQUIRED true)
|
||||
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_CXX_STANDARD_REQUIRED true)
|
||||
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS}" CACHE STRING "" FORCE)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}" CACHE STRING "" FORCE)
|
||||
|
||||
# --------------------------------------------------------------------------
|
||||
# AI Chat library
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
if(DEFINED ANDROID_ABI)
|
||||
message(STATUS "Detected Android ABI: ${ANDROID_ABI}")
|
||||
if(ANDROID_ABI STREQUAL "arm64-v8a")
|
||||
set(GGML_SYSTEM_ARCH "ARM")
|
||||
set(GGML_CPU_KLEIDIAI ON)
|
||||
set(GGML_OPENMP ON)
|
||||
elseif(ANDROID_ABI STREQUAL "x86_64")
|
||||
set(GGML_SYSTEM_ARCH "x86")
|
||||
set(GGML_CPU_KLEIDIAI OFF)
|
||||
set(GGML_OPENMP OFF)
|
||||
else()
|
||||
message(FATAL_ERROR "Unsupported ABI: ${ANDROID_ABI}")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
set(LLAMA_SRC ${CMAKE_CURRENT_LIST_DIR}/../../../../../../)
|
||||
add_subdirectory(${LLAMA_SRC} build-llama)
|
||||
|
||||
add_library(${CMAKE_PROJECT_NAME} SHARED
|
||||
ai_chat.cpp)
|
||||
|
||||
target_compile_definitions(${CMAKE_PROJECT_NAME} PRIVATE
|
||||
GGML_SYSTEM_ARCH=${GGML_SYSTEM_ARCH}
|
||||
GGML_CPU_KLEIDIAI=$<BOOL:${GGML_CPU_KLEIDIAI}>
|
||||
GGML_OPENMP=$<BOOL:${GGML_OPENMP}>
|
||||
)
|
||||
|
||||
target_include_directories(${CMAKE_PROJECT_NAME} PRIVATE
|
||||
${LLAMA_SRC}
|
||||
${LLAMA_SRC}/common
|
||||
${LLAMA_SRC}/include
|
||||
${LLAMA_SRC}/ggml/include
|
||||
${LLAMA_SRC}/ggml/src)
|
||||
|
||||
target_link_libraries(${CMAKE_PROJECT_NAME}
|
||||
llama
|
||||
common
|
||||
android
|
||||
log)
|
||||
565
examples/llama.android/lib/src/main/cpp/ai_chat.cpp
Normal file
565
examples/llama.android/lib/src/main/cpp/ai_chat.cpp
Normal file
@@ -0,0 +1,565 @@
|
||||
#include <android/log.h>
|
||||
#include <jni.h>
|
||||
#include <iomanip>
|
||||
#include <cmath>
|
||||
#include <string>
|
||||
#include <unistd.h>
|
||||
#include <sampling.h>
|
||||
|
||||
#include "logging.h"
|
||||
#include "chat.h"
|
||||
#include "common.h"
|
||||
#include "llama.h"
|
||||
|
||||
template<class T>
|
||||
static std::string join(const std::vector<T> &values, const std::string &delim) {
|
||||
std::ostringstream str;
|
||||
for (size_t i = 0; i < values.size(); i++) {
|
||||
str << values[i];
|
||||
if (i < values.size() - 1) { str << delim; }
|
||||
}
|
||||
return str.str();
|
||||
}
|
||||
|
||||
/**
|
||||
* LLama resources: context, model, batch and sampler
|
||||
*/
|
||||
constexpr int N_THREADS_MIN = 2;
|
||||
constexpr int N_THREADS_MAX = 4;
|
||||
constexpr int N_THREADS_HEADROOM = 2;
|
||||
|
||||
constexpr int DEFAULT_CONTEXT_SIZE = 8192;
|
||||
constexpr int OVERFLOW_HEADROOM = 4;
|
||||
constexpr int BATCH_SIZE = 512;
|
||||
constexpr float DEFAULT_SAMPLER_TEMP = 0.3f;
|
||||
|
||||
static llama_model * g_model;
|
||||
static llama_context * g_context;
|
||||
static llama_batch g_batch;
|
||||
static common_chat_templates_ptr g_chat_templates;
|
||||
static common_sampler * g_sampler;
|
||||
|
||||
extern "C"
|
||||
JNIEXPORT void JNICALL
|
||||
Java_com_arm_aichat_internal_InferenceEngineImpl_init(JNIEnv *env, jobject /*unused*/, jstring nativeLibDir) {
|
||||
// Set llama log handler to Android
|
||||
llama_log_set(aichat_android_log_callback, nullptr);
|
||||
|
||||
// Loading all CPU backend variants
|
||||
const auto *path_to_backend = env->GetStringUTFChars(nativeLibDir, 0);
|
||||
LOGi("Loading backends from %s", path_to_backend);
|
||||
ggml_backend_load_all_from_path(path_to_backend);
|
||||
env->ReleaseStringUTFChars(nativeLibDir, path_to_backend);
|
||||
|
||||
// Initialize backends
|
||||
llama_backend_init();
|
||||
LOGi("Backend initiated; Log handler set.");
|
||||
}
|
||||
|
||||
extern "C"
|
||||
JNIEXPORT jint JNICALL
|
||||
Java_com_arm_aichat_internal_InferenceEngineImpl_load(JNIEnv *env, jobject, jstring jmodel_path) {
|
||||
llama_model_params model_params = llama_model_default_params();
|
||||
|
||||
const auto *model_path = env->GetStringUTFChars(jmodel_path, 0);
|
||||
LOGd("%s: Loading model from: \n%s\n", __func__, model_path);
|
||||
|
||||
auto *model = llama_model_load_from_file(model_path, model_params);
|
||||
env->ReleaseStringUTFChars(jmodel_path, model_path);
|
||||
if (!model) {
|
||||
return 1;
|
||||
}
|
||||
g_model = model;
|
||||
return 0;
|
||||
}
|
||||
|
||||
static llama_context *init_context(llama_model *model, const int n_ctx = DEFAULT_CONTEXT_SIZE) {
|
||||
if (!model) {
|
||||
LOGe("%s: model cannot be null", __func__);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Multi-threading setup
|
||||
const int n_threads = std::max(N_THREADS_MIN, std::min(N_THREADS_MAX,
|
||||
(int) sysconf(_SC_NPROCESSORS_ONLN) -
|
||||
N_THREADS_HEADROOM));
|
||||
LOGi("%s: Using %d threads", __func__, n_threads);
|
||||
|
||||
// Context parameters setup
|
||||
llama_context_params ctx_params = llama_context_default_params();
|
||||
const int trained_context_size = llama_model_n_ctx_train(model);
|
||||
if (n_ctx > trained_context_size) {
|
||||
LOGw("%s: Model was trained with only %d context size! Enforcing %d context size...",
|
||||
__func__, trained_context_size, n_ctx);
|
||||
}
|
||||
ctx_params.n_ctx = n_ctx;
|
||||
ctx_params.n_batch = BATCH_SIZE;
|
||||
ctx_params.n_ubatch = BATCH_SIZE;
|
||||
ctx_params.n_threads = n_threads;
|
||||
ctx_params.n_threads_batch = n_threads;
|
||||
auto *context = llama_init_from_model(g_model, ctx_params);
|
||||
if (context == nullptr) {
|
||||
LOGe("%s: llama_new_context_with_model() returned null)", __func__);
|
||||
}
|
||||
return context;
|
||||
}
|
||||
|
||||
static common_sampler *new_sampler(float temp) {
|
||||
common_params_sampling sparams;
|
||||
sparams.temp = temp;
|
||||
return common_sampler_init(g_model, sparams);
|
||||
}
|
||||
|
||||
extern "C"
|
||||
JNIEXPORT jint JNICALL
|
||||
Java_com_arm_aichat_internal_InferenceEngineImpl_prepare(JNIEnv * /*env*/, jobject /*unused*/) {
|
||||
auto *context = init_context(g_model);
|
||||
if (!context) { return 1; }
|
||||
g_context = context;
|
||||
g_batch = llama_batch_init(BATCH_SIZE, 0, 1);
|
||||
g_chat_templates = common_chat_templates_init(g_model, "");
|
||||
g_sampler = new_sampler(DEFAULT_SAMPLER_TEMP);
|
||||
return 0;
|
||||
}
|
||||
|
||||
static std::string get_backend() {
|
||||
std::vector<std::string> backends;
|
||||
for (size_t i = 0; i < ggml_backend_reg_count(); i++) {
|
||||
auto *reg = ggml_backend_reg_get(i);
|
||||
std::string name = ggml_backend_reg_name(reg);
|
||||
if (name != "CPU") {
|
||||
backends.push_back(ggml_backend_reg_name(reg));
|
||||
}
|
||||
}
|
||||
return backends.empty() ? "CPU" : join(backends, ",");
|
||||
}
|
||||
|
||||
extern "C"
|
||||
JNIEXPORT jstring JNICALL
|
||||
Java_com_arm_aichat_internal_InferenceEngineImpl_systemInfo(JNIEnv *env, jobject /*unused*/) {
|
||||
return env->NewStringUTF(llama_print_system_info());
|
||||
}
|
||||
|
||||
extern "C"
|
||||
JNIEXPORT jstring JNICALL
|
||||
Java_com_arm_aichat_internal_InferenceEngineImpl_benchModel(JNIEnv *env, jobject /*unused*/, jint pp, jint tg,
|
||||
jint pl, jint nr) {
|
||||
auto *context = init_context(g_model, pp);
|
||||
if (!context) {
|
||||
const auto *const err_msg = "Fail to init_context! Bench aborted.";
|
||||
LOGe(err_msg);
|
||||
return env->NewStringUTF(err_msg);
|
||||
}
|
||||
|
||||
auto pp_avg = 0.0;
|
||||
auto tg_avg = 0.0;
|
||||
auto pp_std = 0.0;
|
||||
auto tg_std = 0.0;
|
||||
|
||||
const uint32_t n_ctx = llama_n_ctx(context);
|
||||
LOGi("n_ctx = %d", n_ctx);
|
||||
|
||||
int i, j;
|
||||
int nri;
|
||||
for (nri = 0; nri < nr; nri++) {
|
||||
LOGi("Benchmark prompt processing (pp = %d)", pp);
|
||||
|
||||
common_batch_clear(g_batch);
|
||||
|
||||
const int n_tokens = pp;
|
||||
for (i = 0; i < n_tokens; i++) {
|
||||
common_batch_add(g_batch, 0, i, {0}, false);
|
||||
}
|
||||
|
||||
g_batch.logits[g_batch.n_tokens - 1] = true;
|
||||
llama_memory_clear(llama_get_memory(context), false);
|
||||
|
||||
const auto t_pp_start = ggml_time_us();
|
||||
if (llama_decode(context, g_batch) != 0) {
|
||||
LOGe("llama_decode() failed during prompt processing");
|
||||
}
|
||||
const auto t_pp_end = ggml_time_us();
|
||||
|
||||
// bench text generation
|
||||
|
||||
LOGi("Benchmark text generation (tg = %d)", tg);
|
||||
|
||||
llama_memory_clear(llama_get_memory(context), false);
|
||||
const auto t_tg_start = ggml_time_us();
|
||||
for (i = 0; i < tg; i++) {
|
||||
common_batch_clear(g_batch);
|
||||
for (j = 0; j < pl; j++) {
|
||||
common_batch_add(g_batch, 0, i, {j}, true);
|
||||
}
|
||||
|
||||
if (llama_decode(context, g_batch) != 0) {
|
||||
LOGe("llama_decode() failed during text generation");
|
||||
}
|
||||
}
|
||||
const auto t_tg_end = ggml_time_us();
|
||||
|
||||
llama_memory_clear(llama_get_memory(context), false);
|
||||
|
||||
const auto t_pp = double(t_pp_end - t_pp_start) / 1000000.0;
|
||||
const auto t_tg = double(t_tg_end - t_tg_start) / 1000000.0;
|
||||
|
||||
const auto speed_pp = double(pp) / t_pp;
|
||||
const auto speed_tg = double(pl * tg) / t_tg;
|
||||
|
||||
pp_avg += speed_pp;
|
||||
tg_avg += speed_tg;
|
||||
|
||||
pp_std += speed_pp * speed_pp;
|
||||
tg_std += speed_tg * speed_tg;
|
||||
|
||||
LOGi("pp %f t/s, tg %f t/s", speed_pp, speed_tg);
|
||||
}
|
||||
|
||||
llama_free(context);
|
||||
|
||||
pp_avg /= double(nr);
|
||||
tg_avg /= double(nr);
|
||||
|
||||
if (nr > 1) {
|
||||
pp_std = sqrt(pp_std / double(nr - 1) - pp_avg * pp_avg * double(nr) / double(nr - 1));
|
||||
tg_std = sqrt(tg_std / double(nr - 1) - tg_avg * tg_avg * double(nr) / double(nr - 1));
|
||||
} else {
|
||||
pp_std = 0;
|
||||
tg_std = 0;
|
||||
}
|
||||
|
||||
char model_desc[128];
|
||||
llama_model_desc(g_model, model_desc, sizeof(model_desc));
|
||||
|
||||
const auto model_size = double(llama_model_size(g_model)) / 1024.0 / 1024.0 / 1024.0;
|
||||
const auto model_n_params = double(llama_model_n_params(g_model)) / 1e9;
|
||||
|
||||
const auto backend = get_backend();
|
||||
std::stringstream result;
|
||||
result << std::setprecision(3);
|
||||
result << "| model | size | params | backend | test | t/s |\n";
|
||||
result << "| --- | --- | --- | --- | --- | --- |\n";
|
||||
result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | "
|
||||
<< backend << " | pp " << pp << " | " << pp_avg << " ± " << pp_std << " |\n";
|
||||
result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | "
|
||||
<< backend << " | tg " << tg << " | " << tg_avg << " ± " << tg_std << " |\n";
|
||||
return env->NewStringUTF(result.str().c_str());
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Completion loop's long-term states:
|
||||
* - chat management
|
||||
* - position tracking
|
||||
*/
|
||||
constexpr const char *ROLE_SYSTEM = "system";
|
||||
constexpr const char *ROLE_USER = "user";
|
||||
constexpr const char *ROLE_ASSISTANT = "assistant";
|
||||
|
||||
static std::vector<common_chat_msg> chat_msgs;
|
||||
static llama_pos system_prompt_position;
|
||||
static llama_pos current_position;
|
||||
|
||||
static void reset_long_term_states(const bool clear_kv_cache = true) {
|
||||
chat_msgs.clear();
|
||||
system_prompt_position = 0;
|
||||
current_position = 0;
|
||||
|
||||
if (clear_kv_cache)
|
||||
llama_memory_clear(llama_get_memory(g_context), false);
|
||||
}
|
||||
|
||||
/**
|
||||
* TODO-hyin: implement sliding-window version as a better alternative
|
||||
*
|
||||
* Context shifting by discarding the older half of the tokens appended after system prompt:
|
||||
* - take the [system_prompt_position] first tokens from the original prompt
|
||||
* - take half of the last (system_prompt_position - system_prompt_position) tokens
|
||||
* - recompute the logits in batches
|
||||
*/
|
||||
static void shift_context() {
|
||||
const int n_discard = (current_position - system_prompt_position) / 2;
|
||||
LOGi("%s: Discarding %d tokens", __func__, n_discard);
|
||||
llama_memory_seq_rm(llama_get_memory(g_context), 0, system_prompt_position, system_prompt_position + n_discard);
|
||||
llama_memory_seq_add(llama_get_memory(g_context), 0, system_prompt_position + n_discard, current_position, -n_discard);
|
||||
current_position -= n_discard;
|
||||
LOGi("%s: Context shifting done! Current position: %d", __func__, current_position);
|
||||
}
|
||||
|
||||
static std::string chat_add_and_format(const std::string &role, const std::string &content) {
|
||||
common_chat_msg new_msg;
|
||||
new_msg.role = role;
|
||||
new_msg.content = content;
|
||||
auto formatted = common_chat_format_single(
|
||||
g_chat_templates.get(), chat_msgs, new_msg, role == ROLE_USER, /* use_jinja */ false);
|
||||
chat_msgs.push_back(new_msg);
|
||||
LOGi("%s: Formatted and added %s message: \n%s\n", __func__, role.c_str(), formatted.c_str());
|
||||
return formatted;
|
||||
}
|
||||
|
||||
/**
|
||||
* Completion loop's short-term states:
|
||||
* - stop generation position
|
||||
* - token chars caching
|
||||
* - current assistant message being generated
|
||||
*/
|
||||
static llama_pos stop_generation_position;
|
||||
static std::string cached_token_chars;
|
||||
static std::ostringstream assistant_ss;
|
||||
|
||||
static void reset_short_term_states() {
|
||||
stop_generation_position = 0;
|
||||
cached_token_chars.clear();
|
||||
assistant_ss.str("");
|
||||
}
|
||||
|
||||
static int decode_tokens_in_batches(
|
||||
llama_context *context,
|
||||
llama_batch &batch,
|
||||
const llama_tokens &tokens,
|
||||
const llama_pos start_pos,
|
||||
const bool compute_last_logit = false) {
|
||||
// Process tokens in batches using the global batch
|
||||
LOGd("%s: Decode %d tokens starting at position %d", __func__, (int) tokens.size(), start_pos);
|
||||
for (int i = 0; i < (int) tokens.size(); i += BATCH_SIZE) {
|
||||
const int cur_batch_size = std::min((int) tokens.size() - i, BATCH_SIZE);
|
||||
common_batch_clear(batch);
|
||||
LOGv("%s: Preparing a batch size of %d starting at: %d", __func__, cur_batch_size, i);
|
||||
|
||||
// Shift context if current batch cannot fit into the context
|
||||
if (start_pos + i + cur_batch_size >= DEFAULT_CONTEXT_SIZE - OVERFLOW_HEADROOM) {
|
||||
LOGw("%s: Current batch won't fit into context! Shifting...", __func__);
|
||||
shift_context();
|
||||
}
|
||||
|
||||
// Add tokens to the batch with proper positions
|
||||
for (int j = 0; j < cur_batch_size; j++) {
|
||||
const llama_token token_id = tokens[i + j];
|
||||
const llama_pos position = start_pos + i + j;
|
||||
const bool want_logit = compute_last_logit && (i + j == tokens.size() - 1);
|
||||
common_batch_add(batch, token_id, position, {0}, want_logit);
|
||||
}
|
||||
|
||||
// Decode this batch
|
||||
const int decode_result = llama_decode(context, batch);
|
||||
if (decode_result) {
|
||||
LOGe("%s: llama_decode failed w/ %d", __func__, decode_result);
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
extern "C"
|
||||
JNIEXPORT jint JNICALL
|
||||
Java_com_arm_aichat_internal_InferenceEngineImpl_processSystemPrompt(
|
||||
JNIEnv *env,
|
||||
jobject /*unused*/,
|
||||
jstring jsystem_prompt
|
||||
) {
|
||||
// Reset long-term & short-term states
|
||||
reset_long_term_states();
|
||||
reset_short_term_states();
|
||||
|
||||
// Obtain system prompt from JEnv
|
||||
const auto *system_prompt = env->GetStringUTFChars(jsystem_prompt, nullptr);
|
||||
LOGd("%s: System prompt received: \n%s", __func__, system_prompt);
|
||||
std::string formatted_system_prompt(system_prompt);
|
||||
env->ReleaseStringUTFChars(jsystem_prompt, system_prompt);
|
||||
|
||||
// Format system prompt if applicable
|
||||
const bool has_chat_template = common_chat_templates_was_explicit(g_chat_templates.get());
|
||||
if (has_chat_template) {
|
||||
formatted_system_prompt = chat_add_and_format(ROLE_SYSTEM, system_prompt);
|
||||
}
|
||||
|
||||
// Tokenize system prompt
|
||||
const auto system_tokens = common_tokenize(g_context, formatted_system_prompt,
|
||||
has_chat_template, has_chat_template);
|
||||
for (auto id: system_tokens) {
|
||||
LOGv("token: `%s`\t -> `%d`", common_token_to_piece(g_context, id).c_str(), id);
|
||||
}
|
||||
|
||||
// Handle context overflow
|
||||
const int max_batch_size = DEFAULT_CONTEXT_SIZE - OVERFLOW_HEADROOM;
|
||||
if ((int) system_tokens.size() > max_batch_size) {
|
||||
LOGe("%s: System prompt too long for context! %d tokens, max: %d",
|
||||
__func__, (int) system_tokens.size(), max_batch_size);
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Decode system tokens in batches
|
||||
if (decode_tokens_in_batches(g_context, g_batch, system_tokens, current_position)) {
|
||||
LOGe("%s: llama_decode() failed!", __func__);
|
||||
return 2;
|
||||
}
|
||||
|
||||
// Update position
|
||||
system_prompt_position = current_position = (int) system_tokens.size();
|
||||
return 0;
|
||||
}
|
||||
|
||||
extern "C"
|
||||
JNIEXPORT jint JNICALL
|
||||
Java_com_arm_aichat_internal_InferenceEngineImpl_processUserPrompt(
|
||||
JNIEnv *env,
|
||||
jobject /*unused*/,
|
||||
jstring juser_prompt,
|
||||
jint n_predict
|
||||
) {
|
||||
// Reset short-term states
|
||||
reset_short_term_states();
|
||||
|
||||
// Obtain and tokenize user prompt
|
||||
const auto *const user_prompt = env->GetStringUTFChars(juser_prompt, nullptr);
|
||||
LOGd("%s: User prompt received: \n%s", __func__, user_prompt);
|
||||
std::string formatted_user_prompt(user_prompt);
|
||||
env->ReleaseStringUTFChars(juser_prompt, user_prompt);
|
||||
|
||||
// Format user prompt if applicable
|
||||
const bool has_chat_template = common_chat_templates_was_explicit(g_chat_templates.get());
|
||||
if (has_chat_template) {
|
||||
formatted_user_prompt = chat_add_and_format(ROLE_USER, user_prompt);
|
||||
}
|
||||
|
||||
// Decode formatted user prompts
|
||||
auto user_tokens = common_tokenize(g_context, formatted_user_prompt, has_chat_template, has_chat_template);
|
||||
for (auto id: user_tokens) {
|
||||
LOGv("token: `%s`\t -> `%d`", common_token_to_piece(g_context, id).c_str(), id);
|
||||
}
|
||||
|
||||
// Ensure user prompt doesn't exceed the context size by truncating if necessary.
|
||||
const int user_prompt_size = (int) user_tokens.size();
|
||||
const int max_batch_size = DEFAULT_CONTEXT_SIZE - OVERFLOW_HEADROOM;
|
||||
if (user_prompt_size > max_batch_size) {
|
||||
const int skipped_tokens = user_prompt_size - max_batch_size;
|
||||
user_tokens.resize(max_batch_size);
|
||||
LOGw("%s: User prompt too long! Skipped %d tokens!", __func__, skipped_tokens);
|
||||
}
|
||||
|
||||
// Decode user tokens in batches
|
||||
if (decode_tokens_in_batches(g_context, g_batch, user_tokens, current_position, true)) {
|
||||
LOGe("%s: llama_decode() failed!", __func__);
|
||||
return 2;
|
||||
}
|
||||
|
||||
// Update position
|
||||
current_position += user_prompt_size;
|
||||
stop_generation_position = current_position + user_prompt_size + n_predict;
|
||||
return 0;
|
||||
}
|
||||
|
||||
static bool is_valid_utf8(const char *string) {
|
||||
if (!string) { return true; }
|
||||
|
||||
const auto *bytes = (const unsigned char *) string;
|
||||
int num;
|
||||
|
||||
while (*bytes != 0x00) {
|
||||
if ((*bytes & 0x80) == 0x00) {
|
||||
// U+0000 to U+007F
|
||||
num = 1;
|
||||
} else if ((*bytes & 0xE0) == 0xC0) {
|
||||
// U+0080 to U+07FF
|
||||
num = 2;
|
||||
} else if ((*bytes & 0xF0) == 0xE0) {
|
||||
// U+0800 to U+FFFF
|
||||
num = 3;
|
||||
} else if ((*bytes & 0xF8) == 0xF0) {
|
||||
// U+10000 to U+10FFFF
|
||||
num = 4;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
|
||||
bytes += 1;
|
||||
for (int i = 1; i < num; ++i) {
|
||||
if ((*bytes & 0xC0) != 0x80) {
|
||||
return false;
|
||||
}
|
||||
bytes += 1;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
extern "C"
|
||||
JNIEXPORT jstring JNICALL
|
||||
Java_com_arm_aichat_internal_InferenceEngineImpl_generateNextToken(
|
||||
JNIEnv *env,
|
||||
jobject /*unused*/
|
||||
) {
|
||||
// Infinite text generation via context shifting
|
||||
if (current_position >= DEFAULT_CONTEXT_SIZE - OVERFLOW_HEADROOM) {
|
||||
LOGw("%s: Context full! Shifting...", __func__);
|
||||
shift_context();
|
||||
}
|
||||
|
||||
// Stop if reaching the marked position
|
||||
if (current_position >= stop_generation_position) {
|
||||
LOGw("%s: STOP: hitting stop position: %d", __func__, stop_generation_position);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Sample next token
|
||||
const auto new_token_id = common_sampler_sample(g_sampler, g_context, -1);
|
||||
common_sampler_accept(g_sampler, new_token_id, true);
|
||||
|
||||
// Populate the batch with new token, then decode
|
||||
common_batch_clear(g_batch);
|
||||
common_batch_add(g_batch, new_token_id, current_position, {0}, true);
|
||||
if (llama_decode(g_context, g_batch) != 0) {
|
||||
LOGe("%s: llama_decode() failed for generated token", __func__);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Update position
|
||||
current_position++;
|
||||
|
||||
// Stop if next token is EOG
|
||||
if (llama_vocab_is_eog(llama_model_get_vocab(g_model), new_token_id)) {
|
||||
LOGd("id: %d,\tIS EOG!\nSTOP.", new_token_id);
|
||||
chat_add_and_format(ROLE_ASSISTANT, assistant_ss.str());
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// If not EOG, convert to text
|
||||
auto new_token_chars = common_token_to_piece(g_context, new_token_id);
|
||||
cached_token_chars += new_token_chars;
|
||||
|
||||
// Create and return a valid UTF-8 Java string
|
||||
jstring result = nullptr;
|
||||
if (is_valid_utf8(cached_token_chars.c_str())) {
|
||||
result = env->NewStringUTF(cached_token_chars.c_str());
|
||||
LOGv("id: %d,\tcached: `%s`,\tnew: `%s`", new_token_id, cached_token_chars.c_str(), new_token_chars.c_str());
|
||||
|
||||
assistant_ss << cached_token_chars;
|
||||
cached_token_chars.clear();
|
||||
} else {
|
||||
LOGv("id: %d,\tappend to cache", new_token_id);
|
||||
result = env->NewStringUTF("");
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
extern "C"
|
||||
JNIEXPORT void JNICALL
|
||||
Java_com_arm_aichat_internal_InferenceEngineImpl_unload(JNIEnv * /*unused*/, jobject /*unused*/) {
|
||||
// Reset long-term & short-term states
|
||||
reset_long_term_states();
|
||||
reset_short_term_states();
|
||||
|
||||
// Free up resources
|
||||
common_sampler_free(g_sampler);
|
||||
g_chat_templates.reset();
|
||||
llama_batch_free(g_batch);
|
||||
llama_free(g_context);
|
||||
llama_model_free(g_model);
|
||||
}
|
||||
|
||||
extern "C"
|
||||
JNIEXPORT void JNICALL
|
||||
Java_com_arm_aichat_internal_InferenceEngineImpl_shutdown(JNIEnv *, jobject /*unused*/) {
|
||||
llama_backend_free();
|
||||
}
|
||||
61
examples/llama.android/lib/src/main/cpp/logging.h
Normal file
61
examples/llama.android/lib/src/main/cpp/logging.h
Normal file
@@ -0,0 +1,61 @@
|
||||
//
|
||||
// Created by Han Yin on 10/31/25.
|
||||
//
|
||||
|
||||
#ifndef AICHAT_LOGGING_H
|
||||
#define AICHAT_LOGGING_H
|
||||
|
||||
#endif //AICHAT_LOGGING_H
|
||||
|
||||
#pragma once
|
||||
#include <android/log.h>
|
||||
|
||||
#ifndef LOG_TAG
|
||||
#define LOG_TAG "ai-chat"
|
||||
#endif
|
||||
|
||||
#ifndef LOG_MIN_LEVEL
|
||||
#if defined(NDEBUG)
|
||||
#define LOG_MIN_LEVEL ANDROID_LOG_INFO
|
||||
#else
|
||||
#define LOG_MIN_LEVEL ANDROID_LOG_VERBOSE
|
||||
#endif
|
||||
#endif
|
||||
|
||||
static inline int ai_should_log(int prio) {
|
||||
return __android_log_is_loggable(prio, LOG_TAG, LOG_MIN_LEVEL);
|
||||
}
|
||||
|
||||
#if LOG_MIN_LEVEL <= ANDROID_LOG_VERBOSE
|
||||
#define LOGv(...) do { if (ai_should_log(ANDROID_LOG_VERBOSE)) __android_log_print(ANDROID_LOG_VERBOSE, LOG_TAG, __VA_ARGS__); } while (0)
|
||||
#else
|
||||
#define LOGv(...) ((void)0)
|
||||
#endif
|
||||
|
||||
#if LOG_MIN_LEVEL <= ANDROID_LOG_DEBUG
|
||||
#define LOGd(...) do { if (ai_should_log(ANDROID_LOG_DEBUG)) __android_log_print(ANDROID_LOG_DEBUG, LOG_TAG, __VA_ARGS__); } while (0)
|
||||
#else
|
||||
#define LOGd(...) ((void)0)
|
||||
#endif
|
||||
|
||||
#define LOGi(...) do { if (ai_should_log(ANDROID_LOG_INFO )) __android_log_print(ANDROID_LOG_INFO , LOG_TAG, __VA_ARGS__); } while (0)
|
||||
#define LOGw(...) do { if (ai_should_log(ANDROID_LOG_WARN )) __android_log_print(ANDROID_LOG_WARN , LOG_TAG, __VA_ARGS__); } while (0)
|
||||
#define LOGe(...) do { if (ai_should_log(ANDROID_LOG_ERROR)) __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, __VA_ARGS__); } while (0)
|
||||
|
||||
static inline int android_log_prio_from_ggml(enum ggml_log_level level) {
|
||||
switch (level) {
|
||||
case GGML_LOG_LEVEL_ERROR: return ANDROID_LOG_ERROR;
|
||||
case GGML_LOG_LEVEL_WARN: return ANDROID_LOG_WARN;
|
||||
case GGML_LOG_LEVEL_INFO: return ANDROID_LOG_INFO;
|
||||
case GGML_LOG_LEVEL_DEBUG: return ANDROID_LOG_DEBUG;
|
||||
default: return ANDROID_LOG_DEFAULT;
|
||||
}
|
||||
}
|
||||
|
||||
static inline void aichat_android_log_callback(enum ggml_log_level level,
|
||||
const char* text,
|
||||
void* /*user*/) {
|
||||
const int prio = android_log_prio_from_ggml(level);
|
||||
if (!ai_should_log(prio)) return;
|
||||
__android_log_write(prio, LOG_TAG, text);
|
||||
}
|
||||
@@ -0,0 +1,14 @@
|
||||
package com.arm.aichat
|
||||
|
||||
import android.content.Context
|
||||
import com.arm.aichat.internal.InferenceEngineImpl
|
||||
|
||||
/**
|
||||
* Main entry point for Arm's AI Chat library.
|
||||
*/
|
||||
object AiChat {
|
||||
/**
|
||||
* Get the inference engine single instance.
|
||||
*/
|
||||
fun getInferenceEngine(context: Context) = InferenceEngineImpl.getInstance(context)
|
||||
}
|
||||
@@ -0,0 +1,89 @@
|
||||
package com.arm.aichat
|
||||
|
||||
import com.arm.aichat.InferenceEngine.State
|
||||
import kotlinx.coroutines.flow.Flow
|
||||
import kotlinx.coroutines.flow.StateFlow
|
||||
|
||||
/**
|
||||
* Interface defining the core LLM inference operations.
|
||||
*/
|
||||
interface InferenceEngine {
|
||||
/**
|
||||
* Current state of the inference engine
|
||||
*/
|
||||
val state: StateFlow<State>
|
||||
|
||||
/**
|
||||
* Load a model from the given path.
|
||||
*
|
||||
* @throws UnsupportedArchitectureException if model architecture not supported
|
||||
*/
|
||||
suspend fun loadModel(pathToModel: String)
|
||||
|
||||
/**
|
||||
* Sends a system prompt to the loaded model
|
||||
*/
|
||||
suspend fun setSystemPrompt(systemPrompt: String)
|
||||
|
||||
/**
|
||||
* Sends a user prompt to the loaded model and returns a Flow of generated tokens.
|
||||
*/
|
||||
fun sendUserPrompt(message: String, predictLength: Int = DEFAULT_PREDICT_LENGTH): Flow<String>
|
||||
|
||||
/**
|
||||
* Runs a benchmark with the specified parameters.
|
||||
*/
|
||||
suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String
|
||||
|
||||
/**
|
||||
* Unloads the currently loaded model.
|
||||
*/
|
||||
fun cleanUp()
|
||||
|
||||
/**
|
||||
* Cleans up resources when the engine is no longer needed.
|
||||
*/
|
||||
fun destroy()
|
||||
|
||||
/**
|
||||
* States of the inference engine
|
||||
*/
|
||||
sealed class State {
|
||||
object Uninitialized : State()
|
||||
object Initializing : State()
|
||||
object Initialized : State()
|
||||
|
||||
object LoadingModel : State()
|
||||
object UnloadingModel : State()
|
||||
object ModelReady : State()
|
||||
|
||||
object Benchmarking : State()
|
||||
object ProcessingSystemPrompt : State()
|
||||
object ProcessingUserPrompt : State()
|
||||
|
||||
object Generating : State()
|
||||
|
||||
data class Error(val exception: Exception) : State()
|
||||
}
|
||||
|
||||
companion object {
|
||||
const val DEFAULT_PREDICT_LENGTH = 1024
|
||||
}
|
||||
}
|
||||
|
||||
val State.isUninterruptible
|
||||
get() = this is State.Initializing ||
|
||||
this is State.LoadingModel ||
|
||||
this is State.UnloadingModel ||
|
||||
this is State.Benchmarking ||
|
||||
this is State.ProcessingSystemPrompt ||
|
||||
this is State.ProcessingUserPrompt
|
||||
|
||||
val State.isModelLoaded: Boolean
|
||||
get() = this is State.ModelReady ||
|
||||
this is State.Benchmarking ||
|
||||
this is State.ProcessingSystemPrompt ||
|
||||
this is State.ProcessingUserPrompt ||
|
||||
this is State.Generating
|
||||
|
||||
class UnsupportedArchitectureException : Exception()
|
||||
@@ -0,0 +1,61 @@
|
||||
package com.arm.aichat.gguf
|
||||
|
||||
import kotlin.collections.get
|
||||
|
||||
|
||||
/**
|
||||
* Numerical codes used by `general.file_type` (see llama.cpp repo's `constants.py`).
|
||||
* The `label` matches what llama‑cli prints.
|
||||
*/
|
||||
enum class FileType(val code: Int, val label: String) {
|
||||
ALL_F32(0, "all F32"),
|
||||
MOSTLY_F16(1, "F16"),
|
||||
MOSTLY_Q4_0(2, "Q4_0"),
|
||||
MOSTLY_Q4_1(3, "Q4_1"),
|
||||
// 4 removed
|
||||
MOSTLY_Q8_0(7, "Q8_0"),
|
||||
MOSTLY_Q5_0(8, "Q5_0"),
|
||||
MOSTLY_Q5_1(9, "Q5_1"),
|
||||
|
||||
/* K‑quants ------------------------------------------------------------ */
|
||||
MOSTLY_Q2_K (10, "Q2_K - Medium"),
|
||||
MOSTLY_Q3_K_S (11, "Q3_K - Small"),
|
||||
MOSTLY_Q3_K_M (12, "Q3_K - Medium"),
|
||||
MOSTLY_Q3_K_L (13, "Q3_K - Large"),
|
||||
MOSTLY_Q4_K_S (14, "Q4_K - Small"),
|
||||
MOSTLY_Q4_K_M (15, "Q4_K - Medium"),
|
||||
MOSTLY_Q5_K_S (16, "Q5_K - Small"),
|
||||
MOSTLY_Q5_K_M (17, "Q5_K - Medium"),
|
||||
MOSTLY_Q6_K (18, "Q6_K"),
|
||||
|
||||
/* IQ quants ----------------------------------------------------------- */
|
||||
MOSTLY_IQ2_XXS (19, "IQ2_XXS - 2.06 bpw"),
|
||||
MOSTLY_IQ2_XS (20, "IQ2_XS - 2.31 bpw"),
|
||||
MOSTLY_Q2_K_S (21, "Q2_K - Small"),
|
||||
MOSTLY_IQ3_XS (22, "IQ3_XS - 3.30 bpw"),
|
||||
MOSTLY_IQ3_XXS (23, "IQ3_XXS - 3.06 bpw"),
|
||||
MOSTLY_IQ1_S (24, "IQ1_S - 1.56 bpw"),
|
||||
MOSTLY_IQ4_NL (25, "IQ4_NL - 4.5 bpw"),
|
||||
MOSTLY_IQ3_S (26, "IQ3_S - 3.44 bpw"),
|
||||
MOSTLY_IQ3_M (27, "IQ3_M - 3.66 bpw"),
|
||||
MOSTLY_IQ2_S (28, "IQ2_S - 2.50 bpw"),
|
||||
MOSTLY_IQ2_M (29, "IQ2_M - 2.70 bpw"),
|
||||
MOSTLY_IQ4_XS (30, "IQ4_XS - 4.25 bpw"),
|
||||
MOSTLY_IQ1_M (31, "IQ1_M - 1.75 bpw"),
|
||||
|
||||
/* BF16 & Ternary ------------------------------------------------------ */
|
||||
MOSTLY_BF16 (32, "BF16"),
|
||||
MOSTLY_TQ1_0 (36, "TQ1_0 - 1.69 bpw ternary"),
|
||||
MOSTLY_TQ2_0 (37, "TQ2_0 - 2.06 bpw ternary"),
|
||||
|
||||
/* Special flag -------------------------------------------------------- */
|
||||
GUESSED(1024, "(guessed)"),
|
||||
|
||||
UNKNOWN(-1, "unknown");
|
||||
|
||||
companion object {
|
||||
private val map = entries.associateBy(FileType::code)
|
||||
|
||||
fun fromCode(code: Int?): FileType = map[code] ?: UNKNOWN
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,132 @@
|
||||
package com.arm.aichat.gguf
|
||||
|
||||
import java.io.IOException
|
||||
|
||||
|
||||
/**
|
||||
* Structured metadata of GGUF
|
||||
*/
|
||||
data class GgufMetadata(
|
||||
// Basic file info
|
||||
val version: GgufVersion,
|
||||
val tensorCount: Long,
|
||||
val kvCount: Long,
|
||||
|
||||
// General info
|
||||
val basic: BasicInfo,
|
||||
val author: AuthorInfo? = null,
|
||||
val additional: AdditionalInfo? = null,
|
||||
val architecture: ArchitectureInfo? = null,
|
||||
val baseModels: List<BaseModelInfo>? = null,
|
||||
val tokenizer: TokenizerInfo? = null,
|
||||
|
||||
// Derivative info
|
||||
val dimensions: DimensionsInfo? = null,
|
||||
val attention: AttentionInfo? = null,
|
||||
val rope: RopeInfo? = null,
|
||||
val experts: ExpertsInfo? = null
|
||||
) {
|
||||
enum class GgufVersion(val code: Int, val label: String) {
|
||||
/** First public draft; little‑endian only, no alignment key. */
|
||||
LEGACY_V1(1, "Legacy v1"),
|
||||
|
||||
/** Added split‑file support and some extra metadata keys. */
|
||||
EXTENDED_V2(2, "Extended v2"),
|
||||
|
||||
/** Current spec: endian‑aware, mandatory alignment, fully validated. */
|
||||
VALIDATED_V3(3, "Validated v3");
|
||||
|
||||
companion object {
|
||||
fun fromCode(code: Int): GgufVersion =
|
||||
entries.firstOrNull { it.code == code }
|
||||
?: throw IOException("Unknown GGUF version code $code")
|
||||
}
|
||||
|
||||
override fun toString(): String = "$label (code=$code)"
|
||||
}
|
||||
|
||||
data class BasicInfo(
|
||||
val uuid: String? = null,
|
||||
val name: String? = null,
|
||||
val nameLabel: String? = null,
|
||||
val sizeLabel: String? = null, // Size label like "7B"
|
||||
)
|
||||
|
||||
data class AuthorInfo(
|
||||
val organization: String? = null,
|
||||
val author: String? = null,
|
||||
val doi: String? = null,
|
||||
val url: String? = null,
|
||||
val repoUrl: String? = null,
|
||||
val license: String? = null,
|
||||
val licenseLink: String? = null,
|
||||
)
|
||||
|
||||
data class AdditionalInfo(
|
||||
val type: String? = null,
|
||||
val description: String? = null,
|
||||
val tags: List<String>? = null,
|
||||
val languages: List<String>? = null,
|
||||
)
|
||||
|
||||
data class ArchitectureInfo(
|
||||
val architecture: String? = null,
|
||||
val fileType: Int? = null,
|
||||
val vocabSize: Int? = null,
|
||||
val finetune: String? = null,
|
||||
val quantizationVersion: Int? = null,
|
||||
)
|
||||
|
||||
data class BaseModelInfo(
|
||||
val name: String? = null,
|
||||
val author: String? = null,
|
||||
val version: String? = null,
|
||||
val organization: String? = null,
|
||||
val url: String? = null,
|
||||
val doi: String? = null,
|
||||
val uuid: String? = null,
|
||||
val repoUrl: String? = null,
|
||||
)
|
||||
|
||||
data class TokenizerInfo(
|
||||
val model: String? = null,
|
||||
val bosTokenId: Int? = null,
|
||||
val eosTokenId: Int? = null,
|
||||
val unknownTokenId: Int? = null,
|
||||
val paddingTokenId: Int? = null,
|
||||
val addBosToken: Boolean? = null,
|
||||
val addEosToken: Boolean? = null,
|
||||
val chatTemplate: String? = null,
|
||||
)
|
||||
|
||||
data class DimensionsInfo(
|
||||
val contextLength: Int? = null,
|
||||
val embeddingSize: Int? = null,
|
||||
val blockCount: Int? = null,
|
||||
val feedForwardSize: Int? = null,
|
||||
)
|
||||
|
||||
data class AttentionInfo(
|
||||
val headCount: Int? = null,
|
||||
val headCountKv: Int? = null,
|
||||
val keyLength: Int? = null,
|
||||
val valueLength: Int? = null,
|
||||
val layerNormEpsilon: Float? = null,
|
||||
val layerNormRmsEpsilon: Float? = null,
|
||||
)
|
||||
|
||||
data class RopeInfo(
|
||||
val frequencyBase: Float? = null,
|
||||
val dimensionCount: Int? = null,
|
||||
val scalingType: String? = null,
|
||||
val scalingFactor: Float? = null,
|
||||
val attnFactor: Float? = null,
|
||||
val originalContextLength: Int? = null,
|
||||
val finetuned: Boolean? = null,
|
||||
)
|
||||
|
||||
data class ExpertsInfo(
|
||||
val count: Int? = null,
|
||||
val usedCount: Int? = null,
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,77 @@
|
||||
package com.arm.aichat.gguf
|
||||
|
||||
import android.content.Context
|
||||
import android.net.Uri
|
||||
import com.arm.aichat.internal.gguf.GgufMetadataReaderImpl
|
||||
import java.io.File
|
||||
import java.io.IOException
|
||||
import java.io.InputStream
|
||||
|
||||
/**
|
||||
* Interface for reading GGUF metadata from model files.
|
||||
* Use `GgufMetadataReader.create()` to get an instance.
|
||||
*/
|
||||
interface GgufMetadataReader {
|
||||
/**
|
||||
* Reads the magic number from the specified file path.
|
||||
*
|
||||
* @param file Java File to the GGUF file with absolute path
|
||||
* @return true if file is valid GGUF, otherwise false
|
||||
* @throws InvalidFileFormatException if file format is invalid
|
||||
*/
|
||||
suspend fun ensureSourceFileFormat(file: File): Boolean
|
||||
|
||||
/**
|
||||
* Reads the magic number from the specified file path.
|
||||
*
|
||||
* @param context Context for obtaining [android.content.ContentProvider]
|
||||
* @param uri Uri to the GGUF file provided by [android.content.ContentProvider]
|
||||
* @return true if file is valid GGUF, otherwise false
|
||||
* @throws InvalidFileFormatException if file format is invalid
|
||||
*/
|
||||
suspend fun ensureSourceFileFormat(context: Context, uri: Uri): Boolean
|
||||
|
||||
/**
|
||||
* Reads and parses GGUF metadata from the specified file path.
|
||||
*
|
||||
* @param input the [InputStream] obtained from a readable file or content
|
||||
* @return Structured metadata extracted from the file
|
||||
* @throws IOException if file is damaged or cannot be read
|
||||
* @throws InvalidFileFormatException if file format is invalid
|
||||
*/
|
||||
suspend fun readStructuredMetadata(input: InputStream): GgufMetadata
|
||||
|
||||
companion object {
|
||||
private val DEFAULT_SKIP_KEYS = setOf(
|
||||
"tokenizer.chat_template",
|
||||
"tokenizer.ggml.scores",
|
||||
"tokenizer.ggml.tokens",
|
||||
"tokenizer.ggml.token_type"
|
||||
)
|
||||
|
||||
/**
|
||||
* Creates a default GgufMetadataReader instance
|
||||
*/
|
||||
fun create(): GgufMetadataReader = GgufMetadataReaderImpl(
|
||||
skipKeys = DEFAULT_SKIP_KEYS,
|
||||
arraySummariseThreshold = 1_000
|
||||
)
|
||||
|
||||
/**
|
||||
* Creates a GgufMetadataReader with custom configuration
|
||||
*
|
||||
* @param skipKeys Keys whose value should be skipped entirely (not kept in the result map)
|
||||
* @param arraySummariseThreshold If ≥0, arrays longer get summarised, not materialised;
|
||||
* If -1, never summarise.
|
||||
*/
|
||||
fun create(
|
||||
skipKeys: Set<String> = DEFAULT_SKIP_KEYS,
|
||||
arraySummariseThreshold: Int = 1_000
|
||||
): GgufMetadataReader = GgufMetadataReaderImpl(
|
||||
skipKeys = skipKeys,
|
||||
arraySummariseThreshold = arraySummariseThreshold
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
class InvalidFileFormatException : IOException()
|
||||
@@ -0,0 +1,324 @@
|
||||
package com.arm.aichat.internal
|
||||
|
||||
import android.content.Context
|
||||
import android.util.Log
|
||||
import com.arm.aichat.InferenceEngine
|
||||
import com.arm.aichat.UnsupportedArchitectureException
|
||||
import com.arm.aichat.internal.InferenceEngineImpl.Companion.getInstance
|
||||
import dalvik.annotation.optimization.FastNative
|
||||
import kotlinx.coroutines.CancellationException
|
||||
import kotlinx.coroutines.CoroutineScope
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.ExperimentalCoroutinesApi
|
||||
import kotlinx.coroutines.SupervisorJob
|
||||
import kotlinx.coroutines.cancel
|
||||
import kotlinx.coroutines.flow.Flow
|
||||
import kotlinx.coroutines.flow.MutableStateFlow
|
||||
import kotlinx.coroutines.flow.StateFlow
|
||||
import kotlinx.coroutines.flow.asStateFlow
|
||||
import kotlinx.coroutines.flow.flow
|
||||
import kotlinx.coroutines.flow.flowOn
|
||||
import kotlinx.coroutines.launch
|
||||
import kotlinx.coroutines.runBlocking
|
||||
import kotlinx.coroutines.withContext
|
||||
import java.io.File
|
||||
import java.io.IOException
|
||||
|
||||
/**
|
||||
* JNI wrapper for the llama.cpp library providing Android-friendly access to large language models.
|
||||
*
|
||||
* This class implements a singleton pattern for managing the lifecycle of a single LLM instance.
|
||||
* All operations are executed on a dedicated single-threaded dispatcher to ensure thread safety
|
||||
* with the underlying C++ native code.
|
||||
*
|
||||
* The typical usage flow is:
|
||||
* 1. Get instance via [getInstance]
|
||||
* 2. Load a model with [loadModel]
|
||||
* 3. Send prompts with [sendUserPrompt]
|
||||
* 4. Generate responses as token streams
|
||||
* 5. Perform [cleanUp] when done with a model
|
||||
* 6. Properly [destroy] when completely done
|
||||
*
|
||||
* State transitions are managed automatically and validated at each operation.
|
||||
*
|
||||
* @see ai_chat.cpp for the native implementation details
|
||||
*/
|
||||
internal class InferenceEngineImpl private constructor(
|
||||
private val nativeLibDir: String
|
||||
) : InferenceEngine {
|
||||
|
||||
companion object {
|
||||
private val TAG = InferenceEngineImpl::class.java.simpleName
|
||||
|
||||
@Volatile
|
||||
private var instance: InferenceEngine? = null
|
||||
|
||||
/**
|
||||
* Create or obtain [InferenceEngineImpl]'s single instance.
|
||||
*
|
||||
* @param Context for obtaining native library directory
|
||||
* @throws IllegalArgumentException if native library path is invalid
|
||||
* @throws UnsatisfiedLinkError if library failed to load
|
||||
*/
|
||||
internal fun getInstance(context: Context) =
|
||||
instance ?: synchronized(this) {
|
||||
val nativeLibDir = context.applicationInfo.nativeLibraryDir
|
||||
require(nativeLibDir.isNotBlank()) { "Expected a valid native library path!" }
|
||||
|
||||
try {
|
||||
Log.i(TAG, "Instantiating InferenceEngineImpl,,,")
|
||||
InferenceEngineImpl(nativeLibDir).also { instance = it }
|
||||
} catch (e: UnsatisfiedLinkError) {
|
||||
Log.e(TAG, "Failed to load native library from $nativeLibDir", e)
|
||||
throw e
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* JNI methods
|
||||
* @see ai_chat.cpp
|
||||
*/
|
||||
@FastNative
|
||||
private external fun init(nativeLibDir: String)
|
||||
|
||||
@FastNative
|
||||
private external fun load(modelPath: String): Int
|
||||
|
||||
@FastNative
|
||||
private external fun prepare(): Int
|
||||
|
||||
@FastNative
|
||||
private external fun systemInfo(): String
|
||||
|
||||
@FastNative
|
||||
private external fun benchModel(pp: Int, tg: Int, pl: Int, nr: Int): String
|
||||
|
||||
@FastNative
|
||||
private external fun processSystemPrompt(systemPrompt: String): Int
|
||||
|
||||
@FastNative
|
||||
private external fun processUserPrompt(userPrompt: String, predictLength: Int): Int
|
||||
|
||||
@FastNative
|
||||
private external fun generateNextToken(): String?
|
||||
|
||||
@FastNative
|
||||
private external fun unload()
|
||||
|
||||
@FastNative
|
||||
private external fun shutdown()
|
||||
|
||||
private val _state =
|
||||
MutableStateFlow<InferenceEngine.State>(InferenceEngine.State.Uninitialized)
|
||||
override val state: StateFlow<InferenceEngine.State> = _state.asStateFlow()
|
||||
|
||||
private var _readyForSystemPrompt = false
|
||||
@Volatile
|
||||
private var _cancelGeneration = false
|
||||
|
||||
/**
|
||||
* Single-threaded coroutine dispatcher & scope for LLama asynchronous operations
|
||||
*/
|
||||
@OptIn(ExperimentalCoroutinesApi::class)
|
||||
private val llamaDispatcher = Dispatchers.IO.limitedParallelism(1)
|
||||
private val llamaScope = CoroutineScope(llamaDispatcher + SupervisorJob())
|
||||
|
||||
init {
|
||||
llamaScope.launch {
|
||||
try {
|
||||
check(_state.value is InferenceEngine.State.Uninitialized) {
|
||||
"Cannot load native library in ${_state.value.javaClass.simpleName}!"
|
||||
}
|
||||
_state.value = InferenceEngine.State.Initializing
|
||||
Log.i(TAG, "Loading native library...")
|
||||
System.loadLibrary("ai-chat")
|
||||
init(nativeLibDir)
|
||||
_state.value = InferenceEngine.State.Initialized
|
||||
Log.i(TAG, "Native library loaded! System info: \n${systemInfo()}")
|
||||
|
||||
} catch (e: Exception) {
|
||||
Log.e(TAG, "Failed to load native library", e)
|
||||
throw e
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Load the LLM
|
||||
*/
|
||||
override suspend fun loadModel(pathToModel: String) =
|
||||
withContext(llamaDispatcher) {
|
||||
check(_state.value is InferenceEngine.State.Initialized) {
|
||||
"Cannot load model in ${_state.value.javaClass.simpleName}!"
|
||||
}
|
||||
|
||||
try {
|
||||
Log.i(TAG, "Checking access to model file... \n$pathToModel")
|
||||
File(pathToModel).let {
|
||||
require(it.exists()) { "File not found" }
|
||||
require(it.isFile) { "Not a valid file" }
|
||||
require(it.canRead()) { "Cannot read file" }
|
||||
}
|
||||
|
||||
Log.i(TAG, "Loading model... \n$pathToModel")
|
||||
_readyForSystemPrompt = false
|
||||
_state.value = InferenceEngine.State.LoadingModel
|
||||
load(pathToModel).let {
|
||||
// TODO-han.yin: find a better way to pass other error codes
|
||||
if (it != 0) throw UnsupportedArchitectureException()
|
||||
}
|
||||
prepare().let {
|
||||
if (it != 0) throw IOException("Failed to prepare resources")
|
||||
}
|
||||
Log.i(TAG, "Model loaded!")
|
||||
_readyForSystemPrompt = true
|
||||
|
||||
_cancelGeneration = false
|
||||
_state.value = InferenceEngine.State.ModelReady
|
||||
} catch (e: Exception) {
|
||||
Log.e(TAG, (e.message ?: "Error loading model") + "\n" + pathToModel, e)
|
||||
_state.value = InferenceEngine.State.Error(e)
|
||||
throw e
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Process the plain text system prompt
|
||||
*
|
||||
* TODO-han.yin: return error code if system prompt not correct processed?
|
||||
*/
|
||||
override suspend fun setSystemPrompt(prompt: String) =
|
||||
withContext(llamaDispatcher) {
|
||||
require(prompt.isNotBlank()) { "Cannot process empty system prompt!" }
|
||||
check(_readyForSystemPrompt) { "System prompt must be set ** RIGHT AFTER ** model loaded!" }
|
||||
check(_state.value is InferenceEngine.State.ModelReady) {
|
||||
"Cannot process system prompt in ${_state.value.javaClass.simpleName}!"
|
||||
}
|
||||
|
||||
Log.i(TAG, "Sending system prompt...")
|
||||
_readyForSystemPrompt = false
|
||||
_state.value = InferenceEngine.State.ProcessingSystemPrompt
|
||||
processSystemPrompt(prompt).let { result ->
|
||||
if (result != 0) {
|
||||
RuntimeException("Failed to process system prompt: $result").also {
|
||||
_state.value = InferenceEngine.State.Error(it)
|
||||
throw it
|
||||
}
|
||||
}
|
||||
}
|
||||
Log.i(TAG, "System prompt processed! Awaiting user prompt...")
|
||||
_state.value = InferenceEngine.State.ModelReady
|
||||
}
|
||||
|
||||
/**
|
||||
* Send plain text user prompt to LLM, which starts generating tokens in a [Flow]
|
||||
*/
|
||||
override fun sendUserPrompt(
|
||||
message: String,
|
||||
predictLength: Int,
|
||||
): Flow<String> = flow {
|
||||
require(message.isNotEmpty()) { "User prompt discarded due to being empty!" }
|
||||
check(_state.value is InferenceEngine.State.ModelReady) {
|
||||
"User prompt discarded due to: ${_state.value.javaClass.simpleName}"
|
||||
}
|
||||
|
||||
try {
|
||||
Log.i(TAG, "Sending user prompt...")
|
||||
_readyForSystemPrompt = false
|
||||
_state.value = InferenceEngine.State.ProcessingUserPrompt
|
||||
|
||||
processUserPrompt(message, predictLength).let { result ->
|
||||
if (result != 0) {
|
||||
Log.e(TAG, "Failed to process user prompt: $result")
|
||||
return@flow
|
||||
}
|
||||
}
|
||||
|
||||
Log.i(TAG, "User prompt processed. Generating assistant prompt...")
|
||||
_state.value = InferenceEngine.State.Generating
|
||||
while (!_cancelGeneration) {
|
||||
generateNextToken()?.let { utf8token ->
|
||||
if (utf8token.isNotEmpty()) emit(utf8token)
|
||||
} ?: break
|
||||
}
|
||||
if (_cancelGeneration) {
|
||||
Log.i(TAG, "Assistant generation aborted per requested.")
|
||||
} else {
|
||||
Log.i(TAG, "Assistant generation complete. Awaiting user prompt...")
|
||||
}
|
||||
_state.value = InferenceEngine.State.ModelReady
|
||||
} catch (e: CancellationException) {
|
||||
Log.i(TAG, "Assistant generation's flow collection cancelled.")
|
||||
_state.value = InferenceEngine.State.ModelReady
|
||||
throw e
|
||||
} catch (e: Exception) {
|
||||
Log.e(TAG, "Error during generation!", e)
|
||||
_state.value = InferenceEngine.State.Error(e)
|
||||
throw e
|
||||
}
|
||||
}.flowOn(llamaDispatcher)
|
||||
|
||||
/**
|
||||
* Benchmark the model
|
||||
*/
|
||||
override suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int): String =
|
||||
withContext(llamaDispatcher) {
|
||||
check(_state.value is InferenceEngine.State.ModelReady) {
|
||||
"Benchmark request discarded due to: $state"
|
||||
}
|
||||
Log.i(TAG, "Start benchmark (pp: $pp, tg: $tg, pl: $pl, nr: $nr)")
|
||||
_readyForSystemPrompt = false // Just to be safe
|
||||
_state.value = InferenceEngine.State.Benchmarking
|
||||
benchModel(pp, tg, pl, nr).also {
|
||||
_state.value = InferenceEngine.State.ModelReady
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Unloads the model and frees resources, or reset error states
|
||||
*/
|
||||
override fun cleanUp() {
|
||||
_cancelGeneration = true
|
||||
runBlocking(llamaDispatcher) {
|
||||
when (val state = _state.value) {
|
||||
is InferenceEngine.State.ModelReady -> {
|
||||
Log.i(TAG, "Unloading model and free resources...")
|
||||
_readyForSystemPrompt = false
|
||||
_state.value = InferenceEngine.State.UnloadingModel
|
||||
|
||||
unload()
|
||||
|
||||
_state.value = InferenceEngine.State.Initialized
|
||||
Log.i(TAG, "Model unloaded!")
|
||||
Unit
|
||||
}
|
||||
|
||||
is InferenceEngine.State.Error -> {
|
||||
Log.i(TAG, "Resetting error states...")
|
||||
_state.value = InferenceEngine.State.Initialized
|
||||
Log.i(TAG, "States reset!")
|
||||
Unit
|
||||
}
|
||||
|
||||
else -> throw IllegalStateException("Cannot unload model in ${state.javaClass.simpleName}")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Cancel all ongoing coroutines and free GGML backends
|
||||
*/
|
||||
override fun destroy() {
|
||||
_cancelGeneration = true
|
||||
runBlocking(llamaDispatcher) {
|
||||
_readyForSystemPrompt = false
|
||||
when(_state.value) {
|
||||
is InferenceEngine.State.Uninitialized -> {}
|
||||
is InferenceEngine.State.Initialized -> shutdown()
|
||||
else -> { unload(); shutdown() }
|
||||
}
|
||||
}
|
||||
llamaScope.cancel()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,590 @@
|
||||
package com.arm.aichat.internal.gguf
|
||||
|
||||
import android.content.Context
|
||||
import android.net.Uri
|
||||
import com.arm.aichat.gguf.GgufMetadata
|
||||
import com.arm.aichat.gguf.GgufMetadataReader
|
||||
import com.arm.aichat.gguf.InvalidFileFormatException
|
||||
import java.io.File
|
||||
import java.io.IOException
|
||||
import java.io.InputStream
|
||||
|
||||
|
||||
/**
|
||||
* Utility class to read GGUF model files and extract metadata key-value pairs.
|
||||
* This parser reads the header and metadata of a GGUF v3 file (little-endian) and skips tensor data.
|
||||
*/
|
||||
internal class GgufMetadataReaderImpl(
|
||||
private val skipKeys: Set<String>,
|
||||
private val arraySummariseThreshold: Int,
|
||||
) : GgufMetadataReader {
|
||||
companion object {
|
||||
private const val ARCH_LLAMA = "llama"
|
||||
}
|
||||
|
||||
/** Enum corresponding to GGUF metadata value types (for convenience and array element typing). */
|
||||
enum class MetadataType(val code: Int) {
|
||||
UINT8(0), INT8(1), UINT16(2), INT16(3),
|
||||
UINT32(4), INT32(5), FLOAT32(6), BOOL(7),
|
||||
STRING(8), ARRAY(9), UINT64(10), INT64(11), FLOAT64(12);
|
||||
companion object {
|
||||
private val codeMap = entries.associateBy(MetadataType::code)
|
||||
fun fromCode(code: Int): MetadataType = codeMap[code]
|
||||
?: throw IOException("Unknown metadata value type code: $code")
|
||||
}
|
||||
}
|
||||
|
||||
/** Sealed class hierarchy for metadata values, providing type-safe representations for each GGUF metadata type. */
|
||||
sealed class MetadataValue {
|
||||
data class UInt8(val value: UByte) : MetadataValue() // 0: 8-bit unsigned int
|
||||
data class Int8(val value: Byte) : MetadataValue() // 1: 8-bit signed int
|
||||
data class UInt16(val value: UShort) : MetadataValue() // 2: 16-bit unsigned int (little-endian)
|
||||
data class Int16(val value: Short) : MetadataValue() // 3: 16-bit signed int (little-endian)
|
||||
data class UInt32(val value: UInt) : MetadataValue() // 4: 32-bit unsigned int (little-endian)
|
||||
data class Int32(val value: Int) : MetadataValue() // 5: 32-bit signed int (little-endian)
|
||||
data class Float32(val value: Float) : MetadataValue() // 6: 32-bit IEEE754 float
|
||||
data class Bool(val value: Boolean) : MetadataValue() // 7: Boolean (1-byte, 0=false, 1=true)
|
||||
data class StringVal(val value: String) : MetadataValue() // 8: UTF-8 string (length-prefixed)
|
||||
data class ArrayVal(val elementType: MetadataType, val elements: List<MetadataValue>) : MetadataValue()
|
||||
data class UInt64(val value: ULong) : MetadataValue() // 10: 64-bit unsigned int (little-endian)
|
||||
data class Int64(val value: Long) : MetadataValue() // 11: 64-bit signed int (little-endian)
|
||||
data class Float64(val value: Double) : MetadataValue() // 12: 64-bit IEEE754 double
|
||||
}
|
||||
|
||||
/* Convert MetadataValue to plain Kotlin primitives for allMetadata map */
|
||||
private fun MetadataValue.toPrimitive(): Any = when (this) {
|
||||
is MetadataValue.UInt8 -> value
|
||||
is MetadataValue.Int8 -> value
|
||||
is MetadataValue.UInt16 -> value
|
||||
is MetadataValue.Int16 -> value
|
||||
is MetadataValue.UInt32 -> value
|
||||
is MetadataValue.Int32 -> value
|
||||
is MetadataValue.Float32 -> value
|
||||
is MetadataValue.Bool -> value
|
||||
is MetadataValue.StringVal -> value
|
||||
is MetadataValue.UInt64 -> value
|
||||
is MetadataValue.Int64 -> value
|
||||
is MetadataValue.Float64 -> value
|
||||
is MetadataValue.ArrayVal -> elements.map { it.toPrimitive() }
|
||||
}
|
||||
|
||||
/**
|
||||
* Reads the magic number from the specified file path.
|
||||
*
|
||||
* @param context Context for obtaining ContentResolver
|
||||
* @param uri Uri to the GGUF file provided by ContentProvider
|
||||
* @return true if file is valid GGUF, otherwise false
|
||||
*/
|
||||
override suspend fun ensureSourceFileFormat(file: File): Boolean =
|
||||
file.inputStream().buffered().use { ensureMagic(it) }
|
||||
|
||||
/**
|
||||
* Reads the magic number from the specified file path.
|
||||
*
|
||||
* @param context Context for obtaining ContentResolver
|
||||
* @param uri Uri to the GGUF file provided by ContentProvider
|
||||
* @return true if file is valid GGUF, otherwise false
|
||||
*/
|
||||
override suspend fun ensureSourceFileFormat(context: Context, uri: Uri): Boolean =
|
||||
context.contentResolver.openInputStream(uri)?.buffered()?.use { ensureMagic(it) } == true
|
||||
|
||||
/** Reads the 4‑byte magic; throws if magic ≠ "GGUF". */
|
||||
private fun ensureMagic(input: InputStream): Boolean =
|
||||
ByteArray(4).let {
|
||||
if (input.read(it) != 4) throw IOException("Not a valid file!")
|
||||
it.contentEquals(byteArrayOf(0x47, 0x47, 0x55, 0x46)) // "GGUF"
|
||||
}
|
||||
|
||||
/**
|
||||
* High‑level entry point: parses a `.gguf` file on disk and returns the fully
|
||||
* populated [GgufMetadata] tree.
|
||||
*
|
||||
* Steps performed internally:
|
||||
* 1. Reads and validates the 8‑byte header (`"GGUF"` magic + version).
|
||||
* 2. Streams through the key‑value section, skipping large blobs if the key
|
||||
* appears in [skipKeys] or if an array exceeds [arraySummariseThreshold].
|
||||
* 3. Converts the resulting raw map into strongly‑typed sub‑structures
|
||||
* (basic info, tokenizer, rope, etc.).
|
||||
*
|
||||
* The method is STREAMING‑ONLY: tensors are never mapped or loaded into
|
||||
* memory, so even multi‑GB model files can be processed in < 50 ms.
|
||||
*
|
||||
* @param path Absolute or relative filesystem path to a `.gguf` file.
|
||||
* @return A [GgufMetadata] instance containing all recognised metadata plus
|
||||
* an `allMetadata` map with any keys that were not given a dedicated
|
||||
* field.
|
||||
* @throws IOException if the file is not GGUF, the version is unsupported,
|
||||
* or the metadata block is truncated / corrupt.
|
||||
*/
|
||||
override suspend fun readStructuredMetadata(input: InputStream): GgufMetadata {
|
||||
// ── 1. header ──────────────────────────────────────────────────────────
|
||||
// throws on mismatch
|
||||
val version = ensureMagicAndVersion(input)
|
||||
val tensorCount = readLittleLong(input)
|
||||
val kvCount = readLittleLong(input)
|
||||
|
||||
// ── 2. metadata map (reuse our raw parser, but we need access to the stream) ──
|
||||
val meta = readMetaMap(input, kvCount) // <String, MetadataValue>
|
||||
|
||||
// ── 3. build structured object ────────────────────────────────────────
|
||||
return buildStructured(meta, version, tensorCount, kvCount)
|
||||
}
|
||||
|
||||
/** Reads the 4‑byte magic + 4‑byte version; throws if magic ≠ "GGUF". */
|
||||
private fun ensureMagicAndVersion(input: InputStream): GgufMetadata.GgufVersion {
|
||||
if (!ensureMagic(input)) throw InvalidFileFormatException()
|
||||
return GgufMetadata.GgufVersion.fromCode(readLEUInt32(input))
|
||||
}
|
||||
|
||||
/**
|
||||
* Read an unsigned 32‑bit little‑endian integer.
|
||||
*
|
||||
* @throws IOException if fewer than four bytes are available.
|
||||
*/
|
||||
private fun readLEUInt32(input: InputStream): Int {
|
||||
val b0 = input.read(); val b1 = input.read(); val b2 = input.read(); val b3 = input.read()
|
||||
if (b3 == -1) throw IOException("Unexpected EOF while reading UInt32")
|
||||
return (b3 and 0xFF shl 24) or
|
||||
(b2 and 0xFF shl 16) or
|
||||
(b1 and 0xFF shl 8) or
|
||||
(b0 and 0xFF)
|
||||
}
|
||||
|
||||
/**
|
||||
* Low‑level helper that reads the entire “key-value” section from the current
|
||||
* stream position.
|
||||
*
|
||||
* @param input Open stream positioned JUST AFTER the header.
|
||||
* @param kvCnt Number of key‑value pairs (taken from the header).
|
||||
* @return Mutable map with one [MetadataValue] for every key that is NOT skipped.
|
||||
*
|
||||
* The function honours [skipKeys] and [arraySummariseThreshold] by invoking
|
||||
* [skipValue] or [parseValue] accordingly.
|
||||
*/
|
||||
private fun readMetaMap(input: InputStream, kvCnt: Long): Map<String, MetadataValue> =
|
||||
mutableMapOf<String, MetadataValue>().apply {
|
||||
repeat(kvCnt.toInt()) {
|
||||
val key = readString(input)
|
||||
val valueT = MetadataType.fromCode(littleEndianBytesToInt(input.readNBytesExact(4)))
|
||||
if (key in skipKeys) {
|
||||
skipValue(input, valueT)
|
||||
} else {
|
||||
this[key] = parseValue(input, valueT)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Converts a flat [Map]<[String], [MetadataValue]> into the strongly‑typed
|
||||
* [GgufMetadata] tree used by the rest of the app.
|
||||
*
|
||||
* Only the keys listed in the spec are copied into dedicated data classes;
|
||||
* everything else is preserved in `GgufMetadata.allMetadata`.
|
||||
*
|
||||
* @param m Raw key/value map.
|
||||
* @param version GGUF file‑format version (enum).
|
||||
* @param tensorCnt Number of tensors (from the header).
|
||||
* @param kvCnt Total metadata pair count (from the header).
|
||||
*/
|
||||
private fun buildStructured(
|
||||
m: Map<String, MetadataValue>,
|
||||
version: GgufMetadata.GgufVersion,
|
||||
tensorCnt: Long,
|
||||
kvCnt: Long
|
||||
): GgufMetadata {
|
||||
// ---------- helpers ----------
|
||||
fun String.str() = (m[this] as? MetadataValue.StringVal)?.value
|
||||
fun String.bool() = (m[this] as? MetadataValue.Bool)?.value
|
||||
fun String.i32() = (m[this] as? MetadataValue.Int32)?.value
|
||||
fun String.u32() = (m[this] as? MetadataValue.UInt32)?.value?.toInt()
|
||||
fun String.f32() = (m[this] as? MetadataValue.Float32)?.value
|
||||
fun String.f64() = (m[this] as? MetadataValue.Float64)?.value?.toFloat()
|
||||
fun String.strList(): List<String>? =
|
||||
(m[this] as? MetadataValue.ArrayVal)
|
||||
?.elements
|
||||
?.mapNotNull { (it as? MetadataValue.StringVal)?.value }
|
||||
|
||||
val arch = "general.architecture".str() ?: ARCH_LLAMA
|
||||
|
||||
// -------------- populate sections ----------------
|
||||
val basic = GgufMetadata.BasicInfo(
|
||||
uuid = "general.uuid".str(),
|
||||
name = "general.basename".str(),
|
||||
nameLabel = "general.name".str(),
|
||||
sizeLabel = "general.size_label".str()
|
||||
)
|
||||
|
||||
val author = GgufMetadata.AuthorInfo(
|
||||
organization = "general.organization".str(),
|
||||
author = "general.author".str(),
|
||||
doi = "general.doi".str(),
|
||||
url = "general.url".str(),
|
||||
repoUrl = "general.repo_url".str(),
|
||||
license = "general.license".str(),
|
||||
licenseLink = "general.license.link".str()
|
||||
).takeUnless {
|
||||
organization == null && author == null && doi == null &&
|
||||
url == null && repoUrl == null && license == null && licenseLink == null
|
||||
}
|
||||
|
||||
val additional = GgufMetadata.AdditionalInfo(
|
||||
type = "general.type".str(),
|
||||
description = "general.description".str(),
|
||||
tags = "general.tags".strList(),
|
||||
languages = "general.languages".strList()
|
||||
).takeUnless {
|
||||
type == null && description == null && tags == null && languages == null
|
||||
}
|
||||
|
||||
val architectureInfo = GgufMetadata.ArchitectureInfo(
|
||||
architecture = arch,
|
||||
fileType = "general.file_type".u32(),
|
||||
vocabSize = "$arch.vocab_size".u32(),
|
||||
finetune = "general.finetune".str(),
|
||||
quantizationVersion = "general.quantization_version".u32()
|
||||
).takeUnless { fileType == null && vocabSize == null && finetune == null && quantizationVersion == null }
|
||||
|
||||
val baseModels = buildList {
|
||||
val n = "general.base_model.count".u32() ?: 0
|
||||
for (i in 0 until n) {
|
||||
fun k(s: String) = "general.base_model.$i.$s"
|
||||
add(
|
||||
GgufMetadata.BaseModelInfo(
|
||||
name = k("name").str(),
|
||||
author = k("author").str(),
|
||||
version = k("version").str(),
|
||||
organization = k("organization").str(),
|
||||
url = k("url").str(),
|
||||
doi = k("doi").str(),
|
||||
uuid = k("uuid").str(),
|
||||
repoUrl = k("repo_url").str(),
|
||||
)
|
||||
)
|
||||
}
|
||||
}.takeIf { it.isNotEmpty() }
|
||||
|
||||
val tokenizer = GgufMetadata.TokenizerInfo(
|
||||
model = "tokenizer.ggml.model".str(),
|
||||
bosTokenId = "tokenizer.ggml.bos_token_id".u32(),
|
||||
eosTokenId = "tokenizer.ggml.eos_token_id".u32(),
|
||||
unknownTokenId = "tokenizer.ggml.unknown_token_id".u32(),
|
||||
paddingTokenId = "tokenizer.ggml.padding_token_id".u32(),
|
||||
addBosToken = "tokenizer.ggml.add_bos_token".bool(),
|
||||
addEosToken = "tokenizer.ggml.add_eos_token".bool(),
|
||||
chatTemplate = "tokenizer.chat_template".str()
|
||||
).takeUnless { model == null && bosTokenId == null && eosTokenId == null &&
|
||||
unknownTokenId == null && paddingTokenId == null &&
|
||||
addBosToken == null && addEosToken == null && chatTemplate == null
|
||||
}
|
||||
|
||||
val dimensions = GgufMetadata.DimensionsInfo(
|
||||
contextLength = "$arch.context_length".u32(),
|
||||
embeddingSize = "$arch.embedding_length".u32(),
|
||||
blockCount = "$arch.block_count".u32(),
|
||||
feedForwardSize = "$arch.feed_forward_length".u32()
|
||||
).takeUnless { contextLength == null && embeddingSize == null && blockCount == null && feedForwardSize == null }
|
||||
|
||||
val attention = GgufMetadata.AttentionInfo(
|
||||
headCount = "$arch.attention.head_count".u32(),
|
||||
headCountKv = "$arch.attention.head_count_kv".u32(),
|
||||
keyLength = "$arch.attention.key_length".u32(),
|
||||
valueLength = "$arch.attention.value_length".u32(),
|
||||
layerNormEpsilon = "$arch.attention.layer_norm_epsilon".f32(),
|
||||
layerNormRmsEpsilon = "$arch.attention.layer_norm_rms_epsilon".f32(),
|
||||
).takeUnless { headCount == null && headCountKv == null && keyLength == null && valueLength == null &&
|
||||
layerNormEpsilon == null && layerNormRmsEpsilon == null
|
||||
}
|
||||
|
||||
val rope = GgufMetadata.RopeInfo(
|
||||
frequencyBase = "$arch.rope.freq_base".f32(),
|
||||
dimensionCount = "$arch.rope.dimension_count".u32(),
|
||||
scalingType = "$arch.rope.scaling.type".str(),
|
||||
scalingFactor = "$arch.rope.scaling.factor".f32(),
|
||||
attnFactor = "$arch.rope.scaling.attn_factor".f32(),
|
||||
originalContextLength = "$arch.rope.scaling.original_context_length".u32(),
|
||||
finetuned = "$arch.rope.scaling.finetuned".bool()
|
||||
).takeUnless { frequencyBase == null && dimensionCount == null &&
|
||||
scalingType == null && scalingFactor == null && attnFactor == null &&
|
||||
originalContextLength == null && finetuned == null
|
||||
}
|
||||
|
||||
val experts = GgufMetadata.ExpertsInfo(
|
||||
count = "$arch.expert_count".u32(),
|
||||
usedCount = "$arch.expert_used_count".u32()
|
||||
).takeUnless { count == null && usedCount == null }
|
||||
|
||||
return GgufMetadata(
|
||||
version = version,
|
||||
tensorCount = tensorCnt,
|
||||
kvCount = kvCnt,
|
||||
basic = basic,
|
||||
author = author,
|
||||
additional = additional,
|
||||
architecture = architectureInfo,
|
||||
baseModels = baseModels,
|
||||
tokenizer = tokenizer,
|
||||
dimensions = dimensions,
|
||||
attention = attention,
|
||||
rope = rope,
|
||||
experts = experts
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Recursively parses a metadata value of the given type from the input stream.
|
||||
* @param input The input stream positioned at the start of the value.
|
||||
* @param type The metadata value type to parse.
|
||||
*/
|
||||
private fun parseValue(input: InputStream, type: MetadataType): MetadataValue = when (type) {
|
||||
MetadataType.UINT8 -> {
|
||||
// 1-byte unsigned integer
|
||||
val byteVal = input.read()
|
||||
if (byteVal == -1) throw IOException("Unexpected EOF while reading uint8 value.")
|
||||
MetadataValue.UInt8(byteVal.toUByte())
|
||||
}
|
||||
MetadataType.INT8 -> {
|
||||
// 1-byte signed integer
|
||||
val byteVal = input.read()
|
||||
if (byteVal == -1) throw IOException("Unexpected EOF while reading int8 value.")
|
||||
MetadataValue.Int8(byteVal.toByte())
|
||||
}
|
||||
MetadataType.UINT16 -> {
|
||||
// 2-byte unsigned integer (little-endian)
|
||||
val bytes = ByteArray(2)
|
||||
if (input.read(bytes) != 2) throw IOException("Unexpected EOF while reading uint16 value.")
|
||||
// Combine two bytes (little-endian) into an unsigned 16-bit value
|
||||
val u16 = ((bytes[1].toInt() and 0xFF) shl 8) or (bytes[0].toInt() and 0xFF)
|
||||
MetadataValue.UInt16(u16.toUShort())
|
||||
}
|
||||
MetadataType.INT16 -> {
|
||||
// 2-byte signed integer (little-endian)
|
||||
val bytes = ByteArray(2)
|
||||
if (input.read(bytes) != 2) throw IOException("Unexpected EOF while reading int16 value.")
|
||||
// Combine to 16-bit and interpret as signed
|
||||
val i16 = ((bytes[1].toInt() and 0xFF) shl 8) or (bytes[0].toInt() and 0xFF)
|
||||
MetadataValue.Int16(i16.toShort())
|
||||
}
|
||||
MetadataType.UINT32 -> {
|
||||
// 4-byte unsigned integer (little-endian)
|
||||
val bytes = ByteArray(4)
|
||||
if (input.read(bytes) != 4) throw IOException("Unexpected EOF while reading uint32 value.")
|
||||
// Combine four bytes into a 32-bit value (as Long to avoid overflow), then convert to UInt
|
||||
val u32 = (bytes[3].toLong() and 0xFFL shl 24) or
|
||||
(bytes[2].toLong() and 0xFFL shl 16) or
|
||||
(bytes[1].toLong() and 0xFFL shl 8) or
|
||||
(bytes[0].toLong() and 0xFFL)
|
||||
MetadataValue.UInt32(u32.toUInt())
|
||||
}
|
||||
MetadataType.INT32 -> {
|
||||
// 4-byte signed integer (little-endian)
|
||||
val bytes = ByteArray(4)
|
||||
if (input.read(bytes) != 4) throw IOException("Unexpected EOF while reading int32 value.")
|
||||
// Combine four bytes into a 32-bit signed int
|
||||
val i32 = (bytes[3].toInt() and 0xFF shl 24) or
|
||||
(bytes[2].toInt() and 0xFF shl 16) or
|
||||
(bytes[1].toInt() and 0xFF shl 8) or
|
||||
(bytes[0].toInt() and 0xFF)
|
||||
MetadataValue.Int32(i32)
|
||||
}
|
||||
MetadataType.FLOAT32 -> {
|
||||
// 4-byte IEEE 754 float (little-endian)
|
||||
val bytes = ByteArray(4)
|
||||
if (input.read(bytes) != 4) throw IOException("Unexpected EOF while reading float32 value.")
|
||||
// Assemble 4 bytes into a 32-bit int bit-pattern, then convert to Float
|
||||
val bits = (bytes[3].toInt() and 0xFF shl 24) or
|
||||
(bytes[2].toInt() and 0xFF shl 16) or
|
||||
(bytes[1].toInt() and 0xFF shl 8) or
|
||||
(bytes[0].toInt() and 0xFF)
|
||||
val floatVal = Float.fromBits(bits)
|
||||
MetadataValue.Float32(floatVal)
|
||||
}
|
||||
MetadataType.BOOL -> {
|
||||
// 1-byte boolean (0 = false, 1 = true)
|
||||
val byteVal = input.read()
|
||||
if (byteVal == -1) throw IOException("Unexpected EOF while reading boolean value.")
|
||||
if (byteVal != 0 && byteVal != 1) {
|
||||
throw IOException("Invalid boolean value: $byteVal (must be 0 or 1).")
|
||||
}
|
||||
MetadataValue.Bool(byteVal != 0)
|
||||
}
|
||||
MetadataType.STRING -> {
|
||||
// UTF-8 string (length-prefixed with 8-byte length)
|
||||
val str = readString(input)
|
||||
MetadataValue.StringVal(str)
|
||||
}
|
||||
MetadataType.ARRAY -> {
|
||||
val elemType = MetadataType.fromCode(littleEndianBytesToInt(input.readNBytesExact(4)))
|
||||
val len = readLittleLong(input)
|
||||
val count = len.toInt()
|
||||
|
||||
if (arraySummariseThreshold >= 0 && count > arraySummariseThreshold) {
|
||||
// fast‑forward without allocation
|
||||
repeat(count) { skipValue(input, elemType) }
|
||||
MetadataValue.StringVal("Array($elemType, $count items) /* summarised */")
|
||||
} else {
|
||||
val list = ArrayList<MetadataValue>(count)
|
||||
repeat(count) { list += parseValue(input, elemType) }
|
||||
MetadataValue.ArrayVal(elemType, list)
|
||||
}
|
||||
}
|
||||
MetadataType.UINT64 -> {
|
||||
// 8-byte unsigned integer (little-endian)
|
||||
val bytes = ByteArray(8)
|
||||
if (input.read(bytes) != 8) throw IOException("Unexpected EOF while reading uint64 value.")
|
||||
// Combine 8 bytes into an unsigned 64-bit (ULong). Use ULong for full 0 to 2^64-1 range.
|
||||
val u64 = (bytes[7].toULong() and 0xFFuL shl 56) or
|
||||
(bytes[6].toULong() and 0xFFuL shl 48) or
|
||||
(bytes[5].toULong() and 0xFFuL shl 40) or
|
||||
(bytes[4].toULong() and 0xFFuL shl 32) or
|
||||
(bytes[3].toULong() and 0xFFuL shl 24) or
|
||||
(bytes[2].toULong() and 0xFFuL shl 16) or
|
||||
(bytes[1].toULong() and 0xFFuL shl 8) or
|
||||
(bytes[0].toULong() and 0xFFuL)
|
||||
MetadataValue.UInt64(u64)
|
||||
}
|
||||
MetadataType.INT64 -> {
|
||||
// 8-byte signed integer (little-endian)
|
||||
val bytes = ByteArray(8)
|
||||
if (input.read(bytes) != 8) throw IOException("Unexpected EOF while reading int64 value.")
|
||||
// Combine 8 bytes into a signed 64-bit value (Long)
|
||||
val i64 = (bytes[7].toLong() and 0xFFL shl 56) or
|
||||
(bytes[6].toLong() and 0xFFL shl 48) or
|
||||
(bytes[5].toLong() and 0xFFL shl 40) or
|
||||
(bytes[4].toLong() and 0xFFL shl 32) or
|
||||
(bytes[3].toLong() and 0xFFL shl 24) or
|
||||
(bytes[2].toLong() and 0xFFL shl 16) or
|
||||
(bytes[1].toLong() and 0xFFL shl 8) or
|
||||
(bytes[0].toLong() and 0xFFL)
|
||||
MetadataValue.Int64(i64)
|
||||
}
|
||||
MetadataType.FLOAT64 -> {
|
||||
// 8-byte IEEE 754 double (little-endian)
|
||||
val bytes = ByteArray(8)
|
||||
if (input.read(bytes) != 8) throw IOException("Unexpected EOF while reading float64 value.")
|
||||
// Assemble 8 bytes into a 64-bit bit-pattern, then convert to Double
|
||||
val bits = (bytes[7].toLong() and 0xFFL shl 56) or
|
||||
(bytes[6].toLong() and 0xFFL shl 48) or
|
||||
(bytes[5].toLong() and 0xFFL shl 40) or
|
||||
(bytes[4].toLong() and 0xFFL shl 32) or
|
||||
(bytes[3].toLong() and 0xFFL shl 24) or
|
||||
(bytes[2].toLong() and 0xFFL shl 16) or
|
||||
(bytes[1].toLong() and 0xFFL shl 8) or
|
||||
(bytes[0].toLong() and 0xFFL)
|
||||
val doubleVal = Double.fromBits(bits)
|
||||
MetadataValue.Float64(doubleVal)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
private fun <T> T?.takeUnless(check: T.() -> Boolean): T? =
|
||||
this?.takeIf { !it.check() }
|
||||
|
||||
/** Helper: Skip a value in the stream without storing it (still maintains pointer). */
|
||||
private fun skipValue(input: InputStream, type: MetadataType) {
|
||||
when (type) {
|
||||
MetadataType.UINT8, MetadataType.INT8, MetadataType.BOOL -> input.skipFully(1)
|
||||
MetadataType.UINT16, MetadataType.INT16 -> input.skipFully(2)
|
||||
MetadataType.UINT32, MetadataType.INT32, MetadataType.FLOAT32 -> input.skipFully(4)
|
||||
MetadataType.UINT64, MetadataType.INT64, MetadataType.FLOAT64 -> input.skipFully(8)
|
||||
MetadataType.STRING -> {
|
||||
val len = readLittleLong(input); input.skipFully(len)
|
||||
}
|
||||
MetadataType.ARRAY -> {
|
||||
val elemType = MetadataType.fromCode(littleEndianBytesToInt(input.readNBytesExact(4)))
|
||||
val len = readLittleLong(input)
|
||||
repeat(len.toInt()) { skipValue(input, elemType) } // recursive skip
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/** Helper: Read an 8-byte little-endian unsigned value and return it as a signed Long (assuming it fits in 63 bits). */
|
||||
private fun readLittleLong(input: InputStream): Long {
|
||||
val bytes = ByteArray(8)
|
||||
input.readFully(bytes)
|
||||
|
||||
// Combine 8 bytes into a 64-bit value (Little Endian).
|
||||
// Note: If the value exceeds Long.MAX_VALUE (bit 63 is 1), this will produce a negative Long (two's complement).
|
||||
// In our context (lengths/counts), such extremely large values are not expected.
|
||||
return (bytes[7].toLong() and 0xFFL shl 56) or
|
||||
(bytes[6].toLong() and 0xFFL shl 48) or
|
||||
(bytes[5].toLong() and 0xFFL shl 40) or
|
||||
(bytes[4].toLong() and 0xFFL shl 32) or
|
||||
(bytes[3].toLong() and 0xFFL shl 24) or
|
||||
(bytes[2].toLong() and 0xFFL shl 16) or
|
||||
(bytes[1].toLong() and 0xFFL shl 8) or
|
||||
(bytes[0].toLong() and 0xFFL)
|
||||
}
|
||||
|
||||
/** Helper: Read a GGUF string from the stream (8-byte length followed by UTF-8 bytes). */
|
||||
private fun readString(input: InputStream): String =
|
||||
// Read 8-byte little-endian length (number of bytes in the string).
|
||||
readLittleLong(input).let { len ->
|
||||
if (len < 0 || len > Int.MAX_VALUE) throw IOException("String too long: $len")
|
||||
|
||||
// Read the UTF-8 bytes of the given length.
|
||||
ByteArray(len.toInt()).let {
|
||||
if (it.isNotEmpty()) input.readFully(it)
|
||||
String(it, Charsets.UTF_8)
|
||||
}
|
||||
}
|
||||
|
||||
/** Helper: Convert a 4-byte little-endian byte array to a 32-bit integer. */
|
||||
private fun littleEndianBytesToInt(bytes: ByteArray): Int =
|
||||
// Note: assumes bytes length is 4.
|
||||
(bytes[3].toInt() and 0xFF shl 24) or
|
||||
(bytes[2].toInt() and 0xFF shl 16) or
|
||||
(bytes[1].toInt() and 0xFF shl 8) or
|
||||
(bytes[0].toInt() and 0xFF)
|
||||
|
||||
/**
|
||||
* Robust skip that works the same on JDK 11 and Android’s desugared runtime.
|
||||
*
|
||||
* @param n Number of bytes to advance in the stream.
|
||||
* @throws IOException on premature EOF.
|
||||
*/
|
||||
private fun InputStream.skipFully(n: Long) {
|
||||
var remaining = n
|
||||
val scratch = ByteArray(8192) // read‑and‑toss buffer
|
||||
while (remaining > 0) {
|
||||
val skipped = skip(remaining)
|
||||
when {
|
||||
skipped > 0 -> remaining -= skipped // normal fast path
|
||||
skipped == 0L -> {
|
||||
// fallback: read and discard
|
||||
val read = read(scratch, 0, minOf(remaining, scratch.size.toLong()).toInt())
|
||||
if (read == -1) throw IOException("EOF while skipping $n bytes")
|
||||
remaining -= read
|
||||
}
|
||||
else -> throw IOException("Skip returned negative value")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Extension that keeps reading until the requested number of bytes are filled.
|
||||
* Falls back to `read()` when `skip()` returns 0, which happens on some Android
|
||||
* streams.
|
||||
*
|
||||
* @param buf Destination buffer.
|
||||
* @param len Number of bytes to fill (defaults to `buf.size`).
|
||||
* @throws IOException on premature EOF.
|
||||
*/
|
||||
private fun InputStream.readFully(buf: ByteArray, len: Int = buf.size) {
|
||||
var off = 0
|
||||
while (off < len) {
|
||||
val n = read(buf, off, len - off)
|
||||
if (n == -1) throw IOException("EOF after $off of $len bytes")
|
||||
off += n
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Read EXACTLY `n` bytes or throw – never returns a partially‑filled array.
|
||||
* This is used for small fixed‑length reads (e.g. 4‑byte type codes).
|
||||
*
|
||||
* @throws IOException on premature EOF.
|
||||
*/
|
||||
private fun InputStream.readNBytesExact(n: Int) = ByteArray(n).also {
|
||||
if (read(it) != n) throw IOException("Unexpected EOF")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,17 @@
|
||||
package android.llama.cpp
|
||||
|
||||
import org.junit.Test
|
||||
|
||||
import org.junit.Assert.*
|
||||
|
||||
/**
|
||||
* Example local unit test, which will execute on the development machine (host).
|
||||
*
|
||||
* See [testing documentation](http://d.android.com/tools/testing).
|
||||
*/
|
||||
class ExampleUnitTest {
|
||||
@Test
|
||||
fun addition_isCorrect() {
|
||||
assertEquals(4, 2 + 2)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user