208 lines
6.8 KiB
C++
208 lines
6.8 KiB
C++
// sherpa-onnx/jni/speaker-embedding-manager.cc
|
|
//
|
|
// Copyright (c) 2024 Xiaomi Corporation
|
|
#include "sherpa-onnx/csrc/speaker-embedding-manager.h"
|
|
|
|
#include "sherpa-onnx/csrc/macros.h"
|
|
#include "sherpa-onnx/jni/common.h"
|
|
|
|
SHERPA_ONNX_EXTERN_C
|
|
JNIEXPORT jlong JNICALL
|
|
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_create(JNIEnv *env,
|
|
jobject /*obj*/,
|
|
jint dim) {
|
|
auto p = new sherpa_onnx::SpeakerEmbeddingManager(dim);
|
|
return (jlong)p;
|
|
}
|
|
|
|
SHERPA_ONNX_EXTERN_C
|
|
JNIEXPORT void JNICALL
|
|
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_delete(JNIEnv * /*env*/,
|
|
jobject /*obj*/,
|
|
jlong ptr) {
|
|
auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
|
|
delete manager;
|
|
}
|
|
|
|
SHERPA_ONNX_EXTERN_C
|
|
JNIEXPORT jboolean JNICALL
|
|
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_add(JNIEnv *env,
|
|
jobject /*obj*/,
|
|
jlong ptr, jstring name,
|
|
jfloatArray embedding) {
|
|
auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
|
|
|
|
jfloat *p = env->GetFloatArrayElements(embedding, nullptr);
|
|
jsize n = env->GetArrayLength(embedding);
|
|
|
|
if (n != manager->Dim()) {
|
|
SHERPA_ONNX_LOGE("Expected dim %d, given %d", manager->Dim(),
|
|
static_cast<int32_t>(n));
|
|
exit(-1);
|
|
}
|
|
|
|
const char *p_name = env->GetStringUTFChars(name, nullptr);
|
|
|
|
jboolean ok = manager->Add(p_name, p);
|
|
env->ReleaseStringUTFChars(name, p_name);
|
|
env->ReleaseFloatArrayElements(embedding, p, JNI_ABORT);
|
|
|
|
return ok;
|
|
}
|
|
|
|
SHERPA_ONNX_EXTERN_C
|
|
JNIEXPORT jboolean JNICALL
|
|
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_addList(
|
|
JNIEnv *env, jobject /*obj*/, jlong ptr, jstring name,
|
|
jobjectArray embedding_arr) {
|
|
auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
|
|
|
|
int num_embeddings = env->GetArrayLength(embedding_arr);
|
|
if (num_embeddings == 0) {
|
|
return false;
|
|
}
|
|
|
|
std::vector<std::vector<float>> embedding_list;
|
|
embedding_list.reserve(num_embeddings);
|
|
for (int32_t i = 0; i != num_embeddings; ++i) {
|
|
jfloatArray embedding =
|
|
(jfloatArray)env->GetObjectArrayElement(embedding_arr, i);
|
|
|
|
jfloat *p = env->GetFloatArrayElements(embedding, nullptr);
|
|
jsize n = env->GetArrayLength(embedding);
|
|
|
|
if (n != manager->Dim()) {
|
|
SHERPA_ONNX_LOGE("i: %d. Expected dim %d, given %d", i, manager->Dim(),
|
|
static_cast<int32_t>(n));
|
|
exit(-1);
|
|
}
|
|
|
|
embedding_list.push_back({p, p + n});
|
|
env->ReleaseFloatArrayElements(embedding, p, JNI_ABORT);
|
|
}
|
|
|
|
const char *p_name = env->GetStringUTFChars(name, nullptr);
|
|
|
|
jboolean ok = manager->Add(p_name, embedding_list);
|
|
|
|
env->ReleaseStringUTFChars(name, p_name);
|
|
|
|
return ok;
|
|
}
|
|
|
|
SHERPA_ONNX_EXTERN_C
|
|
JNIEXPORT jboolean JNICALL
|
|
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_remove(JNIEnv *env,
|
|
jobject /*obj*/,
|
|
jlong ptr,
|
|
jstring name) {
|
|
auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
|
|
|
|
const char *p_name = env->GetStringUTFChars(name, nullptr);
|
|
|
|
jboolean ok = manager->Remove(p_name);
|
|
|
|
env->ReleaseStringUTFChars(name, p_name);
|
|
|
|
return ok;
|
|
}
|
|
|
|
SHERPA_ONNX_EXTERN_C
|
|
JNIEXPORT jstring JNICALL
|
|
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_search(JNIEnv *env,
|
|
jobject /*obj*/,
|
|
jlong ptr,
|
|
jfloatArray embedding,
|
|
jfloat threshold) {
|
|
auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
|
|
|
|
jfloat *p = env->GetFloatArrayElements(embedding, nullptr);
|
|
jsize n = env->GetArrayLength(embedding);
|
|
|
|
if (n != manager->Dim()) {
|
|
SHERPA_ONNX_LOGE("Expected dim %d, given %d", manager->Dim(),
|
|
static_cast<int32_t>(n));
|
|
exit(-1);
|
|
}
|
|
|
|
std::string name = manager->Search(p, threshold);
|
|
|
|
env->ReleaseFloatArrayElements(embedding, p, JNI_ABORT);
|
|
|
|
return env->NewStringUTF(name.c_str());
|
|
}
|
|
|
|
SHERPA_ONNX_EXTERN_C
|
|
JNIEXPORT jboolean JNICALL
|
|
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_verify(
|
|
JNIEnv *env, jobject /*obj*/, jlong ptr, jstring name,
|
|
jfloatArray embedding, jfloat threshold) {
|
|
auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
|
|
|
|
jfloat *p = env->GetFloatArrayElements(embedding, nullptr);
|
|
jsize n = env->GetArrayLength(embedding);
|
|
|
|
if (n != manager->Dim()) {
|
|
SHERPA_ONNX_LOGE("Expected dim %d, given %d", manager->Dim(),
|
|
static_cast<int32_t>(n));
|
|
exit(-1);
|
|
}
|
|
|
|
const char *p_name = env->GetStringUTFChars(name, nullptr);
|
|
|
|
jboolean ok = manager->Verify(p_name, p, threshold);
|
|
|
|
env->ReleaseFloatArrayElements(embedding, p, JNI_ABORT);
|
|
|
|
env->ReleaseStringUTFChars(name, p_name);
|
|
|
|
return ok;
|
|
}
|
|
|
|
SHERPA_ONNX_EXTERN_C
|
|
JNIEXPORT jboolean JNICALL
|
|
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_contains(JNIEnv *env,
|
|
jobject /*obj*/,
|
|
jlong ptr,
|
|
jstring name) {
|
|
auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
|
|
|
|
const char *p_name = env->GetStringUTFChars(name, nullptr);
|
|
|
|
jboolean ok = manager->Contains(p_name);
|
|
|
|
env->ReleaseStringUTFChars(name, p_name);
|
|
|
|
return ok;
|
|
}
|
|
|
|
SHERPA_ONNX_EXTERN_C
|
|
JNIEXPORT jint JNICALL
|
|
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_numSpeakers(JNIEnv * /*env*/,
|
|
jobject /*obj*/,
|
|
jlong ptr) {
|
|
auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
|
|
return manager->NumSpeakers();
|
|
}
|
|
|
|
SHERPA_ONNX_EXTERN_C
|
|
JNIEXPORT jobjectArray JNICALL
|
|
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_allSpeakerNames(
|
|
JNIEnv *env, jobject /*obj*/, jlong ptr) {
|
|
auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
|
|
std::vector<std::string> all_speakers = manager->GetAllSpeakers();
|
|
|
|
jobjectArray obj_arr = (jobjectArray)env->NewObjectArray(
|
|
all_speakers.size(), env->FindClass("java/lang/String"), nullptr);
|
|
|
|
int32_t i = 0;
|
|
for (auto &s : all_speakers) {
|
|
jstring js = env->NewStringUTF(s.c_str());
|
|
env->SetObjectArrayElement(obj_arr, i, js);
|
|
|
|
++i;
|
|
}
|
|
|
|
return obj_arr;
|
|
}
|