Support streaming paraformer (#263)

This commit is contained in:
Fangjun Kuang
2023-08-14 10:32:14 +08:00
committed by GitHub
parent a4bff28e21
commit 6038e2aa62
38 changed files with 1488 additions and 112 deletions

View File

@@ -46,6 +46,8 @@ set(sources
online-lm.cc
online-lstm-transducer-model.cc
online-model-config.cc
online-paraformer-model-config.cc
online-paraformer-model.cc
online-recognizer-impl.cc
online-recognizer.cc
online-rnn-lm.cc

View File

@@ -39,7 +39,7 @@ std::string FeatureExtractorConfig::ToString() const {
class FeatureExtractor::Impl {
public:
explicit Impl(const FeatureExtractorConfig &config) {
explicit Impl(const FeatureExtractorConfig &config) : config_(config) {
opts_.frame_opts.dither = 0;
opts_.frame_opts.snip_edges = false;
opts_.frame_opts.samp_freq = config.sampling_rate;
@@ -50,6 +50,19 @@ class FeatureExtractor::Impl {
}
void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
if (config_.normalize_samples) {
AcceptWaveformImpl(sampling_rate, waveform, n);
} else {
std::vector<float> buf(n);
for (int32_t i = 0; i != n; ++i) {
buf[i] = waveform[i] * 32768;
}
AcceptWaveformImpl(sampling_rate, buf.data(), n);
}
}
void AcceptWaveformImpl(int32_t sampling_rate, const float *waveform,
int32_t n) {
std::lock_guard<std::mutex> lock(mutex_);
if (resampler_) {
@@ -146,6 +159,7 @@ class FeatureExtractor::Impl {
private:
std::unique_ptr<knf::OnlineFbank> fbank_;
knf::FbankOptions opts_;
FeatureExtractorConfig config_;
mutable std::mutex mutex_;
std::unique_ptr<LinearResample> resampler_;
int32_t last_frame_index_ = 0;

View File

@@ -21,6 +21,13 @@ struct FeatureExtractorConfig {
// Feature dimension
int32_t feature_dim = 80;
// Set internally by some models, e.g., paraformer sets it to false.
// This parameter is not exposed to users from the commandline
// If true, the feature extractor expects inputs to be normalized to
// the range [-1, 1].
// If false, we will multiply the inputs by 32768
bool normalize_samples = true;
std::string ToString() const;
void Register(ParseOptions *po);

View File

@@ -12,6 +12,7 @@ namespace sherpa_onnx {
void OnlineModelConfig::Register(ParseOptions *po) {
transducer.Register(po);
paraformer.Register(po);
po->Register("tokens", &tokens, "Path to tokens.txt");
@@ -41,6 +42,10 @@ bool OnlineModelConfig::Validate() const {
return false;
}
if (!paraformer.encoder.empty()) {
return paraformer.Validate();
}
return transducer.Validate();
}
@@ -49,6 +54,7 @@ std::string OnlineModelConfig::ToString() const {
os << "OnlineModelConfig(";
os << "transducer=" << transducer.ToString() << ", ";
os << "paraformer=" << paraformer.ToString() << ", ";
os << "tokens=\"" << tokens << "\", ";
os << "num_threads=" << num_threads << ", ";
os << "debug=" << (debug ? "True" : "False") << ", ";

View File

@@ -6,12 +6,14 @@
#include <string>
#include "sherpa-onnx/csrc/online-paraformer-model-config.h"
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
namespace sherpa_onnx {
struct OnlineModelConfig {
OnlineTransducerModelConfig transducer;
OnlineParaformerModelConfig paraformer;
std::string tokens;
int32_t num_threads = 1;
bool debug = false;
@@ -28,9 +30,11 @@ struct OnlineModelConfig {
OnlineModelConfig() = default;
OnlineModelConfig(const OnlineTransducerModelConfig &transducer,
const OnlineParaformerModelConfig &paraformer,
const std::string &tokens, int32_t num_threads, bool debug,
const std::string &provider, const std::string &model_type)
: transducer(transducer),
paraformer(paraformer),
tokens(tokens),
num_threads(num_threads),
debug(debug),

View File

@@ -0,0 +1,23 @@
// sherpa-onnx/csrc/online-paraformer-decoder.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_DECODER_H_
#define SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_DECODER_H_
#include <vector>
#include "onnxruntime_cxx_api.h" // NOLINT
namespace sherpa_onnx {
struct OnlineParaformerDecoderResult {
/// The decoded token IDs
std::vector<int32_t> tokens;
int32_t last_non_blank_frame_index = 0;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_DECODER_H_

View File

@@ -0,0 +1,43 @@
// sherpa-onnx/csrc/online-paraformer-model-config.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/online-paraformer-model-config.h"
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
namespace sherpa_onnx {
void OnlineParaformerModelConfig::Register(ParseOptions *po) {
po->Register("paraformer-encoder", &encoder,
"Path to encoder.onnx of paraformer.");
po->Register("paraformer-decoder", &decoder,
"Path to decoder.onnx of paraformer.");
}
bool OnlineParaformerModelConfig::Validate() const {
if (!FileExists(encoder)) {
SHERPA_ONNX_LOGE("Paraformer encoder %s does not exist", encoder.c_str());
return false;
}
if (!FileExists(decoder)) {
SHERPA_ONNX_LOGE("Paraformer decoder %s does not exist", decoder.c_str());
return false;
}
return true;
}
std::string OnlineParaformerModelConfig::ToString() const {
std::ostringstream os;
os << "OnlineParaformerModelConfig(";
os << "encoder=\"" << encoder << "\", ";
os << "decoder=\"" << decoder << "\")";
return os.str();
}
} // namespace sherpa_onnx

View File

@@ -0,0 +1,31 @@
// sherpa-onnx/csrc/online-paraformer-model-config.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_MODEL_CONFIG_H_
#define SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_MODEL_CONFIG_H_
#include <string>
#include "sherpa-onnx/csrc/parse-options.h"
namespace sherpa_onnx {
struct OnlineParaformerModelConfig {
std::string encoder;
std::string decoder;
OnlineParaformerModelConfig() = default;
OnlineParaformerModelConfig(const std::string &encoder,
const std::string &decoder)
: encoder(encoder), decoder(decoder) {}
void Register(ParseOptions *po);
bool Validate() const;
std::string ToString() const;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_MODEL_CONFIG_H_

View File

@@ -0,0 +1,249 @@
// sherpa-onnx/csrc/online-paraformer-model.cc
//
// Copyright (c) 2022-2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/online-paraformer-model.h"
#include <algorithm>
#include <cmath>
#include <string>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/session.h"
#include "sherpa-onnx/csrc/text-utils.h"
namespace sherpa_onnx {
class OnlineParaformerModel::Impl {
public:
explicit Impl(const OnlineModelConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
{
auto buf = ReadFile(config.paraformer.encoder);
InitEncoder(buf.data(), buf.size());
}
{
auto buf = ReadFile(config.paraformer.decoder);
InitDecoder(buf.data(), buf.size());
}
}
#if __ANDROID_API__ >= 9
Impl(AAssetManager *mgr, const OnlineModelConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_WARNING),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
{
auto buf = ReadFile(mgr, config.paraformer.encoder);
InitEncoder(buf.data(), buf.size());
}
{
auto buf = ReadFile(mgr, config.paraformer.decoder);
InitDecoder(buf.data(), buf.size());
}
}
#endif
std::vector<Ort::Value> ForwardEncoder(Ort::Value features,
Ort::Value features_length) {
std::array<Ort::Value, 2> inputs = {std::move(features),
std::move(features_length)};
return encoder_sess_->Run(
{}, encoder_input_names_ptr_.data(), inputs.data(), inputs.size(),
encoder_output_names_ptr_.data(), encoder_output_names_ptr_.size());
}
std::vector<Ort::Value> ForwardDecoder(Ort::Value encoder_out,
Ort::Value encoder_out_length,
Ort::Value acoustic_embedding,
Ort::Value acoustic_embedding_length,
std::vector<Ort::Value> states) {
std::vector<Ort::Value> decoder_inputs;
decoder_inputs.reserve(4 + states.size());
decoder_inputs.push_back(std::move(encoder_out));
decoder_inputs.push_back(std::move(encoder_out_length));
decoder_inputs.push_back(std::move(acoustic_embedding));
decoder_inputs.push_back(std::move(acoustic_embedding_length));
for (auto &v : states) {
decoder_inputs.push_back(std::move(v));
}
return decoder_sess_->Run({}, decoder_input_names_ptr_.data(),
decoder_inputs.data(), decoder_inputs.size(),
decoder_output_names_ptr_.data(),
decoder_output_names_ptr_.size());
}
int32_t VocabSize() const { return vocab_size_; }
int32_t LfrWindowSize() const { return lfr_window_size_; }
int32_t LfrWindowShift() const { return lfr_window_shift_; }
int32_t EncoderOutputSize() const { return encoder_output_size_; }
int32_t DecoderKernelSize() const { return decoder_kernel_size_; }
int32_t DecoderNumBlocks() const { return decoder_num_blocks_; }
const std::vector<float> &NegativeMean() const { return neg_mean_; }
const std::vector<float> &InverseStdDev() const { return inv_stddev_; }
OrtAllocator *Allocator() const { return allocator_; }
private:
void InitEncoder(void *model_data, size_t model_data_length) {
encoder_sess_ = std::make_unique<Ort::Session>(
env_, model_data, model_data_length, sess_opts_);
GetInputNames(encoder_sess_.get(), &encoder_input_names_,
&encoder_input_names_ptr_);
GetOutputNames(encoder_sess_.get(), &encoder_output_names_,
&encoder_output_names_ptr_);
// get meta data
Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata();
if (config_.debug) {
std::ostringstream os;
PrintModelMetadata(os, meta_data);
SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
}
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size");
SHERPA_ONNX_READ_META_DATA(lfr_window_size_, "lfr_window_size");
SHERPA_ONNX_READ_META_DATA(lfr_window_shift_, "lfr_window_shift");
SHERPA_ONNX_READ_META_DATA(encoder_output_size_, "encoder_output_size");
SHERPA_ONNX_READ_META_DATA(decoder_num_blocks_, "decoder_num_blocks");
SHERPA_ONNX_READ_META_DATA(decoder_kernel_size_, "decoder_kernel_size");
SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(neg_mean_, "neg_mean");
SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(inv_stddev_, "inv_stddev");
float scale = std::sqrt(encoder_output_size_);
for (auto &f : inv_stddev_) {
f *= scale;
}
}
void InitDecoder(void *model_data, size_t model_data_length) {
decoder_sess_ = std::make_unique<Ort::Session>(
env_, model_data, model_data_length, sess_opts_);
GetInputNames(decoder_sess_.get(), &decoder_input_names_,
&decoder_input_names_ptr_);
GetOutputNames(decoder_sess_.get(), &decoder_output_names_,
&decoder_output_names_ptr_);
}
private:
OnlineModelConfig config_;
Ort::Env env_;
Ort::SessionOptions sess_opts_;
Ort::AllocatorWithDefaultOptions allocator_;
std::unique_ptr<Ort::Session> encoder_sess_;
std::vector<std::string> encoder_input_names_;
std::vector<const char *> encoder_input_names_ptr_;
std::vector<std::string> encoder_output_names_;
std::vector<const char *> encoder_output_names_ptr_;
std::unique_ptr<Ort::Session> decoder_sess_;
std::vector<std::string> decoder_input_names_;
std::vector<const char *> decoder_input_names_ptr_;
std::vector<std::string> decoder_output_names_;
std::vector<const char *> decoder_output_names_ptr_;
std::vector<float> neg_mean_;
std::vector<float> inv_stddev_;
int32_t vocab_size_ = 0; // initialized in Init
int32_t lfr_window_size_ = 0;
int32_t lfr_window_shift_ = 0;
int32_t encoder_output_size_ = 0;
int32_t decoder_num_blocks_ = 0;
int32_t decoder_kernel_size_ = 0;
};
OnlineParaformerModel::OnlineParaformerModel(const OnlineModelConfig &config)
: impl_(std::make_unique<Impl>(config)) {}
#if __ANDROID_API__ >= 9
OnlineParaformerModel::OnlineParaformerModel(AAssetManager *mgr,
const OnlineModelConfig &config)
: impl_(std::make_unique<Impl>(mgr, config)) {}
#endif
OnlineParaformerModel::~OnlineParaformerModel() = default;
std::vector<Ort::Value> OnlineParaformerModel::ForwardEncoder(
Ort::Value features, Ort::Value features_length) const {
return impl_->ForwardEncoder(std::move(features), std::move(features_length));
}
std::vector<Ort::Value> OnlineParaformerModel::ForwardDecoder(
Ort::Value encoder_out, Ort::Value encoder_out_length,
Ort::Value acoustic_embedding, Ort::Value acoustic_embedding_length,
std::vector<Ort::Value> states) const {
return impl_->ForwardDecoder(
std::move(encoder_out), std::move(encoder_out_length),
std::move(acoustic_embedding), std::move(acoustic_embedding_length),
std::move(states));
}
int32_t OnlineParaformerModel::VocabSize() const { return impl_->VocabSize(); }
int32_t OnlineParaformerModel::LfrWindowSize() const {
return impl_->LfrWindowSize();
}
int32_t OnlineParaformerModel::LfrWindowShift() const {
return impl_->LfrWindowShift();
}
int32_t OnlineParaformerModel::EncoderOutputSize() const {
return impl_->EncoderOutputSize();
}
int32_t OnlineParaformerModel::DecoderKernelSize() const {
return impl_->DecoderKernelSize();
}
int32_t OnlineParaformerModel::DecoderNumBlocks() const {
return impl_->DecoderNumBlocks();
}
const std::vector<float> &OnlineParaformerModel::NegativeMean() const {
return impl_->NegativeMean();
}
const std::vector<float> &OnlineParaformerModel::InverseStdDev() const {
return impl_->InverseStdDev();
}
OrtAllocator *OnlineParaformerModel::Allocator() const {
return impl_->Allocator();
}
} // namespace sherpa_onnx

View File

@@ -0,0 +1,76 @@
// sherpa-onnx/csrc/online-paraformer-model.h
//
// Copyright (c) 2022-2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_MODEL_H_
#define SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_MODEL_H_
#include <memory>
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/online-model-config.h"
namespace sherpa_onnx {
class OnlineParaformerModel {
public:
explicit OnlineParaformerModel(const OnlineModelConfig &config);
#if __ANDROID_API__ >= 9
OnlineParaformerModel(AAssetManager *mgr, const OnlineModelConfig &config);
#endif
~OnlineParaformerModel();
std::vector<Ort::Value> ForwardEncoder(Ort::Value features,
Ort::Value features_length) const;
std::vector<Ort::Value> ForwardDecoder(Ort::Value encoder_out,
Ort::Value encoder_out_length,
Ort::Value acoustic_embedding,
Ort::Value acoustic_embedding_length,
std::vector<Ort::Value> states) const;
/** Return the vocabulary size of the model
*/
int32_t VocabSize() const;
/** It is lfr_m in config.yaml
*/
int32_t LfrWindowSize() const;
/** It is lfr_n in config.yaml
*/
int32_t LfrWindowShift() const;
int32_t EncoderOutputSize() const;
int32_t DecoderKernelSize() const;
int32_t DecoderNumBlocks() const;
/** Return negative mean for CMVN
*/
const std::vector<float> &NegativeMean() const;
/** Return inverse stddev for CMVN
*/
const std::vector<float> &InverseStdDev() const;
/** Return an allocator for allocating memory
*/
OrtAllocator *Allocator() const;
private:
class Impl;
std::unique_ptr<Impl> impl_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_ONLINE_PARAFORMER_MODEL_H_

View File

@@ -4,6 +4,7 @@
#include "sherpa-onnx/csrc/online-recognizer-impl.h"
#include "sherpa-onnx/csrc/online-recognizer-paraformer-impl.h"
#include "sherpa-onnx/csrc/online-recognizer-transducer-impl.h"
namespace sherpa_onnx {
@@ -14,6 +15,10 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
return std::make_unique<OnlineRecognizerTransducerImpl>(config);
}
if (!config.model_config.paraformer.encoder.empty()) {
return std::make_unique<OnlineRecognizerParaformerImpl>(config);
}
SHERPA_ONNX_LOGE("Please specify a model");
exit(-1);
}
@@ -25,6 +30,10 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
return std::make_unique<OnlineRecognizerTransducerImpl>(mgr, config);
}
if (!config.model_config.paraformer.encoder.empty()) {
return std::make_unique<OnlineRecognizerParaformerImpl>(mgr, config);
}
SHERPA_ONNX_LOGE("Please specify a model");
exit(-1);
}

View File

@@ -26,8 +26,6 @@ class OnlineRecognizerImpl {
virtual ~OnlineRecognizerImpl() = default;
virtual void InitOnlineStream(OnlineStream *stream) const = 0;
virtual std::unique_ptr<OnlineStream> CreateStream() const = 0;
virtual std::unique_ptr<OnlineStream> CreateStream(

View File

@@ -0,0 +1,465 @@
// sherpa-onnx/csrc/online-recognizer-paraformer-impl.h
//
// Copyright (c) 2022-2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_PARAFORMER_IMPL_H_
#define SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_PARAFORMER_IMPL_H_
#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/online-lm.h"
#include "sherpa-onnx/csrc/online-paraformer-decoder.h"
#include "sherpa-onnx/csrc/online-paraformer-model.h"
#include "sherpa-onnx/csrc/online-recognizer-impl.h"
#include "sherpa-onnx/csrc/online-recognizer.h"
#include "sherpa-onnx/csrc/symbol-table.h"
namespace sherpa_onnx {
static OnlineRecognizerResult Convert(const OnlineParaformerDecoderResult &src,
const SymbolTable &sym_table) {
OnlineRecognizerResult r;
r.tokens.reserve(src.tokens.size());
std::string text;
// When the current token ends with "@@" we set mergeable to true
bool mergeable = false;
for (int32_t i = 0; i != src.tokens.size(); ++i) {
auto sym = sym_table[src.tokens[i]];
r.tokens.push_back(sym);
if ((sym.back() != '@') || (sym.size() > 2 && sym[sym.size() - 2] != '@')) {
// sym does not end with "@@"
const uint8_t *p = reinterpret_cast<const uint8_t *>(sym.c_str());
if (p[0] < 0x80) {
// an ascii
if (mergeable) {
mergeable = false;
text.append(sym);
} else {
text.append(" ");
text.append(sym);
}
} else {
// not an ascii
mergeable = false;
if (i > 0) {
const uint8_t *p = reinterpret_cast<const uint8_t *>(
sym_table[src.tokens[i - 1]].c_str());
if (p[0] < 0x80) {
// put a space between ascii and non-ascii
text.append(" ");
}
}
text.append(sym);
}
} else {
// this sym ends with @@
sym = std::string(sym.data(), sym.size() - 2);
if (mergeable) {
text.append(sym);
} else {
text.append(" ");
text.append(sym);
mergeable = true;
}
}
}
r.text = std::move(text);
return r;
}
// y[i] += x[i] * scale
static void ScaleAddInPlace(const float *x, int32_t n, float scale, float *y) {
for (int32_t i = 0; i != n; ++i) {
y[i] += x[i] * scale;
}
}
// y[i] = x[i] * scale
static void Scale(const float *x, int32_t n, float scale, float *y) {
for (int32_t i = 0; i != n; ++i) {
y[i] = x[i] * scale;
}
}
class OnlineRecognizerParaformerImpl : public OnlineRecognizerImpl {
public:
explicit OnlineRecognizerParaformerImpl(const OnlineRecognizerConfig &config)
: config_(config),
model_(config.model_config),
sym_(config.model_config.tokens),
endpoint_(config_.endpoint_config) {
if (config.decoding_method != "greedy_search") {
SHERPA_ONNX_LOGE(
"Unsupported decoding method: %s. Support only greedy_search at "
"present",
config.decoding_method.c_str());
exit(-1);
}
// Paraformer models assume input samples are in the range
// [-32768, 32767], so we set normalize_samples to false
config_.feat_config.normalize_samples = false;
}
#if __ANDROID_API__ >= 9
explicit OnlineRecognizerParaformerImpl(AAssetManager *mgr,
const OnlineRecognizerConfig &config)
: config_(config),
model_(mgr, config.model_config),
sym_(mgr, config.model_config.tokens),
endpoint_(config_.endpoint_config) {
if (config.decoding_method == "greedy_search") {
// add greedy search decoder
// SHERPA_ONNX_LOGE("to be implemented");
// exit(-1);
} else {
SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
config.decoding_method.c_str());
exit(-1);
}
// Paraformer models assume input samples are in the range
// [-32768, 32767], so we set normalize_samples to false
config_.feat_config.normalize_samples = false;
}
#endif
OnlineRecognizerParaformerImpl(const OnlineRecognizerParaformerImpl &) =
delete;
OnlineRecognizerParaformerImpl operator=(
const OnlineRecognizerParaformerImpl &) = delete;
std::unique_ptr<OnlineStream> CreateStream() const override {
auto stream = std::make_unique<OnlineStream>(config_.feat_config);
OnlineParaformerDecoderResult r;
stream->SetParaformerResult(r);
return stream;
}
bool IsReady(OnlineStream *s) const override {
return s->GetNumProcessedFrames() + chunk_size_ < s->NumFramesReady();
}
void DecodeStreams(OnlineStream **ss, int32_t n) const override {
// TODO(fangjun): Support batch size > 1
for (int32_t i = 0; i != n; ++i) {
DecodeStream(ss[i]);
}
}
OnlineRecognizerResult GetResult(OnlineStream *s) const override {
auto decoder_result = s->GetParaformerResult();
return Convert(decoder_result, sym_);
}
bool IsEndpoint(OnlineStream *s) const override {
if (!config_.enable_endpoint) {
return false;
}
const auto &result = s->GetParaformerResult();
int32_t num_processed_frames = s->GetNumProcessedFrames();
// frame shift is 10 milliseconds
float frame_shift_in_seconds = 0.01;
int32_t trailing_silence_frames =
num_processed_frames - result.last_non_blank_frame_index;
return endpoint_.IsEndpoint(num_processed_frames, trailing_silence_frames,
frame_shift_in_seconds);
}
void Reset(OnlineStream *s) const override {
OnlineParaformerDecoderResult r;
s->SetParaformerResult(r);
// the internal model caches are not reset
// Note: We only update counters. The underlying audio samples
// are not discarded.
s->Reset();
}
private:
void DecodeStream(OnlineStream *s) const {
const auto num_processed_frames = s->GetNumProcessedFrames();
std::vector<float> frames = s->GetFrames(num_processed_frames, chunk_size_);
s->GetNumProcessedFrames() += chunk_size_ - 1;
frames = ApplyLFR(frames);
ApplyCMVN(&frames);
PositionalEncoding(&frames, num_processed_frames / model_.LfrWindowShift());
int32_t feat_dim = model_.NegativeMean().size();
// We have scaled inv_stddev by sqrt(encoder_output_size)
// so the following line can be commented out
// frames *= encoder_output_size ** 0.5
// add overlap chunk
std::vector<float> &feat_cache = s->GetParaformerFeatCache();
if (feat_cache.empty()) {
int32_t n = (left_chunk_size_ + right_chunk_size_) * feat_dim;
feat_cache.resize(n, 0);
}
frames.insert(frames.begin(), feat_cache.begin(), feat_cache.end());
std::copy(frames.end() - feat_cache.size(), frames.end(),
feat_cache.begin());
int32_t num_frames = frames.size() / feat_dim;
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
std::array<int64_t, 3> x_shape{1, num_frames, feat_dim};
Ort::Value x =
Ort::Value::CreateTensor(memory_info, frames.data(), frames.size(),
x_shape.data(), x_shape.size());
int64_t x_len_shape = 1;
int32_t x_len_val = num_frames;
Ort::Value x_length =
Ort::Value::CreateTensor(memory_info, &x_len_val, 1, &x_len_shape, 1);
auto encoder_out_vec =
model_.ForwardEncoder(std::move(x), std::move(x_length));
// CIF search
auto &encoder_out = encoder_out_vec[0];
auto &encoder_out_len = encoder_out_vec[1];
auto &alpha = encoder_out_vec[2];
float *p_alpha = alpha.GetTensorMutableData<float>();
std::vector<int64_t> alpha_shape =
alpha.GetTensorTypeAndShapeInfo().GetShape();
std::fill(p_alpha, p_alpha + left_chunk_size_, 0);
std::fill(p_alpha + alpha_shape[1] - right_chunk_size_,
p_alpha + alpha_shape[1], 0);
const float *p_encoder_out = encoder_out.GetTensorData<float>();
std::vector<int64_t> encoder_out_shape =
encoder_out.GetTensorTypeAndShapeInfo().GetShape();
std::vector<float> &initial_hidden = s->GetParaformerEncoderOutCache();
if (initial_hidden.empty()) {
initial_hidden.resize(encoder_out_shape[2]);
}
std::vector<float> &alpha_cache = s->GetParaformerAlphaCache();
if (alpha_cache.empty()) {
alpha_cache.resize(1);
}
std::vector<float> acoustic_embedding;
acoustic_embedding.reserve(encoder_out_shape[1] * encoder_out_shape[2]);
float threshold = 1.0;
float integrate = alpha_cache[0];
for (int32_t i = 0; i != encoder_out_shape[1]; ++i) {
float this_alpha = p_alpha[i];
if (integrate + this_alpha < threshold) {
integrate += this_alpha;
ScaleAddInPlace(p_encoder_out + i * encoder_out_shape[2],
encoder_out_shape[2], this_alpha,
initial_hidden.data());
continue;
}
// fire
ScaleAddInPlace(p_encoder_out + i * encoder_out_shape[2],
encoder_out_shape[2], threshold - integrate,
initial_hidden.data());
acoustic_embedding.insert(acoustic_embedding.end(),
initial_hidden.begin(), initial_hidden.end());
integrate += this_alpha - threshold;
Scale(p_encoder_out + i * encoder_out_shape[2], encoder_out_shape[2],
integrate, initial_hidden.data());
}
alpha_cache[0] = integrate;
if (acoustic_embedding.empty()) {
return;
}
auto &states = s->GetStates();
if (states.empty()) {
states.reserve(model_.DecoderNumBlocks());
std::array<int64_t, 3> shape{1, model_.EncoderOutputSize(),
model_.DecoderKernelSize() - 1};
int32_t num_bytes = sizeof(float) * shape[0] * shape[1] * shape[2];
for (int32_t i = 0; i != model_.DecoderNumBlocks(); ++i) {
Ort::Value this_state = Ort::Value::CreateTensor<float>(
model_.Allocator(), shape.data(), shape.size());
memset(this_state.GetTensorMutableData<float>(), 0, num_bytes);
states.push_back(std::move(this_state));
}
}
int32_t num_tokens = acoustic_embedding.size() / initial_hidden.size();
std::array<int64_t, 3> acoustic_embedding_shape{
1, num_tokens, static_cast<int32_t>(initial_hidden.size())};
Ort::Value acoustic_embedding_tensor = Ort::Value::CreateTensor(
memory_info, acoustic_embedding.data(), acoustic_embedding.size(),
acoustic_embedding_shape.data(), acoustic_embedding_shape.size());
std::array<int64_t, 1> acoustic_embedding_length_shape{1};
Ort::Value acoustic_embedding_length_tensor = Ort::Value::CreateTensor(
memory_info, &num_tokens, 1, acoustic_embedding_length_shape.data(),
acoustic_embedding_length_shape.size());
auto decoder_out_vec = model_.ForwardDecoder(
std::move(encoder_out), std::move(encoder_out_len),
std::move(acoustic_embedding_tensor),
std::move(acoustic_embedding_length_tensor), std::move(states));
states.reserve(model_.DecoderNumBlocks());
for (int32_t i = 2; i != decoder_out_vec.size(); ++i) {
// TODO(fangjun): When we change chunk_size_, we need to
// slice decoder_out_vec[i] accordingly.
states.push_back(std::move(decoder_out_vec[i]));
}
const auto &sample_ids = decoder_out_vec[1];
const int64_t *p_sample_ids = sample_ids.GetTensorData<int64_t>();
bool non_blank_detected = false;
auto &result = s->GetParaformerResult();
for (int32_t i = 0; i != num_tokens; ++i) {
int32_t t = p_sample_ids[i];
if (t == 0) {
continue;
}
non_blank_detected = true;
result.tokens.push_back(t);
}
if (non_blank_detected) {
result.last_non_blank_frame_index = num_processed_frames;
}
}
std::vector<float> ApplyLFR(const std::vector<float> &in) const {
int32_t lfr_window_size = model_.LfrWindowSize();
int32_t lfr_window_shift = model_.LfrWindowShift();
int32_t in_feat_dim = config_.feat_config.feature_dim;
int32_t in_num_frames = in.size() / in_feat_dim;
int32_t out_num_frames =
(in_num_frames - lfr_window_size) / lfr_window_shift + 1;
int32_t out_feat_dim = in_feat_dim * lfr_window_size;
std::vector<float> out(out_num_frames * out_feat_dim);
const float *p_in = in.data();
float *p_out = out.data();
for (int32_t i = 0; i != out_num_frames; ++i) {
std::copy(p_in, p_in + out_feat_dim, p_out);
p_out += out_feat_dim;
p_in += lfr_window_shift * in_feat_dim;
}
return out;
}
void ApplyCMVN(std::vector<float> *v) const {
const std::vector<float> &neg_mean = model_.NegativeMean();
const std::vector<float> &inv_stddev = model_.InverseStdDev();
int32_t dim = neg_mean.size();
int32_t num_frames = v->size() / dim;
float *p = v->data();
for (int32_t i = 0; i != num_frames; ++i) {
for (int32_t k = 0; k != dim; ++k) {
p[k] = (p[k] + neg_mean[k]) * inv_stddev[k];
}
p += dim;
}
}
void PositionalEncoding(std::vector<float> *v, int32_t t_offset) const {
int32_t lfr_window_size = model_.LfrWindowSize();
int32_t in_feat_dim = config_.feat_config.feature_dim;
int32_t feat_dim = in_feat_dim * lfr_window_size;
int32_t T = v->size() / feat_dim;
// log(10000)/(7*80/2-1) == 0.03301197265941284
// 7 is lfr_window_size
// 80 is in_feat_dim
// 7*80 is feat_dim
constexpr float kScale = -0.03301197265941284;
for (int32_t t = 0; t != T; ++t) {
float *p = v->data() + t * feat_dim;
int32_t offset = t + 1 + t_offset;
for (int32_t d = 0; d < feat_dim / 2; ++d) {
float inv_timescale = offset * std::exp(d * kScale);
float sin_d = std::sin(inv_timescale);
float cos_d = std::cos(inv_timescale);
p[d] += sin_d;
p[d + feat_dim / 2] += cos_d;
}
}
}
private:
OnlineRecognizerConfig config_;
OnlineParaformerModel model_;
SymbolTable sym_;
Endpoint endpoint_;
// 0.61 seconds
int32_t chunk_size_ = 61;
// (61 - 7) / 6 + 1 = 10
int32_t left_chunk_size_ = 5;
int32_t right_chunk_size_ = 5;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_PARAFORMER_IMPL_H_

View File

@@ -94,21 +94,6 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
}
#endif
void InitOnlineStream(OnlineStream *stream) const override {
auto r = decoder_->GetEmptyResult();
if (config_.decoding_method == "modified_beam_search" &&
nullptr != stream->GetContextGraph()) {
// r.hyps has only one element.
for (auto it = r.hyps.begin(); it != r.hyps.end(); ++it) {
it->second.context_state = stream->GetContextGraph()->Root();
}
}
stream->SetResult(r);
stream->SetStates(model_->GetEncoderInitStates());
}
std::unique_ptr<OnlineStream> CreateStream() const override {
auto stream = std::make_unique<OnlineStream>(config_.feat_config);
InitOnlineStream(stream.get());
@@ -211,7 +196,10 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
}
bool IsEndpoint(OnlineStream *s) const override {
if (!config_.enable_endpoint) return false;
if (!config_.enable_endpoint) {
return false;
}
int32_t num_processed_frames = s->GetNumProcessedFrames();
// frame shift is 10 milliseconds
@@ -244,6 +232,22 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
s->Reset();
}
private:
void InitOnlineStream(OnlineStream *stream) const {
auto r = decoder_->GetEmptyResult();
if (config_.decoding_method == "modified_beam_search" &&
nullptr != stream->GetContextGraph()) {
// r.hyps has only one element.
for (auto it = r.hyps.begin(); it != r.hyps.end(); ++it) {
it->second.context_state = stream->GetContextGraph()->Root();
}
}
stream->SetResult(r);
stream->SetStates(model_->GetEncoderInitStates());
}
private:
OnlineRecognizerConfig config_;
std::unique_ptr<OnlineTransducerModel> model_;

View File

@@ -47,6 +47,14 @@ class OnlineStream::Impl {
OnlineTransducerDecoderResult &GetResult() { return result_; }
void SetParaformerResult(const OnlineParaformerDecoderResult &r) {
paraformer_result_ = r;
}
OnlineParaformerDecoderResult &GetParaformerResult() {
return paraformer_result_;
}
int32_t FeatureDim() const { return feat_extractor_.FeatureDim(); }
void SetStates(std::vector<Ort::Value> states) {
@@ -57,6 +65,18 @@ class OnlineStream::Impl {
const ContextGraphPtr &GetContextGraph() const { return context_graph_; }
std::vector<float> &GetParaformerFeatCache() {
return paraformer_feat_cache_;
}
std::vector<float> &GetParaformerEncoderOutCache() {
return paraformer_encoder_out_cache_;
}
std::vector<float> &GetParaformerAlphaCache() {
return paraformer_alpha_cache_;
}
private:
FeatureExtractor feat_extractor_;
/// For contextual-biasing
@@ -65,6 +85,10 @@ class OnlineStream::Impl {
int32_t start_frame_index_ = 0; // never reset
OnlineTransducerDecoderResult result_;
std::vector<Ort::Value> states_;
std::vector<float> paraformer_feat_cache_;
std::vector<float> paraformer_encoder_out_cache_;
std::vector<float> paraformer_alpha_cache_;
OnlineParaformerDecoderResult paraformer_result_;
};
OnlineStream::OnlineStream(const FeatureExtractorConfig &config /*= {}*/,
@@ -107,6 +131,14 @@ OnlineTransducerDecoderResult &OnlineStream::GetResult() {
return impl_->GetResult();
}
void OnlineStream::SetParaformerResult(const OnlineParaformerDecoderResult &r) {
impl_->SetParaformerResult(r);
}
OnlineParaformerDecoderResult &OnlineStream::GetParaformerResult() {
return impl_->GetParaformerResult();
}
void OnlineStream::SetStates(std::vector<Ort::Value> states) {
impl_->SetStates(std::move(states));
}
@@ -119,4 +151,16 @@ const ContextGraphPtr &OnlineStream::GetContextGraph() const {
return impl_->GetContextGraph();
}
std::vector<float> &OnlineStream::GetParaformerFeatCache() {
return impl_->GetParaformerFeatCache();
}
std::vector<float> &OnlineStream::GetParaformerEncoderOutCache() {
return impl_->GetParaformerEncoderOutCache();
}
std::vector<float> &OnlineStream::GetParaformerAlphaCache() {
return impl_->GetParaformerAlphaCache();
}
} // namespace sherpa_onnx

View File

@@ -11,6 +11,7 @@
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/context-graph.h"
#include "sherpa-onnx/csrc/features.h"
#include "sherpa-onnx/csrc/online-paraformer-decoder.h"
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
namespace sherpa_onnx {
@@ -70,6 +71,9 @@ class OnlineStream {
void SetResult(const OnlineTransducerDecoderResult &r);
OnlineTransducerDecoderResult &GetResult();
void SetParaformerResult(const OnlineParaformerDecoderResult &r);
OnlineParaformerDecoderResult &GetParaformerResult();
void SetStates(std::vector<Ort::Value> states);
std::vector<Ort::Value> &GetStates();
@@ -80,6 +84,11 @@ class OnlineStream {
*/
const ContextGraphPtr &GetContextGraph() const;
// for streaming parformer
std::vector<float> &GetParaformerFeatCache();
std::vector<float> &GetParaformerEncoderOutCache();
std::vector<float> &GetParaformerAlphaCache();
private:
class Impl;
std::unique_ptr<Impl> impl_;

View File

@@ -12,8 +12,8 @@
#include "sherpa-onnx/csrc/online-recognizer.h"
#include "sherpa-onnx/csrc/online-stream.h"
#include "sherpa-onnx/csrc/symbol-table.h"
#include "sherpa-onnx/csrc/parse-options.h"
#include "sherpa-onnx/csrc/symbol-table.h"
#include "sherpa-onnx/csrc/wave-reader.h"
typedef struct {
@@ -80,7 +80,7 @@ for a list of pre-trained models to download.
bool is_ok = false;
const std::vector<float> samples =
sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok);
sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok);
if (!is_ok) {
fprintf(stderr, "Failed to read %s\n", wav_filename.c_str());
@@ -92,14 +92,14 @@ for a list of pre-trained models to download.
auto s = recognizer.CreateStream();
s->AcceptWaveform(sampling_rate, samples.data(), samples.size());
std::vector<float> tail_paddings(static_cast<int>(0.3 * sampling_rate));
std::vector<float> tail_paddings(static_cast<int>(0.8 * sampling_rate));
// Note: We can call AcceptWaveform() multiple times.
s->AcceptWaveform(
sampling_rate, tail_paddings.data(), tail_paddings.size());
s->AcceptWaveform(sampling_rate, tail_paddings.data(),
tail_paddings.size());
// Call InputFinished() to indicate that no audio samples are available
s->InputFinished();
ss.push_back({ std::move(s), duration, 0 });
ss.push_back({std::move(s), duration, 0});
}
std::vector<sherpa_onnx::OnlineStream *> ready_streams;
@@ -112,8 +112,9 @@ for a list of pre-trained models to download.
} else if (s.elapsed_seconds == 0) {
const auto end = std::chrono::steady_clock::now();
const float elapsed_seconds =
std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
.count() / 1000.;
std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
.count() /
1000.;
s.elapsed_seconds = elapsed_seconds;
}
}