Fix nemo streaming transducer greedy search (#944)

This commit is contained in:
Fangjun Kuang
2024-05-30 15:31:10 +08:00
committed by GitHub
parent 3f472a9993
commit 082f230dfb
18 changed files with 318 additions and 288 deletions

View File

@@ -6,6 +6,7 @@
#ifndef SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_
#define SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_
#include <algorithm>
#include <fstream>
#include <ios>
#include <memory>
@@ -32,23 +33,20 @@
namespace sherpa_onnx {
// defined in ./online-recognizer-transducer-impl.h
// static may or may not be here? TODDOs
static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
const SymbolTable &sym_table,
float frame_shift_ms,
int32_t subsampling_factor,
int32_t segment,
int32_t frames_since_start);
OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
const SymbolTable &sym_table,
float frame_shift_ms, int32_t subsampling_factor,
int32_t segment, int32_t frames_since_start);
class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
public:
public:
explicit OnlineRecognizerTransducerNeMoImpl(
const OnlineRecognizerConfig &config)
: config_(config),
symbol_table_(config.model_config.tokens),
endpoint_(config_.endpoint_config),
model_(std::make_unique<OnlineTransducerNeMoModel>(
config.model_config)) {
model_(
std::make_unique<OnlineTransducerNeMoModel>(config.model_config)) {
if (config.decoding_method == "greedy_search") {
decoder_ = std::make_unique<OnlineTransducerGreedySearchNeMoDecoder>(
model_.get(), config_.blank_penalty);
@@ -73,7 +71,7 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
model_.get(), config_.blank_penalty);
} else {
SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
config.decoding_method.c_str());
config.decoding_method.c_str());
exit(-1);
}
@@ -83,7 +81,6 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
std::unique_ptr<OnlineStream> CreateStream() const override {
auto stream = std::make_unique<OnlineStream>(config_.feat_config);
stream->SetStates(model_->GetInitStates());
InitOnlineStream(stream.get());
return stream;
}
@@ -94,14 +91,12 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
}
OnlineRecognizerResult GetResult(OnlineStream *s) const override {
OnlineTransducerDecoderResult decoder_result = s->GetResult();
decoder_->StripLeadingBlanks(&decoder_result);
// TODO(fangjun): Remember to change these constants if needed
int32_t frame_shift_ms = 10;
int32_t subsampling_factor = 8;
return Convert(decoder_result, symbol_table_, frame_shift_ms, subsampling_factor,
s->GetCurrentSegment(), s->GetNumFramesSinceStart());
int32_t subsampling_factor = model_->SubsamplingFactor();
return Convert(s->GetResult(), symbol_table_, frame_shift_ms,
subsampling_factor, s->GetCurrentSegment(),
s->GetNumFramesSinceStart());
}
bool IsEndpoint(OnlineStream *s) const override {
@@ -114,8 +109,8 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
// frame shift is 10 milliseconds
float frame_shift_in_seconds = 0.01;
// subsampling factor is 8
int32_t trailing_silence_frames = s->GetResult().num_trailing_blanks * 8;
int32_t trailing_silence_frames =
s->GetResult().num_trailing_blanks * model_->SubsamplingFactor();
return endpoint_.IsEndpoint(num_processed_frames, trailing_silence_frames,
frame_shift_in_seconds);
@@ -126,19 +121,16 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
// segment is incremented only when the last
// result is not empty
const auto &r = s->GetResult();
if (!r.tokens.empty() && r.tokens.back() != 0) {
if (!r.tokens.empty()) {
s->GetCurrentSegment() += 1;
}
}
// we keep the decoder_out
decoder_->UpdateDecoderOut(&s->GetResult());
Ort::Value decoder_out = std::move(s->GetResult().decoder_out);
s->SetResult({});
auto r = decoder_->GetEmptyResult();
s->SetResult(r);
s->GetResult().decoder_out = std::move(decoder_out);
s->SetStates(model_->GetEncoderInitStates());
s->SetNeMoDecoderStates(model_->GetDecoderInitStates());
// Note: We only update counters. The underlying audio samples
// are not discarded.
@@ -151,10 +143,9 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
int32_t feature_dim = ss[0]->FeatureDim();
std::vector<OnlineTransducerDecoderResult> result(n);
std::vector<float> features_vec(n * chunk_size * feature_dim);
std::vector<std::vector<Ort::Value>> encoder_states(n);
for (int32_t i = 0; i != n; ++i) {
const auto num_processed_frames = ss[i]->GetNumProcessedFrames();
std::vector<float> features =
@@ -166,9 +157,7 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
std::copy(features.begin(), features.end(),
features_vec.data() + i * chunk_size * feature_dim);
result[i] = std::move(ss[i]->GetResult());
encoder_states[i] = std::move(ss[i]->GetStates());
}
auto memory_info =
@@ -180,42 +169,35 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
features_vec.size(), x_shape.data(),
x_shape.size());
// Batch size is 1
auto states = std::move(encoder_states[0]);
int32_t num_states = states.size(); // num_states = 3
auto states = model_->StackStates(std::move(encoder_states));
int32_t num_states = states.size(); // num_states = 3
auto t = model_->RunEncoder(std::move(x), std::move(states));
// t[0] encoder_out, float tensor, (batch_size, dim, T)
// t[1] next states
std::vector<Ort::Value> out_states;
out_states.reserve(num_states);
for (int32_t k = 1; k != num_states + 1; ++k) {
out_states.push_back(std::move(t[k]));
}
auto unstacked_states = model_->UnStackStates(std::move(out_states));
for (int32_t i = 0; i != n; ++i) {
ss[i]->SetStates(std::move(unstacked_states[i]));
}
Ort::Value encoder_out = Transpose12(model_->Allocator(), &t[0]);
// defined in online-transducer-greedy-search-nemo-decoder.h
// get intial states of decoder.
std::vector<Ort::Value> &decoder_states = ss[0]->GetNeMoDecoderStates();
// Subsequent decoder states (for each chunks) are updated inside the Decode method.
// This returns the decoder state from the LAST chunk. We probably dont need it. So we can discard it.
decoder_states = decoder_->Decode(std::move(encoder_out),
std::move(decoder_states),
&result, ss, n);
ss[0]->SetResult(result[0]);
ss[0]->SetStates(std::move(out_states));
decoder_->Decode(std::move(encoder_out), ss, n);
}
void InitOnlineStream(OnlineStream *stream) const {
auto r = decoder_->GetEmptyResult();
// set encoder states
stream->SetStates(model_->GetEncoderInitStates());
stream->SetResult(r);
stream->SetNeMoDecoderStates(model_->GetDecoderInitStates(1));
// set decoder states
stream->SetNeMoDecoderStates(model_->GetDecoderInitStates());
}
private:
@@ -250,7 +232,6 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
symbol_table_.NumSymbols(), vocab_size);
exit(-1);
}
}
private:
@@ -259,9 +240,8 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
std::unique_ptr<OnlineTransducerNeMoModel> model_;
std::unique_ptr<OnlineTransducerGreedySearchNeMoDecoder> decoder_;
Endpoint endpoint_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_
#endif // SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_TRANSDUCER_NEMO_IMPL_H_