Add Android APK for Silero VAD (#335)
This commit is contained in:
@@ -32,7 +32,7 @@ std::vector<std::vector<std::string>> SplitToBatches(
|
||||
process_num += batch_size;
|
||||
}
|
||||
if (itr != input.cend()) {
|
||||
outputs.emplace_back(itr, input.cend());
|
||||
outputs.emplace_back(itr, input.cend());
|
||||
}
|
||||
return outputs;
|
||||
}
|
||||
@@ -41,8 +41,8 @@ std::vector<std::string> LoadScpFile(const std::string &wav_scp_path) {
|
||||
std::vector<std::string> wav_paths;
|
||||
std::ifstream in(wav_scp_path);
|
||||
if (!in.is_open()) {
|
||||
fprintf(stderr, "Failed to open file: %s.\n", wav_scp_path.c_str());
|
||||
return wav_paths;
|
||||
fprintf(stderr, "Failed to open file: %s.\n", wav_scp_path.c_str());
|
||||
return wav_paths;
|
||||
}
|
||||
std::string line, column1, column2;
|
||||
while (std::getline(in, line)) {
|
||||
@@ -55,8 +55,8 @@ std::vector<std::string> LoadScpFile(const std::string &wav_scp_path) {
|
||||
}
|
||||
|
||||
void AsrInference(const std::vector<std::vector<std::string>> &chunk_wav_paths,
|
||||
sherpa_onnx::OfflineRecognizer* recognizer,
|
||||
float* total_length, float* total_time) {
|
||||
sherpa_onnx::OfflineRecognizer *recognizer,
|
||||
float *total_length, float *total_time) {
|
||||
std::vector<std::unique_ptr<sherpa_onnx::OfflineStream>> ss;
|
||||
std::vector<sherpa_onnx::OfflineStream *> ss_pointers;
|
||||
float duration = 0.0f;
|
||||
@@ -70,7 +70,7 @@ void AsrInference(const std::vector<std::vector<std::string>> &chunk_wav_paths,
|
||||
sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok);
|
||||
if (!is_ok) {
|
||||
fprintf(stderr, "Failed to read %s\n", wav_filename.c_str());
|
||||
continue;
|
||||
continue;
|
||||
}
|
||||
duration += samples.size() / static_cast<float>(sampling_rate);
|
||||
auto s = recognizer->CreateStream();
|
||||
@@ -97,7 +97,7 @@ void AsrInference(const std::vector<std::vector<std::string>> &chunk_wav_paths,
|
||||
sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok);
|
||||
if (!is_ok) {
|
||||
fprintf(stderr, "Failed to read %s\n", wav_filename.c_str());
|
||||
continue;
|
||||
continue;
|
||||
}
|
||||
duration += samples.size() / static_cast<float>(sampling_rate);
|
||||
auto s = recognizer->CreateStream();
|
||||
@@ -109,9 +109,9 @@ void AsrInference(const std::vector<std::vector<std::string>> &chunk_wav_paths,
|
||||
recognizer->DecodeStreams(ss_pointers.data(), ss_pointers.size());
|
||||
const auto end = std::chrono::steady_clock::now();
|
||||
float elapsed_seconds =
|
||||
std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
|
||||
.count() /
|
||||
1000.;
|
||||
std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
|
||||
.count() /
|
||||
1000.;
|
||||
elapsed_seconds_batch += elapsed_seconds;
|
||||
int i = 0;
|
||||
for (const auto &wav_filename : wav_paths) {
|
||||
@@ -122,7 +122,7 @@ void AsrInference(const std::vector<std::vector<std::string>> &chunk_wav_paths,
|
||||
ss_pointers.clear();
|
||||
ss.clear();
|
||||
}
|
||||
fprintf(stderr, "thread %lu.\n", std::this_thread::get_id());
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> guard(mtx);
|
||||
*total_length += duration;
|
||||
@@ -132,7 +132,6 @@ void AsrInference(const std::vector<std::vector<std::string>> &chunk_wav_paths,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
int main(int32_t argc, char *argv[]) {
|
||||
const char *kUsageMessage = R"usage(
|
||||
Speech recognition using non-streaming models with sherpa-onnx.
|
||||
@@ -223,17 +222,17 @@ https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
|
||||
for a list of pre-trained models to download.
|
||||
)usage";
|
||||
std::string wav_scp = ""; // file path, kaldi style wav list.
|
||||
int32_t nj = 1; // thread number
|
||||
int32_t batch_size = 1; // number of wav files processed at once.
|
||||
int32_t nj = 1; // thread number
|
||||
int32_t batch_size = 1; // number of wav files processed at once.
|
||||
sherpa_onnx::ParseOptions po(kUsageMessage);
|
||||
sherpa_onnx::OfflineRecognizerConfig config;
|
||||
config.Register(&po);
|
||||
po.Register("wav-scp", &wav_scp,
|
||||
"a file including wav-id and wav-path, kaldi style wav list."
|
||||
"default="". when it is not empty, wav files which positional "
|
||||
"default="
|
||||
". when it is not empty, wav files which positional "
|
||||
"parameters provide are invalid.");
|
||||
po.Register("nj", &nj,
|
||||
"multi-thread num for decoding, default=1");
|
||||
po.Register("nj", &nj, "multi-thread num for decoding, default=1");
|
||||
po.Register("batch-size", &batch_size,
|
||||
"number of wav files processed at once during the decoding"
|
||||
"process. default=1");
|
||||
@@ -262,7 +261,8 @@ for a list of pre-trained models to download.
|
||||
1000.;
|
||||
fprintf(stderr,
|
||||
"Started nj: %d, batch_size: %d, wav_path: %s. recognizer init time: "
|
||||
"%.6f\n", nj, batch_size, wav_scp.c_str(), elapsed_seconds);
|
||||
"%.6f\n",
|
||||
nj, batch_size, wav_scp.c_str(), elapsed_seconds);
|
||||
std::this_thread::sleep_for(std::chrono::seconds(10)); // sleep 10s
|
||||
std::vector<std::string> wav_paths;
|
||||
if (!wav_scp.empty()) {
|
||||
@@ -282,12 +282,12 @@ for a list of pre-trained models to download.
|
||||
float total_length = 0.0f;
|
||||
float total_time = 0.0f;
|
||||
for (int i = 0; i < nj; i++) {
|
||||
threads.emplace_back(std::thread(AsrInference, batch_wav_paths,
|
||||
&recognizer, &total_length, &total_time));
|
||||
threads.emplace_back(std::thread(AsrInference, batch_wav_paths, &recognizer,
|
||||
&total_length, &total_time));
|
||||
}
|
||||
|
||||
for (auto& thread : threads) {
|
||||
thread.join();
|
||||
for (auto &thread : threads) {
|
||||
thread.join();
|
||||
}
|
||||
|
||||
fprintf(stderr, "num threads: %d\n", config.model_config.num_threads);
|
||||
@@ -297,8 +297,8 @@ for a list of pre-trained models to download.
|
||||
}
|
||||
fprintf(stderr, "Elapsed seconds: %.3f s\n", total_time);
|
||||
float rtf = total_time / total_length;
|
||||
fprintf(stderr, "Real time factor (RTF): %.6f / %.6f = %.4f\n",
|
||||
total_time, total_length, rtf);
|
||||
fprintf(stderr, "Real time factor (RTF): %.6f / %.6f = %.4f\n", total_time,
|
||||
total_length, rtf);
|
||||
fprintf(stderr, "SPEEDUP: %.4f\n", 1.0 / rtf);
|
||||
|
||||
return 0;
|
||||
|
||||
@@ -37,6 +37,29 @@ class SileroVadModel::Impl {
|
||||
min_speech_samples_ = sample_rate_ * config_.silero_vad.min_speech_duration;
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
// Android-only constructor. Loads the model named by
// config.silero_vad.model via ReadFile(mgr, ...) and passes the resulting
// buffer to Init(), then derives the sample-count thresholds from the
// configured durations.
//
// Calls exit(-1) when config.sample_rate is not 16000.
Impl(AAssetManager *mgr, const VadModelConfig &config)
    : config_(config),
      env_(ORT_LOGGING_LEVEL_ERROR),
      sess_opts_(GetSessionOptions(config)),
      allocator_{} {
  auto model_data = ReadFile(mgr, config.silero_vad.model);
  Init(model_data.data(), model_data.size());

  sample_rate_ = config.sample_rate;
  const bool is_supported_rate = (sample_rate_ == 16000);
  if (!is_supported_rate) {
    SHERPA_ONNX_LOGE("Expected sample rate 16000. Given: %d",
                     config.sample_rate);
    exit(-1);
  }

  // Durations are multiplied by the sample rate to obtain sample counts.
  const auto &vad_opts = config_.silero_vad;
  min_speech_samples_ = sample_rate_ * vad_opts.min_speech_duration;
  min_silence_samples_ = sample_rate_ * vad_opts.min_silence_duration;
}
|
||||
#endif
|
||||
|
||||
void Reset() {
|
||||
// 2 - number of LSTM layers
|
||||
// 1 - batch size
|
||||
@@ -260,6 +283,11 @@ class SileroVadModel::Impl {
|
||||
SileroVadModel::SileroVadModel(const VadModelConfig &config)
|
||||
: impl_(std::make_unique<Impl>(config)) {}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
SileroVadModel::SileroVadModel(AAssetManager *mgr, const VadModelConfig &config)
|
||||
: impl_(std::make_unique<Impl>(mgr, config)) {}
|
||||
#endif
|
||||
|
||||
SileroVadModel::~SileroVadModel() = default;
|
||||
|
||||
// Resets the model by delegating to the private implementation.
void SileroVadModel::Reset() {
  impl_->Reset();
}
|
||||
|
||||
@@ -6,6 +6,11 @@
|
||||
|
||||
#include <memory>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/vad-model.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
@@ -13,6 +18,11 @@ namespace sherpa_onnx {
|
||||
class SileroVadModel : public VadModel {
|
||||
public:
|
||||
explicit SileroVadModel(const VadModelConfig &config);
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
SileroVadModel(AAssetManager *mgr, const VadModelConfig &config);
|
||||
#endif
|
||||
|
||||
~SileroVadModel() override;
|
||||
|
||||
// reset the internal model states
|
||||
|
||||
@@ -13,4 +13,12 @@ std::unique_ptr<VadModel> VadModel::Create(const VadModelConfig &config) {
|
||||
return std::make_unique<SileroVadModel>(config);
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
// Android-only factory: builds a VAD model from `config`, passing the asset
// manager `mgr` through to the model's constructor. A SileroVadModel is
// always returned at present.
std::unique_ptr<VadModel> VadModel::Create(AAssetManager *mgr,
                                           const VadModelConfig &config) {
  // TODO(fangjun): Support other VAD models.
  std::unique_ptr<VadModel> model =
      std::make_unique<SileroVadModel>(mgr, config);
  return model;
}
|
||||
#endif
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -6,6 +6,11 @@
|
||||
|
||||
#include <memory>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/vad-model-config.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
@@ -16,6 +21,11 @@ class VadModel {
|
||||
|
||||
static std::unique_ptr<VadModel> Create(const VadModelConfig &config);
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
static std::unique_ptr<VadModel> Create(AAssetManager *mgr,
|
||||
const VadModelConfig &config);
|
||||
#endif
|
||||
|
||||
// reset the internal model states
|
||||
virtual void Reset() = 0;
|
||||
|
||||
|
||||
@@ -19,10 +19,32 @@ class VoiceActivityDetector::Impl {
|
||||
config_(config),
|
||||
buffer_(buffer_size_in_seconds * config.sample_rate) {}
|
||||
|
||||
void AcceptWaveform(const float *samples, int32_t n) {
|
||||
buffer_.Push(samples, n);
|
||||
#if __ANDROID_API__ >= 9
|
||||
Impl(AAssetManager *mgr, const VadModelConfig &config,
|
||||
float buffer_size_in_seconds = 60)
|
||||
: model_(VadModel::Create(mgr, config)),
|
||||
config_(config),
|
||||
buffer_(buffer_size_in_seconds * config.sample_rate) {}
|
||||
#endif
|
||||
|
||||
void AcceptWaveform(const float *samples, int32_t n) {
|
||||
int32_t window_size = model_->WindowSize();
|
||||
|
||||
// note n is usually window_size and there is no need to use
|
||||
// an extra buffer here
|
||||
last_.insert(last_.end(), samples, samples + n);
|
||||
int32_t k = static_cast<int32_t>(last_.size()) / window_size;
|
||||
const float *p = last_.data();
|
||||
bool is_speech = false;
|
||||
|
||||
for (int32_t i = 0; i != k; ++i, p += window_size) {
|
||||
buffer_.Push(p, window_size);
|
||||
is_speech = model_->IsSpeech(p, window_size);
|
||||
}
|
||||
|
||||
last_ = std::vector<float>(
|
||||
p, static_cast<const float *>(last_.data()) + last_.size());
|
||||
|
||||
bool is_speech = model_->IsSpeech(samples, n);
|
||||
if (is_speech) {
|
||||
if (start_ == -1) {
|
||||
// beginning of speech
|
||||
@@ -31,15 +53,15 @@ class VoiceActivityDetector::Impl {
|
||||
}
|
||||
} else {
|
||||
// non-speech
|
||||
if (start_ != -1) {
|
||||
if (start_ != -1 && buffer_.Size()) {
|
||||
// end of speech, save the speech segment
|
||||
int32_t end = buffer_.Tail() - model_->MinSilenceDurationSamples();
|
||||
|
||||
std::vector<float> samples = buffer_.Get(start_, end - start_);
|
||||
std::vector<float> s = buffer_.Get(start_, end - start_);
|
||||
SpeechSegment segment;
|
||||
|
||||
segment.start = start_;
|
||||
segment.samples = std::move(samples);
|
||||
segment.samples = std::move(s);
|
||||
|
||||
segments_.push(std::move(segment));
|
||||
|
||||
@@ -73,6 +95,7 @@ class VoiceActivityDetector::Impl {
|
||||
std::unique_ptr<VadModel> model_;
|
||||
VadModelConfig config_;
|
||||
CircularBuffer buffer_;
|
||||
std::vector<float> last_;
|
||||
|
||||
int32_t start_ = -1;
|
||||
};
|
||||
@@ -81,6 +104,13 @@ VoiceActivityDetector::VoiceActivityDetector(
|
||||
const VadModelConfig &config, float buffer_size_in_seconds /*= 60*/)
|
||||
: impl_(std::make_unique<Impl>(config, buffer_size_in_seconds)) {}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
VoiceActivityDetector::VoiceActivityDetector(
|
||||
AAssetManager *mgr, const VadModelConfig &config,
|
||||
float buffer_size_in_seconds /*= 60*/)
|
||||
: impl_(std::make_unique<Impl>(mgr, config, buffer_size_in_seconds)) {}
|
||||
#endif
|
||||
|
||||
VoiceActivityDetector::~VoiceActivityDetector() = default;
|
||||
|
||||
void VoiceActivityDetector::AcceptWaveform(const float *samples, int32_t n) {
|
||||
|
||||
@@ -7,6 +7,11 @@
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/vad-model-config.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
@@ -20,6 +25,12 @@ class VoiceActivityDetector {
|
||||
public:
|
||||
explicit VoiceActivityDetector(const VadModelConfig &config,
|
||||
float buffer_size_in_seconds = 60);
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
VoiceActivityDetector(AAssetManager *mgr, const VadModelConfig &config,
|
||||
float buffer_size_in_seconds = 60);
|
||||
#endif
|
||||
|
||||
~VoiceActivityDetector();
|
||||
|
||||
void AcceptWaveform(const float *samples, int32_t n);
|
||||
|
||||
Reference in New Issue
Block a user