diff --git a/cmake/kaldi-native-fbank.cmake b/cmake/kaldi-native-fbank.cmake index db87977a..3450a9f1 100644 --- a/cmake/kaldi-native-fbank.cmake +++ b/cmake/kaldi-native-fbank.cmake @@ -1,8 +1,8 @@ function(download_kaldi_native_fbank) include(FetchContent) - set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.11.tar.gz") - set(kaldi_native_fbank_HASH "SHA256=e69ae25ef6f30566ef31ca949dd1b0b8ec3a827caeba93a61d82bb848dac5d69") + set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.12.tar.gz") + set(kaldi_native_fbank_HASH "SHA256=8f4dfc3f6ddb1adcd9ac0ae87743ebc6cbcae147aacf9d46e76fa54134e12b44") set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE) set(KALDI_NATIVE_FBANK_BUILD_PYTHON OFF CACHE BOOL "" FORCE) @@ -11,10 +11,11 @@ function(download_kaldi_native_fbank) # If you don't have access to the Internet, # please pre-download kaldi-native-fbank set(possible_file_locations - $ENV{HOME}/Downloads/kaldi-native-fbank-1.11.tar.gz - ${PROJECT_SOURCE_DIR}/kaldi-native-fbank-1.11.tar.gz - ${PROJECT_BINARY_DIR}/kaldi-native-fbank-1.11.tar.gz - /tmp/kaldi-native-fbank-1.11.tar.gz + $ENV{HOME}/Downloads/kaldi-native-fbank-1.12.tar.gz + ${PROJECT_SOURCE_DIR}/kaldi-native-fbank-1.12.tar.gz + ${PROJECT_BINARY_DIR}/kaldi-native-fbank-1.12.tar.gz + /tmp/kaldi-native-fbank-1.12.tar.gz + /star-fj/fangjun/download/github/kaldi-native-fbank-1.12.tar.gz ) foreach(f IN LISTS possible_file_locations) diff --git a/cmake/onnxruntime.cmake b/cmake/onnxruntime.cmake index bb2a64b9..9325b536 100644 --- a/cmake/onnxruntime.cmake +++ b/cmake/onnxruntime.cmake @@ -9,6 +9,7 @@ function(download_onnxruntime) ${PROJECT_SOURCE_DIR}/onnxruntime-linux-x64-1.14.0.tgz ${PROJECT_BINARY_DIR}/onnxruntime-linux-x64-1.14.0.tgz /tmp/onnxruntime-linux-x64-1.14.0.tgz + /star-fj/fangjun/download/github/onnxruntime-linux-x64-1.14.0.tgz ) set(onnxruntime_URL 
"https://github.com/microsoft/onnxruntime/releases/download/v1.14.0/onnxruntime-linux-x64-1.14.0.tgz") diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index dbc4461e..f2b5fbd7 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -1,9 +1,9 @@ include_directories(${CMAKE_SOURCE_DIR}) add_executable(sherpa-onnx - decode.cc features.cc online-lstm-transducer-model.cc + online-transducer-greedy-search-decoder.cc online-transducer-model-config.cc online-transducer-model.cc onnx-utils.cc diff --git a/sherpa-onnx/csrc/decode.cc b/sherpa-onnx/csrc/decode.cc deleted file mode 100644 index 5f0cb0f1..00000000 --- a/sherpa-onnx/csrc/decode.cc +++ /dev/null @@ -1,79 +0,0 @@ -// sherpa/csrc/decode.cc -// -// Copyright (c) 2023 Xiaomi Corporation - -#include "sherpa-onnx/csrc/decode.h" - -#include - -#include -#include -#include - -namespace sherpa_onnx { - -static Ort::Value Clone(Ort::Value *v) { - auto type_and_shape = v->GetTensorTypeAndShapeInfo(); - std::vector shape = type_and_shape.GetShape(); - - auto memory_info = - Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); - - return Ort::Value::CreateTensor(memory_info, v->GetTensorMutableData(), - type_and_shape.GetElementCount(), - shape.data(), shape.size()); -} - -static Ort::Value GetFrame(Ort::Value *encoder_out, int32_t t) { - std::vector encoder_out_shape = - encoder_out->GetTensorTypeAndShapeInfo().GetShape(); - assert(encoder_out_shape[0] == 1); - - int32_t encoder_out_dim = encoder_out_shape[2]; - - auto memory_info = - Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); - - std::array shape{1, encoder_out_dim}; - - return Ort::Value::CreateTensor( - memory_info, - encoder_out->GetTensorMutableData() + t * encoder_out_dim, - encoder_out_dim, shape.data(), shape.size()); -} - -void GreedySearch(OnlineTransducerModel *model, Ort::Value encoder_out, - std::vector *hyp) { - std::vector encoder_out_shape = - 
encoder_out.GetTensorTypeAndShapeInfo().GetShape(); - - if (encoder_out_shape[0] > 1) { - fprintf(stderr, "Only batch_size=1 is implemented. Given: %d\n", - static_cast(encoder_out_shape[0])); - } - - int32_t num_frames = encoder_out_shape[1]; - int32_t vocab_size = model->VocabSize(); - - Ort::Value decoder_input = model->BuildDecoderInput(*hyp); - Ort::Value decoder_out = model->RunDecoder(std::move(decoder_input)); - - for (int32_t t = 0; t != num_frames; ++t) { - Ort::Value cur_encoder_out = GetFrame(&encoder_out, t); - Ort::Value logit = - model->RunJoiner(std::move(cur_encoder_out), Clone(&decoder_out)); - const float *p_logit = logit.GetTensorData(); - - auto y = static_cast(std::distance( - static_cast(p_logit), - std::max_element(static_cast(p_logit), - static_cast(p_logit) + vocab_size))); - if (y != 0) { - hyp->push_back(y); - decoder_input = model->BuildDecoderInput(*hyp); - decoder_out = model->RunDecoder(std::move(decoder_input)); - } - } -} - -} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/decode.h b/sherpa-onnx/csrc/decode.h deleted file mode 100644 index 88821573..00000000 --- a/sherpa-onnx/csrc/decode.h +++ /dev/null @@ -1,26 +0,0 @@ -// sherpa/csrc/decode.h -// -// Copyright (c) 2023 Xiaomi Corporation - -#ifndef SHERPA_ONNX_CSRC_DECODE_H_ -#define SHERPA_ONNX_CSRC_DECODE_H_ - -#include - -#include "sherpa-onnx/csrc/online-transducer-model.h" - -namespace sherpa_onnx { - -/** Greedy search for non-streaming ASR. - * - * @TODO(fangjun) Support batch size > 1 - * - * @param model The RnntModel - * @param encoder_out Its shape is (1, num_frames, encoder_out_dim). 
- */ -void GreedySearch(OnlineTransducerModel *model, Ort::Value encoder_out, - std::vector *hyp); - -} // namespace sherpa_onnx - -#endif // SHERPA_ONNX_CSRC_DECODE_H_ diff --git a/sherpa-onnx/csrc/features.cc b/sherpa-onnx/csrc/features.cc index 2dab27fd..3201f4fe 100644 --- a/sherpa-onnx/csrc/features.cc +++ b/sherpa-onnx/csrc/features.cc @@ -15,16 +15,16 @@ namespace sherpa_onnx { class FeatureExtractor::Impl { public: - Impl(int32_t sampling_rate, int32_t feature_dim) { + explicit Impl(const FeatureExtractorConfig &config) { opts_.frame_opts.dither = 0; opts_.frame_opts.snip_edges = false; - opts_.frame_opts.samp_freq = sampling_rate; + opts_.frame_opts.samp_freq = config.sampling_rate; // cache 100 seconds of feature frames, which is more than enough // for real needs opts_.frame_opts.max_feature_vectors = 100 * 100; - opts_.mel_opts.num_bins = feature_dim; + opts_.mel_opts.num_bins = config.feature_dim; fbank_ = std::make_unique(opts_); } @@ -80,9 +80,8 @@ class FeatureExtractor::Impl { mutable std::mutex mutex_; }; -FeatureExtractor::FeatureExtractor(int32_t sampling_rate /*=16000*/, - int32_t feature_dim /*=80*/) - : impl_(std::make_unique(sampling_rate, feature_dim)) {} +FeatureExtractor::FeatureExtractor(const FeatureExtractorConfig &config /*={}*/) + : impl_(std::make_unique(config)) {} FeatureExtractor::~FeatureExtractor() = default; diff --git a/sherpa-onnx/csrc/features.h b/sherpa-onnx/csrc/features.h index 9ff3104b..e65a64ba 100644 --- a/sherpa-onnx/csrc/features.h +++ b/sherpa-onnx/csrc/features.h @@ -10,14 +10,18 @@ namespace sherpa_onnx { +struct FeatureExtractorConfig { + int32_t sampling_rate = 16000; + int32_t feature_dim = 80; +}; + class FeatureExtractor { public: /** * @param sampling_rate Sampling rate of the data used to train the model. * @param feature_dim Dimension of the features used to train the model. 
*/ - explicit FeatureExtractor(int32_t sampling_rate = 16000, - int32_t feature_dim = 80); + explicit FeatureExtractor(const FeatureExtractorConfig &config = {}); ~FeatureExtractor(); /** diff --git a/sherpa-onnx/csrc/online-transducer-decoder.h b/sherpa-onnx/csrc/online-transducer-decoder.h new file mode 100644 index 00000000..d4dfd109 --- /dev/null +++ b/sherpa-onnx/csrc/online-transducer-decoder.h @@ -0,0 +1,52 @@ +// sherpa/csrc/online-transducer-decoder.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_DECODER_H_ +#define SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_DECODER_H_ +
+#include 
+
+#include "onnxruntime_cxx_api.h"  // NOLINT
+
+namespace sherpa_onnx {
+
+struct OnlineTransducerDecoderResult {
+  /// The decoded token IDs so far
+  std::vector tokens;
+};
+
+class OnlineTransducerDecoder {
+ public:
+  virtual ~OnlineTransducerDecoder() = default;
+
+  /* Return an empty result.
+   *
+   * To simplify the decoding code, we add `context_size` blanks
+   * to the beginning of the decoding result, which will be
+   * stripped by calling `StripLeadingBlanks()`.
+   */
+  virtual OnlineTransducerDecoderResult GetEmptyResult() = 0;
+
+  /** Strip blanks added by `GetEmptyResult()`.
+   *
+   * @param r It is changed in-place.
+   */
+  virtual void StripLeadingBlanks(OnlineTransducerDecoderResult * /*r*/) {}
+
+  /** Run transducer beam search given the output from the encoder model.
+   *
+   * @param encoder_out A 3-D tensor of shape (N, T, joiner_dim)
+   * @param result It is modified in-place.
+   *
+   * @note There is no need to pass encoder_out_length here since for the
+   *       online decoding case, each utterance has the same number of frames
+   *       and there are no paddings.
+ */ + virtual void Decode(Ort::Value encoder_out, + std::vector *result) = 0; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_DECODER_H_ diff --git a/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc b/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc new file mode 100644 index 00000000..9ef41a1e --- /dev/null +++ b/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc @@ -0,0 +1,101 @@ +// sherpa/csrc/online-transducer-greedy-search-decoder.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h" + +#include + +#include +#include +#include + +#include "sherpa-onnx/csrc/onnx-utils.h" + +namespace sherpa_onnx { + +static Ort::Value GetFrame(Ort::Value *encoder_out, int32_t t) { + std::vector encoder_out_shape = + encoder_out->GetTensorTypeAndShapeInfo().GetShape(); + assert(encoder_out_shape[0] == 1); + + int32_t encoder_out_dim = encoder_out_shape[2]; + + auto memory_info = + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); + + std::array shape{1, encoder_out_dim}; + + return Ort::Value::CreateTensor( + memory_info, + encoder_out->GetTensorMutableData() + t * encoder_out_dim, + encoder_out_dim, shape.data(), shape.size()); +} + +OnlineTransducerDecoderResult +OnlineTransducerGreedySearchDecoder::GetEmptyResult() { + int32_t context_size = model_->ContextSize(); + int32_t blank_id = 0; // always 0 + OnlineTransducerDecoderResult r; + r.tokens.resize(context_size, blank_id); + + return r; +} + +void OnlineTransducerGreedySearchDecoder::StripLeadingBlanks( + OnlineTransducerDecoderResult *r) { + int32_t context_size = model_->ContextSize(); + + auto start = r->tokens.begin() + context_size; + auto end = r->tokens.end(); + + r->tokens = std::vector(start, end); +} + +void OnlineTransducerGreedySearchDecoder::Decode( + Ort::Value encoder_out, + std::vector *result) { + std::vector encoder_out_shape = + 
encoder_out.GetTensorTypeAndShapeInfo().GetShape(); + + if (encoder_out_shape[0] != result->size()) { + fprintf(stderr, + "Size mismatch! encoder_out.size(0) %d, result.size(0): %d\n", + static_cast(encoder_out_shape[0]), + static_cast(result->size())); + exit(-1); + } + + if (result->size() != 1) { + fprintf(stderr, "only batch size == 1 is implemented. Given: %d", + static_cast(result->size())); + exit(-1); + } + + auto &hyp = (*result)[0].tokens; + + int32_t num_frames = encoder_out_shape[1]; + int32_t vocab_size = model_->VocabSize(); + + Ort::Value decoder_input = model_->BuildDecoderInput(hyp); + Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input)); + + for (int32_t t = 0; t != num_frames; ++t) { + Ort::Value cur_encoder_out = GetFrame(&encoder_out, t); + Ort::Value logit = + model_->RunJoiner(std::move(cur_encoder_out), Clone(&decoder_out)); + const float *p_logit = logit.GetTensorData(); + + auto y = static_cast(std::distance( + static_cast(p_logit), + std::max_element(static_cast(p_logit), + static_cast(p_logit) + vocab_size))); + if (y != 0) { + hyp.push_back(y); + decoder_input = model_->BuildDecoderInput(hyp); + decoder_out = model_->RunDecoder(std::move(decoder_input)); + } + } +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h b/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h new file mode 100644 index 00000000..26e35238 --- /dev/null +++ b/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h @@ -0,0 +1,33 @@ +// sherpa/csrc/online-transducer-greedy-search-decoder.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_GREEDY_SEARCH_DECODER_H_ +#define SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_GREEDY_SEARCH_DECODER_H_ + +#include + +#include "sherpa-onnx/csrc/online-transducer-decoder.h" +#include "sherpa-onnx/csrc/online-transducer-model.h" + +namespace sherpa_onnx { + +class OnlineTransducerGreedySearchDecoder : public 
OnlineTransducerDecoder { + public: + explicit OnlineTransducerGreedySearchDecoder(OnlineTransducerModel *model) + : model_(model) {} + + OnlineTransducerDecoderResult GetEmptyResult() override; + + void StripLeadingBlanks(OnlineTransducerDecoderResult *r) override; + + void Decode(Ort::Value encoder_out, + std::vector *result) override; + + private: + OnlineTransducerModel *model_; // Not owned +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_GREEDY_SEARCH_DECODER_H_ diff --git a/sherpa-onnx/csrc/onnx-utils.cc b/sherpa-onnx/csrc/onnx-utils.cc index 47dd1576..9105fab0 100644 --- a/sherpa-onnx/csrc/onnx-utils.cc +++ b/sherpa-onnx/csrc/onnx-utils.cc @@ -46,4 +46,16 @@ void PrintModelMetadata(std::ostream &os, const Ort::ModelMetadata &meta_data) { } } +Ort::Value Clone(Ort::Value *v) { + auto type_and_shape = v->GetTensorTypeAndShapeInfo(); + std::vector shape = type_and_shape.GetShape(); + + auto memory_info = + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); + + return Ort::Value::CreateTensor(memory_info, v->GetTensorMutableData(), + type_and_shape.GetElementCount(), + shape.data(), shape.size()); +} + } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/onnx-utils.h b/sherpa-onnx/csrc/onnx-utils.h index f7f5677e..2307722b 100644 --- a/sherpa-onnx/csrc/onnx-utils.h +++ b/sherpa-onnx/csrc/onnx-utils.h @@ -55,6 +55,9 @@ void GetOutputNames(Ort::Session *sess, std::vector *output_names, void PrintModelMetadata(std::ostream &os, const Ort::ModelMetadata &meta_data); // NOLINT +// Return a shallow copy of v +Ort::Value Clone(Ort::Value *v); + } // namespace sherpa_onnx #endif // SHERPA_ONNX_CSRC_ONNX_UTILS_H_ diff --git a/sherpa-onnx/csrc/sherpa-onnx.cc b/sherpa-onnx/csrc/sherpa-onnx.cc index 1e18a9be..6338317f 100644 --- a/sherpa-onnx/csrc/sherpa-onnx.cc +++ b/sherpa-onnx/csrc/sherpa-onnx.cc @@ -9,8 +9,8 @@ #include #include "kaldi-native-fbank/csrc/online-feature.h" -#include "sherpa-onnx/csrc/decode.h" #include 
"sherpa-onnx/csrc/features.h" +#include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h" #include "sherpa-onnx/csrc/online-transducer-model-config.h" #include "sherpa-onnx/csrc/online-transducer-model.h" #include "sherpa-onnx/csrc/symbol-table.h" @@ -64,8 +64,6 @@ for a list of pre-trained models to download. std::vector states = model->GetEncoderInitStates(); - std::vector hyp(model->ContextSize(), 0); - int32_t expected_sampling_rate = 16000; bool is_ok = false; @@ -100,6 +98,10 @@ for a list of pre-trained models to download. std::array x_shape{1, chunk_size, feature_dim}; + sherpa_onnx::OnlineTransducerGreedySearchDecoder decoder(model.get()); + std::vector result = { + decoder.GetEmptyResult()}; + for (int32_t start = 0; start + chunk_size < num_frames; start += chunk_shift) { std::vector features = feat_extractor.GetFrames(start, chunk_size); @@ -109,8 +111,10 @@ for a list of pre-trained models to download. x_shape.data(), x_shape.size()); auto pair = model->RunEncoder(std::move(x), states); states = std::move(pair.second); - sherpa_onnx::GreedySearch(model.get(), std::move(pair.first), &hyp); + decoder.Decode(std::move(pair.first), &result); } + decoder.StripLeadingBlanks(&result[0]); + const auto &hyp = result[0].tokens; std::string text; for (size_t i = model->ContextSize(); i != hyp.size(); ++i) { text += sym[hyp[i]];