add online-recognizer (#29)
This commit is contained in:
@@ -3,6 +3,7 @@ include_directories(${CMAKE_SOURCE_DIR})
|
|||||||
add_executable(sherpa-onnx
|
add_executable(sherpa-onnx
|
||||||
features.cc
|
features.cc
|
||||||
online-lstm-transducer-model.cc
|
online-lstm-transducer-model.cc
|
||||||
|
online-recognizer.cc
|
||||||
online-stream.cc
|
online-stream.cc
|
||||||
online-transducer-greedy-search-decoder.cc
|
online-transducer-greedy-search-decoder.cc
|
||||||
online-transducer-model-config.cc
|
online-transducer-model-config.cc
|
||||||
|
|||||||
@@ -7,12 +7,23 @@
|
|||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <mutex> // NOLINT
|
#include <mutex> // NOLINT
|
||||||
|
#include <sstream>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "kaldi-native-fbank/csrc/online-feature.h"
|
#include "kaldi-native-fbank/csrc/online-feature.h"
|
||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
std::string FeatureExtractorConfig::ToString() const {
|
||||||
|
std::ostringstream os;
|
||||||
|
|
||||||
|
os << "FeatureExtractorConfig(";
|
||||||
|
os << "sampling_rate=" << sampling_rate << ", ";
|
||||||
|
os << "feature_dim=" << feature_dim << ")";
|
||||||
|
|
||||||
|
return os.str();
|
||||||
|
}
|
||||||
|
|
||||||
class FeatureExtractor::Impl {
|
class FeatureExtractor::Impl {
|
||||||
public:
|
public:
|
||||||
explicit Impl(const FeatureExtractorConfig &config) {
|
explicit Impl(const FeatureExtractorConfig &config) {
|
||||||
|
|||||||
@@ -6,6 +6,7 @@
|
|||||||
#define SHERPA_ONNX_CSRC_FEATURES_H_
|
#define SHERPA_ONNX_CSRC_FEATURES_H_
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
@@ -13,6 +14,8 @@ namespace sherpa_onnx {
|
|||||||
struct FeatureExtractorConfig {
|
struct FeatureExtractorConfig {
|
||||||
float sampling_rate = 16000;
|
float sampling_rate = 16000;
|
||||||
int32_t feature_dim = 80;
|
int32_t feature_dim = 80;
|
||||||
|
|
||||||
|
std::string ToString() const;
|
||||||
};
|
};
|
||||||
|
|
||||||
class FeatureExtractor {
|
class FeatureExtractor {
|
||||||
|
|||||||
136
sherpa-onnx/csrc/online-recognizer.cc
Normal file
136
sherpa-onnx/csrc/online-recognizer.cc
Normal file
@@ -0,0 +1,136 @@
|
|||||||
|
// sherpa-onnx/csrc/online-recognizer.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/online-recognizer.h"
|
||||||
|
|
||||||
|
#include <assert.h>
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <sstream>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
|
||||||
|
#include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h"
|
||||||
|
#include "sherpa-onnx/csrc/online-transducer-model.h"
|
||||||
|
#include "sherpa-onnx/csrc/symbol-table.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
|
||||||
|
const SymbolTable &sym_table) {
|
||||||
|
std::string text;
|
||||||
|
for (auto t : src.tokens) {
|
||||||
|
text += sym_table[t];
|
||||||
|
}
|
||||||
|
|
||||||
|
OnlineRecognizerResult ans;
|
||||||
|
ans.text = std::move(text);
|
||||||
|
return ans;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string OnlineRecognizerConfig::ToString() const {
|
||||||
|
std::ostringstream os;
|
||||||
|
|
||||||
|
os << "OnlineRecognizerConfig(";
|
||||||
|
os << "feat_config=" << feat_config.ToString() << ", ";
|
||||||
|
os << "model_config=" << model_config.ToString() << ", ";
|
||||||
|
os << "tokens=\"" << tokens << "\")";
|
||||||
|
|
||||||
|
return os.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
class OnlineRecognizer::Impl {
|
||||||
|
public:
|
||||||
|
explicit Impl(const OnlineRecognizerConfig &config)
|
||||||
|
: config_(config),
|
||||||
|
model_(OnlineTransducerModel::Create(config.model_config)),
|
||||||
|
sym_(config.tokens) {
|
||||||
|
decoder_ =
|
||||||
|
std::make_unique<OnlineTransducerGreedySearchDecoder>(model_.get());
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<OnlineStream> CreateStream() const {
|
||||||
|
auto stream = std::make_unique<OnlineStream>(config_.feat_config);
|
||||||
|
stream->SetResult(decoder_->GetEmptyResult());
|
||||||
|
stream->SetStates(model_->GetEncoderInitStates());
|
||||||
|
return stream;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool IsReady(OnlineStream *s) const {
|
||||||
|
return s->GetNumProcessedFrames() + model_->ChunkSize() <
|
||||||
|
s->NumFramesReady();
|
||||||
|
}
|
||||||
|
|
||||||
|
void DecodeStreams(OnlineStream **ss, int32_t n) {
|
||||||
|
if (n != 1) {
|
||||||
|
fprintf(stderr, "only n == 1 is implemented\n");
|
||||||
|
exit(-1);
|
||||||
|
}
|
||||||
|
OnlineStream *s = ss[0];
|
||||||
|
assert(IsReady(s));
|
||||||
|
|
||||||
|
int32_t chunk_size = model_->ChunkSize();
|
||||||
|
int32_t chunk_shift = model_->ChunkShift();
|
||||||
|
|
||||||
|
int32_t feature_dim = s->FeatureDim();
|
||||||
|
|
||||||
|
std::array<int64_t, 3> x_shape{1, chunk_size, feature_dim};
|
||||||
|
|
||||||
|
auto memory_info =
|
||||||
|
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||||
|
|
||||||
|
std::vector<float> features =
|
||||||
|
s->GetFrames(s->GetNumProcessedFrames(), chunk_size);
|
||||||
|
|
||||||
|
s->GetNumProcessedFrames() += chunk_shift;
|
||||||
|
|
||||||
|
Ort::Value x =
|
||||||
|
Ort::Value::CreateTensor(memory_info, features.data(), features.size(),
|
||||||
|
x_shape.data(), x_shape.size());
|
||||||
|
|
||||||
|
auto pair = model_->RunEncoder(std::move(x), s->GetStates());
|
||||||
|
|
||||||
|
s->SetStates(std::move(pair.second));
|
||||||
|
std::vector<OnlineTransducerDecoderResult> results = {s->GetResult()};
|
||||||
|
|
||||||
|
decoder_->Decode(std::move(pair.first), &results);
|
||||||
|
s->SetResult(results[0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
OnlineRecognizerResult GetResult(OnlineStream *s) {
|
||||||
|
OnlineTransducerDecoderResult decoder_result = s->GetResult();
|
||||||
|
decoder_->StripLeadingBlanks(&decoder_result);
|
||||||
|
|
||||||
|
return Convert(decoder_result, sym_);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
OnlineRecognizerConfig config_;
|
||||||
|
std::unique_ptr<OnlineTransducerModel> model_;
|
||||||
|
std::unique_ptr<OnlineTransducerDecoder> decoder_;
|
||||||
|
SymbolTable sym_;
|
||||||
|
};
|
||||||
|
|
||||||
|
OnlineRecognizer::OnlineRecognizer(const OnlineRecognizerConfig &config)
|
||||||
|
: impl_(std::make_unique<Impl>(config)) {}
|
||||||
|
OnlineRecognizer::~OnlineRecognizer() = default;
|
||||||
|
|
||||||
|
std::unique_ptr<OnlineStream> OnlineRecognizer::CreateStream() const {
|
||||||
|
return impl_->CreateStream();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool OnlineRecognizer::IsReady(OnlineStream *s) const {
|
||||||
|
return impl_->IsReady(s);
|
||||||
|
}
|
||||||
|
|
||||||
|
void OnlineRecognizer::DecodeStreams(OnlineStream **ss, int32_t n) {
|
||||||
|
impl_->DecodeStreams(ss, n);
|
||||||
|
}
|
||||||
|
|
||||||
|
OnlineRecognizerResult OnlineRecognizer::GetResult(OnlineStream *s) {
|
||||||
|
return impl_->GetResult(s);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
65
sherpa-onnx/csrc/online-recognizer.h
Normal file
65
sherpa-onnx/csrc/online-recognizer.h
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
// sherpa-onnx/csrc/online-recognizer.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/features.h"
|
||||||
|
#include "sherpa-onnx/csrc/online-stream.h"
|
||||||
|
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
struct OnlineRecognizerResult {
|
||||||
|
std::string text;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct OnlineRecognizerConfig {
|
||||||
|
FeatureExtractorConfig feat_config;
|
||||||
|
OnlineTransducerModelConfig model_config;
|
||||||
|
std::string tokens;
|
||||||
|
|
||||||
|
std::string ToString() const;
|
||||||
|
};
|
||||||
|
|
||||||
|
class OnlineRecognizer {
|
||||||
|
public:
|
||||||
|
explicit OnlineRecognizer(const OnlineRecognizerConfig &config);
|
||||||
|
~OnlineRecognizer();
|
||||||
|
|
||||||
|
/// Create a stream for decoding.
|
||||||
|
std::unique_ptr<OnlineStream> CreateStream() const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Return true if the given stream has enough frames for decoding.
|
||||||
|
* Return false otherwise
|
||||||
|
*/
|
||||||
|
bool IsReady(OnlineStream *s) const;
|
||||||
|
|
||||||
|
/** Decode a single stream. */
|
||||||
|
void DecodeStream(OnlineStream *s) {
|
||||||
|
OnlineStream *ss[1] = {s};
|
||||||
|
DecodeStreams(ss, 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Decode multiple streams in parallel
|
||||||
|
*
|
||||||
|
* @param ss Pointer array containing streams to be decoded.
|
||||||
|
* @param n Number of streams in `ss`.
|
||||||
|
*/
|
||||||
|
void DecodeStreams(OnlineStream **ss, int32_t n);
|
||||||
|
|
||||||
|
OnlineRecognizerResult GetResult(OnlineStream *s);
|
||||||
|
|
||||||
|
private:
|
||||||
|
class Impl;
|
||||||
|
std::unique_ptr<Impl> impl_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_ONLINE_RECOGNIZER_H_
|
||||||
@@ -4,6 +4,7 @@
|
|||||||
#include "sherpa-onnx/csrc/online-stream.h"
|
#include "sherpa-onnx/csrc/online-stream.h"
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "sherpa-onnx/csrc/features.h"
|
#include "sherpa-onnx/csrc/features.h"
|
||||||
@@ -41,10 +42,17 @@ class OnlineStream::Impl {
|
|||||||
|
|
||||||
int32_t FeatureDim() const { return feat_extractor_.FeatureDim(); }
|
int32_t FeatureDim() const { return feat_extractor_.FeatureDim(); }
|
||||||
|
|
||||||
|
void SetStates(std::vector<Ort::Value> states) {
|
||||||
|
states_ = std::move(states);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<Ort::Value> &GetStates() { return states_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
FeatureExtractor feat_extractor_;
|
FeatureExtractor feat_extractor_;
|
||||||
int32_t num_processed_frames_ = 0; // before subsampling
|
int32_t num_processed_frames_ = 0; // before subsampling
|
||||||
OnlineTransducerDecoderResult result_;
|
OnlineTransducerDecoderResult result_;
|
||||||
|
std::vector<Ort::Value> states_;
|
||||||
};
|
};
|
||||||
|
|
||||||
OnlineStream::OnlineStream(const FeatureExtractorConfig &config /*= {}*/)
|
OnlineStream::OnlineStream(const FeatureExtractorConfig &config /*= {}*/)
|
||||||
@@ -86,4 +94,12 @@ const OnlineTransducerDecoderResult &OnlineStream::GetResult() const {
|
|||||||
return impl_->GetResult();
|
return impl_->GetResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void OnlineStream::SetStates(std::vector<Ort::Value> states) {
|
||||||
|
impl_->SetStates(std::move(states));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<Ort::Value> &OnlineStream::GetStates() {
|
||||||
|
return impl_->GetStates();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace sherpa_onnx
|
} // namespace sherpa_onnx
|
||||||
|
|||||||
@@ -8,6 +8,7 @@
|
|||||||
#include <memory>
|
#include <memory>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||||
#include "sherpa-onnx/csrc/features.h"
|
#include "sherpa-onnx/csrc/features.h"
|
||||||
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
|
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
|
||||||
|
|
||||||
@@ -63,6 +64,9 @@ class OnlineStream {
|
|||||||
void SetResult(const OnlineTransducerDecoderResult &r);
|
void SetResult(const OnlineTransducerDecoderResult &r);
|
||||||
const OnlineTransducerDecoderResult &GetResult() const;
|
const OnlineTransducerDecoderResult &GetResult() const;
|
||||||
|
|
||||||
|
void SetStates(std::vector<Ort::Value> states);
|
||||||
|
std::vector<Ort::Value> &GetStates();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
class Impl;
|
class Impl;
|
||||||
std::unique_ptr<Impl> impl_;
|
std::unique_ptr<Impl> impl_;
|
||||||
|
|||||||
@@ -26,13 +26,14 @@ class OnlineTransducerDecoder {
|
|||||||
* to the beginning of the decoding result, which will be
|
* to the beginning of the decoding result, which will be
|
||||||
* stripped by calling `StripPrecedingBlanks()`.
|
* stripped by calling `StripPrecedingBlanks()`.
|
||||||
*/
|
*/
|
||||||
virtual OnlineTransducerDecoderResult GetEmptyResult() = 0;
|
virtual OnlineTransducerDecoderResult GetEmptyResult() const = 0;
|
||||||
|
|
||||||
/** Strip blanks added by `GetEmptyResult()`.
|
/** Strip blanks added by `GetEmptyResult()`.
|
||||||
*
|
*
|
||||||
* @param r It is changed in-place.
|
* @param r It is changed in-place.
|
||||||
*/
|
*/
|
||||||
virtual void StripLeadingBlanks(OnlineTransducerDecoderResult * /*r*/) {}
|
virtual void StripLeadingBlanks(OnlineTransducerDecoderResult * /*r*/) const {
|
||||||
|
}
|
||||||
|
|
||||||
/** Run transducer beam search given the output from the encoder model.
|
/** Run transducer beam search given the output from the encoder model.
|
||||||
*
|
*
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ static Ort::Value GetFrame(Ort::Value *encoder_out, int32_t t) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
OnlineTransducerDecoderResult
|
OnlineTransducerDecoderResult
|
||||||
OnlineTransducerGreedySearchDecoder::GetEmptyResult() {
|
OnlineTransducerGreedySearchDecoder::GetEmptyResult() const {
|
||||||
int32_t context_size = model_->ContextSize();
|
int32_t context_size = model_->ContextSize();
|
||||||
int32_t blank_id = 0; // always 0
|
int32_t blank_id = 0; // always 0
|
||||||
OnlineTransducerDecoderResult r;
|
OnlineTransducerDecoderResult r;
|
||||||
@@ -43,7 +43,7 @@ OnlineTransducerGreedySearchDecoder::GetEmptyResult() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void OnlineTransducerGreedySearchDecoder::StripLeadingBlanks(
|
void OnlineTransducerGreedySearchDecoder::StripLeadingBlanks(
|
||||||
OnlineTransducerDecoderResult *r) {
|
OnlineTransducerDecoderResult *r) const {
|
||||||
int32_t context_size = model_->ContextSize();
|
int32_t context_size = model_->ContextSize();
|
||||||
|
|
||||||
auto start = r->tokens.begin() + context_size;
|
auto start = r->tokens.begin() + context_size;
|
||||||
|
|||||||
@@ -17,9 +17,9 @@ class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder {
|
|||||||
explicit OnlineTransducerGreedySearchDecoder(OnlineTransducerModel *model)
|
explicit OnlineTransducerGreedySearchDecoder(OnlineTransducerModel *model)
|
||||||
: model_(model) {}
|
: model_(model) {}
|
||||||
|
|
||||||
OnlineTransducerDecoderResult GetEmptyResult() override;
|
OnlineTransducerDecoderResult GetEmptyResult() const override;
|
||||||
|
|
||||||
void StripLeadingBlanks(OnlineTransducerDecoderResult *r) override;
|
void StripLeadingBlanks(OnlineTransducerDecoderResult *r) const override;
|
||||||
|
|
||||||
void Decode(Ort::Value encoder_out,
|
void Decode(Ort::Value encoder_out,
|
||||||
std::vector<OnlineTransducerDecoderResult> *result) override;
|
std::vector<OnlineTransducerDecoderResult> *result) override;
|
||||||
|
|||||||
@@ -8,6 +8,7 @@
|
|||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/online-recognizer.h"
|
||||||
#include "sherpa-onnx/csrc/online-stream.h"
|
#include "sherpa-onnx/csrc/online-stream.h"
|
||||||
#include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h"
|
#include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h"
|
||||||
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
|
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
|
||||||
@@ -35,35 +36,26 @@ for a list of pre-trained models to download.
|
|||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string tokens = argv[1];
|
sherpa_onnx::OnlineRecognizerConfig config;
|
||||||
sherpa_onnx::OnlineTransducerModelConfig config;
|
|
||||||
config.debug = false;
|
config.tokens = argv[1];
|
||||||
config.encoder_filename = argv[2];
|
|
||||||
config.decoder_filename = argv[3];
|
config.model_config.debug = false;
|
||||||
config.joiner_filename = argv[4];
|
config.model_config.encoder_filename = argv[2];
|
||||||
|
config.model_config.decoder_filename = argv[3];
|
||||||
|
config.model_config.joiner_filename = argv[4];
|
||||||
|
|
||||||
std::string wav_filename = argv[5];
|
std::string wav_filename = argv[5];
|
||||||
|
|
||||||
config.num_threads = 2;
|
config.model_config.num_threads = 2;
|
||||||
if (argc == 7) {
|
if (argc == 7) {
|
||||||
config.num_threads = atoi(argv[6]);
|
config.model_config.num_threads = atoi(argv[6]);
|
||||||
}
|
}
|
||||||
fprintf(stderr, "%s\n", config.ToString().c_str());
|
fprintf(stderr, "%s\n", config.ToString().c_str());
|
||||||
|
|
||||||
auto model = sherpa_onnx::OnlineTransducerModel::Create(config);
|
sherpa_onnx::OnlineRecognizer recognizer(config);
|
||||||
|
|
||||||
sherpa_onnx::SymbolTable sym(tokens);
|
float expected_sampling_rate = config.feat_config.sampling_rate;
|
||||||
|
|
||||||
Ort::AllocatorWithDefaultOptions allocator;
|
|
||||||
|
|
||||||
int32_t chunk_size = model->ChunkSize();
|
|
||||||
int32_t chunk_shift = model->ChunkShift();
|
|
||||||
|
|
||||||
auto memory_info =
|
|
||||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
|
||||||
|
|
||||||
std::vector<Ort::Value> states = model->GetEncoderInitStates();
|
|
||||||
|
|
||||||
float expected_sampling_rate = 16000;
|
|
||||||
|
|
||||||
bool is_ok = false;
|
bool is_ok = false;
|
||||||
std::vector<float> samples =
|
std::vector<float> samples =
|
||||||
@@ -82,44 +74,21 @@ for a list of pre-trained models to download.
|
|||||||
auto begin = std::chrono::steady_clock::now();
|
auto begin = std::chrono::steady_clock::now();
|
||||||
fprintf(stderr, "Started\n");
|
fprintf(stderr, "Started\n");
|
||||||
|
|
||||||
sherpa_onnx::OnlineStream stream;
|
auto s = recognizer.CreateStream();
|
||||||
stream.AcceptWaveform(expected_sampling_rate, samples.data(), samples.size());
|
s->AcceptWaveform(expected_sampling_rate, samples.data(), samples.size());
|
||||||
|
|
||||||
std::vector<float> tail_paddings(
|
std::vector<float> tail_paddings(
|
||||||
static_cast<int>(0.2 * expected_sampling_rate));
|
static_cast<int>(0.2 * expected_sampling_rate));
|
||||||
stream.AcceptWaveform(expected_sampling_rate, tail_paddings.data(),
|
s->AcceptWaveform(expected_sampling_rate, tail_paddings.data(),
|
||||||
tail_paddings.size());
|
tail_paddings.size());
|
||||||
stream.InputFinished();
|
s->InputFinished();
|
||||||
|
|
||||||
int32_t num_frames = stream.NumFramesReady();
|
while (recognizer.IsReady(s.get())) {
|
||||||
int32_t feature_dim = stream.FeatureDim();
|
recognizer.DecodeStream(s.get());
|
||||||
|
|
||||||
std::array<int64_t, 3> x_shape{1, chunk_size, feature_dim};
|
|
||||||
|
|
||||||
sherpa_onnx::OnlineTransducerGreedySearchDecoder decoder(model.get());
|
|
||||||
std::vector<sherpa_onnx::OnlineTransducerDecoderResult> result = {
|
|
||||||
decoder.GetEmptyResult()};
|
|
||||||
while (stream.NumFramesReady() - stream.GetNumProcessedFrames() >
|
|
||||||
chunk_size) {
|
|
||||||
std::vector<float> features =
|
|
||||||
stream.GetFrames(stream.GetNumProcessedFrames(), chunk_size);
|
|
||||||
stream.GetNumProcessedFrames() += chunk_shift;
|
|
||||||
|
|
||||||
Ort::Value x =
|
|
||||||
Ort::Value::CreateTensor(memory_info, features.data(), features.size(),
|
|
||||||
x_shape.data(), x_shape.size());
|
|
||||||
|
|
||||||
auto pair = model->RunEncoder(std::move(x), states);
|
|
||||||
states = std::move(pair.second);
|
|
||||||
decoder.Decode(std::move(pair.first), &result);
|
|
||||||
}
|
|
||||||
decoder.StripLeadingBlanks(&result[0]);
|
|
||||||
const auto &hyp = result[0].tokens;
|
|
||||||
std::string text;
|
|
||||||
for (auto t : hyp) {
|
|
||||||
text += sym[t];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::string text = recognizer.GetResult(s.get()).text;
|
||||||
|
|
||||||
fprintf(stderr, "Done!\n");
|
fprintf(stderr, "Done!\n");
|
||||||
|
|
||||||
fprintf(stderr, "Recognition result for %s:\n%s\n", wav_filename.c_str(),
|
fprintf(stderr, "Recognition result for %s:\n%s\n", wav_filename.c_str(),
|
||||||
@@ -131,7 +100,7 @@ for a list of pre-trained models to download.
|
|||||||
.count() /
|
.count() /
|
||||||
1000.;
|
1000.;
|
||||||
|
|
||||||
fprintf(stderr, "num threads: %d\n", config.num_threads);
|
fprintf(stderr, "num threads: %d\n", config.model_config.num_threads);
|
||||||
|
|
||||||
fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
|
fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
|
||||||
float rtf = elapsed_seconds / duration;
|
float rtf = elapsed_seconds / duration;
|
||||||
|
|||||||
Reference in New Issue
Block a user