Add Android demo for speaker recognition (#536)

See pre-built Android APKs at 
https://k2-fsa.github.io/sherpa/onnx/speaker-identification/apk.html
This commit is contained in:
Fangjun Kuang
2024-01-23 16:50:52 +08:00
committed by GitHub
parent 626775e5e2
commit bbd7c7fc18
73 changed files with 3022 additions and 6 deletions

View File

@@ -27,6 +27,8 @@
#include "sherpa-onnx/csrc/offline-tts.h"
#include "sherpa-onnx/csrc/online-recognizer.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
#include "sherpa-onnx/csrc/speaker-embedding-manager.h"
#include "sherpa-onnx/csrc/voice-activity-detector.h"
#include "sherpa-onnx/csrc/wave-reader.h"
#include "sherpa-onnx/csrc/wave-writer.h"
@@ -208,6 +210,85 @@ class SherpaOnnxKws {
int32_t input_sample_rate_ = -1;
};
class SherpaOnnxSpeakerEmbeddingExtractorStream {
public:
explicit SherpaOnnxSpeakerEmbeddingExtractorStream(
std::unique_ptr<OnlineStream> stream)
: stream_(std::move(stream)) {}
void AcceptWaveform(int32_t sample_rate, const float *samples,
int32_t n) const {
stream_->AcceptWaveform(sample_rate, samples, n);
}
void InputFinished() const { stream_->InputFinished(); }
OnlineStream *Get() const { return stream_.get(); }
private:
std::unique_ptr<OnlineStream> stream_;
};
class SherpaOnnxSpeakerEmbeddingExtractor {
public:
#if __ANDROID_API__ >= 9
SherpaOnnxSpeakerEmbeddingExtractor(
AAssetManager *mgr, const SpeakerEmbeddingExtractorConfig &config)
: extractor_(mgr, config) {}
#endif
explicit SherpaOnnxSpeakerEmbeddingExtractor(
const SpeakerEmbeddingExtractorConfig &config)
: extractor_(config) {}
int32_t Dim() const { return extractor_.Dim(); }
bool IsReady(const SherpaOnnxSpeakerEmbeddingExtractorStream *stream) const {
return extractor_.IsReady(stream->Get());
}
SherpaOnnxSpeakerEmbeddingExtractorStream *CreateStream() const {
return new SherpaOnnxSpeakerEmbeddingExtractorStream(
extractor_.CreateStream());
}
std::vector<float> Compute(
const SherpaOnnxSpeakerEmbeddingExtractorStream *stream) const {
return extractor_.Compute(stream->Get());
}
private:
SpeakerEmbeddingExtractor extractor_;
};
static SpeakerEmbeddingExtractorConfig GetSpeakerEmbeddingExtractorConfig(
JNIEnv *env, jobject config) {
SpeakerEmbeddingExtractorConfig ans;
jclass cls = env->GetObjectClass(config);
jfieldID fid = env->GetFieldID(cls, "model", "Ljava/lang/String;");
jstring s = (jstring)env->GetObjectField(config, fid);
const char *p = env->GetStringUTFChars(s, nullptr);
ans.model = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(cls, "numThreads", "I");
ans.num_threads = env->GetIntField(config, fid);
fid = env->GetFieldID(cls, "debug", "Z");
ans.debug = env->GetBooleanField(config, fid);
fid = env->GetFieldID(cls, "provider", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.provider = p;
env->ReleaseStringUTFChars(s, p);
return ans;
}
static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
OnlineRecognizerConfig ans;
@@ -771,6 +852,334 @@ static OfflineTtsConfig GetOfflineTtsConfig(JNIEnv *env, jobject config) {
} // namespace sherpa_onnx
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_new(JNIEnv *env,
jobject /*obj*/,
jobject asset_manager,
jobject _config) {
#if __ANDROID_API__ >= 9
AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
if (!mgr) {
SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
}
#endif
auto config = sherpa_onnx::GetSpeakerEmbeddingExtractorConfig(env, _config);
SHERPA_ONNX_LOGE("new config:\n%s", config.ToString().c_str());
auto extractor = new sherpa_onnx::SherpaOnnxSpeakerEmbeddingExtractor(
#if __ANDROID_API__ >= 9
mgr,
#endif
config);
return (jlong)extractor;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_newFromFile(
JNIEnv *env, jobject /*obj*/, jobject _config) {
auto config = sherpa_onnx::GetSpeakerEmbeddingExtractorConfig(env, _config);
SHERPA_ONNX_LOGE("newFromFile config:\n%s", config.ToString().c_str());
if (!config.Validate()) {
SHERPA_ONNX_LOGE("Errors found in config!");
}
auto extractor = new sherpa_onnx::SherpaOnnxSpeakerEmbeddingExtractor(config);
return (jlong)extractor;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_delete(JNIEnv *env,
jobject /*obj*/,
jlong ptr) {
delete reinterpret_cast<sherpa_onnx::SherpaOnnxSpeakerEmbeddingExtractor *>(
ptr);
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_createStream(
JNIEnv *env, jobject /*obj*/, jlong ptr) {
auto stream =
reinterpret_cast<sherpa_onnx::SherpaOnnxSpeakerEmbeddingExtractor *>(ptr)
->CreateStream();
return (jlong)stream;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jboolean JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_isReady(JNIEnv *env,
jobject /*obj*/,
jlong ptr,
jlong stream_ptr) {
auto extractor =
reinterpret_cast<sherpa_onnx::SherpaOnnxSpeakerEmbeddingExtractor *>(ptr);
auto stream = reinterpret_cast<
sherpa_onnx::SherpaOnnxSpeakerEmbeddingExtractorStream *>(stream_ptr);
return extractor->IsReady(stream);
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jfloatArray JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_compute(JNIEnv *env,
jobject /*obj*/,
jlong ptr,
jlong stream_ptr) {
auto extractor =
reinterpret_cast<sherpa_onnx::SherpaOnnxSpeakerEmbeddingExtractor *>(ptr);
auto stream = reinterpret_cast<
sherpa_onnx::SherpaOnnxSpeakerEmbeddingExtractorStream *>(stream_ptr);
std::vector<float> embedding = extractor->Compute(stream);
jfloatArray embedding_arr = env->NewFloatArray(embedding.size());
env->SetFloatArrayRegion(embedding_arr, 0, embedding.size(),
embedding.data());
return embedding_arr;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jint JNICALL Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_dim(
JNIEnv *env, jobject /*obj*/, jlong ptr) {
auto extractor =
reinterpret_cast<sherpa_onnx::SherpaOnnxSpeakerEmbeddingExtractor *>(ptr);
return extractor->Dim();
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractorStream_delete(
JNIEnv *env, jobject /*obj*/, jlong ptr) {
delete reinterpret_cast<
sherpa_onnx::SherpaOnnxSpeakerEmbeddingExtractorStream *>(ptr);
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractorStream_acceptWaveform(
JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples,
jint sample_rate) {
auto stream = reinterpret_cast<
sherpa_onnx::SherpaOnnxSpeakerEmbeddingExtractorStream *>(ptr);
jfloat *p = env->GetFloatArrayElements(samples, nullptr);
jsize n = env->GetArrayLength(samples);
stream->AcceptWaveform(sample_rate, p, n);
env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractorStream_inputFinished(
JNIEnv *env, jobject /*obj*/, jlong ptr) {
auto stream = reinterpret_cast<
sherpa_onnx::SherpaOnnxSpeakerEmbeddingExtractorStream *>(ptr);
stream->InputFinished();
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_new(
JNIEnv *env, jobject /*obj*/, jint dim) {
auto p = new sherpa_onnx::SpeakerEmbeddingManager(dim);
return (jlong)p;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_delete(JNIEnv *env,
jobject /*obj*/,
jlong ptr) {
auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
delete manager;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jboolean JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_add(JNIEnv *env,
jobject /*obj*/,
jlong ptr, jstring name,
jfloatArray embedding) {
auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
jfloat *p = env->GetFloatArrayElements(embedding, nullptr);
jsize n = env->GetArrayLength(embedding);
if (n != manager->Dim()) {
SHERPA_ONNX_LOGE("Expected dim %d, given %d", manager->Dim(),
static_cast<int32_t>(n));
exit(-1);
}
const char *p_name = env->GetStringUTFChars(name, nullptr);
jboolean ok = manager->Add(p_name, p);
env->ReleaseStringUTFChars(name, p_name);
env->ReleaseFloatArrayElements(embedding, p, JNI_ABORT);
return ok;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jboolean JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_addList(
JNIEnv *env, jobject /*obj*/, jlong ptr, jstring name,
jobjectArray embedding_arr) {
auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
int num_embeddings = env->GetArrayLength(embedding_arr);
if (num_embeddings == 0) {
return false;
}
std::vector<std::vector<float>> embedding_list;
embedding_list.reserve(num_embeddings);
for (int32_t i = 0; i != num_embeddings; ++i) {
jfloatArray embedding =
(jfloatArray)env->GetObjectArrayElement(embedding_arr, i);
jfloat *p = env->GetFloatArrayElements(embedding, nullptr);
jsize n = env->GetArrayLength(embedding);
if (n != manager->Dim()) {
SHERPA_ONNX_LOGE("i: %d. Expected dim %d, given %d", i, manager->Dim(),
static_cast<int32_t>(n));
exit(-1);
}
embedding_list.push_back({p, p + n});
env->ReleaseFloatArrayElements(embedding, p, JNI_ABORT);
}
const char *p_name = env->GetStringUTFChars(name, nullptr);
jboolean ok = manager->Add(p_name, embedding_list);
env->ReleaseStringUTFChars(name, p_name);
return ok;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jboolean JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_remove(JNIEnv *env,
jobject /*obj*/,
jlong ptr,
jstring name) {
auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
const char *p_name = env->GetStringUTFChars(name, nullptr);
jboolean ok = manager->Remove(p_name);
env->ReleaseStringUTFChars(name, p_name);
return ok;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jstring JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_search(JNIEnv *env,
jobject /*obj*/,
jlong ptr,
jfloatArray embedding,
jfloat threshold) {
auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
jfloat *p = env->GetFloatArrayElements(embedding, nullptr);
jsize n = env->GetArrayLength(embedding);
if (n != manager->Dim()) {
SHERPA_ONNX_LOGE("Expected dim %d, given %d", manager->Dim(),
static_cast<int32_t>(n));
exit(-1);
}
std::string name = manager->Search(p, threshold);
env->ReleaseFloatArrayElements(embedding, p, JNI_ABORT);
return env->NewStringUTF(name.c_str());
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jboolean JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_verify(
JNIEnv *env, jobject /*obj*/, jlong ptr, jstring name,
jfloatArray embedding, jfloat threshold) {
auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
jfloat *p = env->GetFloatArrayElements(embedding, nullptr);
jsize n = env->GetArrayLength(embedding);
if (n != manager->Dim()) {
SHERPA_ONNX_LOGE("Expected dim %d, given %d", manager->Dim(),
static_cast<int32_t>(n));
exit(-1);
}
const char *p_name = env->GetStringUTFChars(name, nullptr);
jboolean ok = manager->Verify(p_name, p, threshold);
env->ReleaseFloatArrayElements(embedding, p, JNI_ABORT);
env->ReleaseStringUTFChars(name, p_name);
return ok;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jboolean JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_contains(JNIEnv *env,
jobject /*obj*/,
jlong ptr,
jstring name) {
auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
const char *p_name = env->GetStringUTFChars(name, nullptr);
jboolean ok = manager->Contains(p_name);
env->ReleaseStringUTFChars(name, p_name);
return ok;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jint JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_numSpeakers(JNIEnv *env,
jobject /*obj*/,
jlong ptr) {
auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
return manager->NumSpeakers();
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jobjectArray JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_allSpeakerNames(
JNIEnv *env, jobject /*obj*/, jlong ptr) {
auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
std::vector<std::string> all_speakers = manager->GetAllSpeakers();
jobjectArray obj_arr = (jobjectArray)env->NewObjectArray(
all_speakers.size(), env->FindClass("java/lang/String"), nullptr);
int32_t i = 0;
for (auto &s : all_speakers) {
jstring js = env->NewStringUTF(s.c_str());
env->SetObjectArrayElement(obj_arr, i, js);
++i;
}
return obj_arr;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_new(
JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) {
@@ -783,10 +1192,6 @@ JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_new(
auto config = sherpa_onnx::GetOfflineTtsConfig(env, _config);
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
if (!config.Validate()) {
SHERPA_ONNX_LOGE("Erros found in config!");
}
auto tts = new sherpa_onnx::SherpaOnnxOfflineTts(
#if __ANDROID_API__ >= 9
mgr,
@@ -801,6 +1206,11 @@ JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_newFromFile(
JNIEnv *env, jobject /*obj*/, jobject _config) {
auto config = sherpa_onnx::GetOfflineTtsConfig(env, _config);
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
if (!config.Validate()) {
SHERPA_ONNX_LOGE("Errors found in config!");
}
auto tts = new sherpa_onnx::SherpaOnnxOfflineTts(config);
return (jlong)tts;