Add online transducer decoder (#27)
This commit is contained in:
@@ -1,8 +1,8 @@
|
|||||||
function(download_kaldi_native_fbank)
|
function(download_kaldi_native_fbank)
|
||||||
include(FetchContent)
|
include(FetchContent)
|
||||||
|
|
||||||
set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.11.tar.gz")
|
set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.12.tar.gz")
|
||||||
set(kaldi_native_fbank_HASH "SHA256=e69ae25ef6f30566ef31ca949dd1b0b8ec3a827caeba93a61d82bb848dac5d69")
|
set(kaldi_native_fbank_HASH "SHA256=8f4dfc3f6ddb1adcd9ac0ae87743ebc6cbcae147aacf9d46e76fa54134e12b44")
|
||||||
|
|
||||||
set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE)
|
set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE)
|
||||||
set(KALDI_NATIVE_FBANK_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
|
set(KALDI_NATIVE_FBANK_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
|
||||||
@@ -11,10 +11,11 @@ function(download_kaldi_native_fbank)
|
|||||||
# If you don't have access to the Internet,
|
# If you don't have access to the Internet,
|
||||||
# please pre-download kaldi-native-fbank
|
# please pre-download kaldi-native-fbank
|
||||||
set(possible_file_locations
|
set(possible_file_locations
|
||||||
$ENV{HOME}/Downloads/kaldi-native-fbank-1.11.tar.gz
|
$ENV{HOME}/Downloads/kaldi-native-fbank-1.12.tar.gz
|
||||||
${PROJECT_SOURCE_DIR}/kaldi-native-fbank-1.11.tar.gz
|
${PROJECT_SOURCE_DIR}/kaldi-native-fbank-1.12.tar.gz
|
||||||
${PROJECT_BINARY_DIR}/kaldi-native-fbank-1.11.tar.gz
|
${PROJECT_BINARY_DIR}/kaldi-native-fbank-1.12.tar.gz
|
||||||
/tmp/kaldi-native-fbank-1.11.tar.gz
|
/tmp/kaldi-native-fbank-1.12.tar.gz
|
||||||
|
/star-fj/fangjun/download/github/kaldi-native-fbank-1.12.tar.gz
|
||||||
)
|
)
|
||||||
|
|
||||||
foreach(f IN LISTS possible_file_locations)
|
foreach(f IN LISTS possible_file_locations)
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ function(download_onnxruntime)
|
|||||||
${PROJECT_SOURCE_DIR}/onnxruntime-linux-x64-1.14.0.tgz
|
${PROJECT_SOURCE_DIR}/onnxruntime-linux-x64-1.14.0.tgz
|
||||||
${PROJECT_BINARY_DIR}/onnxruntime-linux-x64-1.14.0.tgz
|
${PROJECT_BINARY_DIR}/onnxruntime-linux-x64-1.14.0.tgz
|
||||||
/tmp/onnxruntime-linux-x64-1.14.0.tgz
|
/tmp/onnxruntime-linux-x64-1.14.0.tgz
|
||||||
|
/star-fj/fangjun/download/github/onnxruntime-linux-x64-1.14.0.tgz
|
||||||
)
|
)
|
||||||
|
|
||||||
set(onnxruntime_URL "https://github.com/microsoft/onnxruntime/releases/download/v1.14.0/onnxruntime-linux-x64-1.14.0.tgz")
|
set(onnxruntime_URL "https://github.com/microsoft/onnxruntime/releases/download/v1.14.0/onnxruntime-linux-x64-1.14.0.tgz")
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
include_directories(${CMAKE_SOURCE_DIR})
|
include_directories(${CMAKE_SOURCE_DIR})
|
||||||
|
|
||||||
add_executable(sherpa-onnx
|
add_executable(sherpa-onnx
|
||||||
decode.cc
|
|
||||||
features.cc
|
features.cc
|
||||||
online-lstm-transducer-model.cc
|
online-lstm-transducer-model.cc
|
||||||
|
online-transducer-greedy-search-decoder.cc
|
||||||
online-transducer-model-config.cc
|
online-transducer-model-config.cc
|
||||||
online-transducer-model.cc
|
online-transducer-model.cc
|
||||||
onnx-utils.cc
|
onnx-utils.cc
|
||||||
|
|||||||
@@ -1,79 +0,0 @@
|
|||||||
// sherpa/csrc/decode.cc
|
|
||||||
//
|
|
||||||
// Copyright (c) 2023 Xiaomi Corporation
|
|
||||||
|
|
||||||
#include "sherpa-onnx/csrc/decode.h"
|
|
||||||
|
|
||||||
#include <assert.h>
|
|
||||||
|
|
||||||
#include <algorithm>
|
|
||||||
#include <utility>
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
namespace sherpa_onnx {
|
|
||||||
|
|
||||||
static Ort::Value Clone(Ort::Value *v) {
|
|
||||||
auto type_and_shape = v->GetTensorTypeAndShapeInfo();
|
|
||||||
std::vector<int64_t> shape = type_and_shape.GetShape();
|
|
||||||
|
|
||||||
auto memory_info =
|
|
||||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
|
||||||
|
|
||||||
return Ort::Value::CreateTensor(memory_info, v->GetTensorMutableData<float>(),
|
|
||||||
type_and_shape.GetElementCount(),
|
|
||||||
shape.data(), shape.size());
|
|
||||||
}
|
|
||||||
|
|
||||||
static Ort::Value GetFrame(Ort::Value *encoder_out, int32_t t) {
|
|
||||||
std::vector<int64_t> encoder_out_shape =
|
|
||||||
encoder_out->GetTensorTypeAndShapeInfo().GetShape();
|
|
||||||
assert(encoder_out_shape[0] == 1);
|
|
||||||
|
|
||||||
int32_t encoder_out_dim = encoder_out_shape[2];
|
|
||||||
|
|
||||||
auto memory_info =
|
|
||||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
|
||||||
|
|
||||||
std::array<int64_t, 2> shape{1, encoder_out_dim};
|
|
||||||
|
|
||||||
return Ort::Value::CreateTensor(
|
|
||||||
memory_info,
|
|
||||||
encoder_out->GetTensorMutableData<float>() + t * encoder_out_dim,
|
|
||||||
encoder_out_dim, shape.data(), shape.size());
|
|
||||||
}
|
|
||||||
|
|
||||||
void GreedySearch(OnlineTransducerModel *model, Ort::Value encoder_out,
|
|
||||||
std::vector<int64_t> *hyp) {
|
|
||||||
std::vector<int64_t> encoder_out_shape =
|
|
||||||
encoder_out.GetTensorTypeAndShapeInfo().GetShape();
|
|
||||||
|
|
||||||
if (encoder_out_shape[0] > 1) {
|
|
||||||
fprintf(stderr, "Only batch_size=1 is implemented. Given: %d\n",
|
|
||||||
static_cast<int32_t>(encoder_out_shape[0]));
|
|
||||||
}
|
|
||||||
|
|
||||||
int32_t num_frames = encoder_out_shape[1];
|
|
||||||
int32_t vocab_size = model->VocabSize();
|
|
||||||
|
|
||||||
Ort::Value decoder_input = model->BuildDecoderInput(*hyp);
|
|
||||||
Ort::Value decoder_out = model->RunDecoder(std::move(decoder_input));
|
|
||||||
|
|
||||||
for (int32_t t = 0; t != num_frames; ++t) {
|
|
||||||
Ort::Value cur_encoder_out = GetFrame(&encoder_out, t);
|
|
||||||
Ort::Value logit =
|
|
||||||
model->RunJoiner(std::move(cur_encoder_out), Clone(&decoder_out));
|
|
||||||
const float *p_logit = logit.GetTensorData<float>();
|
|
||||||
|
|
||||||
auto y = static_cast<int32_t>(std::distance(
|
|
||||||
static_cast<const float *>(p_logit),
|
|
||||||
std::max_element(static_cast<const float *>(p_logit),
|
|
||||||
static_cast<const float *>(p_logit) + vocab_size)));
|
|
||||||
if (y != 0) {
|
|
||||||
hyp->push_back(y);
|
|
||||||
decoder_input = model->BuildDecoderInput(*hyp);
|
|
||||||
decoder_out = model->RunDecoder(std::move(decoder_input));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace sherpa_onnx
|
|
||||||
@@ -1,26 +0,0 @@
|
|||||||
// sherpa/csrc/decode.h
|
|
||||||
//
|
|
||||||
// Copyright (c) 2023 Xiaomi Corporation
|
|
||||||
|
|
||||||
#ifndef SHERPA_ONNX_CSRC_DECODE_H_
|
|
||||||
#define SHERPA_ONNX_CSRC_DECODE_H_
|
|
||||||
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#include "sherpa-onnx/csrc/online-transducer-model.h"
|
|
||||||
|
|
||||||
namespace sherpa_onnx {
|
|
||||||
|
|
||||||
/** Greedy search for non-streaming ASR.
|
|
||||||
*
|
|
||||||
* @TODO(fangjun) Support batch size > 1
|
|
||||||
*
|
|
||||||
* @param model The RnntModel
|
|
||||||
* @param encoder_out Its shape is (1, num_frames, encoder_out_dim).
|
|
||||||
*/
|
|
||||||
void GreedySearch(OnlineTransducerModel *model, Ort::Value encoder_out,
|
|
||||||
std::vector<int64_t> *hyp);
|
|
||||||
|
|
||||||
} // namespace sherpa_onnx
|
|
||||||
|
|
||||||
#endif // SHERPA_ONNX_CSRC_DECODE_H_
|
|
||||||
@@ -15,16 +15,16 @@ namespace sherpa_onnx {
|
|||||||
|
|
||||||
class FeatureExtractor::Impl {
|
class FeatureExtractor::Impl {
|
||||||
public:
|
public:
|
||||||
Impl(int32_t sampling_rate, int32_t feature_dim) {
|
explicit Impl(const FeatureExtractorConfig &config) {
|
||||||
opts_.frame_opts.dither = 0;
|
opts_.frame_opts.dither = 0;
|
||||||
opts_.frame_opts.snip_edges = false;
|
opts_.frame_opts.snip_edges = false;
|
||||||
opts_.frame_opts.samp_freq = sampling_rate;
|
opts_.frame_opts.samp_freq = config.sampling_rate;
|
||||||
|
|
||||||
// cache 100 seconds of feature frames, which is more than enough
|
// cache 100 seconds of feature frames, which is more than enough
|
||||||
// for real needs
|
// for real needs
|
||||||
opts_.frame_opts.max_feature_vectors = 100 * 100;
|
opts_.frame_opts.max_feature_vectors = 100 * 100;
|
||||||
|
|
||||||
opts_.mel_opts.num_bins = feature_dim;
|
opts_.mel_opts.num_bins = config.feature_dim;
|
||||||
|
|
||||||
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
|
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
|
||||||
}
|
}
|
||||||
@@ -80,9 +80,8 @@ class FeatureExtractor::Impl {
|
|||||||
mutable std::mutex mutex_;
|
mutable std::mutex mutex_;
|
||||||
};
|
};
|
||||||
|
|
||||||
FeatureExtractor::FeatureExtractor(int32_t sampling_rate /*=16000*/,
|
FeatureExtractor::FeatureExtractor(const FeatureExtractorConfig &config /*={}*/)
|
||||||
int32_t feature_dim /*=80*/)
|
: impl_(std::make_unique<Impl>(config)) {}
|
||||||
: impl_(std::make_unique<Impl>(sampling_rate, feature_dim)) {}
|
|
||||||
|
|
||||||
FeatureExtractor::~FeatureExtractor() = default;
|
FeatureExtractor::~FeatureExtractor() = default;
|
||||||
|
|
||||||
|
|||||||
@@ -10,14 +10,18 @@
|
|||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
struct FeatureExtractorConfig {
|
||||||
|
int32_t sampling_rate = 16000;
|
||||||
|
int32_t feature_dim = 80;
|
||||||
|
};
|
||||||
|
|
||||||
class FeatureExtractor {
|
class FeatureExtractor {
|
||||||
public:
|
public:
|
||||||
/**
|
/**
|
||||||
* @param sampling_rate Sampling rate of the data used to train the model.
|
* @param sampling_rate Sampling rate of the data used to train the model.
|
||||||
* @param feature_dim Dimension of the features used to train the model.
|
* @param feature_dim Dimension of the features used to train the model.
|
||||||
*/
|
*/
|
||||||
explicit FeatureExtractor(int32_t sampling_rate = 16000,
|
explicit FeatureExtractor(const FeatureExtractorConfig &config = {});
|
||||||
int32_t feature_dim = 80);
|
|
||||||
~FeatureExtractor();
|
~FeatureExtractor();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
52
sherpa-onnx/csrc/online-transducer-decoder.h
Normal file
52
sherpa-onnx/csrc/online-transducer-decoder.h
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
// sherpa/csrc/online-transducer-decoder.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_DECODER_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_DECODER_H_
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
struct OnlineTransducerDecoderResult {
|
||||||
|
/// The decoded token IDs so far
|
||||||
|
std::vector<int64_t> tokens;
|
||||||
|
};
|
||||||
|
|
||||||
|
class OnlineTransducerDecoder {
|
||||||
|
public:
|
||||||
|
virtual ~OnlineTransducerDecoder() = default;
|
||||||
|
|
||||||
|
/* Return an empty result.
|
||||||
|
*
|
||||||
|
* To simplify the decoding code, we add `context_size` blanks
|
||||||
|
* to the beginning of the decoding result, which will be
|
||||||
|
* stripped by calling `StripPrecedingBlanks()`.
|
||||||
|
*/
|
||||||
|
virtual OnlineTransducerDecoderResult GetEmptyResult() = 0;
|
||||||
|
|
||||||
|
/** Strip blanks added by `GetEmptyResult()`.
|
||||||
|
*
|
||||||
|
* @param r It is changed in-place.
|
||||||
|
*/
|
||||||
|
virtual void StripLeadingBlanks(OnlineTransducerDecoderResult * /*r*/) {}
|
||||||
|
|
||||||
|
/** Run transducer beam search given the output from the encoder model.
|
||||||
|
*
|
||||||
|
* @param encoder_out A 3-D tensor of shape (N, T, joiner_dim)
|
||||||
|
* @param result It is modified in-place.
|
||||||
|
*
|
||||||
|
* @note There is no need to pass encoder_out_length here since for the
|
||||||
|
* online decoding case, each utterance has the same number of frames
|
||||||
|
* and there are no paddings.
|
||||||
|
*/
|
||||||
|
virtual void Decode(Ort::Value encoder_out,
|
||||||
|
std::vector<OnlineTransducerDecoderResult> *result) = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_DECODER_H_
|
||||||
101
sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc
Normal file
101
sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
// sherpa/csrc/online-transducer-greedy-search-decoder.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h"
|
||||||
|
|
||||||
|
#include <assert.h>
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
static Ort::Value GetFrame(Ort::Value *encoder_out, int32_t t) {
|
||||||
|
std::vector<int64_t> encoder_out_shape =
|
||||||
|
encoder_out->GetTensorTypeAndShapeInfo().GetShape();
|
||||||
|
assert(encoder_out_shape[0] == 1);
|
||||||
|
|
||||||
|
int32_t encoder_out_dim = encoder_out_shape[2];
|
||||||
|
|
||||||
|
auto memory_info =
|
||||||
|
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||||
|
|
||||||
|
std::array<int64_t, 2> shape{1, encoder_out_dim};
|
||||||
|
|
||||||
|
return Ort::Value::CreateTensor(
|
||||||
|
memory_info,
|
||||||
|
encoder_out->GetTensorMutableData<float>() + t * encoder_out_dim,
|
||||||
|
encoder_out_dim, shape.data(), shape.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
OnlineTransducerDecoderResult
|
||||||
|
OnlineTransducerGreedySearchDecoder::GetEmptyResult() {
|
||||||
|
int32_t context_size = model_->ContextSize();
|
||||||
|
int32_t blank_id = 0; // always 0
|
||||||
|
OnlineTransducerDecoderResult r;
|
||||||
|
r.tokens.resize(context_size, blank_id);
|
||||||
|
|
||||||
|
return r;
|
||||||
|
}
|
||||||
|
|
||||||
|
void OnlineTransducerGreedySearchDecoder::StripLeadingBlanks(
|
||||||
|
OnlineTransducerDecoderResult *r) {
|
||||||
|
int32_t context_size = model_->ContextSize();
|
||||||
|
|
||||||
|
auto start = r->tokens.begin() + context_size;
|
||||||
|
auto end = r->tokens.end();
|
||||||
|
|
||||||
|
r->tokens = std::vector<int64_t>(start, end);
|
||||||
|
}
|
||||||
|
|
||||||
|
void OnlineTransducerGreedySearchDecoder::Decode(
|
||||||
|
Ort::Value encoder_out,
|
||||||
|
std::vector<OnlineTransducerDecoderResult> *result) {
|
||||||
|
std::vector<int64_t> encoder_out_shape =
|
||||||
|
encoder_out.GetTensorTypeAndShapeInfo().GetShape();
|
||||||
|
|
||||||
|
if (encoder_out_shape[0] != result->size()) {
|
||||||
|
fprintf(stderr,
|
||||||
|
"Size mismatch! encoder_out.size(0) %d, result.size(0): %d\n",
|
||||||
|
static_cast<int32_t>(encoder_out_shape[0]),
|
||||||
|
static_cast<int32_t>(result->size()));
|
||||||
|
exit(-1);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (result->size() != 1) {
|
||||||
|
fprintf(stderr, "only batch size == 1 is implemented. Given: %d",
|
||||||
|
static_cast<int32_t>(result->size()));
|
||||||
|
exit(-1);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto &hyp = (*result)[0].tokens;
|
||||||
|
|
||||||
|
int32_t num_frames = encoder_out_shape[1];
|
||||||
|
int32_t vocab_size = model_->VocabSize();
|
||||||
|
|
||||||
|
Ort::Value decoder_input = model_->BuildDecoderInput(hyp);
|
||||||
|
Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input));
|
||||||
|
|
||||||
|
for (int32_t t = 0; t != num_frames; ++t) {
|
||||||
|
Ort::Value cur_encoder_out = GetFrame(&encoder_out, t);
|
||||||
|
Ort::Value logit =
|
||||||
|
model_->RunJoiner(std::move(cur_encoder_out), Clone(&decoder_out));
|
||||||
|
const float *p_logit = logit.GetTensorData<float>();
|
||||||
|
|
||||||
|
auto y = static_cast<int32_t>(std::distance(
|
||||||
|
static_cast<const float *>(p_logit),
|
||||||
|
std::max_element(static_cast<const float *>(p_logit),
|
||||||
|
static_cast<const float *>(p_logit) + vocab_size)));
|
||||||
|
if (y != 0) {
|
||||||
|
hyp.push_back(y);
|
||||||
|
decoder_input = model_->BuildDecoderInput(hyp);
|
||||||
|
decoder_out = model_->RunDecoder(std::move(decoder_input));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
33
sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h
Normal file
33
sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
// sherpa/csrc/online-transducer-greedy-search-decoder.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_GREEDY_SEARCH_DECODER_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_GREEDY_SEARCH_DECODER_H_
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
|
||||||
|
#include "sherpa-onnx/csrc/online-transducer-model.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder {
|
||||||
|
public:
|
||||||
|
explicit OnlineTransducerGreedySearchDecoder(OnlineTransducerModel *model)
|
||||||
|
: model_(model) {}
|
||||||
|
|
||||||
|
OnlineTransducerDecoderResult GetEmptyResult() override;
|
||||||
|
|
||||||
|
void StripLeadingBlanks(OnlineTransducerDecoderResult *r) override;
|
||||||
|
|
||||||
|
void Decode(Ort::Value encoder_out,
|
||||||
|
std::vector<OnlineTransducerDecoderResult> *result) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
OnlineTransducerModel *model_; // Not owned
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_GREEDY_SEARCH_DECODER_H_
|
||||||
@@ -46,4 +46,16 @@ void PrintModelMetadata(std::ostream &os, const Ort::ModelMetadata &meta_data) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Ort::Value Clone(Ort::Value *v) {
|
||||||
|
auto type_and_shape = v->GetTensorTypeAndShapeInfo();
|
||||||
|
std::vector<int64_t> shape = type_and_shape.GetShape();
|
||||||
|
|
||||||
|
auto memory_info =
|
||||||
|
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||||
|
|
||||||
|
return Ort::Value::CreateTensor(memory_info, v->GetTensorMutableData<float>(),
|
||||||
|
type_and_shape.GetElementCount(),
|
||||||
|
shape.data(), shape.size());
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace sherpa_onnx
|
} // namespace sherpa_onnx
|
||||||
|
|||||||
@@ -55,6 +55,9 @@ void GetOutputNames(Ort::Session *sess, std::vector<std::string> *output_names,
|
|||||||
void PrintModelMetadata(std::ostream &os,
|
void PrintModelMetadata(std::ostream &os,
|
||||||
const Ort::ModelMetadata &meta_data); // NOLINT
|
const Ort::ModelMetadata &meta_data); // NOLINT
|
||||||
|
|
||||||
|
// Return a shallow copy of v
|
||||||
|
Ort::Value Clone(Ort::Value *v);
|
||||||
|
|
||||||
} // namespace sherpa_onnx
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
#endif // SHERPA_ONNX_CSRC_ONNX_UTILS_H_
|
#endif // SHERPA_ONNX_CSRC_ONNX_UTILS_H_
|
||||||
|
|||||||
@@ -9,8 +9,8 @@
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "kaldi-native-fbank/csrc/online-feature.h"
|
#include "kaldi-native-fbank/csrc/online-feature.h"
|
||||||
#include "sherpa-onnx/csrc/decode.h"
|
|
||||||
#include "sherpa-onnx/csrc/features.h"
|
#include "sherpa-onnx/csrc/features.h"
|
||||||
|
#include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h"
|
||||||
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
|
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
|
||||||
#include "sherpa-onnx/csrc/online-transducer-model.h"
|
#include "sherpa-onnx/csrc/online-transducer-model.h"
|
||||||
#include "sherpa-onnx/csrc/symbol-table.h"
|
#include "sherpa-onnx/csrc/symbol-table.h"
|
||||||
@@ -64,8 +64,6 @@ for a list of pre-trained models to download.
|
|||||||
|
|
||||||
std::vector<Ort::Value> states = model->GetEncoderInitStates();
|
std::vector<Ort::Value> states = model->GetEncoderInitStates();
|
||||||
|
|
||||||
std::vector<int64_t> hyp(model->ContextSize(), 0);
|
|
||||||
|
|
||||||
int32_t expected_sampling_rate = 16000;
|
int32_t expected_sampling_rate = 16000;
|
||||||
|
|
||||||
bool is_ok = false;
|
bool is_ok = false;
|
||||||
@@ -100,6 +98,10 @@ for a list of pre-trained models to download.
|
|||||||
|
|
||||||
std::array<int64_t, 3> x_shape{1, chunk_size, feature_dim};
|
std::array<int64_t, 3> x_shape{1, chunk_size, feature_dim};
|
||||||
|
|
||||||
|
sherpa_onnx::OnlineTransducerGreedySearchDecoder decoder(model.get());
|
||||||
|
std::vector<sherpa_onnx::OnlineTransducerDecoderResult> result = {
|
||||||
|
decoder.GetEmptyResult()};
|
||||||
|
|
||||||
for (int32_t start = 0; start + chunk_size < num_frames;
|
for (int32_t start = 0; start + chunk_size < num_frames;
|
||||||
start += chunk_shift) {
|
start += chunk_shift) {
|
||||||
std::vector<float> features = feat_extractor.GetFrames(start, chunk_size);
|
std::vector<float> features = feat_extractor.GetFrames(start, chunk_size);
|
||||||
@@ -109,8 +111,10 @@ for a list of pre-trained models to download.
|
|||||||
x_shape.data(), x_shape.size());
|
x_shape.data(), x_shape.size());
|
||||||
auto pair = model->RunEncoder(std::move(x), states);
|
auto pair = model->RunEncoder(std::move(x), states);
|
||||||
states = std::move(pair.second);
|
states = std::move(pair.second);
|
||||||
sherpa_onnx::GreedySearch(model.get(), std::move(pair.first), &hyp);
|
decoder.Decode(std::move(pair.first), &result);
|
||||||
}
|
}
|
||||||
|
decoder.StripLeadingBlanks(&result[0]);
|
||||||
|
const auto &hyp = result[0].tokens;
|
||||||
std::string text;
|
std::string text;
|
||||||
for (size_t i = model->ContextSize(); i != hyp.size(); ++i) {
|
for (size_t i = model->ContextSize(); i != hyp.size(); ++i) {
|
||||||
text += sym[hyp[i]];
|
text += sym[hyp[i]];
|
||||||
|
|||||||
Reference in New Issue
Block a user