diff --git a/.github/scripts/test-offline-transducer.sh b/.github/scripts/test-offline-transducer.sh index 9f98c522..33b9bec0 100755 --- a/.github/scripts/test-offline-transducer.sh +++ b/.github/scripts/test-offline-transducer.sh @@ -25,36 +25,59 @@ log "Download pretrained model and test-data from $repo_url" GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url pushd $repo git lfs pull --include "*.onnx" -cd test_wavs popd -waves=( -$repo/test_wavs/0.wav -$repo/test_wavs/1.wav -$repo/test_wavs/2.wav -) - -for wave in ${waves[@]}; do - time $EXE \ - $repo/tokens.txt \ - $repo/encoder-epoch-99-avg-1.onnx \ - $repo/decoder-epoch-99-avg-1.onnx \ - $repo/joiner-epoch-99-avg-1.onnx \ - $wave \ - 2 -done +time $EXE \ + --tokens=$repo/tokens.txt \ + --encoder=$repo/encoder-epoch-99-avg-1.onnx \ + --decoder=$repo/decoder-epoch-99-avg-1.onnx \ + --joiner=$repo/joiner-epoch-99-avg-1.onnx \ + --num-threads=2 \ + $repo/test_wavs/0.wav \ + $repo/test_wavs/1.wav \ + $repo/test_wavs/2.wav if command -v sox &> /dev/null; then echo "test 8kHz" sox $repo/test_wavs/0.wav -r 8000 8k.wav + time $EXE \ - $repo/tokens.txt \ - $repo/encoder-epoch-99-avg-1.onnx \ - $repo/decoder-epoch-99-avg-1.onnx \ - $repo/joiner-epoch-99-avg-1.onnx \ - 8k.wav \ - 2 + --tokens=$repo/tokens.txt \ + --encoder=$repo/encoder-epoch-99-avg-1.onnx \ + --decoder=$repo/decoder-epoch-99-avg-1.onnx \ + --joiner=$repo/joiner-epoch-99-avg-1.onnx \ + --num-threads=2 \ + $repo/test_wavs/0.wav \ + $repo/test_wavs/1.wav \ + $repo/test_wavs/2.wav \ + 8k.wav fi rm -rf $repo + +log "------------------------------------------------------------" +log "Run Paraformer (Chinese)" +log "------------------------------------------------------------" + +repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-paraformer-zh-2023-03-28 +log "Start testing ${repo_url}" +repo=$(basename $repo_url) +log "Download pretrained model and test-data from $repo_url" + +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +pushd $repo +git lfs pull --include "*.onnx" 
+popd + +time $EXE \ + --tokens=$repo/tokens.txt \ + --paraformer=$repo/model.onnx \ + --num-threads=2 \ + --decoding-method=greedy_search \ + $repo/test_wavs/0.wav \ + $repo/test_wavs/1.wav \ + $repo/test_wavs/2.wav \ + $repo/test_wavs/8k.wav + +rm -rf $repo diff --git a/.github/workflows/windows-x64.yaml b/.github/workflows/windows-x64.yaml index 76dbf799..39a4b627 100644 --- a/.github/workflows/windows-x64.yaml +++ b/.github/workflows/windows-x64.yaml @@ -71,7 +71,15 @@ jobs: ls -lh ./bin/Release/sherpa-onnx.exe - - name: Test sherpa-onnx for Windows x64 + - name: Test offline transducer for Windows x64 + shell: bash + run: | + export PATH=$PWD/build/bin/Release:$PATH + export EXE=sherpa-onnx-offline.exe + + .github/scripts/test-offline-transducer.sh + + - name: Test online transducer for Windows x64 shell: bash run: | export PATH=$PWD/build/bin/Release:$PATH diff --git a/.github/workflows/windows-x86.yaml b/.github/workflows/windows-x86.yaml index 5f648d81..eb4022b0 100644 --- a/.github/workflows/windows-x86.yaml +++ b/.github/workflows/windows-x86.yaml @@ -71,7 +71,15 @@ jobs: ls -lh ./bin/Release/sherpa-onnx.exe - - name: Test sherpa-onnx for Windows x86 + - name: Test offline transducer for Windows x86 + shell: bash + run: | + export PATH=$PWD/build/bin/Release:$PATH + export EXE=sherpa-onnx-offline.exe + + .github/scripts/test-offline-transducer.sh + + - name: Test online transducer for Windows x86 shell: bash run: | export PATH=$PWD/build/bin/Release:$PATH diff --git a/.gitignore b/.gitignore index 6c4f80e0..41d1a0a1 100644 --- a/.gitignore +++ b/.gitignore @@ -41,3 +41,7 @@ android/SherpaOnnx/app/src/main/assets/ *.ncnn.* run-sherpa-onnx-offline.sh sherpa-onnx-conformer-en-2023-03-18 +paraformer-onnxruntime-python-example +run-sherpa-onnx-offline-paraformer.sh +run-sherpa-onnx-offline-transducer.sh +sherpa-onnx-paraformer-zh-2023-03-28 diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index f25e8bf6..30ec8016 100644 --- 
a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -6,6 +6,10 @@ set(sources features.cc file-utils.cc hypothesis.cc + offline-model-config.cc + offline-paraformer-greedy-search-decoder.cc + offline-paraformer-model-config.cc + offline-paraformer-model.cc offline-recognizer-impl.cc offline-recognizer.cc offline-stream.cc diff --git a/sherpa-onnx/csrc/macros.h b/sherpa-onnx/csrc/macros.h index a44021e7..efe61289 100644 --- a/sherpa-onnx/csrc/macros.h +++ b/sherpa-onnx/csrc/macros.h @@ -57,6 +57,23 @@ } \ } while (0) +// read a vector of floats +#define SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(dst, src_key) \ + do { \ + auto value = \ + meta_data.LookupCustomMetadataMapAllocated(src_key, allocator); \ + if (!value) { \ + SHERPA_ONNX_LOGE("%s does not exist in the metadata", src_key); \ + exit(-1); \ + } \ + \ + bool ret = SplitStringToFloats(value.get(), ",", true, &dst); \ + if (!ret) { \ + SHERPA_ONNX_LOGE("Invalid value %s for %s", value.get(), src_key); \ + exit(-1); \ + } \ + } while (0) + // Read a string #define SHERPA_ONNX_READ_META_DATA_STR(dst, src_key) \ do { \ diff --git a/sherpa-onnx/csrc/offline-model-config.cc b/sherpa-onnx/csrc/offline-model-config.cc new file mode 100644 index 00000000..29f7b8a7 --- /dev/null +++ b/sherpa-onnx/csrc/offline-model-config.cc @@ -0,0 +1,57 @@ +// sherpa-onnx/csrc/offline-model-config.cc +// +// Copyright (c) 2023 Xiaomi Corporation +#include "sherpa-onnx/csrc/offline-model-config.h" + +#include + +#include "sherpa-onnx/csrc/file-utils.h" +#include "sherpa-onnx/csrc/macros.h" + +namespace sherpa_onnx { + +void OfflineModelConfig::Register(ParseOptions *po) { + transducer.Register(po); + paraformer.Register(po); + + po->Register("tokens", &tokens, "Path to tokens.txt"); + + po->Register("num-threads", &num_threads, + "Number of threads to run the neural network"); + + po->Register("debug", &debug, + "true to print model information while loading it."); +} + +bool OfflineModelConfig::Validate() const { + 
if (num_threads < 1) { + SHERPA_ONNX_LOGE("num_threads should be > 0. Given %d", num_threads); + return false; + } + + if (!FileExists(tokens)) { + SHERPA_ONNX_LOGE("%s does not exist", tokens.c_str()); + return false; + } + + if (!paraformer.model.empty()) { + return paraformer.Validate(); + } + + return transducer.Validate(); +} + +std::string OfflineModelConfig::ToString() const { + std::ostringstream os; + + os << "OfflineModelConfig("; + os << "transducer=" << transducer.ToString() << ", "; + os << "paraformer=" << paraformer.ToString() << ", "; + os << "tokens=\"" << tokens << "\", "; + os << "num_threads=" << num_threads << ", "; + os << "debug=" << (debug ? "True" : "False") << ")"; + + return os.str(); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-model-config.h b/sherpa-onnx/csrc/offline-model-config.h new file mode 100644 index 00000000..d8412316 --- /dev/null +++ b/sherpa-onnx/csrc/offline-model-config.h @@ -0,0 +1,40 @@ +// sherpa-onnx/csrc/offline-model-config.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_MODEL_CONFIG_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_MODEL_CONFIG_H_ + +#include + +#include "sherpa-onnx/csrc/offline-paraformer-model-config.h" +#include "sherpa-onnx/csrc/offline-transducer-model-config.h" + +namespace sherpa_onnx { + +struct OfflineModelConfig { + OfflineTransducerModelConfig transducer; + OfflineParaformerModelConfig paraformer; + + std::string tokens; + int32_t num_threads = 2; + bool debug = false; + + OfflineModelConfig() = default; + OfflineModelConfig(const OfflineTransducerModelConfig &transducer, + const OfflineParaformerModelConfig ¶former, + const std::string &tokens, int32_t num_threads, bool debug) + : transducer(transducer), + paraformer(paraformer), + tokens(tokens), + num_threads(num_threads), + debug(debug) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +} // namespace sherpa_onnx + +#endif // 
SHERPA_ONNX_CSRC_OFFLINE_MODEL_CONFIG_H_ diff --git a/sherpa-onnx/csrc/offline-paraformer-decoder.h b/sherpa-onnx/csrc/offline-paraformer-decoder.h new file mode 100644 index 00000000..65781324 --- /dev/null +++ b/sherpa-onnx/csrc/offline-paraformer-decoder.h @@ -0,0 +1,37 @@ +// sherpa-onnx/csrc/offline-paraformer-decoder.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_PARAFORMER_DECODER_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_PARAFORMER_DECODER_H_ + +#include + +#include "onnxruntime_cxx_api.h" // NOLINT + +namespace sherpa_onnx { + +struct OfflineParaformerDecoderResult { + /// The decoded token IDs + std::vector tokens; +}; + +class OfflineParaformerDecoder { + public: + virtual ~OfflineParaformerDecoder() = default; + + /** Run beam search given the output from the paraformer model. + * + * @param log_probs A 3-D tensor of shape (N, T, vocab_size) + * @param token_num A 2-D tensor of shape (N, T). Its dtype is int64_t. + * log_probs[i].argmax(axis=-1) equals to token_num[i] + * + * @return Return a vector of size `N` containing the decoded results. 
+ */ + virtual std::vector Decode( + Ort::Value log_probs, Ort::Value token_num) = 0; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_PARAFORMER_DECODER_H_ diff --git a/sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.cc b/sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.cc new file mode 100644 index 00000000..54ce545a --- /dev/null +++ b/sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.cc @@ -0,0 +1,34 @@ +// sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.h" + +#include + +namespace sherpa_onnx { + +std::vector +OfflineParaformerGreedySearchDecoder::Decode(Ort::Value /*log_probs*/, + Ort::Value token_num) { + std::vector shape = token_num.GetTensorTypeAndShapeInfo().GetShape(); + int32_t batch_size = shape[0]; + int32_t num_tokens = shape[1]; + + std::vector results(batch_size); + + const int64_t *p = token_num.GetTensorData(); + for (int32_t i = 0; i != batch_size; ++i) { + for (int32_t k = 0; k != num_tokens; ++k) { + if (p[k] == eos_id_) break; + + results[i].tokens.push_back(p[k]); + } + + p += num_tokens; + } + + return results; +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.h b/sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.h new file mode 100644 index 00000000..9ba177c9 --- /dev/null +++ b/sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.h @@ -0,0 +1,28 @@ +// sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_PARAFORMER_GREEDY_SEARCH_DECODER_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_PARAFORMER_GREEDY_SEARCH_DECODER_H_ + +#include + +#include "sherpa-onnx/csrc/offline-paraformer-decoder.h" + +namespace sherpa_onnx { + +class OfflineParaformerGreedySearchDecoder : public OfflineParaformerDecoder { + public: + explicit 
OfflineParaformerGreedySearchDecoder(int32_t eos_id) + : eos_id_(eos_id) {} + + std::vector Decode( + Ort::Value /*log_probs*/, Ort::Value token_num) override; + + private: + int32_t eos_id_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_PARAFORMER_GREEDY_SEARCH_DECODER_H_ diff --git a/sherpa-onnx/csrc/offline-paraformer-model-config.cc b/sherpa-onnx/csrc/offline-paraformer-model-config.cc new file mode 100644 index 00000000..dad43ff6 --- /dev/null +++ b/sherpa-onnx/csrc/offline-paraformer-model-config.cc @@ -0,0 +1,34 @@ +// sherpa-onnx/csrc/offline-paraformer-model-config.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-paraformer-model-config.h" + +#include "sherpa-onnx/csrc/file-utils.h" +#include "sherpa-onnx/csrc/macros.h" + +namespace sherpa_onnx { + +void OfflineParaformerModelConfig::Register(ParseOptions *po) { + po->Register("paraformer", &model, "Path to model.onnx of paraformer."); +} + +bool OfflineParaformerModelConfig::Validate() const { + if (!FileExists(model)) { + SHERPA_ONNX_LOGE("%s does not exist", model.c_str()); + return false; + } + + return true; +} + +std::string OfflineParaformerModelConfig::ToString() const { + std::ostringstream os; + + os << "OfflineParaformerModelConfig("; + os << "model=\"" << model << "\")"; + + return os.str(); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-paraformer-model-config.h b/sherpa-onnx/csrc/offline-paraformer-model-config.h new file mode 100644 index 00000000..f0420dcb --- /dev/null +++ b/sherpa-onnx/csrc/offline-paraformer-model-config.h @@ -0,0 +1,28 @@ +// sherpa-onnx/csrc/offline-paraformer-model-config.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_PARAFORMER_MODEL_CONFIG_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_PARAFORMER_MODEL_CONFIG_H_ + +#include + +#include "sherpa-onnx/csrc/parse-options.h" + +namespace sherpa_onnx { + +struct OfflineParaformerModelConfig { + std::string 
model; + + OfflineParaformerModelConfig() = default; + explicit OfflineParaformerModelConfig(const std::string &model) + : model(model) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_PARAFORMER_MODEL_CONFIG_H_ diff --git a/sherpa-onnx/csrc/offline-paraformer-model.cc b/sherpa-onnx/csrc/offline-paraformer-model.cc new file mode 100644 index 00000000..3accce35 --- /dev/null +++ b/sherpa-onnx/csrc/offline-paraformer-model.cc @@ -0,0 +1,132 @@ +// sherpa-onnx/csrc/offline-paraformer-model.cc +// +// Copyright (c) 2022-2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-paraformer-model.h" + +#include +#include + +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/onnx-utils.h" +#include "sherpa-onnx/csrc/text-utils.h" + +namespace sherpa_onnx { + +class OfflineParaformerModel::Impl { + public: + explicit Impl(const OfflineModelConfig &config) + : config_(config), + env_(ORT_LOGGING_LEVEL_ERROR), + sess_opts_{}, + allocator_{} { + sess_opts_.SetIntraOpNumThreads(config_.num_threads); + sess_opts_.SetInterOpNumThreads(config_.num_threads); + + Init(); + } + + std::pair Forward(Ort::Value features, + Ort::Value features_length) { + std::array inputs = {std::move(features), + std::move(features_length)}; + + auto out = + sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(), + output_names_ptr_.data(), output_names_ptr_.size()); + + return {std::move(out[0]), std::move(out[1])}; + } + + int32_t VocabSize() const { return vocab_size_; } + + int32_t LfrWindowSize() const { return lfr_window_size_; } + + int32_t LfrWindowShift() const { return lfr_window_shift_; } + + const std::vector &NegativeMean() const { return neg_mean_; } + + const std::vector &InverseStdDev() const { return inv_stddev_; } + + OrtAllocator *Allocator() const { return allocator_; } + + private: + void Init() { + auto buf = 
ReadFile(config_.paraformer.model); + + sess_ = std::make_unique(env_, buf.data(), buf.size(), + sess_opts_); + + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); + + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); + + // get meta data + Ort::ModelMetadata meta_data = sess_->GetModelMetadata(); + if (config_.debug) { + std::ostringstream os; + PrintModelMetadata(os, meta_data); + SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); + } + + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below + SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size"); + SHERPA_ONNX_READ_META_DATA(lfr_window_size_, "lfr_window_size"); + SHERPA_ONNX_READ_META_DATA(lfr_window_shift_, "lfr_window_shift"); + + SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(neg_mean_, "neg_mean"); + SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(inv_stddev_, "inv_stddev"); + } + + private: + OfflineModelConfig config_; + Ort::Env env_; + Ort::SessionOptions sess_opts_; + Ort::AllocatorWithDefaultOptions allocator_; + + std::unique_ptr sess_; + + std::vector input_names_; + std::vector input_names_ptr_; + + std::vector output_names_; + std::vector output_names_ptr_; + + std::vector neg_mean_; + std::vector inv_stddev_; + + int32_t vocab_size_ = 0; // initialized in Init + int32_t lfr_window_size_ = 0; + int32_t lfr_window_shift_ = 0; +}; + +OfflineParaformerModel::OfflineParaformerModel(const OfflineModelConfig &config) + : impl_(std::make_unique(config)) {} + +OfflineParaformerModel::~OfflineParaformerModel() = default; + +std::pair OfflineParaformerModel::Forward( + Ort::Value features, Ort::Value features_length) { + return impl_->Forward(std::move(features), std::move(features_length)); +} + +int32_t OfflineParaformerModel::VocabSize() const { return impl_->VocabSize(); } + +int32_t OfflineParaformerModel::LfrWindowSize() const { + return impl_->LfrWindowSize(); +} +int32_t OfflineParaformerModel::LfrWindowShift() const { + return impl_->LfrWindowShift(); +} +const std::vector 
&OfflineParaformerModel::NegativeMean() const { + return impl_->NegativeMean(); +} +const std::vector &OfflineParaformerModel::InverseStdDev() const { + return impl_->InverseStdDev(); +} + +OrtAllocator *OfflineParaformerModel::Allocator() const { + return impl_->Allocator(); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-paraformer-model.h b/sherpa-onnx/csrc/offline-paraformer-model.h new file mode 100644 index 00000000..75de8cb0 --- /dev/null +++ b/sherpa-onnx/csrc/offline-paraformer-model.h @@ -0,0 +1,67 @@ +// sherpa-onnx/csrc/offline-paraformer-model.h +// +// Copyright (c) 2022-2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_PARAFORMER_MODEL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_PARAFORMER_MODEL_H_ + +#include +#include +#include + +#include "onnxruntime_cxx_api.h" // NOLINT +#include "sherpa-onnx/csrc/offline-model-config.h" + +namespace sherpa_onnx { + +class OfflineParaformerModel { + public: + explicit OfflineParaformerModel(const OfflineModelConfig &config); + ~OfflineParaformerModel(); + + /** Run the forward method of the model. + * + * @param features A tensor of shape (N, T, C). It is changed in-place. + * @param features_length A 1-D tensor of shape (N,) containing number of + * valid frames in `features` before padding. + * Its dtype is int32_t. + * + * @return Return a pair containing: + * - log_probs: A 3-D tensor of shape (N, T', vocab_size) + * - token_num: A 1-D tensor of shape (N, T') containing number + * of valid tokens in each utterance. Its dtype is int64_t. 
+ */ + std::pair Forward(Ort::Value features, + Ort::Value features_length); + + /** Return the vocabulary size of the model + */ + int32_t VocabSize() const; + + /** It is lfr_m in config.yaml + */ + int32_t LfrWindowSize() const; + + /** It is lfr_n in config.yaml + */ + int32_t LfrWindowShift() const; + + /** Return negative mean for CMVN + */ + const std::vector &NegativeMean() const; + + /** Return inverse stddev for CMVN + */ + const std::vector &InverseStdDev() const; + + /** Return an allocator for allocating memory + */ + OrtAllocator *Allocator() const; + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_PARAFORMER_MODEL_H_ diff --git a/sherpa-onnx/csrc/offline-recognizer-impl.cc b/sherpa-onnx/csrc/offline-recognizer-impl.cc index cfcad26f..ff03db3f 100644 --- a/sherpa-onnx/csrc/offline-recognizer-impl.cc +++ b/sherpa-onnx/csrc/offline-recognizer-impl.cc @@ -8,6 +8,7 @@ #include "onnxruntime_cxx_api.h" // NOLINT #include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h" #include "sherpa-onnx/csrc/offline-recognizer-transducer-impl.h" #include "sherpa-onnx/csrc/onnx-utils.h" #include "sherpa-onnx/csrc/text-utils.h" @@ -16,10 +17,20 @@ namespace sherpa_onnx { std::unique_ptr OfflineRecognizerImpl::Create( const OfflineRecognizerConfig &config) { - Ort::Env env; + Ort::Env env(ORT_LOGGING_LEVEL_ERROR); Ort::SessionOptions sess_opts; - auto buf = ReadFile(config.model_config.encoder_filename); + std::string model_filename; + if (!config.model_config.transducer.encoder_filename.empty()) { + model_filename = config.model_config.transducer.encoder_filename; + } else if (!config.model_config.paraformer.model.empty()) { + model_filename = config.model_config.paraformer.model; + } else { + SHERPA_ONNX_LOGE("Please provide a model"); + exit(-1); + } + + auto buf = ReadFile(model_filename); auto encoder_sess = std::make_unique(env, buf.data(), 
buf.size(), sess_opts); @@ -35,7 +46,16 @@ std::unique_ptr OfflineRecognizerImpl::Create( return std::make_unique(config); } - SHERPA_ONNX_LOGE("Unsupported model_type: %s\n", model_type.c_str()); + if (model_type == "paraformer") { + return std::make_unique(config); + } + + SHERPA_ONNX_LOGE( + "\nUnsupported model_type: %s\n" + "We support only the following model types at present: \n" + " - transducer models from icefall\n" + " - Paraformer models from FunASR\n", + model_type.c_str()); exit(-1); } diff --git a/sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h b/sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h new file mode 100644 index 00000000..17c64583 --- /dev/null +++ b/sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h @@ -0,0 +1,182 @@ +// sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h +// +// Copyright (c) 2022-2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_PARAFORMER_IMPL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_PARAFORMER_IMPL_H_ + +#include +#include +#include +#include +#include + +#include "sherpa-onnx/csrc/offline-model-config.h" +#include "sherpa-onnx/csrc/offline-paraformer-decoder.h" +#include "sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.h" +#include "sherpa-onnx/csrc/offline-paraformer-model.h" +#include "sherpa-onnx/csrc/offline-recognizer-impl.h" +#include "sherpa-onnx/csrc/offline-recognizer.h" +#include "sherpa-onnx/csrc/pad-sequence.h" +#include "sherpa-onnx/csrc/symbol-table.h" + +namespace sherpa_onnx { + +static OfflineRecognitionResult Convert( + const OfflineParaformerDecoderResult &src, const SymbolTable &sym_table) { + OfflineRecognitionResult r; + r.tokens.reserve(src.tokens.size()); + + std::string text; + for (auto i : src.tokens) { + auto sym = sym_table[i]; + text.append(sym); + + r.tokens.push_back(std::move(sym)); + } + r.text = std::move(text); + + return r; +} + +class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl { + public: + explicit 
OfflineRecognizerParaformerImpl( + const OfflineRecognizerConfig &config) + : config_(config), + symbol_table_(config_.model_config.tokens), + model_(std::make_unique(config.model_config)) { + if (config.decoding_method == "greedy_search") { + int32_t eos_id = symbol_table_[""]; + decoder_ = std::make_unique(eos_id); + } else { + SHERPA_ONNX_LOGE("Only greedy_search is supported at present. Given %s", + config.decoding_method.c_str()); + exit(-1); + } + + // Paraformer models assume input samples are in the range + // [-32768, 32767], so we set normalize_samples to false + config_.feat_config.normalize_samples = false; + } + + std::unique_ptr CreateStream() const override { + return std::make_unique(config_.feat_config); + } + + void DecodeStreams(OfflineStream **ss, int32_t n) const override { + // 1. Apply LFR + // 2. Apply CMVN + // + // Please refer to + // https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/45555.pdf + // for what LFR means + // + // "Lower Frame Rate Neural Network Acoustic Models" + auto memory_info = + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); + + std::vector features; + features.reserve(n); + + int32_t feat_dim = + config_.feat_config.feature_dim * model_->LfrWindowSize(); + + std::vector> features_vec(n); + std::vector features_length_vec(n); + for (int32_t i = 0; i != n; ++i) { + std::vector f = ss[i]->GetFrames(); + + f = ApplyLFR(f); + ApplyCMVN(&f); + + int32_t num_frames = f.size() / feat_dim; + features_vec[i] = std::move(f); + + features_length_vec[i] = num_frames; + + std::array shape = {num_frames, feat_dim}; + + Ort::Value x = Ort::Value::CreateTensor( + memory_info, features_vec[i].data(), features_vec[i].size(), + shape.data(), shape.size()); + features.push_back(std::move(x)); + } + + std::vector features_pointer(n); + for (int32_t i = 0; i != n; ++i) { + features_pointer[i] = &features[i]; + } + + std::array features_length_shape = {n}; + Ort::Value x_length = 
Ort::Value::CreateTensor( + memory_info, features_length_vec.data(), n, + features_length_shape.data(), features_length_shape.size()); + + // Caution(fangjun): We cannot pad it with log(eps), + // i.e., -23.025850929940457f + Ort::Value x = PadSequence(model_->Allocator(), features_pointer, 0); + + auto t = model_->Forward(std::move(x), std::move(x_length)); + + auto results = decoder_->Decode(std::move(t.first), std::move(t.second)); + + for (int32_t i = 0; i != n; ++i) { + auto r = Convert(results[i], symbol_table_); + ss[i]->SetResult(r); + } + } + + private: + std::vector ApplyLFR(const std::vector &in) const { + int32_t lfr_window_size = model_->LfrWindowSize(); + int32_t lfr_window_shift = model_->LfrWindowShift(); + int32_t in_feat_dim = config_.feat_config.feature_dim; + + int32_t in_num_frames = in.size() / in_feat_dim; + int32_t out_num_frames = + (in_num_frames - lfr_window_size) / lfr_window_shift + 1; + int32_t out_feat_dim = in_feat_dim * lfr_window_size; + + std::vector out(out_num_frames * out_feat_dim); + + const float *p_in = in.data(); + float *p_out = out.data(); + + for (int32_t i = 0; i != out_num_frames; ++i) { + std::copy(p_in, p_in + out_feat_dim, p_out); + + p_out += out_feat_dim; + p_in += lfr_window_shift * in_feat_dim; + } + + return out; + } + + void ApplyCMVN(std::vector *v) const { + const std::vector &neg_mean = model_->NegativeMean(); + const std::vector &inv_stddev = model_->InverseStdDev(); + + int32_t dim = neg_mean.size(); + int32_t num_frames = v->size() / dim; + + float *p = v->data(); + + for (int32_t i = 0; i != num_frames; ++i) { + for (int32_t k = 0; k != dim; ++k) { + p[k] = (p[k] + neg_mean[k]) * inv_stddev[k]; + } + + p += dim; + } + } + + OfflineRecognizerConfig config_; + SymbolTable symbol_table_; + std::unique_ptr model_; + std::unique_ptr decoder_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_PARAFORMER_IMPL_H_ diff --git a/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h 
b/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h index b9884e1b..750951fc 100644 --- a/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h +++ b/sherpa-onnx/csrc/offline-recognizer-transducer-impl.h @@ -1,6 +1,6 @@ // sherpa-onnx/csrc/offline-recognizer-transducer-impl.h // -// Copyright (c) 2022 Xiaomi Corporation +// Copyright (c) 2022-2023 Xiaomi Corporation #ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_TRANSDUCER_IMPL_H_ #define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_TRANSDUCER_IMPL_H_ diff --git a/sherpa-onnx/csrc/offline-recognizer.cc b/sherpa-onnx/csrc/offline-recognizer.cc index 873a5f3b..e0e3a651 100644 --- a/sherpa-onnx/csrc/offline-recognizer.cc +++ b/sherpa-onnx/csrc/offline-recognizer.cc @@ -6,6 +6,8 @@ #include +#include "sherpa-onnx/csrc/file-utils.h" +#include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/offline-recognizer-impl.h" namespace sherpa_onnx { diff --git a/sherpa-onnx/csrc/offline-recognizer.h b/sherpa-onnx/csrc/offline-recognizer.h index e02e3095..321e2950 100644 --- a/sherpa-onnx/csrc/offline-recognizer.h +++ b/sherpa-onnx/csrc/offline-recognizer.h @@ -9,6 +9,7 @@ #include #include +#include "sherpa-onnx/csrc/offline-model-config.h" #include "sherpa-onnx/csrc/offline-stream.h" #include "sherpa-onnx/csrc/offline-transducer-model-config.h" #include "sherpa-onnx/csrc/parse-options.h" @@ -32,7 +33,7 @@ struct OfflineRecognitionResult { struct OfflineRecognizerConfig { OfflineFeatureExtractorConfig feat_config; - OfflineTransducerModelConfig model_config; + OfflineModelConfig model_config; std::string decoding_method = "greedy_search"; // only greedy_search is implemented @@ -40,7 +41,7 @@ struct OfflineRecognizerConfig { OfflineRecognizerConfig() = default; OfflineRecognizerConfig(const OfflineFeatureExtractorConfig &feat_config, - const OfflineTransducerModelConfig &model_config, + const OfflineModelConfig &model_config, const std::string &decoding_method) : feat_config(feat_config), model_config(model_config), diff --git 
a/sherpa-onnx/csrc/offline-stream.cc b/sherpa-onnx/csrc/offline-stream.cc index 2fd0a8ab..28ed642f 100644 --- a/sherpa-onnx/csrc/offline-stream.cc +++ b/sherpa-onnx/csrc/offline-stream.cc @@ -38,7 +38,7 @@ std::string OfflineFeatureExtractorConfig::ToString() const { class OfflineStream::Impl { public: - explicit Impl(const OfflineFeatureExtractorConfig &config) { + explicit Impl(const OfflineFeatureExtractorConfig &config) : config_(config) { opts_.frame_opts.dither = 0; opts_.frame_opts.snip_edges = false; opts_.frame_opts.samp_freq = config.sampling_rate; @@ -48,6 +48,19 @@ class OfflineStream::Impl { } void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) { + if (config_.normalize_samples) { + AcceptWaveformImpl(sampling_rate, waveform, n); + } else { + std::vector buf(n); + for (int32_t i = 0; i != n; ++i) { + buf[i] = waveform[i] * 32768; + } + AcceptWaveformImpl(sampling_rate, buf.data(), n); + } + } + + void AcceptWaveformImpl(int32_t sampling_rate, const float *waveform, + int32_t n) { if (sampling_rate != opts_.frame_opts.samp_freq) { SHERPA_ONNX_LOGE( "Creating a resampler:\n" @@ -101,6 +114,7 @@ class OfflineStream::Impl { const OfflineRecognitionResult &GetResult() const { return r_; } private: + OfflineFeatureExtractorConfig config_; std::unique_ptr fbank_; knf::FbankOptions opts_; OfflineRecognitionResult r_; diff --git a/sherpa-onnx/csrc/offline-stream.h b/sherpa-onnx/csrc/offline-stream.h index 3059c38a..7686d641 100644 --- a/sherpa-onnx/csrc/offline-stream.h +++ b/sherpa-onnx/csrc/offline-stream.h @@ -23,6 +23,13 @@ struct OfflineFeatureExtractorConfig { // Feature dimension int32_t feature_dim = 80; + // Set internally by some models, e.g., paraformer + // This parameter is not exposed to users from the commandline + // If true, the feature extractor expects inputs to be normalized to + // the range [-1, 1]. 
+ // If false, we will multiply the inputs by 32768 + bool normalize_samples = true; + std::string ToString() const; void Register(ParseOptions *po); diff --git a/sherpa-onnx/csrc/offline-transducer-model-config.cc b/sherpa-onnx/csrc/offline-transducer-model-config.cc index b66ff303..16b7a9f3 100644 --- a/sherpa-onnx/csrc/offline-transducer-model-config.cc +++ b/sherpa-onnx/csrc/offline-transducer-model-config.cc @@ -14,20 +14,9 @@ void OfflineTransducerModelConfig::Register(ParseOptions *po) { po->Register("encoder", &encoder_filename, "Path to encoder.onnx"); po->Register("decoder", &decoder_filename, "Path to decoder.onnx"); po->Register("joiner", &joiner_filename, "Path to joiner.onnx"); - po->Register("tokens", &tokens, "Path to tokens.txt"); - po->Register("num_threads", &num_threads, - "Number of threads to run the neural network"); - - po->Register("debug", &debug, - "true to print model information while loading it."); } bool OfflineTransducerModelConfig::Validate() const { - if (!FileExists(tokens)) { - SHERPA_ONNX_LOGE("%s does not exist", tokens.c_str()); - return false; - } - if (!FileExists(encoder_filename)) { SHERPA_ONNX_LOGE("%s does not exist", encoder_filename.c_str()); return false; @@ -43,11 +32,6 @@ bool OfflineTransducerModelConfig::Validate() const { return false; } - if (num_threads < 1) { - SHERPA_ONNX_LOGE("num_threads should be > 0. Given %d", num_threads); - return false; - } - return true; } @@ -57,10 +41,7 @@ std::string OfflineTransducerModelConfig::ToString() const { os << "OfflineTransducerModelConfig("; os << "encoder_filename=\"" << encoder_filename << "\", "; os << "decoder_filename=\"" << decoder_filename << "\", "; - os << "joiner_filename=\"" << joiner_filename << "\", "; - os << "tokens=\"" << tokens << "\", "; - os << "num_threads=" << num_threads << ", "; - os << "debug=" << (debug ? 
"True" : "False") << ")"; + os << "joiner_filename=\"" << joiner_filename << "\")"; return os.str(); } diff --git a/sherpa-onnx/csrc/offline-transducer-model-config.h b/sherpa-onnx/csrc/offline-transducer-model-config.h index 39987bbc..1b51f104 100644 --- a/sherpa-onnx/csrc/offline-transducer-model-config.h +++ b/sherpa-onnx/csrc/offline-transducer-model-config.h @@ -14,22 +14,14 @@ struct OfflineTransducerModelConfig { std::string encoder_filename; std::string decoder_filename; std::string joiner_filename; - std::string tokens; - int32_t num_threads = 2; - bool debug = false; OfflineTransducerModelConfig() = default; OfflineTransducerModelConfig(const std::string &encoder_filename, const std::string &decoder_filename, - const std::string &joiner_filename, - const std::string &tokens, int32_t num_threads, - bool debug) + const std::string &joiner_filename) : encoder_filename(encoder_filename), decoder_filename(decoder_filename), - joiner_filename(joiner_filename), - tokens(tokens), - num_threads(num_threads), - debug(debug) {} + joiner_filename(joiner_filename) {} void Register(ParseOptions *po); bool Validate() const; diff --git a/sherpa-onnx/csrc/offline-transducer-model.cc b/sherpa-onnx/csrc/offline-transducer-model.cc index 3d584b5f..6ecb94f8 100644 --- a/sherpa-onnx/csrc/offline-transducer-model.cc +++ b/sherpa-onnx/csrc/offline-transducer-model.cc @@ -16,7 +16,7 @@ namespace sherpa_onnx { class OfflineTransducerModel::Impl { public: - explicit Impl(const OfflineTransducerModelConfig &config) + explicit Impl(const OfflineModelConfig &config) : config_(config), env_(ORT_LOGGING_LEVEL_WARNING), sess_opts_{}, @@ -24,17 +24,17 @@ class OfflineTransducerModel::Impl { sess_opts_.SetIntraOpNumThreads(config.num_threads); sess_opts_.SetInterOpNumThreads(config.num_threads); { - auto buf = ReadFile(config.encoder_filename); + auto buf = ReadFile(config.transducer.encoder_filename); InitEncoder(buf.data(), buf.size()); } { - auto buf = ReadFile(config.decoder_filename); 
+ auto buf = ReadFile(config.transducer.decoder_filename); InitDecoder(buf.data(), buf.size()); } { - auto buf = ReadFile(config.joiner_filename); + auto buf = ReadFile(config.transducer.joiner_filename); InitJoiner(buf.data(), buf.size()); } } @@ -164,7 +164,7 @@ class OfflineTransducerModel::Impl { } private: - OfflineTransducerModelConfig config_; + OfflineModelConfig config_; Ort::Env env_; Ort::SessionOptions sess_opts_; Ort::AllocatorWithDefaultOptions allocator_; @@ -195,8 +195,7 @@ class OfflineTransducerModel::Impl { int32_t context_size_ = 0; // initialized in InitDecoder }; -OfflineTransducerModel::OfflineTransducerModel( - const OfflineTransducerModelConfig &config) +OfflineTransducerModel::OfflineTransducerModel(const OfflineModelConfig &config) : impl_(std::make_unique<Impl>(config)) {} OfflineTransducerModel::~OfflineTransducerModel() = default; diff --git a/sherpa-onnx/csrc/offline-transducer-model.h b/sherpa-onnx/csrc/offline-transducer-model.h index f40c82a0..7f7d24d6 100644 --- a/sherpa-onnx/csrc/offline-transducer-model.h +++ b/sherpa-onnx/csrc/offline-transducer-model.h @@ -9,7 +9,7 @@ #include <utility> #include "onnxruntime_cxx_api.h" // NOLINT -#include "sherpa-onnx/csrc/offline-transducer-model-config.h" +#include "sherpa-onnx/csrc/offline-model-config.h" namespace sherpa_onnx { @@ -17,7 +17,7 @@ struct OfflineTransducerDecoderResult; class OfflineTransducerModel { public: - explicit OfflineTransducerModel(const OfflineTransducerModelConfig &config); + explicit OfflineTransducerModel(const OfflineModelConfig &config); ~OfflineTransducerModel(); /** Run the encoder. @@ -25,6 +25,7 @@ class OfflineTransducerModel { * @param features A tensor of shape (N, T, C). It is changed in-place. * @param features_length A 1-D tensor of shape (N,) containing number of * valid frames in `features` before padding. + * Its dtype is int64_t. 
* * @return Return a pair containing: * - encoder_out: A 3-D tensor of shape (N, T', encoder_dim) diff --git a/sherpa-onnx/csrc/onnx-utils.cc b/sherpa-onnx/csrc/onnx-utils.cc index 8b0cf34e..133d3df3 100644 --- a/sherpa-onnx/csrc/onnx-utils.cc +++ b/sherpa-onnx/csrc/onnx-utils.cc @@ -5,6 +5,7 @@ #include <algorithm> #include <fstream> +#include <sstream> #include <string> #include <vector> @@ -133,19 +134,24 @@ void Print1D(Ort::Value *v) { fprintf(stderr, "\n"); } +template <typename T> void Print2D(Ort::Value *v) { std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape(); - const float *d = v->GetTensorData<float>(); + const T *d = v->GetTensorData<T>(); + std::ostringstream os; for (int32_t r = 0; r != static_cast<int32_t>(shape[0]); ++r) { for (int32_t c = 0; c != static_cast<int32_t>(shape[1]); ++c, ++d) { - fprintf(stderr, "%.3f ", *d); + os << *d << " "; } - fprintf(stderr, "\n"); + os << "\n"; } - fprintf(stderr, "\n"); + fprintf(stderr, "%s\n", os.str().c_str()); } +template void Print2D<float>(Ort::Value *v); +template void Print2D<int32_t>(Ort::Value *v); + void Print3D(Ort::Value *v) { std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape(); const float *d = v->GetTensorData<float>(); diff --git a/sherpa-onnx/csrc/onnx-utils.h b/sherpa-onnx/csrc/onnx-utils.h index af4f3ccb..8bdba7a4 100644 --- a/sherpa-onnx/csrc/onnx-utils.h +++ b/sherpa-onnx/csrc/onnx-utils.h @@ -24,18 +24,6 @@ namespace sherpa_onnx { -#ifdef _MSC_VER -// See -// https://stackoverflow.com/questions/2573834/c-convert-string-or-char-to-wstring-or-wchar-t -static std::wstring ToWide(const std::string &s) { std::wstring_convert<std::codecvt_utf8<wchar_t>> converter; return converter.from_bytes(s); } -#define SHERPA_MAYBE_WIDE(s) ToWide(s) -#else -#define SHERPA_MAYBE_WIDE(s) s -#endif - /** * Get the input names of a model. 
* @@ -79,6 +67,7 @@ Ort::Value Clone(OrtAllocator *allocator, const Ort::Value *v); void Print1D(Ort::Value *v); // Print a 2-D tensor to stderr +template <typename T> void Print2D(Ort::Value *v); // Print a 3-D tensor to stderr diff --git a/sherpa-onnx/csrc/sherpa-onnx-offline.cc b/sherpa-onnx/csrc/sherpa-onnx-offline.cc index d4b7529d..1b5b4d0d 100644 --- a/sherpa-onnx/csrc/sherpa-onnx-offline.cc +++ b/sherpa-onnx/csrc/sherpa-onnx-offline.cc @@ -9,24 +9,35 @@ #include <vector> #include "sherpa-onnx/csrc/offline-recognizer.h" -#include "sherpa-onnx/csrc/offline-stream.h" -#include "sherpa-onnx/csrc/offline-transducer-decoder.h" -#include "sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.h" -#include "sherpa-onnx/csrc/offline-transducer-model.h" -#include "sherpa-onnx/csrc/pad-sequence.h" -#include "sherpa-onnx/csrc/symbol-table.h" +#include "sherpa-onnx/csrc/parse-options.h" #include "sherpa-onnx/csrc/wave-reader.h" int main(int32_t argc, char *argv[]) { - if (argc < 6 || argc > 8) { - const char *usage = R"usage( Usage: + +(1) Transducer from icefall + ./bin/sherpa-onnx-offline \ - /path/to/tokens.txt \ - /path/to/encoder.onnx \ - /path/to/decoder.onnx \ - /path/to/joiner.onnx \ - /path/to/foo.wav [num_threads [decoding_method]] + --tokens=/path/to/tokens.txt \ + --encoder=/path/to/encoder.onnx \ + --decoder=/path/to/decoder.onnx \ + --joiner=/path/to/joiner.onnx \ + --num-threads=2 \ + --decoding-method=greedy_search \ + /path/to/foo.wav [bar.wav foobar.wav ...] + + +(2) Paraformer from FunASR + + ./bin/sherpa-onnx-offline \ + --tokens=/path/to/tokens.txt \ + --paraformer=/path/to/model.onnx \ + --num-threads=2 \ + --decoding-method=greedy_search \ + /path/to/foo.wav [bar.wav foobar.wav ...] + +Note: It supports decoding multiple files in batches Default value for num_threads is 2. Valid values for decoding_method: greedy_search. 
@@ -37,29 +48,15 @@ Please refer to https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html for a list of pre-trained models to download. )usage"; - fprintf(stderr, "%s\n", usage); - - return 0; - } + sherpa_onnx::ParseOptions po(kUsageMessage); sherpa_onnx::OfflineRecognizerConfig config; + config.Register(&po); - config.model_config.tokens = argv[1]; - - config.model_config.debug = false; - config.model_config.encoder_filename = argv[2]; - config.model_config.decoder_filename = argv[3]; - config.model_config.joiner_filename = argv[4]; - - std::string wav_filename = argv[5]; - - config.model_config.num_threads = 2; - if (argc == 7 && atoi(argv[6]) > 0) { - config.model_config.num_threads = atoi(argv[6]); - } - - if (argc == 8) { - config.decoding_method = argv[7]; + po.Read(argc, argv); + if (po.NumArgs() < 1) { + po.PrintUsage(); + exit(EXIT_FAILURE); } fprintf(stderr, "%s\n", config.ToString().c_str()); @@ -69,35 +66,43 @@ for a list of pre-trained models to download. return -1; } - int32_t sampling_rate = -1; - - bool is_ok = false; - std::vector<float> samples = - sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok); - if (!is_ok) { - fprintf(stderr, "Failed to read %s\n", wav_filename.c_str()); - return -1; - } - fprintf(stderr, "sampling rate of input file: %d\n", sampling_rate); - - float duration = samples.size() / static_cast<float>(sampling_rate); - sherpa_onnx::OfflineRecognizer recognizer(config); - auto s = recognizer.CreateStream(); auto begin = std::chrono::steady_clock::now(); fprintf(stderr, "Started\n"); - s->AcceptWaveform(sampling_rate, samples.data(), samples.size()); + std::vector<std::unique_ptr<sherpa_onnx::OfflineStream>> ss; + std::vector<sherpa_onnx::OfflineStream *> ss_pointers; + float duration = 0; + for (int32_t i = 1; i <= po.NumArgs(); ++i) { + std::string wav_filename = po.GetArg(i); + int32_t sampling_rate = -1; + bool is_ok = false; + std::vector<float> samples = + sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok); + if (!is_ok) { + fprintf(stderr, "Failed to read %s\n", wav_filename.c_str()); + 
return -1; + } + duration += samples.size() / static_cast<float>(sampling_rate); - recognizer.DecodeStream(s.get()); + auto s = recognizer.CreateStream(); + s->AcceptWaveform(sampling_rate, samples.data(), samples.size()); - fprintf(stderr, "Done!\n"); + ss.push_back(std::move(s)); + ss_pointers.push_back(ss.back().get()); + } - fprintf(stderr, "Recognition result for %s:\n%s\n", wav_filename.c_str(), - s->GetResult().text.c_str()); + recognizer.DecodeStreams(ss_pointers.data(), ss_pointers.size()); auto end = std::chrono::steady_clock::now(); + + fprintf(stderr, "Done!\n\n"); + for (int32_t i = 1; i <= po.NumArgs(); ++i) { + fprintf(stderr, "%s\n%s\n----\n", po.GetArg(i).c_str(), + ss[i - 1]->GetResult().text.c_str()); + } + float elapsed_seconds = std::chrono::duration_cast<std::chrono::milliseconds>(end - begin) .count() /