Refactor the JNI interface to make it more modular and maintainable (#802)
This commit is contained in:
@@ -82,7 +82,7 @@ bool OfflineTtsVitsModelConfig::Validate() const {
|
||||
|
||||
for (const auto &f : required_files) {
|
||||
if (!FileExists(dict_dir + "/" + f)) {
|
||||
SHERPA_ONNX_LOGE("'%s/%s' does not exist.", data_dir.c_str(),
|
||||
SHERPA_ONNX_LOGE("'%s/%s' does not exist.", dict_dir.c_str(),
|
||||
f.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -12,8 +12,15 @@ endif()
|
||||
set(sources
|
||||
audio-tagging.cc
|
||||
jni.cc
|
||||
keyword-spotter.cc
|
||||
offline-recognizer.cc
|
||||
offline-stream.cc
|
||||
online-recognizer.cc
|
||||
online-stream.cc
|
||||
speaker-embedding-extractor.cc
|
||||
speaker-embedding-manager.cc
|
||||
spoken-language-identification.cc
|
||||
voice-activity-detector.cc
|
||||
)
|
||||
|
||||
if(SHERPA_ONNX_ENABLE_TTS)
|
||||
|
||||
@@ -6,6 +6,8 @@
|
||||
#define SHERPA_ONNX_JNI_COMMON_H_
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include <strstream>
|
||||
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
233
sherpa-onnx/jni/keyword-spotter.cc
Normal file
233
sherpa-onnx/jni/keyword-spotter.cc
Normal file
@@ -0,0 +1,233 @@
|
||||
// sherpa-onnx/jni/keyword-spotter.cc
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/keyword-spotter.h"
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/jni/common.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
static KeywordSpotterConfig GetKwsConfig(JNIEnv *env, jobject config) {
|
||||
KeywordSpotterConfig ans;
|
||||
|
||||
jclass cls = env->GetObjectClass(config);
|
||||
jfieldID fid;
|
||||
|
||||
// https://docs.oracle.com/javase/7/docs/technotes/guides/jni/spec/types.html
|
||||
// https://courses.cs.washington.edu/courses/cse341/99wi/java/tutorial/native1.1/implementing/field.html
|
||||
|
||||
//---------- decoding ----------
|
||||
fid = env->GetFieldID(cls, "maxActivePaths", "I");
|
||||
ans.max_active_paths = env->GetIntField(config, fid);
|
||||
|
||||
fid = env->GetFieldID(cls, "keywordsFile", "Ljava/lang/String;");
|
||||
jstring s = (jstring)env->GetObjectField(config, fid);
|
||||
const char *p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.keywords_file = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
fid = env->GetFieldID(cls, "keywordsScore", "F");
|
||||
ans.keywords_score = env->GetFloatField(config, fid);
|
||||
|
||||
fid = env->GetFieldID(cls, "keywordsThreshold", "F");
|
||||
ans.keywords_threshold = env->GetFloatField(config, fid);
|
||||
|
||||
fid = env->GetFieldID(cls, "numTrailingBlanks", "I");
|
||||
ans.num_trailing_blanks = env->GetIntField(config, fid);
|
||||
|
||||
//---------- feat config ----------
|
||||
fid = env->GetFieldID(cls, "featConfig",
|
||||
"Lcom/k2fsa/sherpa/onnx/FeatureConfig;");
|
||||
jobject feat_config = env->GetObjectField(config, fid);
|
||||
jclass feat_config_cls = env->GetObjectClass(feat_config);
|
||||
|
||||
fid = env->GetFieldID(feat_config_cls, "sampleRate", "I");
|
||||
ans.feat_config.sampling_rate = env->GetIntField(feat_config, fid);
|
||||
|
||||
fid = env->GetFieldID(feat_config_cls, "featureDim", "I");
|
||||
ans.feat_config.feature_dim = env->GetIntField(feat_config, fid);
|
||||
|
||||
//---------- model config ----------
|
||||
fid = env->GetFieldID(cls, "modelConfig",
|
||||
"Lcom/k2fsa/sherpa/onnx/OnlineModelConfig;");
|
||||
jobject model_config = env->GetObjectField(config, fid);
|
||||
jclass model_config_cls = env->GetObjectClass(model_config);
|
||||
|
||||
// transducer
|
||||
fid = env->GetFieldID(model_config_cls, "transducer",
|
||||
"Lcom/k2fsa/sherpa/onnx/OnlineTransducerModelConfig;");
|
||||
jobject transducer_config = env->GetObjectField(model_config, fid);
|
||||
jclass transducer_config_cls = env->GetObjectClass(transducer_config);
|
||||
|
||||
fid = env->GetFieldID(transducer_config_cls, "encoder", "Ljava/lang/String;");
|
||||
s = (jstring)env->GetObjectField(transducer_config, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.model_config.transducer.encoder = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
fid = env->GetFieldID(transducer_config_cls, "decoder", "Ljava/lang/String;");
|
||||
s = (jstring)env->GetObjectField(transducer_config, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.model_config.transducer.decoder = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
fid = env->GetFieldID(transducer_config_cls, "joiner", "Ljava/lang/String;");
|
||||
s = (jstring)env->GetObjectField(transducer_config, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.model_config.transducer.joiner = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
fid = env->GetFieldID(model_config_cls, "tokens", "Ljava/lang/String;");
|
||||
s = (jstring)env->GetObjectField(model_config, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.model_config.tokens = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
fid = env->GetFieldID(model_config_cls, "numThreads", "I");
|
||||
ans.model_config.num_threads = env->GetIntField(model_config, fid);
|
||||
|
||||
fid = env->GetFieldID(model_config_cls, "debug", "Z");
|
||||
ans.model_config.debug = env->GetBooleanField(model_config, fid);
|
||||
|
||||
fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;");
|
||||
s = (jstring)env->GetObjectField(model_config, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.model_config.provider = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
fid = env->GetFieldID(model_config_cls, "modelType", "Ljava/lang/String;");
|
||||
s = (jstring)env->GetObjectField(model_config, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.model_config.model_type = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
|
||||
JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_KeywordSpotter_newFromAsset(
|
||||
JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) {
|
||||
#if __ANDROID_API__ >= 9
|
||||
AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
|
||||
if (!mgr) {
|
||||
SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
|
||||
}
|
||||
#endif
|
||||
auto config = sherpa_onnx::GetKwsConfig(env, _config);
|
||||
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
|
||||
auto kws = new sherpa_onnx::KeywordSpotter(
|
||||
#if __ANDROID_API__ >= 9
|
||||
mgr,
|
||||
#endif
|
||||
config);
|
||||
|
||||
return (jlong)kws;
|
||||
}
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
|
||||
JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_KeywordSpotter_newFromFile(
|
||||
JNIEnv *env, jobject /*obj*/, jobject _config) {
|
||||
auto config = sherpa_onnx::GetKwsConfig(env, _config);
|
||||
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
|
||||
|
||||
if (!config.Validate()) {
|
||||
SHERPA_ONNX_LOGE("Errors found in config!");
|
||||
return 0;
|
||||
}
|
||||
|
||||
auto kws = new sherpa_onnx::KeywordSpotter(config);
|
||||
|
||||
return (jlong)kws;
|
||||
}
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
|
||||
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_KeywordSpotter_delete(
|
||||
JNIEnv *env, jobject /*obj*/, jlong ptr) {
|
||||
delete reinterpret_cast<sherpa_onnx::KeywordSpotter *>(ptr);
|
||||
}
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
|
||||
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_KeywordSpotter_decode(
|
||||
JNIEnv *env, jobject /*obj*/, jlong ptr, jlong stream_ptr) {
|
||||
auto kws = reinterpret_cast<sherpa_onnx::KeywordSpotter *>(ptr);
|
||||
auto stream = reinterpret_cast<sherpa_onnx::OnlineStream *>(stream_ptr);
|
||||
|
||||
kws->DecodeStream(stream);
|
||||
}
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
|
||||
JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_KeywordSpotter_createStream(
|
||||
JNIEnv *env, jobject /*obj*/, jlong ptr, jstring keywords) {
|
||||
auto kws = reinterpret_cast<sherpa_onnx::KeywordSpotter *>(ptr);
|
||||
|
||||
const char *p = env->GetStringUTFChars(keywords, nullptr);
|
||||
std::unique_ptr<sherpa_onnx::OnlineStream> stream;
|
||||
|
||||
if (strlen(p) == 0) {
|
||||
stream = kws->CreateStream();
|
||||
} else {
|
||||
stream = kws->CreateStream(p);
|
||||
}
|
||||
|
||||
env->ReleaseStringUTFChars(keywords, p);
|
||||
|
||||
// The user is responsible to free the returned pointer.
|
||||
//
|
||||
// See Java_com_k2fsa_sherpa_onnx_OfflineStream_delete() from
|
||||
// ./offline-stream.cc
|
||||
sherpa_onnx::OnlineStream *ans = stream.release();
|
||||
return (jlong)ans;
|
||||
}
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
|
||||
JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_KeywordSpotter_isReady(
|
||||
JNIEnv *env, jobject /*obj*/, jlong ptr, jlong stream_ptr) {
|
||||
auto kws = reinterpret_cast<sherpa_onnx::KeywordSpotter *>(ptr);
|
||||
auto stream = reinterpret_cast<sherpa_onnx::OnlineStream *>(stream_ptr);
|
||||
|
||||
return kws->IsReady(stream);
|
||||
}
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
|
||||
JNIEXPORT jobjectArray JNICALL
|
||||
Java_com_k2fsa_sherpa_onnx_KeywordSpotter_getResult(JNIEnv *env,
|
||||
jobject /*obj*/, jlong ptr,
|
||||
jlong stream_ptr) {
|
||||
auto kws = reinterpret_cast<sherpa_onnx::KeywordSpotter *>(ptr);
|
||||
auto stream = reinterpret_cast<sherpa_onnx::OnlineStream *>(stream_ptr);
|
||||
|
||||
sherpa_onnx::KeywordResult result = kws->GetResult(stream);
|
||||
|
||||
// [0]: keyword, jstring
|
||||
// [1]: tokens, array of jstring
|
||||
// [2]: timestamps, array of float
|
||||
jobjectArray obj_arr = (jobjectArray)env->NewObjectArray(
|
||||
3, env->FindClass("java/lang/Object"), nullptr);
|
||||
|
||||
jstring keyword = env->NewStringUTF(result.keyword.c_str());
|
||||
env->SetObjectArrayElement(obj_arr, 0, keyword);
|
||||
|
||||
jobjectArray tokens_arr = (jobjectArray)env->NewObjectArray(
|
||||
result.tokens.size(), env->FindClass("java/lang/String"), nullptr);
|
||||
|
||||
int32_t i = 0;
|
||||
for (const auto &t : result.tokens) {
|
||||
jstring jtext = env->NewStringUTF(t.c_str());
|
||||
env->SetObjectArrayElement(tokens_arr, i, jtext);
|
||||
i += 1;
|
||||
}
|
||||
|
||||
env->SetObjectArrayElement(obj_arr, 1, tokens_arr);
|
||||
|
||||
jfloatArray timestamps_arr = env->NewFloatArray(result.timestamps.size());
|
||||
env->SetFloatArrayRegion(timestamps_arr, 0, result.timestamps.size(),
|
||||
result.timestamps.data());
|
||||
|
||||
env->SetObjectArrayElement(obj_arr, 2, timestamps_arr);
|
||||
|
||||
return obj_arr;
|
||||
}
|
||||
263
sherpa-onnx/jni/offline-recognizer.cc
Normal file
263
sherpa-onnx/jni/offline-recognizer.cc
Normal file
@@ -0,0 +1,263 @@
|
||||
// sherpa-onnx/jni/offline-recognizer.cc
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-recognizer.h"
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/jni/common.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
static OfflineRecognizerConfig GetOfflineConfig(JNIEnv *env, jobject config) {
|
||||
OfflineRecognizerConfig ans;
|
||||
|
||||
jclass cls = env->GetObjectClass(config);
|
||||
jfieldID fid;
|
||||
|
||||
//---------- decoding ----------
|
||||
fid = env->GetFieldID(cls, "decodingMethod", "Ljava/lang/String;");
|
||||
jstring s = (jstring)env->GetObjectField(config, fid);
|
||||
const char *p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.decoding_method = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
fid = env->GetFieldID(cls, "maxActivePaths", "I");
|
||||
ans.max_active_paths = env->GetIntField(config, fid);
|
||||
|
||||
fid = env->GetFieldID(cls, "hotwordsFile", "Ljava/lang/String;");
|
||||
s = (jstring)env->GetObjectField(config, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.hotwords_file = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
fid = env->GetFieldID(cls, "hotwordsScore", "F");
|
||||
ans.hotwords_score = env->GetFloatField(config, fid);
|
||||
|
||||
//---------- feat config ----------
|
||||
fid = env->GetFieldID(cls, "featConfig",
|
||||
"Lcom/k2fsa/sherpa/onnx/FeatureConfig;");
|
||||
jobject feat_config = env->GetObjectField(config, fid);
|
||||
jclass feat_config_cls = env->GetObjectClass(feat_config);
|
||||
|
||||
fid = env->GetFieldID(feat_config_cls, "sampleRate", "I");
|
||||
ans.feat_config.sampling_rate = env->GetIntField(feat_config, fid);
|
||||
|
||||
fid = env->GetFieldID(feat_config_cls, "featureDim", "I");
|
||||
ans.feat_config.feature_dim = env->GetIntField(feat_config, fid);
|
||||
|
||||
//---------- model config ----------
|
||||
fid = env->GetFieldID(cls, "modelConfig",
|
||||
"Lcom/k2fsa/sherpa/onnx/OfflineModelConfig;");
|
||||
jobject model_config = env->GetObjectField(config, fid);
|
||||
jclass model_config_cls = env->GetObjectClass(model_config);
|
||||
|
||||
fid = env->GetFieldID(model_config_cls, "tokens", "Ljava/lang/String;");
|
||||
s = (jstring)env->GetObjectField(model_config, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.model_config.tokens = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
fid = env->GetFieldID(model_config_cls, "numThreads", "I");
|
||||
ans.model_config.num_threads = env->GetIntField(model_config, fid);
|
||||
|
||||
fid = env->GetFieldID(model_config_cls, "debug", "Z");
|
||||
ans.model_config.debug = env->GetBooleanField(model_config, fid);
|
||||
|
||||
fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;");
|
||||
s = (jstring)env->GetObjectField(model_config, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.model_config.provider = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
fid = env->GetFieldID(model_config_cls, "modelType", "Ljava/lang/String;");
|
||||
s = (jstring)env->GetObjectField(model_config, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.model_config.model_type = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
// transducer
|
||||
fid = env->GetFieldID(model_config_cls, "transducer",
|
||||
"Lcom/k2fsa/sherpa/onnx/OfflineTransducerModelConfig;");
|
||||
jobject transducer_config = env->GetObjectField(model_config, fid);
|
||||
jclass transducer_config_cls = env->GetObjectClass(transducer_config);
|
||||
|
||||
fid = env->GetFieldID(transducer_config_cls, "encoder", "Ljava/lang/String;");
|
||||
s = (jstring)env->GetObjectField(transducer_config, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.model_config.transducer.encoder_filename = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
fid = env->GetFieldID(transducer_config_cls, "decoder", "Ljava/lang/String;");
|
||||
s = (jstring)env->GetObjectField(transducer_config, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.model_config.transducer.decoder_filename = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
fid = env->GetFieldID(transducer_config_cls, "joiner", "Ljava/lang/String;");
|
||||
s = (jstring)env->GetObjectField(transducer_config, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.model_config.transducer.joiner_filename = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
// paraformer
|
||||
fid = env->GetFieldID(model_config_cls, "paraformer",
|
||||
"Lcom/k2fsa/sherpa/onnx/OfflineParaformerModelConfig;");
|
||||
jobject paraformer_config = env->GetObjectField(model_config, fid);
|
||||
jclass paraformer_config_cls = env->GetObjectClass(paraformer_config);
|
||||
|
||||
fid = env->GetFieldID(paraformer_config_cls, "model", "Ljava/lang/String;");
|
||||
|
||||
s = (jstring)env->GetObjectField(paraformer_config, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.model_config.paraformer.model = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
// whisper
|
||||
fid = env->GetFieldID(model_config_cls, "whisper",
|
||||
"Lcom/k2fsa/sherpa/onnx/OfflineWhisperModelConfig;");
|
||||
jobject whisper_config = env->GetObjectField(model_config, fid);
|
||||
jclass whisper_config_cls = env->GetObjectClass(whisper_config);
|
||||
|
||||
fid = env->GetFieldID(whisper_config_cls, "encoder", "Ljava/lang/String;");
|
||||
s = (jstring)env->GetObjectField(whisper_config, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.model_config.whisper.encoder = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
fid = env->GetFieldID(whisper_config_cls, "decoder", "Ljava/lang/String;");
|
||||
s = (jstring)env->GetObjectField(whisper_config, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.model_config.whisper.decoder = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
fid = env->GetFieldID(whisper_config_cls, "language", "Ljava/lang/String;");
|
||||
s = (jstring)env->GetObjectField(whisper_config, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.model_config.whisper.language = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
fid = env->GetFieldID(whisper_config_cls, "task", "Ljava/lang/String;");
|
||||
s = (jstring)env->GetObjectField(whisper_config, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.model_config.whisper.task = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
fid = env->GetFieldID(whisper_config_cls, "tailPaddings", "I");
|
||||
ans.model_config.whisper.tail_paddings =
|
||||
env->GetIntField(whisper_config, fid);
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_com_k2fsa_sherpa_onnx_OfflineRecognizer_newFromAsset(JNIEnv *env,
|
||||
jobject /*obj*/,
|
||||
jobject asset_manager,
|
||||
jobject _config) {
|
||||
#if __ANDROID_API__ >= 9
|
||||
AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
|
||||
if (!mgr) {
|
||||
SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
|
||||
}
|
||||
#endif
|
||||
auto config = sherpa_onnx::GetOfflineConfig(env, _config);
|
||||
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
|
||||
auto model = new sherpa_onnx::OfflineRecognizer(
|
||||
#if __ANDROID_API__ >= 9
|
||||
mgr,
|
||||
#endif
|
||||
config);
|
||||
|
||||
return (jlong)model;
|
||||
}
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_com_k2fsa_sherpa_onnx_OfflineRecognizer_newFromFile(JNIEnv *env,
|
||||
jobject /*obj*/,
|
||||
jobject _config) {
|
||||
auto config = sherpa_onnx::GetOfflineConfig(env, _config);
|
||||
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
|
||||
|
||||
if (!config.Validate()) {
|
||||
SHERPA_ONNX_LOGE("Errors found in config!");
|
||||
return 0;
|
||||
}
|
||||
|
||||
auto model = new sherpa_onnx::OfflineRecognizer(config);
|
||||
|
||||
return (jlong)model;
|
||||
}
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
|
||||
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OfflineRecognizer_delete(
|
||||
JNIEnv *env, jobject /*obj*/, jlong ptr) {
|
||||
delete reinterpret_cast<sherpa_onnx::OfflineRecognizer *>(ptr);
|
||||
}
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_com_k2fsa_sherpa_onnx_OfflineRecognizer_createStream(JNIEnv *env,
|
||||
jobject /*obj*/,
|
||||
jlong ptr) {
|
||||
auto recognizer = reinterpret_cast<sherpa_onnx::OfflineRecognizer *>(ptr);
|
||||
std::unique_ptr<sherpa_onnx::OfflineStream> s = recognizer->CreateStream();
|
||||
|
||||
// The user is responsible to free the returned pointer.
|
||||
//
|
||||
// See Java_com_k2fsa_sherpa_onnx_OfflineStream_delete() from
|
||||
// ./offline-stream.cc
|
||||
sherpa_onnx::OfflineStream *p = s.release();
|
||||
return (jlong)p;
|
||||
}
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
|
||||
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OfflineRecognizer_decode(
|
||||
JNIEnv *env, jobject /*obj*/, jlong ptr, jlong streamPtr) {
|
||||
auto recognizer = reinterpret_cast<sherpa_onnx::OfflineRecognizer *>(ptr);
|
||||
auto stream = reinterpret_cast<sherpa_onnx::OfflineStream *>(streamPtr);
|
||||
|
||||
recognizer->DecodeStream(stream);
|
||||
}
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
|
||||
JNIEXPORT jobjectArray JNICALL
|
||||
Java_com_k2fsa_sherpa_onnx_OfflineRecognizer_getResult(JNIEnv *env,
|
||||
jobject /*obj*/,
|
||||
jlong streamPtr) {
|
||||
auto stream = reinterpret_cast<sherpa_onnx::OfflineStream *>(streamPtr);
|
||||
sherpa_onnx::OfflineRecognitionResult result = stream->GetResult();
|
||||
|
||||
// [0]: text, jstring
|
||||
// [1]: tokens, array of jstring
|
||||
// [2]: timestamps, array of float
|
||||
jobjectArray obj_arr = (jobjectArray)env->NewObjectArray(
|
||||
3, env->FindClass("java/lang/Object"), nullptr);
|
||||
|
||||
jstring text = env->NewStringUTF(result.text.c_str());
|
||||
env->SetObjectArrayElement(obj_arr, 0, text);
|
||||
|
||||
jobjectArray tokens_arr = (jobjectArray)env->NewObjectArray(
|
||||
result.tokens.size(), env->FindClass("java/lang/String"), nullptr);
|
||||
|
||||
int32_t i = 0;
|
||||
for (const auto &t : result.tokens) {
|
||||
jstring jtext = env->NewStringUTF(t.c_str());
|
||||
env->SetObjectArrayElement(tokens_arr, i, jtext);
|
||||
i += 1;
|
||||
}
|
||||
|
||||
env->SetObjectArrayElement(obj_arr, 1, tokens_arr);
|
||||
|
||||
jfloatArray timestamps_arr = env->NewFloatArray(result.timestamps.size());
|
||||
env->SetFloatArrayRegion(timestamps_arr, 0, result.timestamps.size(),
|
||||
result.timestamps.data());
|
||||
|
||||
env->SetObjectArrayElement(obj_arr, 2, timestamps_arr);
|
||||
|
||||
return obj_arr;
|
||||
}
|
||||
352
sherpa-onnx/jni/online-recognizer.cc
Normal file
352
sherpa-onnx/jni/online-recognizer.cc
Normal file
@@ -0,0 +1,352 @@
|
||||
// sherpa-onnx/jni/online-recognizer.cc
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/online-recognizer.h"
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/jni/common.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
|
||||
OnlineRecognizerConfig ans;
|
||||
|
||||
jclass cls = env->GetObjectClass(config);
|
||||
jfieldID fid;
|
||||
|
||||
// https://docs.oracle.com/javase/7/docs/technotes/guides/jni/spec/types.html
|
||||
// https://courses.cs.washington.edu/courses/cse341/99wi/java/tutorial/native1.1/implementing/field.html
|
||||
|
||||
//---------- decoding ----------
|
||||
fid = env->GetFieldID(cls, "decodingMethod", "Ljava/lang/String;");
|
||||
jstring s = (jstring)env->GetObjectField(config, fid);
|
||||
const char *p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.decoding_method = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
fid = env->GetFieldID(cls, "maxActivePaths", "I");
|
||||
ans.max_active_paths = env->GetIntField(config, fid);
|
||||
|
||||
fid = env->GetFieldID(cls, "hotwordsFile", "Ljava/lang/String;");
|
||||
s = (jstring)env->GetObjectField(config, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.hotwords_file = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
fid = env->GetFieldID(cls, "hotwordsScore", "F");
|
||||
ans.hotwords_score = env->GetFloatField(config, fid);
|
||||
|
||||
//---------- feat config ----------
|
||||
fid = env->GetFieldID(cls, "featConfig",
|
||||
"Lcom/k2fsa/sherpa/onnx/FeatureConfig;");
|
||||
jobject feat_config = env->GetObjectField(config, fid);
|
||||
jclass feat_config_cls = env->GetObjectClass(feat_config);
|
||||
|
||||
fid = env->GetFieldID(feat_config_cls, "sampleRate", "I");
|
||||
ans.feat_config.sampling_rate = env->GetIntField(feat_config, fid);
|
||||
|
||||
fid = env->GetFieldID(feat_config_cls, "featureDim", "I");
|
||||
ans.feat_config.feature_dim = env->GetIntField(feat_config, fid);
|
||||
|
||||
//---------- enable endpoint ----------
|
||||
fid = env->GetFieldID(cls, "enableEndpoint", "Z");
|
||||
ans.enable_endpoint = env->GetBooleanField(config, fid);
|
||||
|
||||
//---------- endpoint_config ----------
|
||||
|
||||
fid = env->GetFieldID(cls, "endpointConfig",
|
||||
"Lcom/k2fsa/sherpa/onnx/EndpointConfig;");
|
||||
jobject endpoint_config = env->GetObjectField(config, fid);
|
||||
jclass endpoint_config_cls = env->GetObjectClass(endpoint_config);
|
||||
|
||||
fid = env->GetFieldID(endpoint_config_cls, "rule1",
|
||||
"Lcom/k2fsa/sherpa/onnx/EndpointRule;");
|
||||
jobject rule1 = env->GetObjectField(endpoint_config, fid);
|
||||
jclass rule_class = env->GetObjectClass(rule1);
|
||||
|
||||
fid = env->GetFieldID(endpoint_config_cls, "rule2",
|
||||
"Lcom/k2fsa/sherpa/onnx/EndpointRule;");
|
||||
jobject rule2 = env->GetObjectField(endpoint_config, fid);
|
||||
|
||||
fid = env->GetFieldID(endpoint_config_cls, "rule3",
|
||||
"Lcom/k2fsa/sherpa/onnx/EndpointRule;");
|
||||
jobject rule3 = env->GetObjectField(endpoint_config, fid);
|
||||
|
||||
fid = env->GetFieldID(rule_class, "mustContainNonSilence", "Z");
|
||||
ans.endpoint_config.rule1.must_contain_nonsilence =
|
||||
env->GetBooleanField(rule1, fid);
|
||||
ans.endpoint_config.rule2.must_contain_nonsilence =
|
||||
env->GetBooleanField(rule2, fid);
|
||||
ans.endpoint_config.rule3.must_contain_nonsilence =
|
||||
env->GetBooleanField(rule3, fid);
|
||||
|
||||
fid = env->GetFieldID(rule_class, "minTrailingSilence", "F");
|
||||
ans.endpoint_config.rule1.min_trailing_silence =
|
||||
env->GetFloatField(rule1, fid);
|
||||
ans.endpoint_config.rule2.min_trailing_silence =
|
||||
env->GetFloatField(rule2, fid);
|
||||
ans.endpoint_config.rule3.min_trailing_silence =
|
||||
env->GetFloatField(rule3, fid);
|
||||
|
||||
fid = env->GetFieldID(rule_class, "minUtteranceLength", "F");
|
||||
ans.endpoint_config.rule1.min_utterance_length =
|
||||
env->GetFloatField(rule1, fid);
|
||||
ans.endpoint_config.rule2.min_utterance_length =
|
||||
env->GetFloatField(rule2, fid);
|
||||
ans.endpoint_config.rule3.min_utterance_length =
|
||||
env->GetFloatField(rule3, fid);
|
||||
|
||||
//---------- model config ----------
|
||||
fid = env->GetFieldID(cls, "modelConfig",
|
||||
"Lcom/k2fsa/sherpa/onnx/OnlineModelConfig;");
|
||||
jobject model_config = env->GetObjectField(config, fid);
|
||||
jclass model_config_cls = env->GetObjectClass(model_config);
|
||||
|
||||
// transducer
|
||||
fid = env->GetFieldID(model_config_cls, "transducer",
|
||||
"Lcom/k2fsa/sherpa/onnx/OnlineTransducerModelConfig;");
|
||||
jobject transducer_config = env->GetObjectField(model_config, fid);
|
||||
jclass transducer_config_cls = env->GetObjectClass(transducer_config);
|
||||
|
||||
fid = env->GetFieldID(transducer_config_cls, "encoder", "Ljava/lang/String;");
|
||||
s = (jstring)env->GetObjectField(transducer_config, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.model_config.transducer.encoder = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
fid = env->GetFieldID(transducer_config_cls, "decoder", "Ljava/lang/String;");
|
||||
s = (jstring)env->GetObjectField(transducer_config, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.model_config.transducer.decoder = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
fid = env->GetFieldID(transducer_config_cls, "joiner", "Ljava/lang/String;");
|
||||
s = (jstring)env->GetObjectField(transducer_config, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.model_config.transducer.joiner = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
// paraformer
|
||||
fid = env->GetFieldID(model_config_cls, "paraformer",
|
||||
"Lcom/k2fsa/sherpa/onnx/OnlineParaformerModelConfig;");
|
||||
jobject paraformer_config = env->GetObjectField(model_config, fid);
|
||||
jclass paraformer_config_cls = env->GetObjectClass(paraformer_config);
|
||||
|
||||
fid = env->GetFieldID(paraformer_config_cls, "encoder", "Ljava/lang/String;");
|
||||
s = (jstring)env->GetObjectField(paraformer_config, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.model_config.paraformer.encoder = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
fid = env->GetFieldID(paraformer_config_cls, "decoder", "Ljava/lang/String;");
|
||||
s = (jstring)env->GetObjectField(paraformer_config, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.model_config.paraformer.decoder = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
// streaming zipformer2 CTC
|
||||
fid =
|
||||
env->GetFieldID(model_config_cls, "zipformer2Ctc",
|
||||
"Lcom/k2fsa/sherpa/onnx/OnlineZipformer2CtcModelConfig;");
|
||||
jobject zipformer2_ctc_config = env->GetObjectField(model_config, fid);
|
||||
jclass zipformer2_ctc_config_cls = env->GetObjectClass(zipformer2_ctc_config);
|
||||
|
||||
fid =
|
||||
env->GetFieldID(zipformer2_ctc_config_cls, "model", "Ljava/lang/String;");
|
||||
s = (jstring)env->GetObjectField(zipformer2_ctc_config, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.model_config.zipformer2_ctc.model = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
fid = env->GetFieldID(model_config_cls, "tokens", "Ljava/lang/String;");
|
||||
s = (jstring)env->GetObjectField(model_config, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.model_config.tokens = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
fid = env->GetFieldID(model_config_cls, "numThreads", "I");
|
||||
ans.model_config.num_threads = env->GetIntField(model_config, fid);
|
||||
|
||||
fid = env->GetFieldID(model_config_cls, "debug", "Z");
|
||||
ans.model_config.debug = env->GetBooleanField(model_config, fid);
|
||||
|
||||
fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;");
|
||||
s = (jstring)env->GetObjectField(model_config, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.model_config.provider = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
fid = env->GetFieldID(model_config_cls, "modelType", "Ljava/lang/String;");
|
||||
s = (jstring)env->GetObjectField(model_config, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.model_config.model_type = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
//---------- rnn lm model config ----------
|
||||
fid = env->GetFieldID(cls, "lmConfig",
|
||||
"Lcom/k2fsa/sherpa/onnx/OnlineLMConfig;");
|
||||
jobject lm_model_config = env->GetObjectField(config, fid);
|
||||
jclass lm_model_config_cls = env->GetObjectClass(lm_model_config);
|
||||
|
||||
fid = env->GetFieldID(lm_model_config_cls, "model", "Ljava/lang/String;");
|
||||
s = (jstring)env->GetObjectField(lm_model_config, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.lm_config.model = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
fid = env->GetFieldID(lm_model_config_cls, "scale", "F");
|
||||
ans.lm_config.scale = env->GetFloatField(lm_model_config, fid);
|
||||
|
||||
return ans;
|
||||
}
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_newFromAsset(JNIEnv *env,
|
||||
jobject /*obj*/,
|
||||
jobject asset_manager,
|
||||
jobject _config) {
|
||||
#if __ANDROID_API__ >= 9
|
||||
AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
|
||||
if (!mgr) {
|
||||
SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
|
||||
}
|
||||
#endif
|
||||
auto config = sherpa_onnx::GetConfig(env, _config);
|
||||
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
|
||||
|
||||
auto recognizer = new sherpa_onnx::OnlineRecognizer(
|
||||
#if __ANDROID_API__ >= 9
|
||||
mgr,
|
||||
#endif
|
||||
config);
|
||||
|
||||
return (jlong)recognizer;
|
||||
}
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_newFromFile(
    JNIEnv *env, jobject /*obj*/, jobject _config) {
  // Build an OnlineRecognizer from files on disk. Returns 0 when the
  // config fails validation; otherwise the caller owns the returned
  // pointer and must free it via OnlineRecognizer_delete().
  auto cfg = sherpa_onnx::GetConfig(env, _config);
  SHERPA_ONNX_LOGE("config:\n%s", cfg.ToString().c_str());

  if (!cfg.Validate()) {
    SHERPA_ONNX_LOGE("Errors found in config!");
    return 0;
  }

  return (jlong) new sherpa_onnx::OnlineRecognizer(cfg);
}
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_delete(
    JNIEnv *env, jobject /*obj*/, jlong ptr) {
  // Frees the recognizer created by newFromAsset()/newFromFile().
  delete reinterpret_cast<sherpa_onnx::OnlineRecognizer *>(ptr);
}

SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_reset(
    JNIEnv *env, jobject /*obj*/, jlong ptr, jlong stream_ptr) {
  // Resets the given stream so it can be reused for new audio.
  reinterpret_cast<sherpa_onnx::OnlineRecognizer *>(ptr)->Reset(
      reinterpret_cast<sherpa_onnx::OnlineStream *>(stream_ptr));
}

SHERPA_ONNX_EXTERN_C
JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_isReady(
    JNIEnv *env, jobject /*obj*/, jlong ptr, jlong stream_ptr) {
  // True if the stream has buffered enough audio for one decode step.
  return reinterpret_cast<sherpa_onnx::OnlineRecognizer *>(ptr)->IsReady(
      reinterpret_cast<sherpa_onnx::OnlineStream *>(stream_ptr));
}

SHERPA_ONNX_EXTERN_C
JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_isEndpoint(
    JNIEnv *env, jobject /*obj*/, jlong ptr, jlong stream_ptr) {
  // True if an endpoint (e.g., trailing silence) was detected on the stream.
  return reinterpret_cast<sherpa_onnx::OnlineRecognizer *>(ptr)->IsEndpoint(
      reinterpret_cast<sherpa_onnx::OnlineStream *>(stream_ptr));
}

SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_decode(
    JNIEnv *env, jobject /*obj*/, jlong ptr, jlong stream_ptr) {
  // Runs one decoding step on the stream's buffered audio.
  reinterpret_cast<sherpa_onnx::OnlineRecognizer *>(ptr)->DecodeStream(
      reinterpret_cast<sherpa_onnx::OnlineStream *>(stream_ptr));
}
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL
Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_createStream(JNIEnv *env,
                                                         jobject /*obj*/,
                                                         jlong ptr,
                                                         jstring hotwords) {
  // Creates a new decoding stream. An empty `hotwords` string selects the
  // plain stream; a non-empty one enables hotword boosting.
  auto *rec = reinterpret_cast<sherpa_onnx::OnlineRecognizer *>(ptr);

  const char *hw = env->GetStringUTFChars(hotwords, nullptr);

  std::unique_ptr<sherpa_onnx::OnlineStream> stream =
      (hw[0] == '\0') ? rec->CreateStream() : rec->CreateStream(hw);

  env->ReleaseStringUTFChars(hotwords, hw);

  // The user is responsible to free the returned pointer.
  //
  // See Java_com_k2fsa_sherpa_onnx_OfflineStream_delete() from
  // ./offline-stream.cc
  return (jlong)stream.release();
}
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
JNIEXPORT jobjectArray JNICALL
Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_getResult(JNIEnv *env,
                                                      jobject /*obj*/,
                                                      jlong ptr,
                                                      jlong stream_ptr) {
  // Returns the current recognition result for the stream packed into an
  // Object[3]:
  //   [0]: text, jstring
  //   [1]: tokens, array of jstring
  //   [2]: timestamps, array of float
  auto recognizer = reinterpret_cast<sherpa_onnx::OnlineRecognizer *>(ptr);
  auto stream = reinterpret_cast<sherpa_onnx::OnlineStream *>(stream_ptr);

  sherpa_onnx::OnlineRecognizerResult result = recognizer->GetResult(stream);

  jobjectArray obj_arr = (jobjectArray)env->NewObjectArray(
      3, env->FindClass("java/lang/Object"), nullptr);

  jstring text = env->NewStringUTF(result.text.c_str());
  env->SetObjectArrayElement(obj_arr, 0, text);

  jobjectArray tokens_arr = (jobjectArray)env->NewObjectArray(
      result.tokens.size(), env->FindClass("java/lang/String"), nullptr);

  int32_t i = 0;
  for (const auto &t : result.tokens) {
    jstring jtext = env->NewStringUTF(t.c_str());
    env->SetObjectArrayElement(tokens_arr, i, jtext);
    // Release the per-token local reference right away; otherwise a long
    // result can overflow the JNI local reference table (typically 512).
    env->DeleteLocalRef(jtext);
    i += 1;
  }

  env->SetObjectArrayElement(obj_arr, 1, tokens_arr);

  jfloatArray timestamps_arr = env->NewFloatArray(result.timestamps.size());
  env->SetFloatArrayRegion(timestamps_arr, 0, result.timestamps.size(),
                           result.timestamps.data());

  env->SetObjectArrayElement(obj_arr, 2, timestamps_arr);

  return obj_arr;
}
|
||||
32
sherpa-onnx/jni/online-stream.cc
Normal file
32
sherpa-onnx/jni/online-stream.cc
Normal file
@@ -0,0 +1,32 @@
|
||||
// sherpa-onnx/jni/online-stream.cc
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/online-stream.h"
|
||||
|
||||
#include "sherpa-onnx/jni/common.h"
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OnlineStream_delete(
    JNIEnv *env, jobject /*obj*/, jlong ptr) {
  // Frees a stream previously returned to the Java side.
  delete reinterpret_cast<sherpa_onnx::OnlineStream *>(ptr);
}

SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OnlineStream_acceptWaveform(
    JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples,
    jint sample_rate) {
  // Copies the Java float[] into the stream's feature extractor.
  auto *stream = reinterpret_cast<sherpa_onnx::OnlineStream *>(ptr);

  jfloat *data = env->GetFloatArrayElements(samples, nullptr);
  jsize num_samples = env->GetArrayLength(samples);
  stream->AcceptWaveform(sample_rate, data, num_samples);
  // JNI_ABORT: the samples were only read; no need to copy back.
  env->ReleaseFloatArrayElements(samples, data, JNI_ABORT);
}

SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OnlineStream_inputFinished(
    JNIEnv *env, jobject /*obj*/, jlong ptr) {
  // Signals that no more audio will be fed into the stream.
  reinterpret_cast<sherpa_onnx::OnlineStream *>(ptr)->InputFinished();
}
|
||||
137
sherpa-onnx/jni/speaker-embedding-extractor.cc
Normal file
137
sherpa-onnx/jni/speaker-embedding-extractor.cc
Normal file
@@ -0,0 +1,137 @@
|
||||
// sherpa-onnx/jni/speaker-embedding-extractor.cc
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
|
||||
|
||||
#include "sherpa-onnx/jni/common.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
static SpeakerEmbeddingExtractorConfig GetSpeakerEmbeddingExtractorConfig(
    JNIEnv *env, jobject config) {
  // Translates the Java-side SpeakerEmbeddingExtractorConfig object into
  // its C++ counterpart by reading each field through reflection.
  SpeakerEmbeddingExtractorConfig ans;

  jclass cls = env->GetObjectClass(config);

  jfieldID fid = env->GetFieldID(cls, "model", "Ljava/lang/String;");
  jstring js = (jstring)env->GetObjectField(config, fid);
  const char *chars = env->GetStringUTFChars(js, nullptr);
  ans.model = chars;
  env->ReleaseStringUTFChars(js, chars);

  fid = env->GetFieldID(cls, "numThreads", "I");
  ans.num_threads = env->GetIntField(config, fid);

  fid = env->GetFieldID(cls, "debug", "Z");
  ans.debug = env->GetBooleanField(config, fid);

  fid = env->GetFieldID(cls, "provider", "Ljava/lang/String;");
  js = (jstring)env->GetObjectField(config, fid);
  chars = env->GetStringUTFChars(js, nullptr);
  ans.provider = chars;
  env->ReleaseStringUTFChars(js, chars);

  return ans;
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_newFromAsset(
    JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) {
  // Builds a speaker-embedding extractor whose model is loaded from
  // Android assets. The caller owns the returned pointer.
#if __ANDROID_API__ >= 9
  AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
  if (!mgr) {
    SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
  }
#endif
  auto cfg = sherpa_onnx::GetSpeakerEmbeddingExtractorConfig(env, _config);
  SHERPA_ONNX_LOGE("new config:\n%s", cfg.ToString().c_str());

  auto *extractor = new sherpa_onnx::SpeakerEmbeddingExtractor(
#if __ANDROID_API__ >= 9
      mgr,
#endif
      cfg);

  return (jlong)extractor;
}
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_newFromFile(
    JNIEnv *env, jobject /*obj*/, jobject _config) {
  // Builds a speaker-embedding extractor from files on disk.
  // Returns 0 when the config fails validation, matching the behavior of
  // OnlineRecognizer_newFromFile() and Vad_newFromFile(); previously the
  // extractor was constructed anyway with an invalid config.
  auto config = sherpa_onnx::GetSpeakerEmbeddingExtractorConfig(env, _config);
  SHERPA_ONNX_LOGE("newFromFile config:\n%s", config.ToString().c_str());

  if (!config.Validate()) {
    SHERPA_ONNX_LOGE("Errors found in config!");
    return 0;
  }

  auto extractor = new sherpa_onnx::SpeakerEmbeddingExtractor(config);

  return (jlong)extractor;
}
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_delete(JNIEnv *env,
                                                            jobject /*obj*/,
                                                            jlong ptr) {
  // Frees the extractor created by newFromAsset()/newFromFile().
  delete reinterpret_cast<sherpa_onnx::SpeakerEmbeddingExtractor *>(ptr);
}

SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_createStream(
    JNIEnv *env, jobject /*obj*/, jlong ptr) {
  auto *extractor =
      reinterpret_cast<sherpa_onnx::SpeakerEmbeddingExtractor *>(ptr);
  std::unique_ptr<sherpa_onnx::OnlineStream> stream =
      extractor->CreateStream();

  // The user is responsible to free the returned pointer.
  //
  // See Java_com_k2fsa_sherpa_onnx_OnlineStream_delete() from
  // ./online-stream.cc
  return (jlong)stream.release();
}
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
JNIEXPORT jboolean JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_isReady(JNIEnv *env,
                                                             jobject /*obj*/,
                                                             jlong ptr,
                                                             jlong stream_ptr) {
  // True if the stream has enough audio to compute an embedding.
  return reinterpret_cast<sherpa_onnx::SpeakerEmbeddingExtractor *>(ptr)
      ->IsReady(reinterpret_cast<sherpa_onnx::OnlineStream *>(stream_ptr));
}

SHERPA_ONNX_EXTERN_C
JNIEXPORT jfloatArray JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_compute(JNIEnv *env,
                                                             jobject /*obj*/,
                                                             jlong ptr,
                                                             jlong stream_ptr) {
  // Computes the speaker embedding of the stream and returns it as a
  // Java float[].
  auto *extractor =
      reinterpret_cast<sherpa_onnx::SpeakerEmbeddingExtractor *>(ptr);
  auto *stream = reinterpret_cast<sherpa_onnx::OnlineStream *>(stream_ptr);

  std::vector<float> embedding = extractor->Compute(stream);
  jfloatArray ans = env->NewFloatArray(embedding.size());
  env->SetFloatArrayRegion(ans, 0, embedding.size(), embedding.data());
  return ans;
}

SHERPA_ONNX_EXTERN_C
JNIEXPORT jint JNICALL Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_dim(
    JNIEnv *env, jobject /*obj*/, jlong ptr) {
  // Dimension of the embedding vectors this extractor produces.
  return reinterpret_cast<sherpa_onnx::SpeakerEmbeddingExtractor *>(ptr)
      ->Dim();
}
|
||||
207
sherpa-onnx/jni/speaker-embedding-manager.cc
Normal file
207
sherpa-onnx/jni/speaker-embedding-manager.cc
Normal file
@@ -0,0 +1,207 @@
|
||||
// sherpa-onnx/jni/speaker-embedding-manager.cc
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
#include "sherpa-onnx/csrc/speaker-embedding-manager.h"
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/jni/common.h"
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_create(JNIEnv *env,
                                                          jobject /*obj*/,
                                                          jint dim) {
  // Creates a manager for embeddings of dimension `dim`.
  // The caller owns the returned pointer.
  return (jlong) new sherpa_onnx::SpeakerEmbeddingManager(dim);
}

SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_delete(JNIEnv *env,
                                                          jobject /*obj*/,
                                                          jlong ptr) {
  // Frees the manager created by create().
  delete reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
}
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
JNIEXPORT jboolean JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_add(JNIEnv *env,
                                                       jobject /*obj*/,
                                                       jlong ptr, jstring name,
                                                       jfloatArray embedding) {
  // Registers one embedding under `name`. Returns false if the name
  // already exists (see SpeakerEmbeddingManager::Add).
  auto *manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);

  jfloat *elems = env->GetFloatArrayElements(embedding, nullptr);
  jsize len = env->GetArrayLength(embedding);

  // NOTE(review): a dimension mismatch aborts the whole process. This
  // mirrors the other manager entry points; consider returning false
  // instead if callers should survive bad input.
  if (len != manager->Dim()) {
    SHERPA_ONNX_LOGE("Expected dim %d, given %d", manager->Dim(),
                     static_cast<int32_t>(len));
    exit(-1);
  }

  const char *name_chars = env->GetStringUTFChars(name, nullptr);

  jboolean ok = manager->Add(name_chars, elems);
  env->ReleaseStringUTFChars(name, name_chars);
  env->ReleaseFloatArrayElements(embedding, elems, JNI_ABORT);

  return ok;
}
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
JNIEXPORT jboolean JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_addList(
    JNIEnv *env, jobject /*obj*/, jlong ptr, jstring name,
    jobjectArray embedding_arr) {
  // Registers a list of embeddings (float[][]) under `name`.
  // Returns false when the list is empty or the name already exists.
  auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);

  int num_embeddings = env->GetArrayLength(embedding_arr);
  if (num_embeddings == 0) {
    return false;
  }

  std::vector<std::vector<float>> embedding_list;
  embedding_list.reserve(num_embeddings);
  for (int32_t i = 0; i != num_embeddings; ++i) {
    jfloatArray embedding =
        (jfloatArray)env->GetObjectArrayElement(embedding_arr, i);

    jfloat *p = env->GetFloatArrayElements(embedding, nullptr);
    jsize n = env->GetArrayLength(embedding);

    if (n != manager->Dim()) {
      SHERPA_ONNX_LOGE("i: %d. Expected dim %d, given %d", i, manager->Dim(),
                       static_cast<int32_t>(n));
      exit(-1);
    }

    embedding_list.push_back({p, p + n});
    env->ReleaseFloatArrayElements(embedding, p, JNI_ABORT);
    // Release the per-row local reference; large lists would otherwise
    // overflow the JNI local reference table (typically 512 entries).
    env->DeleteLocalRef(embedding);
  }

  const char *p_name = env->GetStringUTFChars(name, nullptr);

  jboolean ok = manager->Add(p_name, embedding_list);

  env->ReleaseStringUTFChars(name, p_name);

  return ok;
}
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
JNIEXPORT jboolean JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_remove(JNIEnv *env,
                                                          jobject /*obj*/,
                                                          jlong ptr,
                                                          jstring name) {
  // Removes the speaker registered under `name`.
  // Returns false if no such speaker exists.
  auto *manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);

  const char *name_chars = env->GetStringUTFChars(name, nullptr);
  jboolean ok = manager->Remove(name_chars);
  env->ReleaseStringUTFChars(name, name_chars);

  return ok;
}
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
JNIEXPORT jstring JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_search(JNIEnv *env,
                                                          jobject /*obj*/,
                                                          jlong ptr,
                                                          jfloatArray embedding,
                                                          jfloat threshold) {
  // Looks up the best-matching registered speaker for `embedding`.
  // Returns the speaker name (empty string when no match beats the
  // threshold — see SpeakerEmbeddingManager::Search).
  auto *manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);

  jfloat *elems = env->GetFloatArrayElements(embedding, nullptr);
  jsize len = env->GetArrayLength(embedding);

  if (len != manager->Dim()) {
    SHERPA_ONNX_LOGE("Expected dim %d, given %d", manager->Dim(),
                     static_cast<int32_t>(len));
    exit(-1);
  }

  std::string name = manager->Search(elems, threshold);

  env->ReleaseFloatArrayElements(embedding, elems, JNI_ABORT);

  return env->NewStringUTF(name.c_str());
}

SHERPA_ONNX_EXTERN_C
JNIEXPORT jboolean JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_verify(
    JNIEnv *env, jobject /*obj*/, jlong ptr, jstring name,
    jfloatArray embedding, jfloat threshold) {
  // Checks whether `embedding` matches the speaker registered under
  // `name` with similarity above `threshold`.
  auto *manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);

  jfloat *elems = env->GetFloatArrayElements(embedding, nullptr);
  jsize len = env->GetArrayLength(embedding);

  if (len != manager->Dim()) {
    SHERPA_ONNX_LOGE("Expected dim %d, given %d", manager->Dim(),
                     static_cast<int32_t>(len));
    exit(-1);
  }

  const char *name_chars = env->GetStringUTFChars(name, nullptr);

  jboolean ok = manager->Verify(name_chars, elems, threshold);

  env->ReleaseFloatArrayElements(embedding, elems, JNI_ABORT);
  env->ReleaseStringUTFChars(name, name_chars);

  return ok;
}
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
JNIEXPORT jboolean JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_contains(JNIEnv *env,
                                                            jobject /*obj*/,
                                                            jlong ptr,
                                                            jstring name) {
  // True if a speaker named `name` is registered.
  auto *manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);

  const char *name_chars = env->GetStringUTFChars(name, nullptr);
  jboolean ok = manager->Contains(name_chars);
  env->ReleaseStringUTFChars(name, name_chars);

  return ok;
}

SHERPA_ONNX_EXTERN_C
JNIEXPORT jint JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_numSpeakers(JNIEnv *env,
                                                               jobject /*obj*/,
                                                               jlong ptr) {
  // Number of currently registered speakers.
  return reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr)
      ->NumSpeakers();
}
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
JNIEXPORT jobjectArray JNICALL
Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingManager_allSpeakerNames(
    JNIEnv *env, jobject /*obj*/, jlong ptr) {
  // Returns the names of all registered speakers as a String[].
  auto manager = reinterpret_cast<sherpa_onnx::SpeakerEmbeddingManager *>(ptr);
  std::vector<std::string> all_speakers = manager->GetAllSpeakers();

  jobjectArray obj_arr = (jobjectArray)env->NewObjectArray(
      all_speakers.size(), env->FindClass("java/lang/String"), nullptr);

  int32_t i = 0;
  for (auto &s : all_speakers) {
    jstring js = env->NewStringUTF(s.c_str());
    env->SetObjectArrayElement(obj_arr, i, js);
    // Release the per-name local reference; many speakers would otherwise
    // overflow the JNI local reference table (typically 512 entries).
    env->DeleteLocalRef(js);

    ++i;
  }

  return obj_arr;
}
|
||||
175
sherpa-onnx/jni/voice-activity-detector.cc
Normal file
175
sherpa-onnx/jni/voice-activity-detector.cc
Normal file
@@ -0,0 +1,175 @@
|
||||
// sherpa-onnx/csrc/voice-activity-detector.cc
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
#include "sherpa-onnx/csrc/voice-activity-detector.h"
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/jni/common.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
static VadModelConfig GetVadModelConfig(JNIEnv *env, jobject config) {
  // Translates the Java-side VadModelConfig object into its C++
  // counterpart by reading each field through reflection.
  VadModelConfig ans;

  jclass cls = env->GetObjectClass(config);

  // silero_vad
  jfieldID fid = env->GetFieldID(cls, "sileroVadModelConfig",
                                 "Lcom/k2fsa/sherpa/onnx/SileroVadModelConfig;");
  jobject silero_cfg = env->GetObjectField(config, fid);
  jclass silero_cls = env->GetObjectClass(silero_cfg);

  fid = env->GetFieldID(silero_cls, "model", "Ljava/lang/String;");
  jstring js = (jstring)env->GetObjectField(silero_cfg, fid);
  const char *chars = env->GetStringUTFChars(js, nullptr);
  ans.silero_vad.model = chars;
  env->ReleaseStringUTFChars(js, chars);

  fid = env->GetFieldID(silero_cls, "threshold", "F");
  ans.silero_vad.threshold = env->GetFloatField(silero_cfg, fid);

  fid = env->GetFieldID(silero_cls, "minSilenceDuration", "F");
  ans.silero_vad.min_silence_duration = env->GetFloatField(silero_cfg, fid);

  fid = env->GetFieldID(silero_cls, "minSpeechDuration", "F");
  ans.silero_vad.min_speech_duration = env->GetFloatField(silero_cfg, fid);

  fid = env->GetFieldID(silero_cls, "windowSize", "I");
  ans.silero_vad.window_size = env->GetIntField(silero_cfg, fid);

  fid = env->GetFieldID(cls, "sampleRate", "I");
  ans.sample_rate = env->GetIntField(config, fid);

  fid = env->GetFieldID(cls, "numThreads", "I");
  ans.num_threads = env->GetIntField(config, fid);

  fid = env->GetFieldID(cls, "provider", "Ljava/lang/String;");
  js = (jstring)env->GetObjectField(config, fid);
  chars = env->GetStringUTFChars(js, nullptr);
  ans.provider = chars;
  env->ReleaseStringUTFChars(js, chars);

  fid = env->GetFieldID(cls, "debug", "Z");
  ans.debug = env->GetBooleanField(config, fid);

  return ans;
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_Vad_newFromAsset(
    JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) {
  // Builds a voice-activity detector whose model is loaded from Android
  // assets. The caller owns the returned pointer.
#if __ANDROID_API__ >= 9
  AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
  if (!mgr) {
    SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
  }
#endif
  auto cfg = sherpa_onnx::GetVadModelConfig(env, _config);
  SHERPA_ONNX_LOGE("config:\n%s", cfg.ToString().c_str());
  auto *vad = new sherpa_onnx::VoiceActivityDetector(
#if __ANDROID_API__ >= 9
      mgr,
#endif
      cfg);

  return (jlong)vad;
}

SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_Vad_newFromFile(
    JNIEnv *env, jobject /*obj*/, jobject _config) {
  // Builds a voice-activity detector from files on disk.
  // Returns 0 when the config fails validation.
  auto cfg = sherpa_onnx::GetVadModelConfig(env, _config);
  SHERPA_ONNX_LOGE("config:\n%s", cfg.ToString().c_str());

  if (!cfg.Validate()) {
    SHERPA_ONNX_LOGE("Errors found in config!");
    return 0;
  }

  return (jlong) new sherpa_onnx::VoiceActivityDetector(cfg);
}
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_Vad_delete(JNIEnv *env,
                                                             jobject /*obj*/,
                                                             jlong ptr) {
  // Frees the detector created by newFromAsset()/newFromFile().
  delete reinterpret_cast<sherpa_onnx::VoiceActivityDetector *>(ptr);
}

SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_Vad_acceptWaveform(
    JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples) {
  // Feeds a Java float[] of audio samples into the detector.
  auto *vad = reinterpret_cast<sherpa_onnx::VoiceActivityDetector *>(ptr);

  jfloat *data = env->GetFloatArrayElements(samples, nullptr);
  jsize num_samples = env->GetArrayLength(samples);

  vad->AcceptWaveform(data, num_samples);

  // JNI_ABORT: samples were only read; no copy-back needed.
  env->ReleaseFloatArrayElements(samples, data, JNI_ABORT);
}

SHERPA_ONNX_EXTERN_C
JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_Vad_empty(JNIEnv *env,
                                                            jobject /*obj*/,
                                                            jlong ptr) {
  // True if there is no pending speech segment to consume.
  return reinterpret_cast<sherpa_onnx::VoiceActivityDetector *>(ptr)->Empty();
}

SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_Vad_pop(JNIEnv *env,
                                                          jobject /*obj*/,
                                                          jlong ptr) {
  // Discards the segment currently at the front of the queue.
  reinterpret_cast<sherpa_onnx::VoiceActivityDetector *>(ptr)->Pop();
}

SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_Vad_clear(JNIEnv *env,
                                                            jobject /*obj*/,
                                                            jlong ptr) {
  // Drops all queued speech segments.
  reinterpret_cast<sherpa_onnx::VoiceActivityDetector *>(ptr)->Clear();
}
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
JNIEXPORT jobjectArray JNICALL
Java_com_k2fsa_sherpa_onnx_Vad_front(JNIEnv *env, jobject /*obj*/, jlong ptr) {
  // Returns the speech segment at the front of the queue as an Object[2]:
  //   [0]: start sample index (Integer)
  //   [1]: samples (float[])
  const auto &front =
      reinterpret_cast<sherpa_onnx::VoiceActivityDetector *>(ptr)->Front();

  jfloatArray samples_arr = env->NewFloatArray(front.samples.size());
  env->SetFloatArrayRegion(samples_arr, 0, front.samples.size(),
                           front.samples.data());

  jobjectArray ans = (jobjectArray)env->NewObjectArray(
      2, env->FindClass("java/lang/Object"), nullptr);

  env->SetObjectArrayElement(ans, 0, NewInteger(env, front.start));
  env->SetObjectArrayElement(ans, 1, samples_arr);

  return ans;
}
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_Vad_isSpeechDetected(
    JNIEnv *env, jobject /*obj*/, jlong ptr) {
  // True while the detector is inside a speech region.
  return reinterpret_cast<sherpa_onnx::VoiceActivityDetector *>(ptr)
      ->IsSpeechDetected();
}

SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_Vad_reset(JNIEnv *env,
                                                            jobject /*obj*/,
                                                            jlong ptr) {
  // Resets the detector to its initial state.
  reinterpret_cast<sherpa_onnx::VoiceActivityDetector *>(ptr)->Reset();
}
|
||||
186
sherpa-onnx/kotlin-api/AudioTagging.kt
Normal file
186
sherpa-onnx/kotlin-api/AudioTagging.kt
Normal file
@@ -0,0 +1,186 @@
|
||||
package com.k2fsa.sherpa.onnx
|
||||
|
||||
import android.content.res.AssetManager
|
||||
|
||||
// Config for the zipformer-based audio-tagging model.
data class OfflineZipformerAudioTaggingModelConfig(
    var model: String = "",
)

// Model selection for audio tagging: set either [zipformer] or [ced].
data class AudioTaggingModelConfig(
    var zipformer: OfflineZipformerAudioTaggingModelConfig = OfflineZipformerAudioTaggingModelConfig(),
    var ced: String = "",
    var numThreads: Int = 1,
    var debug: Boolean = false,
    var provider: String = "cpu",
)

// Top-level audio-tagging configuration: model, label CSV path,
// and the default number of top events to return.
data class AudioTaggingConfig(
    var model: AudioTaggingModelConfig,
    var labels: String,
    var topK: Int = 5,
)

// One detected audio event: its label, label index, and probability.
data class AudioEvent(
    val name: String,
    val index: Int,
    val prob: Float,
)
|
||||
|
||||
// Kotlin wrapper around the native audio-tagging engine.
// Pass an [AssetManager] to load models from Android assets; pass null to
// load them from the file system.
class AudioTagging(
    assetManager: AssetManager? = null,
    config: AudioTaggingConfig,
) {
    // Native handle; 0 once released.
    private var ptr: Long

    init {
        ptr = if (assetManager != null) {
            newFromAsset(assetManager, config)
        } else {
            newFromFile(config)
        }
    }

    protected fun finalize() {
        // Guard against double-free when release() was already called.
        if (ptr != 0L) {
            delete(ptr)
            ptr = 0
        }
    }

    // Frees the native object eagerly; safe to call more than once.
    fun release() = finalize()

    // Creates a stream to feed audio into; caller releases it.
    fun createStream(): OfflineStream = OfflineStream(createStream(ptr))

    // Runs tagging on the stream. topK = -1 uses the config's default.
    @Suppress("UNCHECKED_CAST")
    fun compute(stream: OfflineStream, topK: Int = -1): ArrayList<AudioEvent> {
        // Each element of the native result is an Object[3]:
        // [0] name, [1] index, [2] probability.
        val events: Array<Any> = compute(ptr, stream.ptr, topK)
        return events.mapTo(ArrayList()) { item ->
            val fields = item as Array<Any>
            AudioEvent(
                name = fields[0] as String,
                index = fields[1] as Int,
                prob = fields[2] as Float,
            )
        }
    }

    private external fun newFromAsset(
        assetManager: AssetManager,
        config: AudioTaggingConfig,
    ): Long

    private external fun newFromFile(
        config: AudioTaggingConfig,
    ): Long

    private external fun delete(ptr: Long)

    private external fun createStream(ptr: Long): Long

    private external fun compute(ptr: Long, streamPtr: Long, topK: Int): Array<Any>

    companion object {
        init {
            System.loadLibrary("sherpa-onnx-jni")
        }
    }
}
|
||||
|
||||
// please refer to
|
||||
// https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models
|
||||
// to download more models
|
||||
//
|
||||
// See also
|
||||
// https://k2-fsa.github.io/sherpa/onnx/audio-tagging/
|
||||
fun getAudioTaggingConfig(type: Int, numThreads: Int = 1): AudioTaggingConfig? {
    // Maps a model-type index to a ready-made AudioTaggingConfig.
    // Types 0-1 are zipformer models; 2-5 are CED variants.
    // Returns null for an unknown type.
    return when (type) {
        0 -> zipformerAudioTaggingConfig(
            "sherpa-onnx-zipformer-small-audio-tagging-2024-04-15", numThreads
        )

        1 -> zipformerAudioTaggingConfig(
            "sherpa-onnx-zipformer-audio-tagging-2024-04-09", numThreads
        )

        2 -> cedAudioTaggingConfig(
            "sherpa-onnx-ced-tiny-audio-tagging-2024-04-19", numThreads
        )

        3 -> cedAudioTaggingConfig(
            "sherpa-onnx-ced-mini-audio-tagging-2024-04-19", numThreads
        )

        4 -> cedAudioTaggingConfig(
            "sherpa-onnx-ced-small-audio-tagging-2024-04-19", numThreads
        )

        5 -> cedAudioTaggingConfig(
            "sherpa-onnx-ced-base-audio-tagging-2024-04-19", numThreads
        )

        else -> null
    }
}

// Builds the config shared by all zipformer audio-tagging bundles.
private fun zipformerAudioTaggingConfig(modelDir: String, numThreads: Int) =
    AudioTaggingConfig(
        model = AudioTaggingModelConfig(
            zipformer = OfflineZipformerAudioTaggingModelConfig(model = "$modelDir/model.int8.onnx"),
            numThreads = numThreads,
            debug = true,
        ),
        labels = "$modelDir/class_labels_indices.csv",
        topK = 3,
    )

// Builds the config shared by all CED audio-tagging bundles.
private fun cedAudioTaggingConfig(modelDir: String, numThreads: Int) =
    AudioTaggingConfig(
        model = AudioTaggingModelConfig(
            ced = "$modelDir/model.int8.onnx",
            numThreads = numThreads,
            debug = true,
        ),
        labels = "$modelDir/class_labels_indices.csv",
        topK = 3,
    )
|
||||
10
sherpa-onnx/kotlin-api/FeatureConfig.kt
Normal file
10
sherpa-onnx/kotlin-api/FeatureConfig.kt
Normal file
@@ -0,0 +1,10 @@
|
||||
package com.k2fsa.sherpa.onnx
|
||||
|
||||
// Feature-extraction settings shared by the online/offline recognizers.
data class FeatureConfig(
    var sampleRate: Int = 16000,
    var featureDim: Int = 80,
)

// Convenience factory mirroring the other get*Config helpers.
fun getFeatureConfig(sampleRate: Int, featureDim: Int): FeatureConfig =
    FeatureConfig(sampleRate = sampleRate, featureDim = featureDim)
|
||||
151
sherpa-onnx/kotlin-api/KeywordSpotter.kt
Normal file
151
sherpa-onnx/kotlin-api/KeywordSpotter.kt
Normal file
@@ -0,0 +1,151 @@
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
package com.k2fsa.sherpa.onnx
|
||||
|
||||
import android.content.res.AssetManager
|
||||
|
||||
// Configuration for the keyword spotter; modelConfig is mandatory,
// everything else has sensible defaults.
data class KeywordSpotterConfig(
    var featConfig: FeatureConfig = FeatureConfig(),
    var modelConfig: OnlineModelConfig,
    var maxActivePaths: Int = 4,
    var keywordsFile: String = "keywords.txt",
    var keywordsScore: Float = 1.5f,
    var keywordsThreshold: Float = 0.25f,
    var numTrailingBlanks: Int = 2,
)

// One spotted keyword with its tokens and per-token timestamps.
// NOTE(review): Array/FloatArray properties give this data class
// reference-based equals/hashCode — confirm structural equality is not
// relied upon by callers.
data class KeywordSpotterResult(
    val keyword: String,
    val tokens: Array<String>,
    val timestamps: FloatArray,
    // TODO(fangjun): Add more fields
)
|
||||
|
||||
// Kotlin wrapper around the native keyword spotter.
// Pass an [AssetManager] to load models from Android assets; pass null to
// load them from the file system.
class KeywordSpotter(
    assetManager: AssetManager? = null,
    val config: KeywordSpotterConfig,
) {
    // Native handle; 0 once released. Made mutable and guarded so that
    // release() followed by GC finalization (or a second release()) does
    // not double-free the native object — matches AudioTagging.
    private var ptr: Long

    init {
        ptr = if (assetManager != null) {
            newFromAsset(assetManager, config)
        } else {
            newFromFile(config)
        }
    }

    protected fun finalize() {
        if (ptr != 0L) {
            delete(ptr)
            ptr = 0
        }
    }

    // Frees the native object eagerly; safe to call more than once.
    fun release() = finalize()

    // Creates a stream; extra keywords (one per line) augment keywordsFile.
    fun createStream(keywords: String = ""): OnlineStream {
        val p = createStream(ptr, keywords)
        return OnlineStream(p)
    }

    fun decode(stream: OnlineStream) = decode(ptr, stream.ptr)
    fun isReady(stream: OnlineStream) = isReady(ptr, stream.ptr)

    // Unpacks the native Object[3] result: [0] keyword, [1] tokens, [2] timestamps.
    @Suppress("UNCHECKED_CAST")
    fun getResult(stream: OnlineStream): KeywordSpotterResult {
        val objArray = getResult(ptr, stream.ptr)

        val keyword = objArray[0] as String
        val tokens = objArray[1] as Array<String>
        val timestamps = objArray[2] as FloatArray

        return KeywordSpotterResult(keyword = keyword, tokens = tokens, timestamps = timestamps)
    }

    private external fun delete(ptr: Long)

    private external fun newFromAsset(
        assetManager: AssetManager,
        config: KeywordSpotterConfig,
    ): Long

    private external fun newFromFile(
        config: KeywordSpotterConfig,
    ): Long

    private external fun createStream(ptr: Long, keywords: String): Long
    private external fun isReady(ptr: Long, streamPtr: Long): Boolean
    private external fun decode(ptr: Long, streamPtr: Long)
    private external fun getResult(ptr: Long, streamPtr: Long): Array<Any>

    companion object {
        init {
            System.loadLibrary("sherpa-onnx-jni")
        }
    }
}
|
||||
|
||||
/*
|
||||
Please see
|
||||
https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html
|
||||
for a list of pre-trained models.
|
||||
|
||||
We only add a few here. Please change the following code
|
||||
to add your own. (It should be straightforward to add a new model
|
||||
by following the code)
|
||||
|
||||
@param type
|
||||
0 - sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01 (Chinese)
|
||||
https://www.modelscope.cn/models/pkufool/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/summary
|
||||
|
||||
1 - sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01 (English)
|
||||
https://www.modelscope.cn/models/pkufool/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/summary
|
||||
|
||||
*/
|
||||
/**
 * Returns the model configuration for a bundled keyword-spotting model,
 * or null for an unknown [type]. See the comment above for the mapping
 * between [type] values and pre-trained models.
 */
fun getKwsModelConfig(type: Int): OnlineModelConfig? = when (type) {
    0 -> {
        val modelDir = "sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01"
        OnlineModelConfig(
            transducer = OnlineTransducerModelConfig(
                encoder = "$modelDir/encoder-epoch-12-avg-2-chunk-16-left-64.onnx",
                decoder = "$modelDir/decoder-epoch-12-avg-2-chunk-16-left-64.onnx",
                joiner = "$modelDir/joiner-epoch-12-avg-2-chunk-16-left-64.onnx",
            ),
            tokens = "$modelDir/tokens.txt",
            modelType = "zipformer2",
        )
    }

    1 -> {
        val modelDir = "sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01"
        OnlineModelConfig(
            transducer = OnlineTransducerModelConfig(
                encoder = "$modelDir/encoder-epoch-12-avg-2-chunk-16-left-64.onnx",
                decoder = "$modelDir/decoder-epoch-12-avg-2-chunk-16-left-64.onnx",
                joiner = "$modelDir/joiner-epoch-12-avg-2-chunk-16-left-64.onnx",
            ),
            tokens = "$modelDir/tokens.txt",
            modelType = "zipformer2",
        )
    }

    else -> null
}
|
||||
|
||||
/*
|
||||
* Get the default keywords for each model.
|
||||
* Caution: The types and modelDir should be the same as those in the
* getKwsModelConfig function above.
|
||||
*/
|
||||
/**
 * Returns the default keywords file shipped with the model for [type],
 * or "" for an unknown type.
 *
 * Caution: the type values and model directories must stay in sync with
 * the getKwsModelConfig function above.
 */
fun getKeywordsFile(type: Int): String = when (type) {
    0 -> {
        val modelDir = "sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01"
        "$modelDir/keywords.txt"
    }

    1 -> {
        val modelDir = "sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01"
        "$modelDir/keywords.txt"
    }

    else -> ""
}
|
||||
221
sherpa-onnx/kotlin-api/OfflineRecognizer.kt
Normal file
221
sherpa-onnx/kotlin-api/OfflineRecognizer.kt
Normal file
@@ -0,0 +1,221 @@
|
||||
package com.k2fsa.sherpa.onnx
|
||||
|
||||
import android.content.res.AssetManager
|
||||
|
||||
/**
 * Result of decoding one offline (non-streaming) utterance.
 *
 * @property text the recognized transcript.
 * @property tokens the recognized token sequence.
 * @property timestamps per-token timestamps in seconds (may be empty when
 *   the model does not produce timestamps).
 */
data class OfflineRecognizerResult(
    val text: String,
    val tokens: Array<String>,
    val timestamps: FloatArray,
) {
    // Array properties break the generated structural equals/hashCode
    // (arrays compare by reference), so compare contents explicitly.
    override fun equals(other: Any?): Boolean {
        if (this === other) return true
        if (other !is OfflineRecognizerResult) return false
        return text == other.text &&
            tokens.contentEquals(other.tokens) &&
            timestamps.contentEquals(other.timestamps)
    }

    override fun hashCode(): Int {
        var result = text.hashCode()
        result = 31 * result + tokens.contentHashCode()
        result = 31 * result + timestamps.contentHashCode()
        return result
    }
}
|
||||
|
||||
/** Paths of a non-streaming transducer model (encoder/decoder/joiner). */
data class OfflineTransducerModelConfig(
    var encoder: String = "",
    var decoder: String = "",
    var joiner: String = "",
)

/** Path of a non-streaming Paraformer model. */
data class OfflineParaformerModelConfig(
    var model: String = "",
)

/** Paths and decoding options of a Whisper model. */
data class OfflineWhisperModelConfig(
    var encoder: String = "",
    var decoder: String = "",
    var language: String = "en", // Used with multilingual model
    var task: String = "transcribe", // transcribe or translate
    var tailPaddings: Int = 1000, // Padding added at the end of the samples
)

/**
 * Selects and configures the offline model to run. Exactly one of
 * [transducer], [paraformer], or [whisper] is expected to be filled in,
 * with [modelType] naming which family it belongs to.
 */
data class OfflineModelConfig(
    var transducer: OfflineTransducerModelConfig = OfflineTransducerModelConfig(),
    var paraformer: OfflineParaformerModelConfig = OfflineParaformerModelConfig(),
    var whisper: OfflineWhisperModelConfig = OfflineWhisperModelConfig(),
    var numThreads: Int = 1,
    var debug: Boolean = false,
    var provider: String = "cpu",
    var modelType: String = "",
    var tokens: String,
)
|
||||
|
||||
/**
 * Top-level configuration for [OfflineRecognizer].
 *
 * @property decodingMethod "greedy_search" or "modified_beam_search".
 * @property hotwordsFile optional hotword list; only used with
 *   modified_beam_search.
 */
data class OfflineRecognizerConfig(
    var featConfig: FeatureConfig = FeatureConfig(),
    var modelConfig: OfflineModelConfig,
    // var lmConfig: OfflineLMConfig(), // TODO(fangjun): enable it
    var decodingMethod: String = "greedy_search",
    var maxActivePaths: Int = 4,
    var hotwordsFile: String = "",
    var hotwordsScore: Float = 1.5f,
)
|
||||
|
||||
/**
 * Non-streaming speech recognizer backed by the native sherpa-onnx library.
 *
 * When [assetManager] is non-null, model files named in [config] are loaded
 * from the APK assets; otherwise they are read from the file system.
 *
 * Call [release] when done. The native handle is freed exactly once even if
 * the garbage collector later runs [finalize] (previously `release()`
 * followed by finalization deleted the same native pointer twice).
 */
class OfflineRecognizer(
    assetManager: AssetManager? = null,
    config: OfflineRecognizerConfig,
) {
    // Native handle; reset to 0 after deletion to guard against double free,
    // matching the other wrappers such as OfflineStream.
    private var ptr: Long

    init {
        ptr = if (assetManager != null) {
            newFromAsset(assetManager, config)
        } else {
            newFromFile(config)
        }
    }

    protected fun finalize() {
        if (ptr != 0L) {
            delete(ptr)
            ptr = 0
        }
    }

    /** Frees the native recognizer. Safe to call more than once. */
    fun release() = finalize()

    fun createStream(): OfflineStream {
        val p = createStream(ptr)
        return OfflineStream(p)
    }

    /** Fetches the decoding result of [stream] after [decode] has been called. */
    fun getResult(stream: OfflineStream): OfflineRecognizerResult {
        val objArray = getResult(stream.ptr)

        val text = objArray[0] as String
        @Suppress("UNCHECKED_CAST")
        val tokens = objArray[1] as Array<String>
        val timestamps = objArray[2] as FloatArray
        return OfflineRecognizerResult(text = text, tokens = tokens, timestamps = timestamps)
    }

    fun decode(stream: OfflineStream) = decode(ptr, stream.ptr)

    private external fun delete(ptr: Long)

    private external fun createStream(ptr: Long): Long

    private external fun newFromAsset(
        assetManager: AssetManager,
        config: OfflineRecognizerConfig,
    ): Long

    private external fun newFromFile(
        config: OfflineRecognizerConfig,
    ): Long

    private external fun decode(ptr: Long, streamPtr: Long)

    // Note: takes the stream pointer only; the result lives on the stream.
    private external fun getResult(streamPtr: Long): Array<Any>

    companion object {
        init {
            System.loadLibrary("sherpa-onnx-jni")
        }
    }
}
|
||||
|
||||
/*
|
||||
Please see
|
||||
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
|
||||
for a list of pre-trained models.
|
||||
|
||||
We only add a few here. Please change the following code
|
||||
to add your own. (It should be straightforward to add a new model
|
||||
by following the code)
|
||||
|
||||
@param type
|
||||
|
||||
0 - csukuangfj/sherpa-onnx-paraformer-zh-2023-03-28 (Chinese)
|
||||
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-paraformer/paraformer-models.html#csukuangfj-sherpa-onnx-paraformer-zh-2023-03-28-chinese
|
||||
int8
|
||||
|
||||
1 - icefall-asr-multidataset-pruned_transducer_stateless7-2023-05-04 (English)
|
||||
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-transducer/zipformer-transducer-models.html#icefall-asr-multidataset-pruned-transducer-stateless7-2023-05-04-english
|
||||
encoder int8, decoder/joiner float32
|
||||
|
||||
2 - sherpa-onnx-whisper-tiny.en
|
||||
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/tiny.en.html#tiny-en
|
||||
encoder int8, decoder int8
|
||||
|
||||
3 - sherpa-onnx-whisper-base.en
    https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/base.en.html#base-en
|
||||
encoder int8, decoder int8
|
||||
|
||||
4 - pkufool/icefall-asr-zipformer-wenetspeech-20230615 (Chinese)
|
||||
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-transducer/zipformer-transducer-models.html#pkufool-icefall-asr-zipformer-wenetspeech-20230615-chinese
|
||||
encoder/joiner int8, decoder fp32
|
||||
|
||||
*/
|
||||
/**
 * Returns the model configuration for a bundled offline ASR model, or null
 * for an unknown [type]. See the comment above for the mapping between
 * [type] values and pre-trained models.
 */
fun getOfflineModelConfig(type: Int): OfflineModelConfig? = when (type) {
    0 -> {
        val modelDir = "sherpa-onnx-paraformer-zh-2023-03-28"
        OfflineModelConfig(
            paraformer = OfflineParaformerModelConfig(
                model = "$modelDir/model.int8.onnx",
            ),
            tokens = "$modelDir/tokens.txt",
            modelType = "paraformer",
        )
    }

    1 -> {
        val modelDir = "icefall-asr-multidataset-pruned_transducer_stateless7-2023-05-04"
        OfflineModelConfig(
            transducer = OfflineTransducerModelConfig(
                encoder = "$modelDir/encoder-epoch-30-avg-4.int8.onnx",
                decoder = "$modelDir/decoder-epoch-30-avg-4.onnx",
                joiner = "$modelDir/joiner-epoch-30-avg-4.onnx",
            ),
            tokens = "$modelDir/tokens.txt",
            modelType = "zipformer",
        )
    }

    2 -> {
        val modelDir = "sherpa-onnx-whisper-tiny.en"
        OfflineModelConfig(
            whisper = OfflineWhisperModelConfig(
                encoder = "$modelDir/tiny.en-encoder.int8.onnx",
                decoder = "$modelDir/tiny.en-decoder.int8.onnx",
            ),
            tokens = "$modelDir/tiny.en-tokens.txt",
            modelType = "whisper",
        )
    }

    3 -> {
        val modelDir = "sherpa-onnx-whisper-base.en"
        OfflineModelConfig(
            whisper = OfflineWhisperModelConfig(
                encoder = "$modelDir/base.en-encoder.int8.onnx",
                decoder = "$modelDir/base.en-decoder.int8.onnx",
            ),
            tokens = "$modelDir/base.en-tokens.txt",
            modelType = "whisper",
        )
    }

    4 -> {
        val modelDir = "icefall-asr-zipformer-wenetspeech-20230615"
        OfflineModelConfig(
            transducer = OfflineTransducerModelConfig(
                encoder = "$modelDir/encoder-epoch-12-avg-4.int8.onnx",
                decoder = "$modelDir/decoder-epoch-12-avg-4.onnx",
                joiner = "$modelDir/joiner-epoch-12-avg-4.int8.onnx",
            ),
            tokens = "$modelDir/tokens.txt",
            modelType = "zipformer",
        )
    }

    5 -> {
        val modelDir = "sherpa-onnx-zipformer-multi-zh-hans-2023-9-2"
        OfflineModelConfig(
            transducer = OfflineTransducerModelConfig(
                encoder = "$modelDir/encoder-epoch-20-avg-1.int8.onnx",
                decoder = "$modelDir/decoder-epoch-20-avg-1.onnx",
                joiner = "$modelDir/joiner-epoch-20-avg-1.int8.onnx",
            ),
            tokens = "$modelDir/tokens.txt",
            modelType = "zipformer2",
        )
    }

    else -> null
}
|
||||
24
sherpa-onnx/kotlin-api/OfflineStream.kt
Normal file
24
sherpa-onnx/kotlin-api/OfflineStream.kt
Normal file
@@ -0,0 +1,24 @@
|
||||
package com.k2fsa.sherpa.onnx
|
||||
|
||||
/**
 * Wraps a native offline (non-streaming) audio stream.
 * [ptr] is the native handle; it becomes 0 once the stream is released.
 */
class OfflineStream(var ptr: Long) {
    /** Feeds audio [samples] in the range [-1, 1] captured at [sampleRate] Hz. */
    fun acceptWaveform(samples: FloatArray, sampleRate: Int) =
        acceptWaveform(ptr, samples, sampleRate)

    protected fun finalize() {
        if (ptr == 0L) return
        delete(ptr)
        ptr = 0
    }

    /** Frees the native stream. Safe to call more than once. */
    fun release() = finalize()

    private external fun acceptWaveform(ptr: Long, samples: FloatArray, sampleRate: Int)
    private external fun delete(ptr: Long)

    companion object {
        init {
            System.loadLibrary("sherpa-onnx-jni")
        }
    }
}
|
||||
352
sherpa-onnx/kotlin-api/OnlineRecognizer.kt
Normal file
352
sherpa-onnx/kotlin-api/OnlineRecognizer.kt
Normal file
@@ -0,0 +1,352 @@
|
||||
package com.k2fsa.sherpa.onnx
|
||||
|
||||
import android.content.res.AssetManager
|
||||
|
||||
/** One endpoint-detection rule; a rule fires when all its conditions hold. */
data class EndpointRule(
    var mustContainNonSilence: Boolean,
    var minTrailingSilence: Float,
    var minUtteranceLength: Float,
)

/** The three standard endpoint rules; an endpoint is declared if any fires. */
data class EndpointConfig(
    var rule1: EndpointRule = EndpointRule(false, 2.4f, 0.0f),
    var rule2: EndpointRule = EndpointRule(true, 1.4f, 0.0f),
    var rule3: EndpointRule = EndpointRule(false, 0.0f, 20.0f)
)

/** Paths of a streaming transducer model (encoder/decoder/joiner). */
data class OnlineTransducerModelConfig(
    var encoder: String = "",
    var decoder: String = "",
    var joiner: String = "",
)

/** Paths of a streaming Paraformer model. */
data class OnlineParaformerModelConfig(
    var encoder: String = "",
    var decoder: String = "",
)

/** Path of a streaming zipformer2 CTC model. */
data class OnlineZipformer2CtcModelConfig(
    var model: String = "",
)

/**
 * Selects and configures the online model to run. Exactly one of
 * [transducer], [paraformer], or [zipformer2Ctc] is expected to be filled
 * in, with [modelType] naming which family it belongs to.
 */
data class OnlineModelConfig(
    var transducer: OnlineTransducerModelConfig = OnlineTransducerModelConfig(),
    var paraformer: OnlineParaformerModelConfig = OnlineParaformerModelConfig(),
    var zipformer2Ctc: OnlineZipformer2CtcModelConfig = OnlineZipformer2CtcModelConfig(),
    var tokens: String,
    var numThreads: Int = 1,
    var debug: Boolean = false,
    var provider: String = "cpu",
    var modelType: String = "",
)

/** Optional neural LM used for rescoring, with its interpolation [scale]. */
data class OnlineLMConfig(
    var model: String = "",
    var scale: Float = 0.5f,
)
|
||||
|
||||
|
||||
/**
 * Top-level configuration for [OnlineRecognizer].
 *
 * @property enableEndpoint when true, [OnlineRecognizer.isEndpoint] uses
 *   [endpointConfig] to detect utterance boundaries.
 * @property decodingMethod "greedy_search" or "modified_beam_search".
 * @property hotwordsFile optional hotword list; only used with
 *   modified_beam_search.
 */
data class OnlineRecognizerConfig(
    var featConfig: FeatureConfig = FeatureConfig(),
    var modelConfig: OnlineModelConfig,
    var lmConfig: OnlineLMConfig = OnlineLMConfig(),
    var endpointConfig: EndpointConfig = EndpointConfig(),
    var enableEndpoint: Boolean = true,
    var decodingMethod: String = "greedy_search",
    var maxActivePaths: Int = 4,
    var hotwordsFile: String = "",
    var hotwordsScore: Float = 1.5f,
)
|
||||
|
||||
/**
 * Partial or final result of decoding a streaming utterance.
 *
 * @property text the recognized transcript so far.
 * @property tokens the recognized token sequence.
 * @property timestamps per-token timestamps in seconds.
 */
data class OnlineRecognizerResult(
    val text: String,
    val tokens: Array<String>,
    val timestamps: FloatArray,
    // TODO(fangjun): Add more fields
) {
    // Array properties break the generated structural equals/hashCode
    // (arrays compare by reference), so compare contents explicitly.
    override fun equals(other: Any?): Boolean {
        if (this === other) return true
        if (other !is OnlineRecognizerResult) return false
        return text == other.text &&
            tokens.contentEquals(other.tokens) &&
            timestamps.contentEquals(other.timestamps)
    }

    override fun hashCode(): Int {
        var result = text.hashCode()
        result = 31 * result + tokens.contentHashCode()
        result = 31 * result + timestamps.contentHashCode()
        return result
    }
}
|
||||
|
||||
/**
 * Streaming (online) speech recognizer backed by the native sherpa-onnx
 * library.
 *
 * When [assetManager] is non-null, model files named in [config] are loaded
 * from the APK assets; otherwise they are read from the file system.
 *
 * Call [release] when done. The native handle is freed exactly once even if
 * the garbage collector later runs [finalize] (previously `release()`
 * followed by finalization deleted the same native pointer twice).
 */
class OnlineRecognizer(
    assetManager: AssetManager? = null,
    val config: OnlineRecognizerConfig,
) {
    // Native handle; reset to 0 after deletion to guard against double free,
    // matching the other wrappers such as OnlineStream.
    private var ptr: Long

    init {
        ptr = if (assetManager != null) {
            newFromAsset(assetManager, config)
        } else {
            newFromFile(config)
        }
    }

    protected fun finalize() {
        if (ptr != 0L) {
            delete(ptr)
            ptr = 0
        }
    }

    /** Frees the native recognizer. Safe to call more than once. */
    fun release() = finalize()

    /**
     * Creates a decoding stream. [hotwords] is only used when the
     * recognizer was configured with modified_beam_search.
     */
    fun createStream(hotwords: String = ""): OnlineStream {
        val p = createStream(ptr, hotwords)
        return OnlineStream(p)
    }

    fun reset(stream: OnlineStream) = reset(ptr, stream.ptr)
    fun decode(stream: OnlineStream) = decode(ptr, stream.ptr)
    fun isEndpoint(stream: OnlineStream) = isEndpoint(ptr, stream.ptr)
    fun isReady(stream: OnlineStream) = isReady(ptr, stream.ptr)

    /** Fetches the current (partial or final) result of [stream]. */
    fun getResult(stream: OnlineStream): OnlineRecognizerResult {
        val objArray = getResult(ptr, stream.ptr)

        val text = objArray[0] as String
        @Suppress("UNCHECKED_CAST")
        val tokens = objArray[1] as Array<String>
        val timestamps = objArray[2] as FloatArray

        return OnlineRecognizerResult(text = text, tokens = tokens, timestamps = timestamps)
    }

    private external fun delete(ptr: Long)

    private external fun newFromAsset(
        assetManager: AssetManager,
        config: OnlineRecognizerConfig,
    ): Long

    private external fun newFromFile(
        config: OnlineRecognizerConfig,
    ): Long

    private external fun createStream(ptr: Long, hotwords: String): Long
    private external fun reset(ptr: Long, streamPtr: Long)
    private external fun decode(ptr: Long, streamPtr: Long)
    private external fun isEndpoint(ptr: Long, streamPtr: Long): Boolean
    private external fun isReady(ptr: Long, streamPtr: Long): Boolean
    private external fun getResult(ptr: Long, streamPtr: Long): Array<Any>

    companion object {
        init {
            System.loadLibrary("sherpa-onnx-jni")
        }
    }
}
|
||||
|
||||
|
||||
/*
|
||||
Please see
|
||||
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
|
||||
for a list of pre-trained models.
|
||||
|
||||
We only add a few here. Please change the following code
|
||||
to add your own. (It should be straightforward to add a new model
|
||||
by following the code)
|
||||
|
||||
@param type
|
||||
0 - sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 (Bilingual, Chinese + English)
|
||||
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/zipformer-transducer-models.html#sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20-bilingual-chinese-english
|
||||
|
||||
1 - csukuangfj/sherpa-onnx-lstm-zh-2023-02-20 (Chinese)
|
||||
|
||||
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/lstm-transducer-models.html#csukuangfj-sherpa-onnx-lstm-zh-2023-02-20-chinese
|
||||
|
||||
2 - csukuangfj/sherpa-onnx-lstm-en-2023-02-17 (English)
|
||||
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/lstm-transducer-models.html#csukuangfj-sherpa-onnx-lstm-en-2023-02-17-english
|
||||
|
||||
3,4 - pkufool/icefall-asr-zipformer-streaming-wenetspeech-20230615
|
||||
https://huggingface.co/pkufool/icefall-asr-zipformer-streaming-wenetspeech-20230615
|
||||
3 - int8 encoder
|
||||
4 - float32 encoder
|
||||
|
||||
5 - csukuangfj/sherpa-onnx-streaming-paraformer-bilingual-zh-en
|
||||
https://huggingface.co/csukuangfj/sherpa-onnx-streaming-paraformer-bilingual-zh-en
|
||||
|
||||
6 - sherpa-onnx-streaming-zipformer-en-2023-06-26
|
||||
https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-06-26
|
||||
|
||||
7 - shaojieli/sherpa-onnx-streaming-zipformer-fr-2023-04-14 (French)
|
||||
https://huggingface.co/shaojieli/sherpa-onnx-streaming-zipformer-fr-2023-04-14
|
||||
|
||||
8 - csukuangfj/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 (Bilingual, Chinese + English)
|
||||
https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20
|
||||
encoder int8, decoder/joiner float32
|
||||
|
||||
*/
|
||||
/**
 * Returns the model configuration for a bundled streaming ASR model, or
 * null for an unknown [type]. See the comment above for the mapping
 * between [type] values and pre-trained models.
 */
fun getModelConfig(type: Int): OnlineModelConfig? = when (type) {
    0 -> {
        val modelDir = "sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20"
        OnlineModelConfig(
            transducer = OnlineTransducerModelConfig(
                encoder = "$modelDir/encoder-epoch-99-avg-1.onnx",
                decoder = "$modelDir/decoder-epoch-99-avg-1.onnx",
                joiner = "$modelDir/joiner-epoch-99-avg-1.onnx",
            ),
            tokens = "$modelDir/tokens.txt",
            modelType = "zipformer",
        )
    }

    1 -> {
        val modelDir = "sherpa-onnx-lstm-zh-2023-02-20"
        OnlineModelConfig(
            transducer = OnlineTransducerModelConfig(
                encoder = "$modelDir/encoder-epoch-11-avg-1.onnx",
                decoder = "$modelDir/decoder-epoch-11-avg-1.onnx",
                joiner = "$modelDir/joiner-epoch-11-avg-1.onnx",
            ),
            tokens = "$modelDir/tokens.txt",
            modelType = "lstm",
        )
    }

    2 -> {
        val modelDir = "sherpa-onnx-lstm-en-2023-02-17"
        OnlineModelConfig(
            transducer = OnlineTransducerModelConfig(
                encoder = "$modelDir/encoder-epoch-99-avg-1.onnx",
                decoder = "$modelDir/decoder-epoch-99-avg-1.onnx",
                joiner = "$modelDir/joiner-epoch-99-avg-1.onnx",
            ),
            tokens = "$modelDir/tokens.txt",
            modelType = "lstm",
        )
    }

    3 -> {
        val modelDir = "icefall-asr-zipformer-streaming-wenetspeech-20230615"
        OnlineModelConfig(
            transducer = OnlineTransducerModelConfig(
                encoder = "$modelDir/exp/encoder-epoch-12-avg-4-chunk-16-left-128.int8.onnx",
                decoder = "$modelDir/exp/decoder-epoch-12-avg-4-chunk-16-left-128.onnx",
                joiner = "$modelDir/exp/joiner-epoch-12-avg-4-chunk-16-left-128.onnx",
            ),
            tokens = "$modelDir/data/lang_char/tokens.txt",
            modelType = "zipformer2",
        )
    }

    4 -> {
        val modelDir = "icefall-asr-zipformer-streaming-wenetspeech-20230615"
        OnlineModelConfig(
            transducer = OnlineTransducerModelConfig(
                encoder = "$modelDir/exp/encoder-epoch-12-avg-4-chunk-16-left-128.onnx",
                decoder = "$modelDir/exp/decoder-epoch-12-avg-4-chunk-16-left-128.onnx",
                joiner = "$modelDir/exp/joiner-epoch-12-avg-4-chunk-16-left-128.onnx",
            ),
            tokens = "$modelDir/data/lang_char/tokens.txt",
            modelType = "zipformer2",
        )
    }

    5 -> {
        val modelDir = "sherpa-onnx-streaming-paraformer-bilingual-zh-en"
        OnlineModelConfig(
            paraformer = OnlineParaformerModelConfig(
                encoder = "$modelDir/encoder.int8.onnx",
                decoder = "$modelDir/decoder.int8.onnx",
            ),
            tokens = "$modelDir/tokens.txt",
            modelType = "paraformer",
        )
    }

    6 -> {
        val modelDir = "sherpa-onnx-streaming-zipformer-en-2023-06-26"
        OnlineModelConfig(
            transducer = OnlineTransducerModelConfig(
                encoder = "$modelDir/encoder-epoch-99-avg-1-chunk-16-left-128.int8.onnx",
                decoder = "$modelDir/decoder-epoch-99-avg-1-chunk-16-left-128.onnx",
                joiner = "$modelDir/joiner-epoch-99-avg-1-chunk-16-left-128.onnx",
            ),
            tokens = "$modelDir/tokens.txt",
            modelType = "zipformer2",
        )
    }

    7 -> {
        val modelDir = "sherpa-onnx-streaming-zipformer-fr-2023-04-14"
        OnlineModelConfig(
            transducer = OnlineTransducerModelConfig(
                encoder = "$modelDir/encoder-epoch-29-avg-9-with-averaged-model.int8.onnx",
                decoder = "$modelDir/decoder-epoch-29-avg-9-with-averaged-model.onnx",
                joiner = "$modelDir/joiner-epoch-29-avg-9-with-averaged-model.onnx",
            ),
            tokens = "$modelDir/tokens.txt",
            modelType = "zipformer",
        )
    }

    8 -> {
        val modelDir = "sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20"
        OnlineModelConfig(
            transducer = OnlineTransducerModelConfig(
                encoder = "$modelDir/encoder-epoch-99-avg-1.int8.onnx",
                decoder = "$modelDir/decoder-epoch-99-avg-1.onnx",
                joiner = "$modelDir/joiner-epoch-99-avg-1.int8.onnx",
            ),
            tokens = "$modelDir/tokens.txt",
            modelType = "zipformer",
        )
    }

    9 -> {
        val modelDir = "sherpa-onnx-streaming-zipformer-zh-14M-2023-02-23"
        OnlineModelConfig(
            transducer = OnlineTransducerModelConfig(
                encoder = "$modelDir/encoder-epoch-99-avg-1.int8.onnx",
                decoder = "$modelDir/decoder-epoch-99-avg-1.onnx",
                joiner = "$modelDir/joiner-epoch-99-avg-1.int8.onnx",
            ),
            tokens = "$modelDir/tokens.txt",
            modelType = "zipformer",
        )
    }

    10 -> {
        val modelDir = "sherpa-onnx-streaming-zipformer-en-20M-2023-02-17"
        OnlineModelConfig(
            transducer = OnlineTransducerModelConfig(
                encoder = "$modelDir/encoder-epoch-99-avg-1.int8.onnx",
                decoder = "$modelDir/decoder-epoch-99-avg-1.onnx",
                joiner = "$modelDir/joiner-epoch-99-avg-1.int8.onnx",
            ),
            tokens = "$modelDir/tokens.txt",
            modelType = "zipformer",
        )
    }

    else -> null
}
|
||||
|
||||
/*
|
||||
Please see
|
||||
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
|
||||
for a list of pre-trained models.
|
||||
|
||||
We only add a few here. Please change the following code
|
||||
to add your own LM model. (It should be straightforward to train a new NN LM model
|
||||
by following the code, https://github.com/k2-fsa/icefall/blob/master/icefall/rnn_lm/train.py)
|
||||
|
||||
@param type
|
||||
0 - sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 (Bilingual, Chinese + English)
|
||||
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/zipformer-transducer-models.html#sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20-bilingual-chinese-english
|
||||
*/
|
||||
/**
 * Returns the neural LM configuration for the given model [type]; an empty
 * default [OnlineLMConfig] (no LM rescoring) is returned for unknown types.
 */
fun getOnlineLMConfig(type: Int): OnlineLMConfig = when (type) {
    0 -> {
        val modelDir = "sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20"
        OnlineLMConfig(
            model = "$modelDir/with-state-epoch-99-avg-1.int8.onnx",
            scale = 0.5f,
        )
    }

    else -> OnlineLMConfig()
}
|
||||
|
||||
/** Returns the standard three-rule endpoint configuration. */
fun getEndpointConfig(): EndpointConfig = EndpointConfig(
    rule1 = EndpointRule(false, 2.4f, 0.0f),
    rule2 = EndpointRule(true, 1.4f, 0.0f),
    rule3 = EndpointRule(false, 0.0f, 20.0f)
)
|
||||
|
||||
27
sherpa-onnx/kotlin-api/OnlineStream.kt
Normal file
27
sherpa-onnx/kotlin-api/OnlineStream.kt
Normal file
@@ -0,0 +1,27 @@
|
||||
package com.k2fsa.sherpa.onnx
|
||||
|
||||
/**
 * Wraps a native online (streaming) audio stream.
 * [ptr] is the native handle; it becomes 0 once the stream is released.
 */
class OnlineStream(var ptr: Long = 0) {
    /** Feeds audio [samples] in the range [-1, 1] captured at [sampleRate] Hz. */
    fun acceptWaveform(samples: FloatArray, sampleRate: Int) =
        acceptWaveform(ptr, samples, sampleRate)

    /** Signals that no more audio will be supplied to this stream. */
    fun inputFinished() = inputFinished(ptr)

    protected fun finalize() {
        if (ptr == 0L) return
        delete(ptr)
        ptr = 0
    }

    /** Frees the native stream. Safe to call more than once. */
    fun release() = finalize()

    private external fun acceptWaveform(ptr: Long, samples: FloatArray, sampleRate: Int)
    private external fun inputFinished(ptr: Long)
    private external fun delete(ptr: Long)

    companion object {
        init {
            System.loadLibrary("sherpa-onnx-jni")
        }
    }
}
|
||||
164
sherpa-onnx/kotlin-api/Speaker.kt
Normal file
164
sherpa-onnx/kotlin-api/Speaker.kt
Normal file
@@ -0,0 +1,164 @@
|
||||
package com.k2fsa.sherpa.onnx
|
||||
|
||||
import android.content.res.AssetManager
|
||||
import android.util.Log
|
||||
|
||||
/**
 * Configuration for [SpeakerEmbeddingExtractor].
 *
 * @property model path of the speaker-embedding ONNX model (required).
 * @property provider inference backend passed to onnxruntime, e.g. "cpu".
 */
data class SpeakerEmbeddingExtractorConfig(
    val model: String,
    var numThreads: Int = 1,
    var debug: Boolean = false,
    var provider: String = "cpu",
)
|
||||
|
||||
/**
 * Computes fixed-dimensional speaker embeddings from audio, backed by the
 * native sherpa-onnx library.
 *
 * When [assetManager] is non-null, the model named in [config] is loaded
 * from the APK assets; otherwise it is read from the file system.
 */
class SpeakerEmbeddingExtractor(
    assetManager: AssetManager? = null,
    config: SpeakerEmbeddingExtractorConfig,
) {
    // Native handle; 0 once released.
    private var ptr: Long =
        if (assetManager != null) {
            newFromAsset(assetManager, config)
        } else {
            newFromFile(config)
        }

    protected fun finalize() {
        if (ptr == 0L) return
        delete(ptr)
        ptr = 0
    }

    /** Frees the native extractor. Safe to call more than once. */
    fun release() = finalize()

    /** Creates a stream to feed audio into. */
    fun createStream(): OnlineStream = OnlineStream(createStream(ptr))

    fun isReady(stream: OnlineStream) = isReady(ptr, stream.ptr)

    /** Returns the embedding vector computed from [stream]. */
    fun compute(stream: OnlineStream) = compute(ptr, stream.ptr)

    /** Dimension of the embeddings produced by the loaded model. */
    fun dim() = dim(ptr)

    private external fun newFromAsset(
        assetManager: AssetManager,
        config: SpeakerEmbeddingExtractorConfig,
    ): Long

    private external fun newFromFile(
        config: SpeakerEmbeddingExtractorConfig,
    ): Long

    private external fun delete(ptr: Long)

    private external fun createStream(ptr: Long): Long

    private external fun isReady(ptr: Long, streamPtr: Long): Boolean

    private external fun compute(ptr: Long, streamPtr: Long): FloatArray

    private external fun dim(ptr: Long): Int

    companion object {
        init {
            System.loadLibrary("sherpa-onnx-jni")
        }
    }
}
|
||||
|
||||
/**
 * In-memory registry mapping speaker names to embeddings, backed by native
 * code. All embeddings must have dimension [dim].
 */
class SpeakerEmbeddingManager(val dim: Int) {
    // Native handle; 0 once released.
    private var ptr: Long = create(dim)

    protected fun finalize() {
        if (ptr == 0L) return
        delete(ptr)
        ptr = 0
    }

    /** Frees the native manager. Safe to call more than once. */
    fun release() = finalize()

    /** Registers one embedding for [name]; returns true on success. */
    fun add(name: String, embedding: FloatArray) = add(ptr, name, embedding)

    /** Registers several embeddings for [name]; returns true on success. */
    fun add(name: String, embedding: Array<FloatArray>) = addList(ptr, name, embedding)

    fun remove(name: String) = remove(ptr, name)

    /** Returns the best-matching speaker name, or "" if none exceeds [threshold]. */
    fun search(embedding: FloatArray, threshold: Float) = search(ptr, embedding, threshold)

    /** Checks whether [embedding] matches the registered speaker [name]. */
    fun verify(name: String, embedding: FloatArray, threshold: Float) =
        verify(ptr, name, embedding, threshold)

    fun contains(name: String) = contains(ptr, name)

    fun numSpeakers() = numSpeakers(ptr)

    fun allSpeakerNames() = allSpeakerNames(ptr)

    private external fun create(dim: Int): Long
    private external fun delete(ptr: Long): Unit
    private external fun add(ptr: Long, name: String, embedding: FloatArray): Boolean
    private external fun addList(ptr: Long, name: String, embedding: Array<FloatArray>): Boolean
    private external fun remove(ptr: Long, name: String): Boolean
    private external fun search(ptr: Long, embedding: FloatArray, threshold: Float): String
    private external fun verify(
        ptr: Long,
        name: String,
        embedding: FloatArray,
        threshold: Float
    ): Boolean

    private external fun contains(ptr: Long, name: String): Boolean
    private external fun numSpeakers(ptr: Long): Int

    private external fun allSpeakerNames(ptr: Long): Array<String>

    companion object {
        init {
            System.loadLibrary("sherpa-onnx-jni")
        }
    }
}
|
||||
|
||||
// Please download the model file from
|
||||
// https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-recongition-models
|
||||
// and put it inside the assets directory.
|
||||
//
|
||||
// Please don't put it in a subdirectory of assets
|
||||
private val modelName = "3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx"
|
||||
|
||||
/**
 * Process-wide singleton holding one [SpeakerEmbeddingExtractor] and one
 * [SpeakerEmbeddingManager]. Call [initExtractor] once before reading
 * [extractor] or [manager]; reading them earlier throws NullPointerException.
 */
object SpeakerRecognition {
    var _extractor: SpeakerEmbeddingExtractor? = null
    var _manager: SpeakerEmbeddingManager? = null

    val extractor: SpeakerEmbeddingExtractor
        get() = _extractor!! // initExtractor() must have been called first

    val manager: SpeakerEmbeddingManager
        get() = _manager!! // initExtractor() must have been called first

    /**
     * Initializes the extractor and manager exactly once; subsequent calls
     * are no-ops. Thread-safe via synchronization on this object.
     */
    fun initExtractor(assetManager: AssetManager? = null) {
        synchronized(this) {
            if (_extractor != null) {
                return
            }
            Log.i("sherpa-onnx", "Initializing speaker embedding extractor")

            _extractor = SpeakerEmbeddingExtractor(
                assetManager = assetManager,
                config = SpeakerEmbeddingExtractorConfig(
                    model = modelName,
                    numThreads = 2,
                    debug = false,
                    provider = "cpu",
                )
            )

            _manager = SpeakerEmbeddingManager(dim = _extractor!!.dim())
        }
    }
}
|
||||
103
sherpa-onnx/kotlin-api/SpokenLanguageIdentification.kt
Normal file
103
sherpa-onnx/kotlin-api/SpokenLanguageIdentification.kt
Normal file
@@ -0,0 +1,103 @@
|
||||
package com.k2fsa.sherpa.onnx
|
||||
|
||||
import android.content.res.AssetManager
|
||||
|
||||
/**
 * Whisper model files used for language identification.
 *
 * @property tailPaddings padding appended to the samples; -1 keeps the
 *   native default.
 */
data class SpokenLanguageIdentificationWhisperConfig(
    var encoder: String,
    var decoder: String,
    var tailPaddings: Int = -1,
)

/** Top-level configuration for [SpokenLanguageIdentification]. */
data class SpokenLanguageIdentificationConfig(
    var whisper: SpokenLanguageIdentificationWhisperConfig,
    var numThreads: Int = 1,
    var debug: Boolean = false,
    var provider: String = "cpu",
)
|
||||
|
||||
// Kotlin wrapper around the native spoken-language-identification engine.
//
// Pass a non-null assetManager to load the model from Android assets,
// otherwise the model is loaded from the file system.
class SpokenLanguageIdentification(
    assetManager: AssetManager? = null,
    config: SpokenLanguageIdentificationConfig,
) {
    // Handle to the native object; 0 once released.
    private var ptr: Long

    init {
        ptr = when {
            assetManager != null -> newFromAsset(assetManager, config)
            else -> newFromFile(config)
        }
    }

    protected fun finalize() {
        // Guard against double-free: finalize() may run after release().
        if (ptr != 0L) {
            delete(ptr)
            ptr = 0
        }
    }

    // Explicitly frees the native object; safe to call more than once.
    fun release() = finalize()

    // Creates an offline stream to feed audio into compute().
    fun createStream(): OfflineStream = OfflineStream(createStream(ptr))

    // Returns the identified language for the audio in `stream`.
    fun compute(stream: OfflineStream) = compute(ptr, stream.ptr)

    private external fun newFromAsset(
        assetManager: AssetManager,
        config: SpokenLanguageIdentificationConfig,
    ): Long

    private external fun newFromFile(
        config: SpokenLanguageIdentificationConfig,
    ): Long

    private external fun delete(ptr: Long)

    private external fun createStream(ptr: Long): Long

    private external fun compute(ptr: Long, streamPtr: Long): String

    companion object {
        init {
            System.loadLibrary("sherpa-onnx-jni")
        }
    }
}
|
||||
|
||||
// please refer to
// https://k2-fsa.github.io/sherpa/onnx/spolken-language-identification/pretrained_models.html#whisper
// to download more models
//
// Returns a config for the given model type, or null if `type` is unknown.
fun getSpokenLanguageIdentificationConfig(
    type: Int,
    numThreads: Int = 1
): SpokenLanguageIdentificationConfig? {
    when (type) {
        0 -> {
            val modelDir = "sherpa-onnx-whisper-tiny"
            return SpokenLanguageIdentificationConfig(
                whisper = SpokenLanguageIdentificationWhisperConfig(
                    encoder = "$modelDir/tiny-encoder.int8.onnx",
                    decoder = "$modelDir/tiny-decoder.int8.onnx",
                ),
                numThreads = numThreads,
                debug = true,
            )
        }

        1 -> {
            val modelDir = "sherpa-onnx-whisper-base"
            return SpokenLanguageIdentificationConfig(
                whisper = SpokenLanguageIdentificationWhisperConfig(
                    // Fixed copy-paste bug: the base model ships base-*.onnx
                    // files, not tiny-*.onnx.
                    encoder = "$modelDir/base-encoder.int8.onnx",
                    decoder = "$modelDir/base-decoder.int8.onnx",
                ),
                // Fixed: honor the caller-supplied numThreads (was hard-coded 1,
                // unlike branch 0).
                numThreads = numThreads,
                debug = true,
            )
        }
    }
    return null
}
|
||||
104
sherpa-onnx/kotlin-api/Vad.kt
Normal file
104
sherpa-onnx/kotlin-api/Vad.kt
Normal file
@@ -0,0 +1,104 @@
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
package com.k2fsa.sherpa.onnx
|
||||
|
||||
import android.content.res.AssetManager
|
||||
|
||||
// Configuration of the silero-vad model.
data class SileroVadModelConfig(
    // Path (or asset path) to silero_vad.onnx.
    var model: String,
    // Speech probability threshold; frames above it count as speech.
    var threshold: Float = 0.5F,
    // Minimum silence duration in seconds before a segment is closed.
    var minSilenceDuration: Float = 0.25F,
    // Minimum speech duration in seconds for a segment to be kept.
    var minSpeechDuration: Float = 0.25F,
    // Number of samples fed to the model per inference call.
    var windowSize: Int = 512,
)
|
||||
|
||||
// Configuration for the voice-activity detector.
data class VadModelConfig(
    // Underlying silero-vad model settings.
    var sileroVadModelConfig: SileroVadModelConfig,
    // Expected input sample rate in Hz.
    var sampleRate: Int = 16000,
    // Number of threads for onnxruntime.
    var numThreads: Int = 1,
    // Execution provider, e.g. "cpu".
    var provider: String = "cpu",
    // If true, the native layer prints extra diagnostics.
    var debug: Boolean = false,
)
|
||||
|
||||
// Kotlin wrapper around the native voice-activity detector.
//
// Pass a non-null assetManager to load the model from Android assets,
// otherwise the model is loaded from the file system.
class Vad(
    assetManager: AssetManager? = null,
    var config: VadModelConfig,
) {
    // Handle to the native VAD object; 0 once released.
    // (Changed from val so the handle can be zeroed after deletion.)
    private var ptr: Long

    init {
        if (assetManager != null) {
            ptr = newFromAsset(assetManager, config)
        } else {
            ptr = newFromFile(config)
        }
    }

    protected fun finalize() {
        // Fixed: guard against double-free — the original deleted the native
        // object unconditionally and could not zero the handle.
        if (ptr != 0L) {
            delete(ptr)
            ptr = 0
        }
    }

    // Explicitly frees the native object; safe to call more than once.
    // Added for consistency with the other JNI wrappers in this file.
    fun release() = finalize()

    fun acceptWaveform(samples: FloatArray) = acceptWaveform(ptr, samples)

    fun empty(): Boolean = empty(ptr)

    fun pop() = pop(ptr)

    // return an array containing
    // [start: Int, samples: FloatArray]
    fun front() = front(ptr)

    fun clear() = clear(ptr)

    fun isSpeechDetected(): Boolean = isSpeechDetected(ptr)

    fun reset() = reset(ptr)

    private external fun delete(ptr: Long)

    private external fun newFromAsset(
        assetManager: AssetManager,
        config: VadModelConfig,
    ): Long

    private external fun newFromFile(
        config: VadModelConfig,
    ): Long

    private external fun acceptWaveform(ptr: Long, samples: FloatArray)
    private external fun empty(ptr: Long): Boolean
    private external fun pop(ptr: Long)
    private external fun clear(ptr: Long)
    private external fun front(ptr: Long): Array<Any>
    private external fun isSpeechDetected(ptr: Long): Boolean
    private external fun reset(ptr: Long)

    companion object {
        init {
            System.loadLibrary("sherpa-onnx-jni")
        }
    }
}
|
||||
|
||||
// Please visit
// https://github.com/snakers4/silero-vad/blob/master/files/silero_vad.onnx
// to download silero_vad.onnx
// and put it inside the assets/
// directory
//
// Returns a config for the given model type, or null if `type` is unknown.
// (Rewritten as a when-expression; also drops the stray `;` after `return null`.)
fun getVadModelConfig(type: Int): VadModelConfig? = when (type) {
    0 -> VadModelConfig(
        sileroVadModelConfig = SileroVadModelConfig(
            model = "silero_vad.onnx",
            threshold = 0.5F,
            minSilenceDuration = 0.25F,
            minSpeechDuration = 0.25F,
            windowSize = 512,
        ),
        sampleRate = 16000,
        numThreads = 1,
        provider = "cpu",
    )

    else -> null
}
|
||||
29
sherpa-onnx/kotlin-api/WaveReader.kt
Normal file
29
sherpa-onnx/kotlin-api/WaveReader.kt
Normal file
@@ -0,0 +1,29 @@
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
package com.k2fsa.sherpa.onnx
|
||||
|
||||
import android.content.res.AssetManager
|
||||
|
||||
// Utility for decoding mono wave files via the native sherpa-onnx JNI library.
class WaveReader {
    companion object {
        // Read a mono wave file asset
        // The returned array has two entries:
        //  - the first entry contains an 1-D float array (the samples)
        //  - the second entry is the sample rate (Int)
        external fun readWaveFromAsset(
            assetManager: AssetManager,
            filename: String,
        ): Array<Any>

        // Read a mono wave file from disk
        // The returned array has two entries:
        //  - the first entry contains an 1-D float array (the samples)
        //  - the second entry is the sample rate (Int)
        external fun readWaveFromFile(
            filename: String,
        ): Array<Any>

        init {
            System.loadLibrary("sherpa-onnx-jni")
        }
    }
}
|
||||
Reference in New Issue
Block a user