Support silero_vad version 5 (#1064)
This commit is contained in:
@@ -8,7 +8,7 @@ project(sherpa-onnx)
|
|||||||
# ./nodejs-addon-examples
|
# ./nodejs-addon-examples
|
||||||
# ./dart-api-examples/
|
# ./dart-api-examples/
|
||||||
# ./sherpa-onnx/flutter/CHANGELOG.md
|
# ./sherpa-onnx/flutter/CHANGELOG.md
|
||||||
set(SHERPA_ONNX_VERSION "1.10.5")
|
set(SHERPA_ONNX_VERSION "1.10.6")
|
||||||
|
|
||||||
# Disable warning about
|
# Disable warning about
|
||||||
#
|
#
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
{
|
{
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
"sherpa-onnx-node": "^1.10.3"
|
"sherpa-onnx-node": "^1.10.6"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -61,25 +61,11 @@ class SileroVadModel::Impl {
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
void Reset() {
|
void Reset() {
|
||||||
// 2 - number of LSTM layer
|
if (is_v5_) {
|
||||||
// 1 - batch size
|
ResetV5();
|
||||||
// 64 - hidden dim
|
} else {
|
||||||
std::array<int64_t, 3> shape{2, 1, 64};
|
ResetV4();
|
||||||
|
}
|
||||||
Ort::Value h =
|
|
||||||
Ort::Value::CreateTensor<float>(allocator_, shape.data(), shape.size());
|
|
||||||
|
|
||||||
Ort::Value c =
|
|
||||||
Ort::Value::CreateTensor<float>(allocator_, shape.data(), shape.size());
|
|
||||||
|
|
||||||
Fill<float>(&h, 0);
|
|
||||||
Fill<float>(&c, 0);
|
|
||||||
|
|
||||||
states_.clear();
|
|
||||||
|
|
||||||
states_.reserve(2);
|
|
||||||
states_.push_back(std::move(h));
|
|
||||||
states_.push_back(std::move(c));
|
|
||||||
|
|
||||||
triggered_ = false;
|
triggered_ = false;
|
||||||
current_sample_ = 0;
|
current_sample_ = 0;
|
||||||
@@ -94,31 +80,7 @@ class SileroVadModel::Impl {
|
|||||||
exit(-1);
|
exit(-1);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto memory_info =
|
float prob = Run(samples, n);
|
||||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
|
||||||
|
|
||||||
std::array<int64_t, 2> x_shape = {1, n};
|
|
||||||
|
|
||||||
Ort::Value x =
|
|
||||||
Ort::Value::CreateTensor(memory_info, const_cast<float *>(samples), n,
|
|
||||||
x_shape.data(), x_shape.size());
|
|
||||||
|
|
||||||
int64_t sr_shape = 1;
|
|
||||||
Ort::Value sr =
|
|
||||||
Ort::Value::CreateTensor(memory_info, &sample_rate_, 1, &sr_shape, 1);
|
|
||||||
|
|
||||||
std::array<Ort::Value, 4> inputs = {std::move(x), std::move(sr),
|
|
||||||
std::move(states_[0]),
|
|
||||||
std::move(states_[1])};
|
|
||||||
|
|
||||||
auto out =
|
|
||||||
sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
|
|
||||||
output_names_ptr_.data(), output_names_ptr_.size());
|
|
||||||
|
|
||||||
states_[0] = std::move(out[1]);
|
|
||||||
states_[1] = std::move(out[2]);
|
|
||||||
|
|
||||||
float prob = out[0].GetTensorData<float>()[0];
|
|
||||||
|
|
||||||
float threshold = config_.silero_vad.threshold;
|
float threshold = config_.silero_vad.threshold;
|
||||||
|
|
||||||
@@ -186,6 +148,8 @@ class SileroVadModel::Impl {
|
|||||||
|
|
||||||
int32_t WindowSize() const { return config_.silero_vad.window_size; }
|
int32_t WindowSize() const { return config_.silero_vad.window_size; }
|
||||||
|
|
||||||
|
int32_t WindowShift() const { return WindowSize() - window_shift_; }
|
||||||
|
|
||||||
int32_t MinSilenceDurationSamples() const { return min_silence_samples_; }
|
int32_t MinSilenceDurationSamples() const { return min_silence_samples_; }
|
||||||
|
|
||||||
int32_t MinSpeechDurationSamples() const { return min_speech_samples_; }
|
int32_t MinSpeechDurationSamples() const { return min_speech_samples_; }
|
||||||
@@ -205,12 +169,76 @@ class SileroVadModel::Impl {
|
|||||||
|
|
||||||
GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
|
GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
|
||||||
GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);
|
GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);
|
||||||
|
|
||||||
|
if (input_names_.size() == 4 && output_names_.size() == 3) {
|
||||||
|
is_v5_ = false;
|
||||||
|
} else if (input_names_.size() == 3 && output_names_.size() == 2) {
|
||||||
|
is_v5_ = true;
|
||||||
|
|
||||||
|
// 64 for 16kHz
|
||||||
|
// 32 for 8kHz
|
||||||
|
window_shift_ = 64;
|
||||||
|
|
||||||
|
if (WindowSize() != 512) {
|
||||||
|
SHERPA_ONNX_LOGE(
|
||||||
|
"For silero_vad v5, we require window_size to be 512 for 16kHz");
|
||||||
|
exit(-1);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
SHERPA_ONNX_LOGE("Unsupported silero vad model");
|
||||||
|
exit(-1);
|
||||||
|
}
|
||||||
|
|
||||||
Check();
|
Check();
|
||||||
|
|
||||||
Reset();
|
Reset();
|
||||||
}
|
}
|
||||||
|
|
||||||
void Check() {
|
void ResetV5() {
|
||||||
|
// 2 - number of LSTM layer
|
||||||
|
// 1 - batch size
|
||||||
|
// 128 - hidden dim
|
||||||
|
std::array<int64_t, 3> shape{2, 1, 128};
|
||||||
|
|
||||||
|
Ort::Value s =
|
||||||
|
Ort::Value::CreateTensor<float>(allocator_, shape.data(), shape.size());
|
||||||
|
|
||||||
|
Fill<float>(&s, 0);
|
||||||
|
states_.clear();
|
||||||
|
states_.push_back(std::move(s));
|
||||||
|
}
|
||||||
|
|
||||||
|
void ResetV4() {
|
||||||
|
// 2 - number of LSTM layer
|
||||||
|
// 1 - batch size
|
||||||
|
// 64 - hidden dim
|
||||||
|
std::array<int64_t, 3> shape{2, 1, 64};
|
||||||
|
|
||||||
|
Ort::Value h =
|
||||||
|
Ort::Value::CreateTensor<float>(allocator_, shape.data(), shape.size());
|
||||||
|
|
||||||
|
Ort::Value c =
|
||||||
|
Ort::Value::CreateTensor<float>(allocator_, shape.data(), shape.size());
|
||||||
|
|
||||||
|
Fill<float>(&h, 0);
|
||||||
|
Fill<float>(&c, 0);
|
||||||
|
|
||||||
|
states_.clear();
|
||||||
|
|
||||||
|
states_.reserve(2);
|
||||||
|
states_.push_back(std::move(h));
|
||||||
|
states_.push_back(std::move(c));
|
||||||
|
}
|
||||||
|
|
||||||
|
void Check() const {
|
||||||
|
if (is_v5_) {
|
||||||
|
CheckV5();
|
||||||
|
} else {
|
||||||
|
CheckV4();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void CheckV4() const {
|
||||||
if (input_names_.size() != 4) {
|
if (input_names_.size() != 4) {
|
||||||
SHERPA_ONNX_LOGE("Expect 4 inputs. Given: %d",
|
SHERPA_ONNX_LOGE("Expect 4 inputs. Given: %d",
|
||||||
static_cast<int32_t>(input_names_.size()));
|
static_cast<int32_t>(input_names_.size()));
|
||||||
@@ -262,6 +290,114 @@ class SileroVadModel::Impl {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void CheckV5() const {
|
||||||
|
if (input_names_.size() != 3) {
|
||||||
|
SHERPA_ONNX_LOGE("Expect 3 inputs. Given: %d",
|
||||||
|
static_cast<int32_t>(input_names_.size()));
|
||||||
|
exit(-1);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (input_names_[0] != "input") {
|
||||||
|
SHERPA_ONNX_LOGE("Input[0]: %s. Expected: input",
|
||||||
|
input_names_[0].c_str());
|
||||||
|
exit(-1);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (input_names_[1] != "state") {
|
||||||
|
SHERPA_ONNX_LOGE("Input[1]: %s. Expected: state",
|
||||||
|
input_names_[1].c_str());
|
||||||
|
exit(-1);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (input_names_[2] != "sr") {
|
||||||
|
SHERPA_ONNX_LOGE("Input[2]: %s. Expected: sr", input_names_[2].c_str());
|
||||||
|
exit(-1);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now for outputs
|
||||||
|
if (output_names_.size() != 2) {
|
||||||
|
SHERPA_ONNX_LOGE("Expect 2 outputs. Given: %d",
|
||||||
|
static_cast<int32_t>(output_names_.size()));
|
||||||
|
exit(-1);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (output_names_[0] != "output") {
|
||||||
|
SHERPA_ONNX_LOGE("Output[0]: %s. Expected: output",
|
||||||
|
output_names_[0].c_str());
|
||||||
|
exit(-1);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (output_names_[1] != "stateN") {
|
||||||
|
SHERPA_ONNX_LOGE("Output[1]: %s. Expected: stateN",
|
||||||
|
output_names_[1].c_str());
|
||||||
|
exit(-1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
float Run(const float *samples, int32_t n) {
|
||||||
|
if (is_v5_) {
|
||||||
|
return RunV5(samples, n);
|
||||||
|
} else {
|
||||||
|
return RunV4(samples, n);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
float RunV5(const float *samples, int32_t n) {
|
||||||
|
auto memory_info =
|
||||||
|
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||||
|
|
||||||
|
std::array<int64_t, 2> x_shape = {1, n};
|
||||||
|
|
||||||
|
Ort::Value x =
|
||||||
|
Ort::Value::CreateTensor(memory_info, const_cast<float *>(samples), n,
|
||||||
|
x_shape.data(), x_shape.size());
|
||||||
|
|
||||||
|
int64_t sr_shape = 1;
|
||||||
|
Ort::Value sr =
|
||||||
|
Ort::Value::CreateTensor(memory_info, &sample_rate_, 1, &sr_shape, 1);
|
||||||
|
|
||||||
|
std::array<Ort::Value, 3> inputs = {std::move(x), std::move(states_[0]),
|
||||||
|
std::move(sr)};
|
||||||
|
|
||||||
|
auto out =
|
||||||
|
sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
|
||||||
|
output_names_ptr_.data(), output_names_ptr_.size());
|
||||||
|
|
||||||
|
states_[0] = std::move(out[1]);
|
||||||
|
|
||||||
|
float prob = out[0].GetTensorData<float>()[0];
|
||||||
|
return prob;
|
||||||
|
}
|
||||||
|
|
||||||
|
float RunV4(const float *samples, int32_t n) {
|
||||||
|
auto memory_info =
|
||||||
|
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||||
|
|
||||||
|
std::array<int64_t, 2> x_shape = {1, n};
|
||||||
|
|
||||||
|
Ort::Value x =
|
||||||
|
Ort::Value::CreateTensor(memory_info, const_cast<float *>(samples), n,
|
||||||
|
x_shape.data(), x_shape.size());
|
||||||
|
|
||||||
|
int64_t sr_shape = 1;
|
||||||
|
Ort::Value sr =
|
||||||
|
Ort::Value::CreateTensor(memory_info, &sample_rate_, 1, &sr_shape, 1);
|
||||||
|
|
||||||
|
std::array<Ort::Value, 4> inputs = {std::move(x), std::move(sr),
|
||||||
|
std::move(states_[0]),
|
||||||
|
std::move(states_[1])};
|
||||||
|
|
||||||
|
auto out =
|
||||||
|
sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
|
||||||
|
output_names_ptr_.data(), output_names_ptr_.size());
|
||||||
|
|
||||||
|
states_[0] = std::move(out[1]);
|
||||||
|
states_[1] = std::move(out[2]);
|
||||||
|
|
||||||
|
float prob = out[0].GetTensorData<float>()[0];
|
||||||
|
return prob;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
VadModelConfig config_;
|
VadModelConfig config_;
|
||||||
|
|
||||||
@@ -286,6 +422,10 @@ class SileroVadModel::Impl {
|
|||||||
int32_t current_sample_ = 0;
|
int32_t current_sample_ = 0;
|
||||||
int32_t temp_start_ = 0;
|
int32_t temp_start_ = 0;
|
||||||
int32_t temp_end_ = 0;
|
int32_t temp_end_ = 0;
|
||||||
|
|
||||||
|
int32_t window_shift_ = 0;
|
||||||
|
|
||||||
|
bool is_v5_ = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
SileroVadModel::SileroVadModel(const VadModelConfig &config)
|
SileroVadModel::SileroVadModel(const VadModelConfig &config)
|
||||||
@@ -306,6 +446,8 @@ bool SileroVadModel::IsSpeech(const float *samples, int32_t n) {
|
|||||||
|
|
||||||
int32_t SileroVadModel::WindowSize() const { return impl_->WindowSize(); }
|
int32_t SileroVadModel::WindowSize() const { return impl_->WindowSize(); }
|
||||||
|
|
||||||
|
int32_t SileroVadModel::WindowShift() const { return impl_->WindowShift(); }
|
||||||
|
|
||||||
int32_t SileroVadModel::MinSilenceDurationSamples() const {
|
int32_t SileroVadModel::MinSilenceDurationSamples() const {
|
||||||
return impl_->MinSilenceDurationSamples();
|
return impl_->MinSilenceDurationSamples();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -39,6 +39,11 @@ class SileroVadModel : public VadModel {
|
|||||||
|
|
||||||
int32_t WindowSize() const override;
|
int32_t WindowSize() const override;
|
||||||
|
|
||||||
|
// For silero vad V4, it is WindowSize().
|
||||||
|
// For silero vad V5, it is WindowSize()-64 for 16kHz and
|
||||||
|
// WindowSize()-32 for 8kHz
|
||||||
|
int32_t WindowShift() const override;
|
||||||
|
|
||||||
int32_t MinSilenceDurationSamples() const override;
|
int32_t MinSilenceDurationSamples() const override;
|
||||||
int32_t MinSpeechDurationSamples() const override;
|
int32_t MinSpeechDurationSamples() const override;
|
||||||
|
|
||||||
|
|||||||
@@ -40,6 +40,8 @@ class VadModel {
|
|||||||
|
|
||||||
virtual int32_t WindowSize() const = 0;
|
virtual int32_t WindowSize() const = 0;
|
||||||
|
|
||||||
|
virtual int32_t WindowShift() const = 0;
|
||||||
|
|
||||||
virtual int32_t MinSilenceDurationSamples() const = 0;
|
virtual int32_t MinSilenceDurationSamples() const = 0;
|
||||||
virtual int32_t MinSpeechDurationSamples() const = 0;
|
virtual int32_t MinSpeechDurationSamples() const = 0;
|
||||||
virtual void SetMinSilenceDuration(float s) = 0;
|
virtual void SetMinSilenceDuration(float s) = 0;
|
||||||
|
|||||||
@@ -38,16 +38,20 @@ class VoiceActivityDetector::Impl {
|
|||||||
}
|
}
|
||||||
|
|
||||||
int32_t window_size = model_->WindowSize();
|
int32_t window_size = model_->WindowSize();
|
||||||
|
int32_t window_shift = model_->WindowShift();
|
||||||
|
|
||||||
// note n is usually window_size and there is no need to use
|
// note n is usually window_size and there is no need to use
|
||||||
// an extra buffer here
|
// an extra buffer here
|
||||||
last_.insert(last_.end(), samples, samples + n);
|
last_.insert(last_.end(), samples, samples + n);
|
||||||
int32_t k = static_cast<int32_t>(last_.size()) / window_size;
|
|
||||||
|
// Note: For v4, window_shift == window_size
|
||||||
|
int32_t k =
|
||||||
|
(static_cast<int32_t>(last_.size()) - window_size) / window_shift + 1;
|
||||||
const float *p = last_.data();
|
const float *p = last_.data();
|
||||||
bool is_speech = false;
|
bool is_speech = false;
|
||||||
|
|
||||||
for (int32_t i = 0; i != k; ++i, p += window_size) {
|
for (int32_t i = 0; i != k; ++i, p += window_shift) {
|
||||||
buffer_.Push(p, window_size);
|
buffer_.Push(p, window_shift);
|
||||||
// NOTE(fangjun): Please don't use a very large n.
|
// NOTE(fangjun): Please don't use a very large n.
|
||||||
bool this_window_is_speech = model_->IsSpeech(p, window_size);
|
bool this_window_is_speech = model_->IsSpeech(p, window_size);
|
||||||
is_speech = is_speech || this_window_is_speech;
|
is_speech = is_speech || this_window_is_speech;
|
||||||
|
|||||||
Reference in New Issue
Block a user