Support paraformer. (#95)
This commit is contained in:
67
.github/scripts/test-offline-transducer.sh
vendored
67
.github/scripts/test-offline-transducer.sh
vendored
@@ -25,36 +25,59 @@ log "Download pretrained model and test-data from $repo_url"
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||
pushd $repo
|
||||
git lfs pull --include "*.onnx"
|
||||
cd test_wavs
|
||||
popd
|
||||
|
||||
waves=(
|
||||
$repo/test_wavs/0.wav
|
||||
$repo/test_wavs/1.wav
|
||||
$repo/test_wavs/2.wav
|
||||
)
|
||||
|
||||
for wave in ${waves[@]}; do
|
||||
time $EXE \
|
||||
$repo/tokens.txt \
|
||||
$repo/encoder-epoch-99-avg-1.onnx \
|
||||
$repo/decoder-epoch-99-avg-1.onnx \
|
||||
$repo/joiner-epoch-99-avg-1.onnx \
|
||||
$wave \
|
||||
2
|
||||
done
|
||||
time $EXE \
|
||||
--tokens=$repo/tokens.txt \
|
||||
--encoder=$repo/encoder-epoch-99-avg-1.onnx \
|
||||
--decoder=$repo/decoder-epoch-99-avg-1.onnx \
|
||||
--joiner=$repo/joiner-epoch-99-avg-1.onnx \
|
||||
--num-threads=2 \
|
||||
$repo/test_wavs/0.wav \
|
||||
$repo/test_wavs/1.wav \
|
||||
$repo/test_wavs/2.wav
|
||||
|
||||
|
||||
if command -v sox &> /dev/null; then
|
||||
echo "test 8kHz"
|
||||
sox $repo/test_wavs/0.wav -r 8000 8k.wav
|
||||
|
||||
time $EXE \
|
||||
$repo/tokens.txt \
|
||||
$repo/encoder-epoch-99-avg-1.onnx \
|
||||
$repo/decoder-epoch-99-avg-1.onnx \
|
||||
$repo/joiner-epoch-99-avg-1.onnx \
|
||||
8k.wav \
|
||||
2
|
||||
--tokens=$repo/tokens.txt \
|
||||
--encoder=$repo/encoder-epoch-99-avg-1.onnx \
|
||||
--decoder=$repo/decoder-epoch-99-avg-1.onnx \
|
||||
--joiner=$repo/joiner-epoch-99-avg-1.onnx \
|
||||
--num-threads=2 \
|
||||
$repo/test_wavs/0.wav \
|
||||
$repo/test_wavs/1.wav \
|
||||
$repo/test_wavs/2.wav \
|
||||
8k.wav
|
||||
fi
|
||||
|
||||
rm -rf $repo
|
||||
|
||||
log "------------------------------------------------------------"
|
||||
log "Run Paraformer (Chinese)"
|
||||
log "------------------------------------------------------------"
|
||||
|
||||
repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-paraformer-zh-2023-03-28
|
||||
log "Start testing ${repo_url}"
|
||||
repo=$(basename $repo_url)
|
||||
log "Download pretrained model and test-data from $repo_url"
|
||||
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||
pushd $repo
|
||||
git lfs pull --include "*.onnx"
|
||||
popd
|
||||
|
||||
time $EXE \
|
||||
--tokens=$repo/tokens.txt \
|
||||
--paraformer=$repo/model.onnx \
|
||||
--num-threads=2 \
|
||||
--decoding-method=greedy_search \
|
||||
$repo/test_wavs/0.wav \
|
||||
$repo/test_wavs/1.wav \
|
||||
$repo/test_wavs/2.wav \
|
||||
$repo/test_wavs/8k.wav
|
||||
|
||||
rm -rf $repo
|
||||
|
||||
10
.github/workflows/windows-x64.yaml
vendored
10
.github/workflows/windows-x64.yaml
vendored
@@ -71,7 +71,15 @@ jobs:
|
||||
|
||||
ls -lh ./bin/Release/sherpa-onnx.exe
|
||||
|
||||
- name: Test sherpa-onnx for Windows x64
|
||||
- name: Test offline transducer for Windows x64
|
||||
shell: bash
|
||||
run: |
|
||||
export PATH=$PWD/build/bin/Release:$PATH
|
||||
export EXE=sherpa-onnx-offline.exe
|
||||
|
||||
.github/scripts/test-offline-transducer.sh
|
||||
|
||||
- name: Test online transducer for Windows x64
|
||||
shell: bash
|
||||
run: |
|
||||
export PATH=$PWD/build/bin/Release:$PATH
|
||||
|
||||
10
.github/workflows/windows-x86.yaml
vendored
10
.github/workflows/windows-x86.yaml
vendored
@@ -71,7 +71,15 @@ jobs:
|
||||
|
||||
ls -lh ./bin/Release/sherpa-onnx.exe
|
||||
|
||||
- name: Test sherpa-onnx for Windows x86
|
||||
- name: Test offline transducer for Windows x86
|
||||
shell: bash
|
||||
run: |
|
||||
export PATH=$PWD/build/bin/Release:$PATH
|
||||
export EXE=sherpa-onnx-offline.exe
|
||||
|
||||
.github/scripts/test-offline-transducer.sh
|
||||
|
||||
- name: Test online transducer for Windows x86
|
||||
shell: bash
|
||||
run: |
|
||||
export PATH=$PWD/build/bin/Release:$PATH
|
||||
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -41,3 +41,7 @@ android/SherpaOnnx/app/src/main/assets/
|
||||
*.ncnn.*
|
||||
run-sherpa-onnx-offline.sh
|
||||
sherpa-onnx-conformer-en-2023-03-18
|
||||
paraformer-onnxruntime-python-example
|
||||
run-sherpa-onnx-offline-paraformer.sh
|
||||
run-sherpa-onnx-offline-transducer.sh
|
||||
sherpa-onnx-paraformer-zh-2023-03-28
|
||||
|
||||
@@ -6,6 +6,10 @@ set(sources
|
||||
features.cc
|
||||
file-utils.cc
|
||||
hypothesis.cc
|
||||
offline-model-config.cc
|
||||
offline-paraformer-greedy-search-decoder.cc
|
||||
offline-paraformer-model-config.cc
|
||||
offline-paraformer-model.cc
|
||||
offline-recognizer-impl.cc
|
||||
offline-recognizer.cc
|
||||
offline-stream.cc
|
||||
|
||||
@@ -57,6 +57,23 @@
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
// Read a list of floats from the model metadata into `dst`.
//
// The expansion site must have `meta_data` and `allocator` in scope.
// The value stored under `src_key` is split on "," by SplitStringToFloats.
// If the key is missing or the value cannot be parsed, an error is logged
// and the process exits with -1.
#define SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(dst, src_key)                 \
  do {                                                                     \
    auto value =                                                           \
        meta_data.LookupCustomMetadataMapAllocated(src_key, allocator);    \
    if (!value) {                                                          \
      SHERPA_ONNX_LOGE("%s does not exist in the metadata", src_key);      \
      exit(-1);                                                            \
    }                                                                      \
                                                                           \
    if (!SplitStringToFloats(value.get(), ",", true, &dst)) {              \
      SHERPA_ONNX_LOGE("Invalid value %s for %s", value.get(), src_key);   \
      exit(-1);                                                            \
    }                                                                      \
  } while (0)
|
||||
|
||||
// Read a string
|
||||
#define SHERPA_ONNX_READ_META_DATA_STR(dst, src_key) \
|
||||
do { \
|
||||
|
||||
57
sherpa-onnx/csrc/offline-model-config.cc
Normal file
57
sherpa-onnx/csrc/offline-model-config.cc
Normal file
@@ -0,0 +1,57 @@
|
||||
// sherpa-onnx/csrc/offline-model-config.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
#include "sherpa-onnx/csrc/offline-model-config.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// Register all command-line options of the offline model configuration:
// the options of both sub-configs, then --tokens, --num-threads and --debug.
void OfflineModelConfig::Register(ParseOptions *po) {
  // Options owned by the model-specific sub-configs.
  transducer.Register(po);
  paraformer.Register(po);

  // Options shared by every offline model.
  po->Register("tokens", &tokens, "Path to tokens.txt");
  po->Register("num-threads", &num_threads,
               "Number of threads to run the neural network");
  po->Register("debug", &debug,
               "true to print model information while loading it.");
}
|
||||
|
||||
bool OfflineModelConfig::Validate() const {
|
||||
if (num_threads < 1) {
|
||||
SHERPA_ONNX_LOGE("num_threads should be > 0. Given %d", num_threads);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!FileExists(tokens)) {
|
||||
SHERPA_ONNX_LOGE("%s does not exist", tokens.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!paraformer.model.empty()) {
|
||||
return paraformer.Validate();
|
||||
}
|
||||
|
||||
return transducer.Validate();
|
||||
}
|
||||
|
||||
// Return a human-readable, single-line description of this configuration.
std::string OfflineModelConfig::ToString() const {
  const char *debug_str = debug ? "True" : "False";

  std::ostringstream os;
  os << "OfflineModelConfig("
     << "transducer=" << transducer.ToString() << ", "
     << "paraformer=" << paraformer.ToString() << ", "
     << "tokens=\"" << tokens << "\", "
     << "num_threads=" << num_threads << ", "
     << "debug=" << debug_str << ")";

  return os.str();
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
40
sherpa-onnx/csrc/offline-model-config.h
Normal file
40
sherpa-onnx/csrc/offline-model-config.h
Normal file
@@ -0,0 +1,40 @@
|
||||
// sherpa-onnx/csrc/offline-model-config.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_MODEL_CONFIG_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_MODEL_CONFIG_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-paraformer-model-config.h"
|
||||
#include "sherpa-onnx/csrc/offline-transducer-model-config.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
struct OfflineModelConfig {
|
||||
OfflineTransducerModelConfig transducer;
|
||||
OfflineParaformerModelConfig paraformer;
|
||||
|
||||
std::string tokens;
|
||||
int32_t num_threads = 2;
|
||||
bool debug = false;
|
||||
|
||||
OfflineModelConfig() = default;
|
||||
OfflineModelConfig(const OfflineTransducerModelConfig &transducer,
|
||||
const OfflineParaformerModelConfig ¶former,
|
||||
const std::string &tokens, int32_t num_threads, bool debug)
|
||||
: transducer(transducer),
|
||||
paraformer(paraformer),
|
||||
tokens(tokens),
|
||||
num_threads(num_threads),
|
||||
debug(debug) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
|
||||
std::string ToString() const;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_MODEL_CONFIG_H_
|
||||
37
sherpa-onnx/csrc/offline-paraformer-decoder.h
Normal file
37
sherpa-onnx/csrc/offline-paraformer-decoder.h
Normal file
@@ -0,0 +1,37 @@
|
||||
// sherpa-onnx/csrc/offline-paraformer-decoder.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_PARAFORMER_DECODER_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_PARAFORMER_DECODER_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
struct OfflineParaformerDecoderResult {
|
||||
/// The decoded token IDs
|
||||
std::vector<int64_t> tokens;
|
||||
};
|
||||
|
||||
class OfflineParaformerDecoder {
|
||||
public:
|
||||
virtual ~OfflineParaformerDecoder() = default;
|
||||
|
||||
/** Run beam search given the output from the paraformer model.
|
||||
*
|
||||
* @param log_probs A 3-D tensor of shape (N, T, vocab_size)
|
||||
* @param token_num A 2-D tensor of shape (N, T). Its dtype is int64_t.
|
||||
* log_probs[i].argmax(axis=-1) equals to token_num[i]
|
||||
*
|
||||
* @return Return a vector of size `N` containing the decoded results.
|
||||
*/
|
||||
virtual std::vector<OfflineParaformerDecoderResult> Decode(
|
||||
Ort::Value log_probs, Ort::Value token_num) = 0;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_PARAFORMER_DECODER_H_
|
||||
34
sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.cc
Normal file
34
sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.cc
Normal file
@@ -0,0 +1,34 @@
|
||||
// sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
std::vector<OfflineParaformerDecoderResult>
|
||||
OfflineParaformerGreedySearchDecoder::Decode(Ort::Value /*log_probs*/,
|
||||
Ort::Value token_num) {
|
||||
std::vector<int64_t> shape = token_num.GetTensorTypeAndShapeInfo().GetShape();
|
||||
int32_t batch_size = shape[0];
|
||||
int32_t num_tokens = shape[1];
|
||||
|
||||
std::vector<OfflineParaformerDecoderResult> results(batch_size);
|
||||
|
||||
const int64_t *p = token_num.GetTensorData<int64_t>();
|
||||
for (int32_t i = 0; i != batch_size; ++i) {
|
||||
for (int32_t k = 0; k != num_tokens; ++k) {
|
||||
if (p[k] == eos_id_) break;
|
||||
|
||||
results[i].tokens.push_back(p[k]);
|
||||
}
|
||||
|
||||
p += num_tokens;
|
||||
}
|
||||
|
||||
return results;
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
28
sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.h
Normal file
28
sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.h
Normal file
@@ -0,0 +1,28 @@
|
||||
// sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_PARAFORMER_GREEDY_SEARCH_DECODER_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_PARAFORMER_GREEDY_SEARCH_DECODER_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-paraformer-decoder.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class OfflineParaformerGreedySearchDecoder : public OfflineParaformerDecoder {
|
||||
public:
|
||||
explicit OfflineParaformerGreedySearchDecoder(int32_t eos_id)
|
||||
: eos_id_(eos_id) {}
|
||||
|
||||
std::vector<OfflineParaformerDecoderResult> Decode(
|
||||
Ort::Value /*log_probs*/, Ort::Value token_num) override;
|
||||
|
||||
private:
|
||||
int32_t eos_id_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_PARAFORMER_GREEDY_SEARCH_DECODER_H_
|
||||
34
sherpa-onnx/csrc/offline-paraformer-model-config.cc
Normal file
34
sherpa-onnx/csrc/offline-paraformer-model-config.cc
Normal file
@@ -0,0 +1,34 @@
|
||||
// sherpa-onnx/csrc/offline-paraformer-model-config.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-paraformer-model-config.h"
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// Register the --paraformer command-line option, which sets `model`.
void OfflineParaformerModelConfig::Register(ParseOptions *po) {
  po->Register("paraformer", &model,
               "Path to model.onnx of paraformer.");
}
|
||||
|
||||
bool OfflineParaformerModelConfig::Validate() const {
|
||||
if (!FileExists(model)) {
|
||||
SHERPA_ONNX_LOGE("%s does not exist", model.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// Return a human-readable, single-line description of this configuration.
std::string OfflineParaformerModelConfig::ToString() const {
  std::ostringstream os;
  os << "OfflineParaformerModelConfig("
     << "model=\"" << model << "\")";
  return os.str();
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
28
sherpa-onnx/csrc/offline-paraformer-model-config.h
Normal file
28
sherpa-onnx/csrc/offline-paraformer-model-config.h
Normal file
@@ -0,0 +1,28 @@
|
||||
// sherpa-onnx/csrc/offline-paraformer-model-config.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_PARAFORMER_MODEL_CONFIG_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_PARAFORMER_MODEL_CONFIG_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
struct OfflineParaformerModelConfig {
|
||||
std::string model;
|
||||
|
||||
OfflineParaformerModelConfig() = default;
|
||||
explicit OfflineParaformerModelConfig(const std::string &model)
|
||||
: model(model) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
|
||||
std::string ToString() const;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_PARAFORMER_MODEL_CONFIG_H_
|
||||
132
sherpa-onnx/csrc/offline-paraformer-model.cc
Normal file
132
sherpa-onnx/csrc/offline-paraformer-model.cc
Normal file
@@ -0,0 +1,132 @@
|
||||
// sherpa-onnx/csrc/offline-paraformer-model.cc
|
||||
//
|
||||
// Copyright (c) 2022-2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-paraformer-model.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/text-utils.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// Private implementation of OfflineParaformerModel.
//
// It owns the onnxruntime session and caches the values read from the model
// metadata (vocabulary size, LFR parameters and CMVN statistics).
class OfflineParaformerModel::Impl {
 public:
  explicit Impl(const OfflineModelConfig &config)
      : config_(config),
        env_(ORT_LOGGING_LEVEL_ERROR),
        sess_opts_{},
        allocator_{} {
    const int32_t num_threads = config_.num_threads;
    sess_opts_.SetIntraOpNumThreads(num_threads);
    sess_opts_.SetInterOpNumThreads(num_threads);

    Init();
  }

  // Run the model on a batch and return its first two outputs as a pair.
  std::pair<Ort::Value, Ort::Value> Forward(Ort::Value features,
                                            Ort::Value features_length) {
    std::array<Ort::Value, 2> model_inputs = {std::move(features),
                                              std::move(features_length)};

    auto model_outputs = sess_->Run(
        {}, input_names_ptr_.data(), model_inputs.data(), model_inputs.size(),
        output_names_ptr_.data(), output_names_ptr_.size());

    return {std::move(model_outputs[0]), std::move(model_outputs[1])};
  }

  int32_t VocabSize() const { return vocab_size_; }

  int32_t LfrWindowSize() const { return lfr_window_size_; }

  int32_t LfrWindowShift() const { return lfr_window_shift_; }

  const std::vector<float> &NegativeMean() const { return neg_mean_; }

  const std::vector<float> &InverseStdDev() const { return inv_stddev_; }

  OrtAllocator *Allocator() const { return allocator_; }

 private:
  // Create the session from the model file and read the model metadata.
  void Init() {
    auto buf = ReadFile(config_.paraformer.model);

    sess_ = std::make_unique<Ort::Session>(env_, buf.data(), buf.size(),
                                           sess_opts_);

    GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
    GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);

    // The metadata macros below expect the names `meta_data` and
    // `allocator` to be in scope.
    Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
    if (config_.debug) {
      std::ostringstream os;
      PrintModelMetadata(os, meta_data);
      SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
    }

    Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below

    SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size");
    SHERPA_ONNX_READ_META_DATA(lfr_window_size_, "lfr_window_size");
    SHERPA_ONNX_READ_META_DATA(lfr_window_shift_, "lfr_window_shift");

    SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(neg_mean_, "neg_mean");
    SHERPA_ONNX_READ_META_DATA_VEC_FLOAT(inv_stddev_, "inv_stddev");
  }

  OfflineModelConfig config_;
  Ort::Env env_;
  Ort::SessionOptions sess_opts_;
  Ort::AllocatorWithDefaultOptions allocator_;

  std::unique_ptr<Ort::Session> sess_;

  // The *_ptr_ vectors are the C-string views handed to Session::Run().
  std::vector<std::string> input_names_;
  std::vector<const char *> input_names_ptr_;

  std::vector<std::string> output_names_;
  std::vector<const char *> output_names_ptr_;

  // CMVN statistics read from the model metadata.
  std::vector<float> neg_mean_;
  std::vector<float> inv_stddev_;

  // The following are initialized in Init() from the model metadata.
  int32_t vocab_size_ = 0;
  int32_t lfr_window_size_ = 0;
  int32_t lfr_window_shift_ = 0;
};
|
||||
|
||||
OfflineParaformerModel::OfflineParaformerModel(const OfflineModelConfig &config)
|
||||
: impl_(std::make_unique<Impl>(config)) {}
|
||||
|
||||
OfflineParaformerModel::~OfflineParaformerModel() = default;
|
||||
|
||||
std::pair<Ort::Value, Ort::Value> OfflineParaformerModel::Forward(
|
||||
Ort::Value features, Ort::Value features_length) {
|
||||
return impl_->Forward(std::move(features), std::move(features_length));
|
||||
}
|
||||
|
||||
int32_t OfflineParaformerModel::VocabSize() const { return impl_->VocabSize(); }
|
||||
|
||||
int32_t OfflineParaformerModel::LfrWindowSize() const {
|
||||
return impl_->LfrWindowSize();
|
||||
}
|
||||
int32_t OfflineParaformerModel::LfrWindowShift() const {
|
||||
return impl_->LfrWindowShift();
|
||||
}
|
||||
const std::vector<float> &OfflineParaformerModel::NegativeMean() const {
|
||||
return impl_->NegativeMean();
|
||||
}
|
||||
const std::vector<float> &OfflineParaformerModel::InverseStdDev() const {
|
||||
return impl_->InverseStdDev();
|
||||
}
|
||||
|
||||
OrtAllocator *OfflineParaformerModel::Allocator() const {
|
||||
return impl_->Allocator();
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
67
sherpa-onnx/csrc/offline-paraformer-model.h
Normal file
67
sherpa-onnx/csrc/offline-paraformer-model.h
Normal file
@@ -0,0 +1,67 @@
|
||||
// sherpa-onnx/csrc/offline-paraformer-model.h
|
||||
//
|
||||
// Copyright (c) 2022-2023 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_PARAFORMER_MODEL_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_PARAFORMER_MODEL_H_
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/offline-model-config.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class OfflineParaformerModel {
|
||||
public:
|
||||
explicit OfflineParaformerModel(const OfflineModelConfig &config);
|
||||
~OfflineParaformerModel();
|
||||
|
||||
/** Run the forward method of the model.
|
||||
*
|
||||
* @param features A tensor of shape (N, T, C). It is changed in-place.
|
||||
* @param features_length A 1-D tensor of shape (N,) containing number of
|
||||
* valid frames in `features` before padding.
|
||||
* Its dtype is int32_t.
|
||||
*
|
||||
* @return Return a pair containing:
|
||||
* - log_probs: A 3-D tensor of shape (N, T', vocab_size)
|
||||
* - token_num: A 1-D tensor of shape (N, T') containing number
|
||||
* of valid tokens in each utterance. Its dtype is int64_t.
|
||||
*/
|
||||
std::pair<Ort::Value, Ort::Value> Forward(Ort::Value features,
|
||||
Ort::Value features_length);
|
||||
|
||||
/** Return the vocabulary size of the model
|
||||
*/
|
||||
int32_t VocabSize() const;
|
||||
|
||||
/** It is lfr_m in config.yaml
|
||||
*/
|
||||
int32_t LfrWindowSize() const;
|
||||
|
||||
/** It is lfr_n in config.yaml
|
||||
*/
|
||||
int32_t LfrWindowShift() const;
|
||||
|
||||
/** Return negative mean for CMVN
|
||||
*/
|
||||
const std::vector<float> &NegativeMean() const;
|
||||
|
||||
/** Return inverse stddev for CMVN
|
||||
*/
|
||||
const std::vector<float> &InverseStdDev() const;
|
||||
|
||||
/** Return an allocator for allocating memory
|
||||
*/
|
||||
OrtAllocator *Allocator() const;
|
||||
|
||||
private:
|
||||
class Impl;
|
||||
std::unique_ptr<Impl> impl_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_PARAFORMER_MODEL_H_
|
||||
@@ -8,6 +8,7 @@
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer-transducer-impl.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/text-utils.h"
|
||||
@@ -16,10 +17,20 @@ namespace sherpa_onnx {
|
||||
|
||||
std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
||||
const OfflineRecognizerConfig &config) {
|
||||
Ort::Env env;
|
||||
Ort::Env env(ORT_LOGGING_LEVEL_ERROR);
|
||||
|
||||
Ort::SessionOptions sess_opts;
|
||||
auto buf = ReadFile(config.model_config.encoder_filename);
|
||||
std::string model_filename;
|
||||
if (!config.model_config.transducer.encoder_filename.empty()) {
|
||||
model_filename = config.model_config.transducer.encoder_filename;
|
||||
} else if (!config.model_config.paraformer.model.empty()) {
|
||||
model_filename = config.model_config.paraformer.model;
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE("Please provide a model");
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
auto buf = ReadFile(model_filename);
|
||||
|
||||
auto encoder_sess =
|
||||
std::make_unique<Ort::Session>(env, buf.data(), buf.size(), sess_opts);
|
||||
@@ -35,7 +46,16 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
||||
return std::make_unique<OfflineRecognizerTransducerImpl>(config);
|
||||
}
|
||||
|
||||
SHERPA_ONNX_LOGE("Unsupported model_type: %s\n", model_type.c_str());
|
||||
if (model_type == "paraformer") {
|
||||
return std::make_unique<OfflineRecognizerParaformerImpl>(config);
|
||||
}
|
||||
|
||||
SHERPA_ONNX_LOGE(
|
||||
"\nUnsupported model_type: %s\n"
|
||||
"We support only the following model types at present: \n"
|
||||
" - transducer models from icefall\n"
|
||||
" - Paraformer models from FunASR\n",
|
||||
model_type.c_str());
|
||||
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
182
sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h
Normal file
182
sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h
Normal file
@@ -0,0 +1,182 @@
|
||||
// sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h
|
||||
//
|
||||
// Copyright (c) 2022-2023 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_PARAFORMER_IMPL_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_PARAFORMER_IMPL_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-model-config.h"
|
||||
#include "sherpa-onnx/csrc/offline-paraformer-decoder.h"
|
||||
#include "sherpa-onnx/csrc/offline-paraformer-greedy-search-decoder.h"
|
||||
#include "sherpa-onnx/csrc/offline-paraformer-model.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer.h"
|
||||
#include "sherpa-onnx/csrc/pad-sequence.h"
|
||||
#include "sherpa-onnx/csrc/symbol-table.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// Convert decoded token IDs into a recognition result.
//
// Each ID is mapped to its symbol via `sym_table`. The symbols are stored
// one by one in the result's `tokens` and concatenated, without any
// separator, into the result's `text`.
static OfflineRecognitionResult Convert(
    const OfflineParaformerDecoderResult &src, const SymbolTable &sym_table) {
  OfflineRecognitionResult ans;
  ans.tokens.reserve(src.tokens.size());

  std::string joined;
  for (auto id : src.tokens) {
    auto symbol = sym_table[id];
    joined.append(symbol);

    ans.tokens.push_back(std::move(symbol));
  }

  ans.text = std::move(joined);
  return ans;
}
|
||||
|
||||
class OfflineRecognizerParaformerImpl : public OfflineRecognizerImpl {
|
||||
public:
|
||||
explicit OfflineRecognizerParaformerImpl(
|
||||
const OfflineRecognizerConfig &config)
|
||||
: config_(config),
|
||||
symbol_table_(config_.model_config.tokens),
|
||||
model_(std::make_unique<OfflineParaformerModel>(config.model_config)) {
|
||||
if (config.decoding_method == "greedy_search") {
|
||||
int32_t eos_id = symbol_table_["</s>"];
|
||||
decoder_ = std::make_unique<OfflineParaformerGreedySearchDecoder>(eos_id);
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE("Only greedy_search is supported at present. Given %s",
|
||||
config.decoding_method.c_str());
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
// Paraformer models assume input samples are in the range
|
||||
// [-32768, 32767], so we set normalize_samples to false
|
||||
config_.feat_config.normalize_samples = false;
|
||||
}
|
||||
|
||||
std::unique_ptr<OfflineStream> CreateStream() const override {
|
||||
return std::make_unique<OfflineStream>(config_.feat_config);
|
||||
}
|
||||
|
||||
void DecodeStreams(OfflineStream **ss, int32_t n) const override {
|
||||
// 1. Apply LFR
|
||||
// 2. Apply CMVN
|
||||
//
|
||||
// Please refer to
|
||||
// https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/45555.pdf
|
||||
// for what LFR means
|
||||
//
|
||||
// "Lower Frame Rate Neural Network Acoustic Models"
|
||||
auto memory_info =
|
||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||
|
||||
std::vector<Ort::Value> features;
|
||||
features.reserve(n);
|
||||
|
||||
int32_t feat_dim =
|
||||
config_.feat_config.feature_dim * model_->LfrWindowSize();
|
||||
|
||||
std::vector<std::vector<float>> features_vec(n);
|
||||
std::vector<int32_t> features_length_vec(n);
|
||||
for (int32_t i = 0; i != n; ++i) {
|
||||
std::vector<float> f = ss[i]->GetFrames();
|
||||
|
||||
f = ApplyLFR(f);
|
||||
ApplyCMVN(&f);
|
||||
|
||||
int32_t num_frames = f.size() / feat_dim;
|
||||
features_vec[i] = std::move(f);
|
||||
|
||||
features_length_vec[i] = num_frames;
|
||||
|
||||
std::array<int64_t, 2> shape = {num_frames, feat_dim};
|
||||
|
||||
Ort::Value x = Ort::Value::CreateTensor(
|
||||
memory_info, features_vec[i].data(), features_vec[i].size(),
|
||||
shape.data(), shape.size());
|
||||
features.push_back(std::move(x));
|
||||
}
|
||||
|
||||
std::vector<const Ort::Value *> features_pointer(n);
|
||||
for (int32_t i = 0; i != n; ++i) {
|
||||
features_pointer[i] = &features[i];
|
||||
}
|
||||
|
||||
std::array<int64_t, 1> features_length_shape = {n};
|
||||
Ort::Value x_length = Ort::Value::CreateTensor(
|
||||
memory_info, features_length_vec.data(), n,
|
||||
features_length_shape.data(), features_length_shape.size());
|
||||
|
||||
// Caution(fangjun): We cannot pad it with log(eps),
|
||||
// i.e., -23.025850929940457f
|
||||
Ort::Value x = PadSequence(model_->Allocator(), features_pointer, 0);
|
||||
|
||||
auto t = model_->Forward(std::move(x), std::move(x_length));
|
||||
|
||||
auto results = decoder_->Decode(std::move(t.first), std::move(t.second));
|
||||
|
||||
for (int32_t i = 0; i != n; ++i) {
|
||||
auto r = Convert(results[i], symbol_table_);
|
||||
ss[i]->SetResult(r);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<float> ApplyLFR(const std::vector<float> &in) const {
|
||||
int32_t lfr_window_size = model_->LfrWindowSize();
|
||||
int32_t lfr_window_shift = model_->LfrWindowShift();
|
||||
int32_t in_feat_dim = config_.feat_config.feature_dim;
|
||||
|
||||
int32_t in_num_frames = in.size() / in_feat_dim;
|
||||
int32_t out_num_frames =
|
||||
(in_num_frames - lfr_window_size) / lfr_window_shift + 1;
|
||||
int32_t out_feat_dim = in_feat_dim * lfr_window_size;
|
||||
|
||||
std::vector<float> out(out_num_frames * out_feat_dim);
|
||||
|
||||
const float *p_in = in.data();
|
||||
float *p_out = out.data();
|
||||
|
||||
for (int32_t i = 0; i != out_num_frames; ++i) {
|
||||
std::copy(p_in, p_in + out_feat_dim, p_out);
|
||||
|
||||
p_out += out_feat_dim;
|
||||
p_in += lfr_window_shift * in_feat_dim;
|
||||
}
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
void ApplyCMVN(std::vector<float> *v) const {
|
||||
const std::vector<float> &neg_mean = model_->NegativeMean();
|
||||
const std::vector<float> &inv_stddev = model_->InverseStdDev();
|
||||
|
||||
int32_t dim = neg_mean.size();
|
||||
int32_t num_frames = v->size() / dim;
|
||||
|
||||
float *p = v->data();
|
||||
|
||||
for (int32_t i = 0; i != num_frames; ++i) {
|
||||
for (int32_t k = 0; k != dim; ++k) {
|
||||
p[k] = (p[k] + neg_mean[k]) * inv_stddev[k];
|
||||
}
|
||||
|
||||
p += dim;
|
||||
}
|
||||
}
|
||||
|
||||
OfflineRecognizerConfig config_;
|
||||
SymbolTable symbol_table_;
|
||||
std::unique_ptr<OfflineParaformerModel> model_;
|
||||
std::unique_ptr<OfflineParaformerDecoder> decoder_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_PARAFORMER_IMPL_H_
|
||||
@@ -1,6 +1,6 @@
|
||||
// sherpa-onnx/csrc/offline-recognizer-transducer-impl.h
|
||||
//
|
||||
// Copyright (c) 2022 Xiaomi Corporation
|
||||
// Copyright (c) 2022-2023 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_TRANSDUCER_IMPL_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_TRANSDUCER_IMPL_H_
|
||||
|
||||
@@ -6,6 +6,8 @@
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-model-config.h"
|
||||
#include "sherpa-onnx/csrc/offline-stream.h"
|
||||
#include "sherpa-onnx/csrc/offline-transducer-model-config.h"
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
@@ -32,7 +33,7 @@ struct OfflineRecognitionResult {
|
||||
|
||||
struct OfflineRecognizerConfig {
|
||||
OfflineFeatureExtractorConfig feat_config;
|
||||
OfflineTransducerModelConfig model_config;
|
||||
OfflineModelConfig model_config;
|
||||
|
||||
std::string decoding_method = "greedy_search";
|
||||
// only greedy_search is implemented
|
||||
@@ -40,7 +41,7 @@ struct OfflineRecognizerConfig {
|
||||
|
||||
OfflineRecognizerConfig() = default;
|
||||
OfflineRecognizerConfig(const OfflineFeatureExtractorConfig &feat_config,
|
||||
const OfflineTransducerModelConfig &model_config,
|
||||
const OfflineModelConfig &model_config,
|
||||
const std::string &decoding_method)
|
||||
: feat_config(feat_config),
|
||||
model_config(model_config),
|
||||
|
||||
@@ -38,7 +38,7 @@ std::string OfflineFeatureExtractorConfig::ToString() const {
|
||||
|
||||
class OfflineStream::Impl {
|
||||
public:
|
||||
explicit Impl(const OfflineFeatureExtractorConfig &config) {
|
||||
explicit Impl(const OfflineFeatureExtractorConfig &config) : config_(config) {
|
||||
opts_.frame_opts.dither = 0;
|
||||
opts_.frame_opts.snip_edges = false;
|
||||
opts_.frame_opts.samp_freq = config.sampling_rate;
|
||||
@@ -48,6 +48,19 @@ class OfflineStream::Impl {
|
||||
}
|
||||
|
||||
void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
|
||||
if (config_.normalize_samples) {
|
||||
AcceptWaveformImpl(sampling_rate, waveform, n);
|
||||
} else {
|
||||
std::vector<float> buf(n);
|
||||
for (int32_t i = 0; i != n; ++i) {
|
||||
buf[i] = waveform[i] * 32768;
|
||||
}
|
||||
AcceptWaveformImpl(sampling_rate, buf.data(), n);
|
||||
}
|
||||
}
|
||||
|
||||
void AcceptWaveformImpl(int32_t sampling_rate, const float *waveform,
|
||||
int32_t n) {
|
||||
if (sampling_rate != opts_.frame_opts.samp_freq) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Creating a resampler:\n"
|
||||
@@ -101,6 +114,7 @@ class OfflineStream::Impl {
|
||||
const OfflineRecognitionResult &GetResult() const { return r_; }
|
||||
|
||||
private:
|
||||
OfflineFeatureExtractorConfig config_;
|
||||
std::unique_ptr<knf::OnlineFbank> fbank_;
|
||||
knf::FbankOptions opts_;
|
||||
OfflineRecognitionResult r_;
|
||||
|
||||
@@ -23,6 +23,13 @@ struct OfflineFeatureExtractorConfig {
|
||||
// Feature dimension
|
||||
int32_t feature_dim = 80;
|
||||
|
||||
// Set internally by some models, e.g., paraformer
|
||||
// This parameter is not exposed to users from the commandline
|
||||
// If true, the feature extractor expects inputs to be normalized to
|
||||
// the range [-1, 1].
|
||||
// If false, we will multiply the inputs by 32768
|
||||
bool normalize_samples = true;
|
||||
|
||||
std::string ToString() const;
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
|
||||
@@ -14,20 +14,9 @@ void OfflineTransducerModelConfig::Register(ParseOptions *po) {
|
||||
po->Register("encoder", &encoder_filename, "Path to encoder.onnx");
|
||||
po->Register("decoder", &decoder_filename, "Path to decoder.onnx");
|
||||
po->Register("joiner", &joiner_filename, "Path to joiner.onnx");
|
||||
po->Register("tokens", &tokens, "Path to tokens.txt");
|
||||
po->Register("num_threads", &num_threads,
|
||||
"Number of threads to run the neural network");
|
||||
|
||||
po->Register("debug", &debug,
|
||||
"true to print model information while loading it.");
|
||||
}
|
||||
|
||||
bool OfflineTransducerModelConfig::Validate() const {
|
||||
if (!FileExists(tokens)) {
|
||||
SHERPA_ONNX_LOGE("%s does not exist", tokens.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!FileExists(encoder_filename)) {
|
||||
SHERPA_ONNX_LOGE("%s does not exist", encoder_filename.c_str());
|
||||
return false;
|
||||
@@ -43,11 +32,6 @@ bool OfflineTransducerModelConfig::Validate() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (num_threads < 1) {
|
||||
SHERPA_ONNX_LOGE("num_threads should be > 0. Given %d", num_threads);
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -57,10 +41,7 @@ std::string OfflineTransducerModelConfig::ToString() const {
|
||||
os << "OfflineTransducerModelConfig(";
|
||||
os << "encoder_filename=\"" << encoder_filename << "\", ";
|
||||
os << "decoder_filename=\"" << decoder_filename << "\", ";
|
||||
os << "joiner_filename=\"" << joiner_filename << "\", ";
|
||||
os << "tokens=\"" << tokens << "\", ";
|
||||
os << "num_threads=" << num_threads << ", ";
|
||||
os << "debug=" << (debug ? "True" : "False") << ")";
|
||||
os << "joiner_filename=\"" << joiner_filename << "\")";
|
||||
|
||||
return os.str();
|
||||
}
|
||||
|
||||
@@ -14,22 +14,14 @@ struct OfflineTransducerModelConfig {
|
||||
std::string encoder_filename;
|
||||
std::string decoder_filename;
|
||||
std::string joiner_filename;
|
||||
std::string tokens;
|
||||
int32_t num_threads = 2;
|
||||
bool debug = false;
|
||||
|
||||
OfflineTransducerModelConfig() = default;
|
||||
OfflineTransducerModelConfig(const std::string &encoder_filename,
|
||||
const std::string &decoder_filename,
|
||||
const std::string &joiner_filename,
|
||||
const std::string &tokens, int32_t num_threads,
|
||||
bool debug)
|
||||
const std::string &joiner_filename)
|
||||
: encoder_filename(encoder_filename),
|
||||
decoder_filename(decoder_filename),
|
||||
joiner_filename(joiner_filename),
|
||||
tokens(tokens),
|
||||
num_threads(num_threads),
|
||||
debug(debug) {}
|
||||
joiner_filename(joiner_filename) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
|
||||
@@ -16,7 +16,7 @@ namespace sherpa_onnx {
|
||||
|
||||
class OfflineTransducerModel::Impl {
|
||||
public:
|
||||
explicit Impl(const OfflineTransducerModelConfig &config)
|
||||
explicit Impl(const OfflineModelConfig &config)
|
||||
: config_(config),
|
||||
env_(ORT_LOGGING_LEVEL_WARNING),
|
||||
sess_opts_{},
|
||||
@@ -24,17 +24,17 @@ class OfflineTransducerModel::Impl {
|
||||
sess_opts_.SetIntraOpNumThreads(config.num_threads);
|
||||
sess_opts_.SetInterOpNumThreads(config.num_threads);
|
||||
{
|
||||
auto buf = ReadFile(config.encoder_filename);
|
||||
auto buf = ReadFile(config.transducer.encoder_filename);
|
||||
InitEncoder(buf.data(), buf.size());
|
||||
}
|
||||
|
||||
{
|
||||
auto buf = ReadFile(config.decoder_filename);
|
||||
auto buf = ReadFile(config.transducer.decoder_filename);
|
||||
InitDecoder(buf.data(), buf.size());
|
||||
}
|
||||
|
||||
{
|
||||
auto buf = ReadFile(config.joiner_filename);
|
||||
auto buf = ReadFile(config.transducer.joiner_filename);
|
||||
InitJoiner(buf.data(), buf.size());
|
||||
}
|
||||
}
|
||||
@@ -164,7 +164,7 @@ class OfflineTransducerModel::Impl {
|
||||
}
|
||||
|
||||
private:
|
||||
OfflineTransducerModelConfig config_;
|
||||
OfflineModelConfig config_;
|
||||
Ort::Env env_;
|
||||
Ort::SessionOptions sess_opts_;
|
||||
Ort::AllocatorWithDefaultOptions allocator_;
|
||||
@@ -195,8 +195,7 @@ class OfflineTransducerModel::Impl {
|
||||
int32_t context_size_ = 0; // initialized in InitDecoder
|
||||
};
|
||||
|
||||
OfflineTransducerModel::OfflineTransducerModel(
|
||||
const OfflineTransducerModelConfig &config)
|
||||
OfflineTransducerModel::OfflineTransducerModel(const OfflineModelConfig &config)
|
||||
: impl_(std::make_unique<Impl>(config)) {}
|
||||
|
||||
OfflineTransducerModel::~OfflineTransducerModel() = default;
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
#include <vector>
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/offline-transducer-model-config.h"
|
||||
#include "sherpa-onnx/csrc/offline-model-config.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
@@ -17,7 +17,7 @@ struct OfflineTransducerDecoderResult;
|
||||
|
||||
class OfflineTransducerModel {
|
||||
public:
|
||||
explicit OfflineTransducerModel(const OfflineTransducerModelConfig &config);
|
||||
explicit OfflineTransducerModel(const OfflineModelConfig &config);
|
||||
~OfflineTransducerModel();
|
||||
|
||||
/** Run the encoder.
|
||||
@@ -25,6 +25,7 @@ class OfflineTransducerModel {
|
||||
* @param features A tensor of shape (N, T, C). It is changed in-place.
|
||||
* @param features_length A 1-D tensor of shape (N,) containing number of
|
||||
* valid frames in `features` before padding.
|
||||
* Its dtype is int64_t.
|
||||
*
|
||||
* @return Return a pair containing:
|
||||
* - encoder_out: A 3-D tensor of shape (N, T', encoder_dim)
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
#include <algorithm>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
@@ -133,19 +134,24 @@ void Print1D(Ort::Value *v) {
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
template <typename T /*= float*/>
|
||||
void Print2D(Ort::Value *v) {
|
||||
std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
|
||||
const float *d = v->GetTensorData<float>();
|
||||
const T *d = v->GetTensorData<T>();
|
||||
|
||||
std::ostringstream os;
|
||||
for (int32_t r = 0; r != static_cast<int32_t>(shape[0]); ++r) {
|
||||
for (int32_t c = 0; c != static_cast<int32_t>(shape[1]); ++c, ++d) {
|
||||
fprintf(stderr, "%.3f ", *d);
|
||||
os << *d << " ";
|
||||
}
|
||||
fprintf(stderr, "\n");
|
||||
os << "\n";
|
||||
}
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "%s\n", os.str().c_str());
|
||||
}
|
||||
|
||||
template void Print2D<int64_t>(Ort::Value *v);
|
||||
template void Print2D<float>(Ort::Value *v);
|
||||
|
||||
void Print3D(Ort::Value *v) {
|
||||
std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
|
||||
const float *d = v->GetTensorData<float>();
|
||||
|
||||
@@ -24,18 +24,6 @@
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
#ifdef _MSC_VER
|
||||
// See
|
||||
// https://stackoverflow.com/questions/2573834/c-convert-string-or-char-to-wstring-or-wchar-t
|
||||
static std::wstring ToWide(const std::string &s) {
|
||||
std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>> converter;
|
||||
return converter.from_bytes(s);
|
||||
}
|
||||
#define SHERPA_MAYBE_WIDE(s) ToWide(s)
|
||||
#else
|
||||
#define SHERPA_MAYBE_WIDE(s) s
|
||||
#endif
|
||||
|
||||
/**
|
||||
* Get the input names of a model.
|
||||
*
|
||||
@@ -79,6 +67,7 @@ Ort::Value Clone(OrtAllocator *allocator, const Ort::Value *v);
|
||||
void Print1D(Ort::Value *v);
|
||||
|
||||
// Print a 2-D tensor to stderr
|
||||
template <typename T = float>
|
||||
void Print2D(Ort::Value *v);
|
||||
|
||||
// Print a 3-D tensor to stderr
|
||||
|
||||
@@ -9,24 +9,35 @@
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-recognizer.h"
|
||||
#include "sherpa-onnx/csrc/offline-stream.h"
|
||||
#include "sherpa-onnx/csrc/offline-transducer-decoder.h"
|
||||
#include "sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.h"
|
||||
#include "sherpa-onnx/csrc/offline-transducer-model.h"
|
||||
#include "sherpa-onnx/csrc/pad-sequence.h"
|
||||
#include "sherpa-onnx/csrc/symbol-table.h"
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
#include "sherpa-onnx/csrc/wave-reader.h"
|
||||
|
||||
int main(int32_t argc, char *argv[]) {
|
||||
if (argc < 6 || argc > 8) {
|
||||
const char *usage = R"usage(
|
||||
const char *kUsageMessage = R"usage(
|
||||
Usage:
|
||||
|
||||
(1) Transducer from icefall
|
||||
|
||||
./bin/sherpa-onnx-offline \
|
||||
/path/to/tokens.txt \
|
||||
/path/to/encoder.onnx \
|
||||
/path/to/decoder.onnx \
|
||||
/path/to/joiner.onnx \
|
||||
/path/to/foo.wav [num_threads [decoding_method]]
|
||||
--tokens=/path/to/tokens.txt \
|
||||
--encoder=/path/to/encoder.onnx \
|
||||
--decoder=/path/to/decoder.onnx \
|
||||
--joiner=/path/to/joiner.onnx \
|
||||
--num-threads=2 \
|
||||
--decoding-method=greedy_search \
|
||||
/path/to/foo.wav [bar.wav foobar.wav ...]
|
||||
|
||||
|
||||
(2) Paraformer from FunASR
|
||||
|
||||
./bin/sherpa-onnx-offline \
|
||||
--tokens=/path/to/tokens.txt \
|
||||
--paraformer=/path/to/model.onnx \
|
||||
--num-threads=2 \
|
||||
--decoding-method=greedy_search \
|
||||
/path/to/foo.wav [bar.wav foobar.wav ...]
|
||||
|
||||
Note: It supports decoding multiple files in batches
|
||||
|
||||
Default value for num_threads is 2.
|
||||
Valid values for decoding_method: greedy_search.
|
||||
@@ -37,29 +48,15 @@ Please refer to
|
||||
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
|
||||
for a list of pre-trained models to download.
|
||||
)usage";
|
||||
fprintf(stderr, "%s\n", usage);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
sherpa_onnx::ParseOptions po(kUsageMessage);
|
||||
sherpa_onnx::OfflineRecognizerConfig config;
|
||||
config.Register(&po);
|
||||
|
||||
config.model_config.tokens = argv[1];
|
||||
|
||||
config.model_config.debug = false;
|
||||
config.model_config.encoder_filename = argv[2];
|
||||
config.model_config.decoder_filename = argv[3];
|
||||
config.model_config.joiner_filename = argv[4];
|
||||
|
||||
std::string wav_filename = argv[5];
|
||||
|
||||
config.model_config.num_threads = 2;
|
||||
if (argc == 7 && atoi(argv[6]) > 0) {
|
||||
config.model_config.num_threads = atoi(argv[6]);
|
||||
}
|
||||
|
||||
if (argc == 8) {
|
||||
config.decoding_method = argv[7];
|
||||
po.Read(argc, argv);
|
||||
if (po.NumArgs() < 1) {
|
||||
po.PrintUsage();
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s\n", config.ToString().c_str());
|
||||
@@ -69,35 +66,43 @@ for a list of pre-trained models to download.
|
||||
return -1;
|
||||
}
|
||||
|
||||
int32_t sampling_rate = -1;
|
||||
|
||||
bool is_ok = false;
|
||||
std::vector<float> samples =
|
||||
sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok);
|
||||
if (!is_ok) {
|
||||
fprintf(stderr, "Failed to read %s\n", wav_filename.c_str());
|
||||
return -1;
|
||||
}
|
||||
fprintf(stderr, "sampling rate of input file: %d\n", sampling_rate);
|
||||
|
||||
float duration = samples.size() / static_cast<float>(sampling_rate);
|
||||
|
||||
sherpa_onnx::OfflineRecognizer recognizer(config);
|
||||
auto s = recognizer.CreateStream();
|
||||
|
||||
auto begin = std::chrono::steady_clock::now();
|
||||
fprintf(stderr, "Started\n");
|
||||
|
||||
s->AcceptWaveform(sampling_rate, samples.data(), samples.size());
|
||||
std::vector<std::unique_ptr<sherpa_onnx::OfflineStream>> ss;
|
||||
std::vector<sherpa_onnx::OfflineStream *> ss_pointers;
|
||||
float duration = 0;
|
||||
for (int32_t i = 1; i <= po.NumArgs(); ++i) {
|
||||
std::string wav_filename = po.GetArg(i);
|
||||
int32_t sampling_rate = -1;
|
||||
bool is_ok = false;
|
||||
std::vector<float> samples =
|
||||
sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok);
|
||||
if (!is_ok) {
|
||||
fprintf(stderr, "Failed to read %s\n", wav_filename.c_str());
|
||||
return -1;
|
||||
}
|
||||
duration += samples.size() / static_cast<float>(sampling_rate);
|
||||
|
||||
recognizer.DecodeStream(s.get());
|
||||
auto s = recognizer.CreateStream();
|
||||
s->AcceptWaveform(sampling_rate, samples.data(), samples.size());
|
||||
|
||||
fprintf(stderr, "Done!\n");
|
||||
ss.push_back(std::move(s));
|
||||
ss_pointers.push_back(ss.back().get());
|
||||
}
|
||||
|
||||
fprintf(stderr, "Recognition result for %s:\n%s\n", wav_filename.c_str(),
|
||||
s->GetResult().text.c_str());
|
||||
recognizer.DecodeStreams(ss_pointers.data(), ss_pointers.size());
|
||||
|
||||
auto end = std::chrono::steady_clock::now();
|
||||
|
||||
fprintf(stderr, "Done!\n\n");
|
||||
for (int32_t i = 1; i <= po.NumArgs(); ++i) {
|
||||
fprintf(stderr, "%s\n%s\n----\n", po.GetArg(i).c_str(),
|
||||
ss[i - 1]->GetResult().text.c_str());
|
||||
}
|
||||
|
||||
float elapsed_seconds =
|
||||
std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
|
||||
.count() /
|
||||
|
||||
Reference in New Issue
Block a user