Add Android APK for Silero VAD (#335)

This commit is contained in:
Fangjun Kuang
2023-09-23 20:39:13 +08:00
committed by GitHub
parent 65ec4dc741
commit 6e60a77d89
68 changed files with 1562 additions and 45 deletions

View File

@@ -32,7 +32,7 @@ std::vector<std::vector<std::string>> SplitToBatches(
process_num += batch_size;
}
if (itr != input.cend()) {
outputs.emplace_back(itr, input.cend());
outputs.emplace_back(itr, input.cend());
}
return outputs;
}
@@ -41,8 +41,8 @@ std::vector<std::string> LoadScpFile(const std::string &wav_scp_path) {
std::vector<std::string> wav_paths;
std::ifstream in(wav_scp_path);
if (!in.is_open()) {
fprintf(stderr, "Failed to open file: %s.\n", wav_scp_path.c_str());
return wav_paths;
fprintf(stderr, "Failed to open file: %s.\n", wav_scp_path.c_str());
return wav_paths;
}
std::string line, column1, column2;
while (std::getline(in, line)) {
@@ -55,8 +55,8 @@ std::vector<std::string> LoadScpFile(const std::string &wav_scp_path) {
}
void AsrInference(const std::vector<std::vector<std::string>> &chunk_wav_paths,
sherpa_onnx::OfflineRecognizer* recognizer,
float* total_length, float* total_time) {
sherpa_onnx::OfflineRecognizer *recognizer,
float *total_length, float *total_time) {
std::vector<std::unique_ptr<sherpa_onnx::OfflineStream>> ss;
std::vector<sherpa_onnx::OfflineStream *> ss_pointers;
float duration = 0.0f;
@@ -70,7 +70,7 @@ void AsrInference(const std::vector<std::vector<std::string>> &chunk_wav_paths,
sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok);
if (!is_ok) {
fprintf(stderr, "Failed to read %s\n", wav_filename.c_str());
continue;
continue;
}
duration += samples.size() / static_cast<float>(sampling_rate);
auto s = recognizer->CreateStream();
@@ -97,7 +97,7 @@ void AsrInference(const std::vector<std::vector<std::string>> &chunk_wav_paths,
sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok);
if (!is_ok) {
fprintf(stderr, "Failed to read %s\n", wav_filename.c_str());
continue;
continue;
}
duration += samples.size() / static_cast<float>(sampling_rate);
auto s = recognizer->CreateStream();
@@ -109,9 +109,9 @@ void AsrInference(const std::vector<std::vector<std::string>> &chunk_wav_paths,
recognizer->DecodeStreams(ss_pointers.data(), ss_pointers.size());
const auto end = std::chrono::steady_clock::now();
float elapsed_seconds =
std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
.count() /
1000.;
std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
.count() /
1000.;
elapsed_seconds_batch += elapsed_seconds;
int i = 0;
for (const auto &wav_filename : wav_paths) {
@@ -122,7 +122,7 @@ void AsrInference(const std::vector<std::vector<std::string>> &chunk_wav_paths,
ss_pointers.clear();
ss.clear();
}
fprintf(stderr, "thread %lu.\n", std::this_thread::get_id());
{
std::lock_guard<std::mutex> guard(mtx);
*total_length += duration;
@@ -132,7 +132,6 @@ void AsrInference(const std::vector<std::vector<std::string>> &chunk_wav_paths,
}
}
int main(int32_t argc, char *argv[]) {
const char *kUsageMessage = R"usage(
Speech recognition using non-streaming models with sherpa-onnx.
@@ -223,17 +222,17 @@ https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
for a list of pre-trained models to download.
)usage";
std::string wav_scp = ""; // file path, kaldi style wav list.
int32_t nj = 1; // thread number
int32_t batch_size = 1; // number of wav files processed at once.
int32_t nj = 1; // thread number
int32_t batch_size = 1; // number of wav files processed at once.
sherpa_onnx::ParseOptions po(kUsageMessage);
sherpa_onnx::OfflineRecognizerConfig config;
config.Register(&po);
po.Register("wav-scp", &wav_scp,
"a file including wav-id and wav-path, kaldi style wav list."
"default="". when it is not empty, wav files which positional "
"default="
". when it is not empty, wav files which positional "
"parameters provide are invalid.");
po.Register("nj", &nj,
"multi-thread num for decoding, default=1");
po.Register("nj", &nj, "multi-thread num for decoding, default=1");
po.Register("batch-size", &batch_size,
"number of wav files processed at once during the decoding"
"process. default=1");
@@ -262,7 +261,8 @@ for a list of pre-trained models to download.
1000.;
fprintf(stderr,
"Started nj: %d, batch_size: %d, wav_path: %s. recognizer init time: "
"%.6f\n", nj, batch_size, wav_scp.c_str(), elapsed_seconds);
"%.6f\n",
nj, batch_size, wav_scp.c_str(), elapsed_seconds);
std::this_thread::sleep_for(std::chrono::seconds(10)); // sleep 10s
std::vector<std::string> wav_paths;
if (!wav_scp.empty()) {
@@ -282,12 +282,12 @@ for a list of pre-trained models to download.
float total_length = 0.0f;
float total_time = 0.0f;
for (int i = 0; i < nj; i++) {
threads.emplace_back(std::thread(AsrInference, batch_wav_paths,
&recognizer, &total_length, &total_time));
threads.emplace_back(std::thread(AsrInference, batch_wav_paths, &recognizer,
&total_length, &total_time));
}
for (auto& thread : threads) {
thread.join();
for (auto &thread : threads) {
thread.join();
}
fprintf(stderr, "num threads: %d\n", config.model_config.num_threads);
@@ -297,8 +297,8 @@ for a list of pre-trained models to download.
}
fprintf(stderr, "Elapsed seconds: %.3f s\n", total_time);
float rtf = total_time / total_length;
fprintf(stderr, "Real time factor (RTF): %.6f / %.6f = %.4f\n",
total_time, total_length, rtf);
fprintf(stderr, "Real time factor (RTF): %.6f / %.6f = %.4f\n", total_time,
total_length, rtf);
fprintf(stderr, "SPEEDUP: %.4f\n", 1.0 / rtf);
return 0;

View File

@@ -37,6 +37,29 @@ class SileroVadModel::Impl {
min_speech_samples_ = sample_rate_ * config_.silero_vad.min_speech_duration;
}
#if __ANDROID_API__ >= 9
Impl(AAssetManager *mgr, const VadModelConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
auto buf = ReadFile(mgr, config.silero_vad.model);
Init(buf.data(), buf.size());
sample_rate_ = config.sample_rate;
if (sample_rate_ != 16000) {
SHERPA_ONNX_LOGE("Expected sample rate 16000. Given: %d",
config.sample_rate);
exit(-1);
}
min_silence_samples_ =
sample_rate_ * config_.silero_vad.min_silence_duration;
min_speech_samples_ = sample_rate_ * config_.silero_vad.min_speech_duration;
}
#endif
void Reset() {
// 2 - number of LSTM layer
// 1 - batch size
@@ -260,6 +283,11 @@ class SileroVadModel::Impl {
SileroVadModel::SileroVadModel(const VadModelConfig &config)
: impl_(std::make_unique<Impl>(config)) {}
#if __ANDROID_API__ >= 9
SileroVadModel::SileroVadModel(AAssetManager *mgr, const VadModelConfig &config)
: impl_(std::make_unique<Impl>(mgr, config)) {}
#endif
SileroVadModel::~SileroVadModel() = default;
void SileroVadModel::Reset() { return impl_->Reset(); }

View File

@@ -6,6 +6,11 @@
#include <memory>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "sherpa-onnx/csrc/vad-model.h"
namespace sherpa_onnx {
@@ -13,6 +18,11 @@ namespace sherpa_onnx {
class SileroVadModel : public VadModel {
public:
explicit SileroVadModel(const VadModelConfig &config);
#if __ANDROID_API__ >= 9
SileroVadModel(AAssetManager *mgr, const VadModelConfig &config);
#endif
~SileroVadModel() override;
// reset the internal model states

View File

@@ -13,4 +13,12 @@ std::unique_ptr<VadModel> VadModel::Create(const VadModelConfig &config) {
return std::make_unique<SileroVadModel>(config);
}
#if __ANDROID_API__ >= 9
std::unique_ptr<VadModel> VadModel::Create(AAssetManager *mgr,
const VadModelConfig &config) {
// TODO(fangjun): Support other VAD models.
return std::make_unique<SileroVadModel>(mgr, config);
}
#endif
} // namespace sherpa_onnx

View File

@@ -6,6 +6,11 @@
#include <memory>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "sherpa-onnx/csrc/vad-model-config.h"
namespace sherpa_onnx {
@@ -16,6 +21,11 @@ class VadModel {
static std::unique_ptr<VadModel> Create(const VadModelConfig &config);
#if __ANDROID_API__ >= 9
static std::unique_ptr<VadModel> Create(AAssetManager *mgr,
const VadModelConfig &config);
#endif
// reset the internal model states
virtual void Reset() = 0;

View File

@@ -19,10 +19,32 @@ class VoiceActivityDetector::Impl {
config_(config),
buffer_(buffer_size_in_seconds * config.sample_rate) {}
void AcceptWaveform(const float *samples, int32_t n) {
buffer_.Push(samples, n);
#if __ANDROID_API__ >= 9
Impl(AAssetManager *mgr, const VadModelConfig &config,
float buffer_size_in_seconds = 60)
: model_(VadModel::Create(mgr, config)),
config_(config),
buffer_(buffer_size_in_seconds * config.sample_rate) {}
#endif
void AcceptWaveform(const float *samples, int32_t n) {
int32_t window_size = model_->WindowSize();
// note n is usally window_size and there is no need to use
// an extra buffer here
last_.insert(last_.end(), samples, samples + n);
int32_t k = static_cast<int32_t>(last_.size()) / window_size;
const float *p = last_.data();
bool is_speech = false;
for (int32_t i = 0; i != k; ++i, p += window_size) {
buffer_.Push(p, window_size);
is_speech = model_->IsSpeech(p, window_size);
}
last_ = std::vector<float>(
p, static_cast<const float *>(last_.data()) + last_.size());
bool is_speech = model_->IsSpeech(samples, n);
if (is_speech) {
if (start_ == -1) {
// beginning of speech
@@ -31,15 +53,15 @@ class VoiceActivityDetector::Impl {
}
} else {
// non-speech
if (start_ != -1) {
if (start_ != -1 && buffer_.Size()) {
// end of speech, save the speech segment
int32_t end = buffer_.Tail() - model_->MinSilenceDurationSamples();
std::vector<float> samples = buffer_.Get(start_, end - start_);
std::vector<float> s = buffer_.Get(start_, end - start_);
SpeechSegment segment;
segment.start = start_;
segment.samples = std::move(samples);
segment.samples = std::move(s);
segments_.push(std::move(segment));
@@ -73,6 +95,7 @@ class VoiceActivityDetector::Impl {
std::unique_ptr<VadModel> model_;
VadModelConfig config_;
CircularBuffer buffer_;
std::vector<float> last_;
int32_t start_ = -1;
};
@@ -81,6 +104,13 @@ VoiceActivityDetector::VoiceActivityDetector(
const VadModelConfig &config, float buffer_size_in_seconds /*= 60*/)
: impl_(std::make_unique<Impl>(config, buffer_size_in_seconds)) {}
#if __ANDROID_API__ >= 9
VoiceActivityDetector::VoiceActivityDetector(
AAssetManager *mgr, const VadModelConfig &config,
float buffer_size_in_seconds /*= 60*/)
: impl_(std::make_unique<Impl>(mgr, config, buffer_size_in_seconds)) {}
#endif
VoiceActivityDetector::~VoiceActivityDetector() = default;
void VoiceActivityDetector::AcceptWaveform(const float *samples, int32_t n) {

View File

@@ -7,6 +7,11 @@
#include <memory>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "sherpa-onnx/csrc/vad-model-config.h"
namespace sherpa_onnx {
@@ -20,6 +25,12 @@ class VoiceActivityDetector {
public:
explicit VoiceActivityDetector(const VadModelConfig &config,
float buffer_size_in_seconds = 60);
#if __ANDROID_API__ >= 9
VoiceActivityDetector(AAssetManager *mgr, const VadModelConfig &config,
float buffer_size_in_seconds = 60);
#endif
~VoiceActivityDetector();
void AcceptWaveform(const float *samples, int32_t n);

View File

@@ -23,6 +23,7 @@
#include "sherpa-onnx/csrc/offline-recognizer.h"
#include "sherpa-onnx/csrc/online-recognizer.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/voice-activity-detector.h"
#include "sherpa-onnx/csrc/wave-reader.h"
#define SHERPA_ONNX_EXTERN_C extern "C"
@@ -106,6 +107,33 @@ class SherpaOnnxOffline {
OfflineRecognizer recognizer_;
};
class SherpaOnnxVad {
public:
#if __ANDROID_API__ >= 9
SherpaOnnxVad(AAssetManager *mgr, const VadModelConfig &config)
: vad_(mgr, config) {}
#endif
explicit SherpaOnnxVad(const VadModelConfig &config) : vad_(config) {}
void AcceptWaveform(const float *samples, int32_t n) {
vad_.AcceptWaveform(samples, n);
}
bool Empty() const { return vad_.Empty(); }
void Pop() { vad_.Pop(); }
const SpeechSegment &Front() const { return vad_.Front(); }
bool IsSpeechDetected() const { return vad_.IsSpeechDetected(); }
void Reset() { vad_.Reset(); }
private:
VoiceActivityDetector vad_;
};
static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
OnlineRecognizerConfig ans;
@@ -411,8 +439,165 @@ static OfflineRecognizerConfig GetOfflineConfig(JNIEnv *env, jobject config) {
return ans;
}
static VadModelConfig GetVadModelConfig(JNIEnv *env, jobject config) {
VadModelConfig ans;
jclass cls = env->GetObjectClass(config);
jfieldID fid;
// silero_vad
fid = env->GetFieldID(cls, "sileroVadModelConfig",
"Lcom/k2fsa/sherpa/onnx/SileroVadModelConfig;");
jobject silero_vad_config = env->GetObjectField(config, fid);
jclass silero_vad_config_cls = env->GetObjectClass(silero_vad_config);
fid = env->GetFieldID(silero_vad_config_cls, "model", "Ljava/lang/String;");
auto s = (jstring)env->GetObjectField(silero_vad_config, fid);
auto p = env->GetStringUTFChars(s, nullptr);
ans.silero_vad.model = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(silero_vad_config_cls, "threshold", "F");
ans.silero_vad.threshold = env->GetFloatField(silero_vad_config, fid);
fid = env->GetFieldID(silero_vad_config_cls, "minSilenceDuration", "F");
ans.silero_vad.min_silence_duration =
env->GetFloatField(silero_vad_config, fid);
fid = env->GetFieldID(silero_vad_config_cls, "minSpeechDuration", "F");
ans.silero_vad.min_speech_duration =
env->GetFloatField(silero_vad_config, fid);
fid = env->GetFieldID(silero_vad_config_cls, "windowSize", "I");
ans.silero_vad.window_size = env->GetIntField(silero_vad_config, fid);
fid = env->GetFieldID(cls, "sampleRate", "I");
ans.sample_rate = env->GetIntField(config, fid);
fid = env->GetFieldID(cls, "numThreads", "I");
ans.num_threads = env->GetIntField(config, fid);
fid = env->GetFieldID(cls, "provider", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.provider = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(cls, "debug", "Z");
ans.debug = env->GetBooleanField(config, fid);
return ans;
}
} // namespace sherpa_onnx
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_Vad_new(
JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) {
#if __ANDROID_API__ >= 9
AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
if (!mgr) {
SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
}
#endif
auto config = sherpa_onnx::GetVadModelConfig(env, _config);
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
auto model = new sherpa_onnx::SherpaOnnxVad(
#if __ANDROID_API__ >= 9
mgr,
#endif
config);
return (jlong)model;
}
JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_Vad_newFromFile(
JNIEnv *env, jobject /*obj*/, jobject _config) {
auto config = sherpa_onnx::GetVadModelConfig(env, _config);
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
auto model = new sherpa_onnx::SherpaOnnxVad(config);
return (jlong)model;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_Vad_delete(JNIEnv *env,
jobject /*obj*/,
jlong ptr) {
delete reinterpret_cast<sherpa_onnx::SherpaOnnxVad *>(ptr);
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_Vad_acceptWaveform(
JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples) {
auto model = reinterpret_cast<sherpa_onnx::SherpaOnnxVad *>(ptr);
jfloat *p = env->GetFloatArrayElements(samples, nullptr);
jsize n = env->GetArrayLength(samples);
model->AcceptWaveform(p, n);
env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_Vad_empty(JNIEnv *env,
jobject /*obj*/,
jlong ptr) {
auto model = reinterpret_cast<sherpa_onnx::SherpaOnnxVad *>(ptr);
return model->Empty();
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_Vad_pop(JNIEnv *env,
jobject /*obj*/,
jlong ptr) {
auto model = reinterpret_cast<sherpa_onnx::SherpaOnnxVad *>(ptr);
model->Pop();
}
// see
// https://stackoverflow.com/questions/29043872/android-jni-return-multiple-variables
static jobject NewInteger(JNIEnv *env, int32_t value) {
jclass cls = env->FindClass("java/lang/Integer");
jmethodID constructor = env->GetMethodID(cls, "<init>", "(I)V");
return env->NewObject(cls, constructor, value);
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jobjectArray JNICALL
Java_com_k2fsa_sherpa_onnx_Vad_front(JNIEnv *env, jobject /*obj*/, jlong ptr) {
const auto &front =
reinterpret_cast<sherpa_onnx::SherpaOnnxVad *>(ptr)->Front();
jfloatArray samples_arr = env->NewFloatArray(front.samples.size());
env->SetFloatArrayRegion(samples_arr, 0, front.samples.size(),
front.samples.data());
jobjectArray obj_arr = (jobjectArray)env->NewObjectArray(
2, env->FindClass("java/lang/Object"), nullptr);
env->SetObjectArrayElement(obj_arr, 0, NewInteger(env, front.start));
env->SetObjectArrayElement(obj_arr, 1, samples_arr);
return obj_arr;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_Vad_isSpeechDetected(
JNIEnv *env, jobject /*obj*/, jlong ptr) {
auto model = reinterpret_cast<sherpa_onnx::SherpaOnnxVad *>(ptr);
return model->IsSpeechDetected();
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_Vad_reset(JNIEnv *env,
jobject /*obj*/,
jlong ptr) {
auto model = reinterpret_cast<sherpa_onnx::SherpaOnnxVad *>(ptr);
model->Reset();
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_new(
JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) {
@@ -564,12 +749,12 @@ SHERPA_ONNX_EXTERN_C
JNIEXPORT jobjectArray JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_getTokens(
JNIEnv *env, jobject /*obj*/, jlong ptr) {
auto tokens = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr)->GetTokens();
int size = tokens.size();
int32_t size = tokens.size();
jclass stringClass = env->FindClass("java/lang/String");
// convert C++ list into jni string array
jobjectArray result = env->NewObjectArray(size, stringClass, NULL);
for (int i = 0; i < size; i++) {
for (int32_t i = 0; i < size; i++) {
// Convert the C++ string to a C string
const char *cstr = tokens[i].c_str();
@@ -583,14 +768,6 @@ JNIEXPORT jobjectArray JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_getTokens(
return result;
}
// see
// https://stackoverflow.com/questions/29043872/android-jni-return-multiple-variables
static jobject NewInteger(JNIEnv *env, int32_t value) {
jclass cls = env->FindClass("java/lang/Integer");
jmethodID constructor = env->GetMethodID(cls, "<init>", "(I)V");
return env->NewObject(cls, constructor, value);
}
static jobjectArray ReadWaveImpl(JNIEnv *env, std::istream &is,
const char *p_filename) {
bool is_ok = false;