Add RNN LM rescore for offline ASR with modified_beam_search (#125)
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -56,3 +56,4 @@ run-offline-decode-files.sh
|
||||
sherpa-onnx-nemo-ctc-en-citrinet-512
|
||||
run-offline-decode-files-nemo-ctc.sh
|
||||
*.jar
|
||||
sherpa-onnx-nemo-ctc-*
|
||||
|
||||
@@ -51,6 +51,11 @@ if(DEFINED ANDROID_ABI)
|
||||
set(SHERPA_ONNX_ENABLE_JNI ON CACHE BOOL "" FORCE)
|
||||
endif()
|
||||
|
||||
if(SHERPA_ONNX_ENABLE_PYTHON AND NOT BUILD_SHARED_LIBS)
|
||||
message(STATUS "Set BUILD_SHARED_LIBS to ON since SHERPA_ONNX_ENABLE_PYTHON is ON")
|
||||
set(BUILD_SHARED_LIBS ON CACHE BOOL "" FORCE)
|
||||
endif()
|
||||
|
||||
if(SHERPA_ONNX_ENABLE_JNI AND NOT BUILD_SHARED_LIBS)
|
||||
message(STATUS "Set BUILD_SHARED_LIBS to ON since SHERPA_ONNX_ENABLE_JNI is ON")
|
||||
set(BUILD_SHARED_LIBS ON CACHE BOOL "" FORCE)
|
||||
|
||||
@@ -18,6 +18,8 @@ set(sources
|
||||
hypothesis.cc
|
||||
offline-ctc-greedy-search-decoder.cc
|
||||
offline-ctc-model.cc
|
||||
offline-lm-config.cc
|
||||
offline-lm.cc
|
||||
offline-model-config.cc
|
||||
offline-nemo-enc-dec-ctc-model-config.cc
|
||||
offline-nemo-enc-dec-ctc-model.cc
|
||||
@@ -26,10 +28,13 @@ set(sources
|
||||
offline-paraformer-model.cc
|
||||
offline-recognizer-impl.cc
|
||||
offline-recognizer.cc
|
||||
offline-rnn-lm.cc
|
||||
offline-stream.cc
|
||||
offline-transducer-greedy-search-decoder.cc
|
||||
offline-transducer-model-config.cc
|
||||
offline-transducer-model.cc
|
||||
offline-transducer-modified-beam-search-decoder.cc
|
||||
online-lm-config.cc
|
||||
online-lstm-transducer-model.cc
|
||||
online-recognizer.cc
|
||||
online-stream.cc
|
||||
|
||||
@@ -17,6 +17,9 @@ void Hypotheses::Add(Hypothesis hyp) {
|
||||
hyps_dict_[key] = std::move(hyp);
|
||||
} else {
|
||||
it->second.log_prob = LogAdd<double>()(it->second.log_prob, hyp.log_prob);
|
||||
|
||||
it->second.lm_log_prob =
|
||||
LogAdd<double>()(it->second.lm_log_prob, hyp.lm_log_prob);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -24,8 +27,8 @@ Hypothesis Hypotheses::GetMostProbable(bool length_norm) const {
|
||||
if (length_norm == false) {
|
||||
return std::max_element(hyps_dict_.begin(), hyps_dict_.end(),
|
||||
[](const auto &left, auto &right) -> bool {
|
||||
return left.second.log_prob <
|
||||
right.second.log_prob;
|
||||
return left.second.TotalLogProb() <
|
||||
right.second.TotalLogProb();
|
||||
})
|
||||
->second;
|
||||
} else {
|
||||
@@ -33,8 +36,8 @@ Hypothesis Hypotheses::GetMostProbable(bool length_norm) const {
|
||||
return std::max_element(
|
||||
hyps_dict_.begin(), hyps_dict_.end(),
|
||||
[](const auto &left, const auto &right) -> bool {
|
||||
return left.second.log_prob / left.second.ys.size() <
|
||||
right.second.log_prob / right.second.ys.size();
|
||||
return left.second.TotalLogProb() / left.second.ys.size() <
|
||||
right.second.TotalLogProb() / right.second.ys.size();
|
||||
})
|
||||
->second;
|
||||
}
|
||||
@@ -47,15 +50,16 @@ std::vector<Hypothesis> Hypotheses::GetTopK(int32_t k, bool length_norm) const {
|
||||
std::vector<Hypothesis> all_hyps = Vec();
|
||||
|
||||
if (length_norm == false) {
|
||||
std::partial_sort(
|
||||
all_hyps.begin(), all_hyps.begin() + k, all_hyps.end(),
|
||||
[](const auto &a, const auto &b) { return a.log_prob > b.log_prob; });
|
||||
std::partial_sort(all_hyps.begin(), all_hyps.begin() + k, all_hyps.end(),
|
||||
[](const auto &a, const auto &b) {
|
||||
return a.TotalLogProb() > b.TotalLogProb();
|
||||
});
|
||||
} else {
|
||||
// for length_norm is true
|
||||
std::partial_sort(all_hyps.begin(), all_hyps.begin() + k, all_hyps.end(),
|
||||
[](const auto &a, const auto &b) {
|
||||
return a.log_prob / a.ys.size() >
|
||||
b.log_prob / b.ys.size();
|
||||
return a.TotalLogProb() / a.ys.size() >
|
||||
b.TotalLogProb() / b.ys.size();
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -25,14 +25,20 @@ struct Hypothesis {
|
||||
std::vector<int32_t> timestamps;
|
||||
|
||||
// The total score of ys in log space.
|
||||
// It contains only acoustic scores
|
||||
double log_prob = 0;
|
||||
|
||||
// LM log prob if any.
|
||||
double lm_log_prob = 0;
|
||||
|
||||
int32_t num_trailing_blanks = 0;
|
||||
|
||||
Hypothesis() = default;
|
||||
Hypothesis(const std::vector<int64_t> &ys, double log_prob)
|
||||
: ys(ys), log_prob(log_prob) {}
|
||||
|
||||
double TotalLogProb() const { return log_prob + lm_log_prob; }
|
||||
|
||||
// If two Hypotheses have the same `Key`, then they contain
|
||||
// the same token sequence.
|
||||
std::string Key() const {
|
||||
@@ -94,6 +100,9 @@ class Hypotheses {
|
||||
const auto begin() const { return hyps_dict_.begin(); }
|
||||
const auto end() const { return hyps_dict_.end(); }
|
||||
|
||||
auto begin() { return hyps_dict_.begin(); }
|
||||
auto end() { return hyps_dict_.end(); }
|
||||
|
||||
void Clear() { hyps_dict_.clear(); }
|
||||
|
||||
private:
|
||||
|
||||
@@ -88,6 +88,16 @@ void LogSoftmax(T *input, int32_t input_len) {
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void LogSoftmax(T *in, int32_t w, int32_t h) {
|
||||
for (int32_t i = 0; i != h; ++i) {
|
||||
LogSoftmax(in, w);
|
||||
in += w;
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(fangjun): use std::partial_sort to replace std::sort.
|
||||
// Remember also to fix sherpa-ncnn
|
||||
template <class T>
|
||||
std::vector<int32_t> TopkIndex(const T *vec, int32_t size, int32_t topk) {
|
||||
std::vector<int32_t> vec_index(size);
|
||||
|
||||
38
sherpa-onnx/csrc/offline-lm-config.cc
Normal file
38
sherpa-onnx/csrc/offline-lm-config.cc
Normal file
@@ -0,0 +1,38 @@
|
||||
// sherpa-onnx/csrc/offline-lm-config.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-lm-config.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void OfflineLMConfig::Register(ParseOptions *po) {
|
||||
po->Register("lm", &model, "Path to LM model.");
|
||||
po->Register("lm-scale", &scale, "LM scale.");
|
||||
}
|
||||
|
||||
bool OfflineLMConfig::Validate() const {
|
||||
if (!FileExists(model)) {
|
||||
SHERPA_ONNX_LOGE("%s does not exist", model.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string OfflineLMConfig::ToString() const {
|
||||
std::ostringstream os;
|
||||
|
||||
os << "OfflineLMConfig(";
|
||||
os << "model=\"" << model << "\", ";
|
||||
os << "scale=" << scale << ")";
|
||||
|
||||
return os.str();
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
33
sherpa-onnx/csrc/offline-lm-config.h
Normal file
33
sherpa-onnx/csrc/offline-lm-config.h
Normal file
@@ -0,0 +1,33 @@
|
||||
// sherpa-onnx/csrc/offline-lm-config.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_LM_CONFIG_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_LM_CONFIG_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
struct OfflineLMConfig {
|
||||
// path to the onnx model
|
||||
std::string model;
|
||||
|
||||
// LM scale
|
||||
float scale = 1.0;
|
||||
|
||||
OfflineLMConfig() = default;
|
||||
|
||||
OfflineLMConfig(const std::string &model, float scale)
|
||||
: model(model), scale(scale) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
|
||||
std::string ToString() const;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_LM_CONFIG_H_
|
||||
71
sherpa-onnx/csrc/offline-lm.cc
Normal file
71
sherpa-onnx/csrc/offline-lm.cc
Normal file
@@ -0,0 +1,71 @@
|
||||
// sherpa-onnx/csrc/offline-lm.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-lm.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-rnn-lm.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
std::unique_ptr<OfflineLM> OfflineLM::Create(const OfflineLMConfig &config) {
|
||||
return std::make_unique<OfflineRnnLM>(config);
|
||||
}
|
||||
|
||||
void OfflineLM::ComputeLMScore(float scale, int32_t context_size,
|
||||
std::vector<Hypotheses> *hyps) {
|
||||
// compute the max token seq so that we know how much space to allocate
|
||||
int32_t max_token_seq = 0;
|
||||
int32_t num_hyps = 0;
|
||||
|
||||
// we subtract context_size below since each token sequence is prepended
|
||||
// with context_size blanks
|
||||
for (const auto &h : *hyps) {
|
||||
num_hyps += h.Size();
|
||||
for (const auto &t : h) {
|
||||
max_token_seq =
|
||||
std::max<int32_t>(max_token_seq, t.second.ys.size() - context_size);
|
||||
}
|
||||
}
|
||||
|
||||
Ort::AllocatorWithDefaultOptions allocator;
|
||||
std::array<int64_t, 2> x_shape{num_hyps, max_token_seq};
|
||||
Ort::Value x = Ort::Value::CreateTensor<int64_t>(allocator, x_shape.data(),
|
||||
x_shape.size());
|
||||
|
||||
std::array<int64_t, 1> x_lens_shape{num_hyps};
|
||||
Ort::Value x_lens = Ort::Value::CreateTensor<int64_t>(
|
||||
allocator, x_lens_shape.data(), x_lens_shape.size());
|
||||
|
||||
int64_t *p = x.GetTensorMutableData<int64_t>();
|
||||
std::fill(p, p + num_hyps * max_token_seq, 0);
|
||||
|
||||
int64_t *p_lens = x_lens.GetTensorMutableData<int64_t>();
|
||||
|
||||
for (const auto &h : *hyps) {
|
||||
for (const auto &t : h) {
|
||||
const auto &ys = t.second.ys;
|
||||
int32_t len = ys.size() - context_size;
|
||||
std::copy(ys.begin() + context_size, ys.end(), p);
|
||||
*p_lens = len;
|
||||
|
||||
p += max_token_seq;
|
||||
++p_lens;
|
||||
}
|
||||
}
|
||||
auto negative_loglike = Rescore(std::move(x), std::move(x_lens));
|
||||
const float *p_nll = negative_loglike.GetTensorData<float>();
|
||||
for (auto &h : *hyps) {
|
||||
for (auto &t : h) {
|
||||
// Use -scale here since we want to change negative loglike to loglike.
|
||||
t.second.lm_log_prob = -scale * (*p_nll);
|
||||
++p_nll;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
46
sherpa-onnx/csrc/offline-lm.h
Normal file
46
sherpa-onnx/csrc/offline-lm.h
Normal file
@@ -0,0 +1,46 @@
|
||||
// sherpa-onnx/csrc/offline-lm.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_LM_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_LM_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/hypothesis.h"
|
||||
#include "sherpa-onnx/csrc/offline-lm-config.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class OfflineLM {
|
||||
public:
|
||||
virtual ~OfflineLM() = default;
|
||||
|
||||
static std::unique_ptr<OfflineLM> Create(const OfflineLMConfig &config);
|
||||
|
||||
/** Rescore a batch of sentences.
|
||||
*
|
||||
* @param x A 2-D tensor of shape (N, L) with data type int64.
|
||||
* @param x_lens A 1-D tensor of shape (N,) with data type int64.
|
||||
* It contains number of valid tokens in x before padding.
|
||||
* @return Return a 1-D tensor of shape (N,) containing the negative log
|
||||
* likelihood of each utterance. Its data type is float32.
|
||||
*
|
||||
* Caution: It returns negative log likelihood (nll), not log likelihood
|
||||
*/
|
||||
virtual Ort::Value Rescore(Ort::Value x, Ort::Value x_lens) = 0;
|
||||
|
||||
// This function updates hyp.lm_lob_prob of hyps.
|
||||
//
|
||||
// @param scale LM score
|
||||
// @param context_size Context size of the transducer decoder model
|
||||
// @param hyps It is changed in-place.
|
||||
void ComputeLMScore(float scale, int32_t context_size,
|
||||
std::vector<Hypotheses> *hyps);
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_LM_H_
|
||||
@@ -16,6 +16,7 @@
|
||||
#include "sherpa-onnx/csrc/offline-transducer-decoder.h"
|
||||
#include "sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.h"
|
||||
#include "sherpa-onnx/csrc/offline-transducer-model.h"
|
||||
#include "sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.h"
|
||||
#include "sherpa-onnx/csrc/pad-sequence.h"
|
||||
#include "sherpa-onnx/csrc/symbol-table.h"
|
||||
|
||||
@@ -57,8 +58,13 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
|
||||
decoder_ =
|
||||
std::make_unique<OfflineTransducerGreedySearchDecoder>(model_.get());
|
||||
} else if (config_.decoding_method == "modified_beam_search") {
|
||||
SHERPA_ONNX_LOGE("TODO: modified_beam_search is to be implemented");
|
||||
exit(-1);
|
||||
if (!config_.lm_config.model.empty()) {
|
||||
lm_ = OfflineLM::Create(config.lm_config);
|
||||
}
|
||||
|
||||
decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>(
|
||||
model_.get(), lm_.get(), config_.max_active_paths,
|
||||
config_.lm_config.scale);
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
|
||||
config_.decoding_method.c_str());
|
||||
@@ -127,6 +133,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
|
||||
SymbolTable symbol_table_;
|
||||
std::unique_ptr<OfflineTransducerModel> model_;
|
||||
std::unique_ptr<OfflineTransducerDecoder> decoder_;
|
||||
std::unique_ptr<OfflineLM> lm_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/offline-lm-config.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
@@ -15,13 +16,28 @@ namespace sherpa_onnx {
|
||||
void OfflineRecognizerConfig::Register(ParseOptions *po) {
|
||||
feat_config.Register(po);
|
||||
model_config.Register(po);
|
||||
lm_config.Register(po);
|
||||
|
||||
po->Register("decoding-method", &decoding_method,
|
||||
"decoding method,"
|
||||
"Valid values: greedy_search.");
|
||||
po->Register(
|
||||
"decoding-method", &decoding_method,
|
||||
"decoding method,"
|
||||
"Valid values: greedy_search, modified_beam_search. "
|
||||
"modified_beam_search is applicable only for transducer models.");
|
||||
|
||||
po->Register("max-active-paths", &max_active_paths,
|
||||
"Used only when decoding_method is modified_beam_search");
|
||||
}
|
||||
|
||||
bool OfflineRecognizerConfig::Validate() const {
|
||||
if (decoding_method == "modified_beam_search" && !lm_config.model.empty()) {
|
||||
if (max_active_paths <= 0) {
|
||||
SHERPA_ONNX_LOGE("max_active_paths is less than 0! Given: %d",
|
||||
max_active_paths);
|
||||
return false;
|
||||
}
|
||||
if (!lm_config.Validate()) return false;
|
||||
}
|
||||
|
||||
return model_config.Validate();
|
||||
}
|
||||
|
||||
@@ -31,7 +47,9 @@ std::string OfflineRecognizerConfig::ToString() const {
|
||||
os << "OfflineRecognizerConfig(";
|
||||
os << "feat_config=" << feat_config.ToString() << ", ";
|
||||
os << "model_config=" << model_config.ToString() << ", ";
|
||||
os << "decoding_method=\"" << decoding_method << "\")";
|
||||
os << "lm_config=" << lm_config.ToString() << ", ";
|
||||
os << "decoding_method=\"" << decoding_method << "\", ";
|
||||
os << "max_active_paths=" << max_active_paths << ")";
|
||||
|
||||
return os.str();
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-lm-config.h"
|
||||
#include "sherpa-onnx/csrc/offline-model-config.h"
|
||||
#include "sherpa-onnx/csrc/offline-stream.h"
|
||||
#include "sherpa-onnx/csrc/offline-transducer-model-config.h"
|
||||
@@ -21,18 +22,24 @@ struct OfflineRecognitionResult;
|
||||
struct OfflineRecognizerConfig {
|
||||
OfflineFeatureExtractorConfig feat_config;
|
||||
OfflineModelConfig model_config;
|
||||
OfflineLMConfig lm_config;
|
||||
|
||||
std::string decoding_method = "greedy_search";
|
||||
int32_t max_active_paths = 4;
|
||||
// only greedy_search is implemented
|
||||
// TODO(fangjun): Implement modified_beam_search
|
||||
|
||||
OfflineRecognizerConfig() = default;
|
||||
OfflineRecognizerConfig(const OfflineFeatureExtractorConfig &feat_config,
|
||||
const OfflineModelConfig &model_config,
|
||||
const std::string &decoding_method)
|
||||
const OfflineLMConfig &lm_config,
|
||||
const std::string &decoding_method,
|
||||
int32_t max_active_paths)
|
||||
: feat_config(feat_config),
|
||||
model_config(model_config),
|
||||
decoding_method(decoding_method) {}
|
||||
lm_config(lm_config),
|
||||
decoding_method(decoding_method),
|
||||
max_active_paths(max_active_paths) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
|
||||
74
sherpa-onnx/csrc/offline-rnn-lm.cc
Normal file
74
sherpa-onnx/csrc/offline-rnn-lm.cc
Normal file
@@ -0,0 +1,74 @@
|
||||
// sherpa-onnx/csrc/offline-rnn-lm.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-rnn-lm.h"
|
||||
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/text-utils.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class OfflineRnnLM::Impl {
|
||||
public:
|
||||
explicit Impl(const OfflineLMConfig &config)
|
||||
: config_(config),
|
||||
env_(ORT_LOGGING_LEVEL_ERROR),
|
||||
sess_opts_{},
|
||||
allocator_{} {
|
||||
Init(config);
|
||||
}
|
||||
|
||||
Ort::Value Rescore(Ort::Value x, Ort::Value x_lens) {
|
||||
std::array<Ort::Value, 2> inputs = {std::move(x), std::move(x_lens)};
|
||||
|
||||
auto out =
|
||||
sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
|
||||
output_names_ptr_.data(), output_names_ptr_.size());
|
||||
|
||||
return std::move(out[0]);
|
||||
}
|
||||
|
||||
private:
|
||||
void Init(const OfflineLMConfig &config) {
|
||||
auto buf = ReadFile(config_.model);
|
||||
|
||||
sess_ = std::make_unique<Ort::Session>(env_, buf.data(), buf.size(),
|
||||
sess_opts_);
|
||||
|
||||
GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
|
||||
|
||||
GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);
|
||||
}
|
||||
|
||||
private:
|
||||
OfflineLMConfig config_;
|
||||
Ort::Env env_;
|
||||
Ort::SessionOptions sess_opts_;
|
||||
Ort::AllocatorWithDefaultOptions allocator_;
|
||||
|
||||
std::unique_ptr<Ort::Session> sess_;
|
||||
|
||||
std::vector<std::string> input_names_;
|
||||
std::vector<const char *> input_names_ptr_;
|
||||
|
||||
std::vector<std::string> output_names_;
|
||||
std::vector<const char *> output_names_ptr_;
|
||||
};
|
||||
|
||||
OfflineRnnLM::OfflineRnnLM(const OfflineLMConfig &config)
|
||||
: impl_(std::make_unique<Impl>(config)) {}
|
||||
|
||||
OfflineRnnLM::~OfflineRnnLM() = default;
|
||||
|
||||
Ort::Value OfflineRnnLM::Rescore(Ort::Value x, Ort::Value x_lens) {
|
||||
return impl_->Rescore(std::move(x), std::move(x_lens));
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
41
sherpa-onnx/csrc/offline-rnn-lm.h
Normal file
41
sherpa-onnx/csrc/offline-rnn-lm.h
Normal file
@@ -0,0 +1,41 @@
|
||||
// sherpa-onnx/csrc/offline-rnn-lm.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_RNN_LM_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_RNN_LM_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/offline-lm-config.h"
|
||||
#include "sherpa-onnx/csrc/offline-lm.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class OfflineRnnLM : public OfflineLM {
|
||||
public:
|
||||
~OfflineRnnLM() override;
|
||||
|
||||
explicit OfflineRnnLM(const OfflineLMConfig &config);
|
||||
|
||||
/** Rescore a batch of sentences.
|
||||
*
|
||||
* @param x A 2-D tensor of shape (N, L) with data type int64.
|
||||
* @param x_lens A 1-D tensor of shape (N,) with data type int64.
|
||||
* It contains number of valid tokens in x before padding.
|
||||
* @return Return a 1-D tensor of shape (N,) containing the log likelihood
|
||||
* of each utterance. Its data type is float32.
|
||||
*
|
||||
* Caution: It returns log likelihood, not negative log likelihood (nll).
|
||||
*/
|
||||
Ort::Value Rescore(Ort::Value x, Ort::Value x_lens) override;
|
||||
|
||||
private:
|
||||
class Impl;
|
||||
std::unique_ptr<Impl> impl_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_RNN_LM_H_
|
||||
@@ -95,6 +95,30 @@ class OfflineTransducerModel::Impl {
|
||||
std::copy(begin, end, p);
|
||||
p += context_size;
|
||||
}
|
||||
|
||||
return decoder_input;
|
||||
}
|
||||
|
||||
Ort::Value BuildDecoderInput(const std::vector<Hypothesis> &results,
|
||||
int32_t end_index) const {
|
||||
assert(end_index <= results.size());
|
||||
|
||||
int32_t batch_size = end_index;
|
||||
int32_t context_size = ContextSize();
|
||||
std::array<int64_t, 2> shape{batch_size, context_size};
|
||||
|
||||
Ort::Value decoder_input = Ort::Value::CreateTensor<int64_t>(
|
||||
Allocator(), shape.data(), shape.size());
|
||||
int64_t *p = decoder_input.GetTensorMutableData<int64_t>();
|
||||
|
||||
for (int32_t i = 0; i != batch_size; ++i) {
|
||||
const auto &r = results[i];
|
||||
const int64_t *begin = r.ys.data() + r.ys.size() - context_size;
|
||||
const int64_t *end = r.ys.data() + r.ys.size();
|
||||
std::copy(begin, end, p);
|
||||
p += context_size;
|
||||
}
|
||||
|
||||
return decoder_input;
|
||||
}
|
||||
|
||||
@@ -234,4 +258,9 @@ Ort::Value OfflineTransducerModel::BuildDecoderInput(
|
||||
return impl_->BuildDecoderInput(results, end_index);
|
||||
}
|
||||
|
||||
Ort::Value OfflineTransducerModel::BuildDecoderInput(
|
||||
const std::vector<Hypothesis> &results, int32_t end_index) const {
|
||||
return impl_->BuildDecoderInput(results, end_index);
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include <vector>
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/hypothesis.h"
|
||||
#include "sherpa-onnx/csrc/offline-model-config.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
@@ -79,13 +80,16 @@ class OfflineTransducerModel {
|
||||
*
|
||||
* @param results Current decoded results.
|
||||
* @param end_index We only use results[0:end_index] to build
|
||||
* the decoder_input.
|
||||
* the decoder_input. results[end_index] is not used.
|
||||
* @return Return a tensor of shape (results.size(), ContextSize())
|
||||
*/
|
||||
Ort::Value BuildDecoderInput(
|
||||
const std::vector<OfflineTransducerDecoderResult> &results,
|
||||
int32_t end_index) const;
|
||||
|
||||
Ort::Value BuildDecoderInput(const std::vector<Hypothesis> &results,
|
||||
int32_t end_index) const;
|
||||
|
||||
private:
|
||||
class Impl;
|
||||
std::unique_ptr<Impl> impl_;
|
||||
|
||||
@@ -0,0 +1,163 @@
|
||||
// sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.h"
|
||||
|
||||
#include <deque>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/hypothesis.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/packed-sequence.h"
|
||||
#include "sherpa-onnx/csrc/slice.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
static std::vector<int32_t> GetHypsRowSplits(
|
||||
const std::vector<Hypotheses> &hyps) {
|
||||
std::vector<int32_t> row_splits;
|
||||
row_splits.reserve(hyps.size() + 1);
|
||||
|
||||
row_splits.push_back(0);
|
||||
int32_t s = 0;
|
||||
for (const auto &h : hyps) {
|
||||
s += h.Size();
|
||||
row_splits.push_back(s);
|
||||
}
|
||||
|
||||
return row_splits;
|
||||
}
|
||||
|
||||
std::vector<OfflineTransducerDecoderResult>
|
||||
OfflineTransducerModifiedBeamSearchDecoder::Decode(
|
||||
Ort::Value encoder_out, Ort::Value encoder_out_length) {
|
||||
PackedSequence packed_encoder_out = PackPaddedSequence(
|
||||
model_->Allocator(), &encoder_out, &encoder_out_length);
|
||||
|
||||
int32_t batch_size =
|
||||
static_cast<int32_t>(packed_encoder_out.sorted_indexes.size());
|
||||
|
||||
int32_t vocab_size = model_->VocabSize();
|
||||
int32_t context_size = model_->ContextSize();
|
||||
|
||||
std::vector<int64_t> blanks(context_size, 0);
|
||||
Hypotheses blank_hyp({{blanks, 0}});
|
||||
|
||||
std::deque<Hypotheses> finalized;
|
||||
std::vector<Hypotheses> cur(batch_size, blank_hyp);
|
||||
std::vector<Hypothesis> prev;
|
||||
|
||||
int32_t start = 0;
|
||||
int32_t t = 0;
|
||||
for (auto n : packed_encoder_out.batch_sizes) {
|
||||
Ort::Value cur_encoder_out = packed_encoder_out.Get(start, n);
|
||||
start += n;
|
||||
|
||||
if (n < static_cast<int32_t>(cur.size())) {
|
||||
for (int32_t k = static_cast<int32_t>(cur.size()) - 1; k >= n; --k) {
|
||||
finalized.push_front(std::move(cur[k]));
|
||||
}
|
||||
|
||||
cur.erase(cur.begin() + n, cur.end());
|
||||
} // if (n < static_cast<int32_t>(cur.size()))
|
||||
|
||||
// Due to merging paths with identical token sequences,
|
||||
// not all utterances have "max_active_paths" paths.
|
||||
auto hyps_row_splits = GetHypsRowSplits(cur);
|
||||
int32_t num_hyps = hyps_row_splits.back();
|
||||
|
||||
prev.clear();
|
||||
prev.reserve(num_hyps);
|
||||
|
||||
for (auto &hyps : cur) {
|
||||
for (auto &h : hyps) {
|
||||
prev.push_back(std::move(h.second));
|
||||
}
|
||||
}
|
||||
cur.clear();
|
||||
cur.reserve(n);
|
||||
|
||||
auto decoder_input = model_->BuildDecoderInput(prev, num_hyps);
|
||||
// decoder_input shape: (num_hyps, context_size)
|
||||
|
||||
auto decoder_out = model_->RunDecoder(std::move(decoder_input));
|
||||
// decoder_out is (num_hyps, joiner_dim)
|
||||
|
||||
cur_encoder_out =
|
||||
Repeat(model_->Allocator(), &cur_encoder_out, hyps_row_splits);
|
||||
// now cur_encoder_out is of shape (num_hyps, joiner_dim)
|
||||
|
||||
Ort::Value logit = model_->RunJoiner(
|
||||
std::move(cur_encoder_out), Clone(model_->Allocator(), &decoder_out));
|
||||
|
||||
float *p_logit = logit.GetTensorMutableData<float>();
|
||||
LogSoftmax(p_logit, vocab_size, num_hyps);
|
||||
|
||||
// now p_logit contains log_softmax output, we rename it to p_logprob
|
||||
// to match what it actually contains
|
||||
float *p_logprob = p_logit;
|
||||
|
||||
// add log_prob of each hypothesis to p_logprob before taking top_k
|
||||
for (int32_t i = 0; i != num_hyps; ++i) {
|
||||
float log_prob = prev[i].log_prob;
|
||||
for (int32_t k = 0; k != vocab_size; ++k, ++p_logprob) {
|
||||
*p_logprob += log_prob;
|
||||
}
|
||||
}
|
||||
p_logprob = p_logit; // we changed p_logprob in the above for loop
|
||||
|
||||
// Now compute top_k for each utterance
|
||||
for (int32_t i = 0; i != n; ++i) {
|
||||
int32_t start = hyps_row_splits[i];
|
||||
int32_t end = hyps_row_splits[i + 1];
|
||||
auto topk =
|
||||
TopkIndex(p_logprob, vocab_size * (end - start), max_active_paths_);
|
||||
|
||||
Hypotheses hyps;
|
||||
for (auto k : topk) {
|
||||
int32_t hyp_index = k / vocab_size + start;
|
||||
int32_t new_token = k % vocab_size;
|
||||
Hypothesis new_hyp = prev[hyp_index];
|
||||
|
||||
if (new_token != 0) {
|
||||
// blank id is fixed to 0
|
||||
new_hyp.ys.push_back(new_token);
|
||||
new_hyp.timestamps.push_back(t);
|
||||
}
|
||||
|
||||
new_hyp.log_prob = p_logprob[k];
|
||||
hyps.Add(std::move(new_hyp));
|
||||
} // for (auto k : topk)
|
||||
p_logprob += (end - start) * vocab_size;
|
||||
cur.push_back(std::move(hyps));
|
||||
} // for (int32_t i = 0; i != n; ++i)
|
||||
|
||||
++t;
|
||||
} // for (auto n : packed_encoder_out.batch_sizes)
|
||||
|
||||
for (auto &h : finalized) {
|
||||
cur.push_back(std::move(h));
|
||||
}
|
||||
|
||||
if (lm_) {
|
||||
// use LM for rescoring
|
||||
lm_->ComputeLMScore(lm_scale_, context_size, &cur);
|
||||
}
|
||||
|
||||
std::vector<OfflineTransducerDecoderResult> unsorted_ans(batch_size);
|
||||
for (int32_t i = 0; i != batch_size; ++i) {
|
||||
Hypothesis hyp = cur[i].GetMostProbable(true);
|
||||
|
||||
auto &r = unsorted_ans[packed_encoder_out.sorted_indexes[i]];
|
||||
|
||||
// strip leading blanks
|
||||
r.tokens = {hyp.ys.begin() + context_size, hyp.ys.end()};
|
||||
r.timestamps = std::move(hyp.timestamps);
|
||||
}
|
||||
|
||||
return unsorted_ans;
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
@@ -0,0 +1,41 @@
|
||||
// sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_MODIFIED_BEAM_SEARCH_DECODER_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_MODIFIED_BEAM_SEARCH_DECODER_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-lm.h"
|
||||
#include "sherpa-onnx/csrc/offline-transducer-decoder.h"
|
||||
#include "sherpa-onnx/csrc/offline-transducer-model.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class OfflineTransducerModifiedBeamSearchDecoder
|
||||
: public OfflineTransducerDecoder {
|
||||
public:
|
||||
OfflineTransducerModifiedBeamSearchDecoder(OfflineTransducerModel *model,
|
||||
OfflineLM *lm,
|
||||
int32_t max_active_paths,
|
||||
float lm_scale)
|
||||
: model_(model),
|
||||
lm_(lm),
|
||||
max_active_paths_(max_active_paths),
|
||||
lm_scale_(lm_scale) {}
|
||||
|
||||
std::vector<OfflineTransducerDecoderResult> Decode(
|
||||
Ort::Value encoder_out, Ort::Value encoder_out_length) override;
|
||||
|
||||
private:
|
||||
OfflineTransducerModel *model_; // Not owned
|
||||
OfflineLM *lm_; // Not owned; may be nullptr
|
||||
|
||||
int32_t max_active_paths_;
|
||||
float lm_scale_; // used only when lm_ is not nullptr
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_MODIFIED_BEAM_SEARCH_DECODER_H_
|
||||
38
sherpa-onnx/csrc/online-lm-config.cc
Normal file
38
sherpa-onnx/csrc/online-lm-config.cc
Normal file
@@ -0,0 +1,38 @@
|
||||
// sherpa-onnx/csrc/online-lm-config.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/online-lm-config.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void OnlineLMConfig::Register(ParseOptions *po) {
|
||||
po->Register("lm", &model, "Path to LM model.");
|
||||
po->Register("lm-scale", &scale, "LM scale.");
|
||||
}
|
||||
|
||||
bool OnlineLMConfig::Validate() const {
|
||||
if (!FileExists(model)) {
|
||||
SHERPA_ONNX_LOGE("%s does not exist", model.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string OnlineLMConfig::ToString() const {
|
||||
std::ostringstream os;
|
||||
|
||||
os << "OnlineLMConfig(";
|
||||
os << "model=\"" << model << "\", ";
|
||||
os << "scale=" << scale << ")";
|
||||
|
||||
return os.str();
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
33
sherpa-onnx/csrc/online-lm-config.h
Normal file
33
sherpa-onnx/csrc/online-lm-config.h
Normal file
@@ -0,0 +1,33 @@
|
||||
// sherpa-onnx/csrc/online-lm-config.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_ONLINE_LM_CONFIG_H_
|
||||
#define SHERPA_ONNX_CSRC_ONLINE_LM_CONFIG_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
struct OnlineLMConfig {
|
||||
// path to the onnx model
|
||||
std::string model;
|
||||
|
||||
// LM scale
|
||||
float scale = 1.0;
|
||||
|
||||
OnlineLMConfig() = default;
|
||||
|
||||
OnlineLMConfig(const std::string &model, float scale)
|
||||
: model(model), scale(scale) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
|
||||
std::string ToString() const;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_ONLINE_LM_CONFIG_H_
|
||||
64
sherpa-onnx/csrc/online-lm.h
Normal file
64
sherpa-onnx/csrc/online-lm.h
Normal file
@@ -0,0 +1,64 @@
|
||||
// sherpa-onnx/csrc/online-lm.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_ONLINE_LM_H_
|
||||
#define SHERPA_ONNX_CSRC_ONLINE_LM_H_
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/hypothesis.h"
|
||||
#include "sherpa-onnx/csrc/online-lm-config.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class OnlineLM {
|
||||
public:
|
||||
virtual ~OnlineLM() = default;
|
||||
|
||||
static std::unique_ptr<OnlineLM> Create(const OnlineLMConfig &config);
|
||||
|
||||
virtual std::vector<Ort::Value> GetInitStates() = 0;
|
||||
|
||||
/** Rescore a batch of sentences.
|
||||
*
|
||||
* @param x A 2-D tensor of shape (N, L) with data type int64.
|
||||
* @param y A 2-D tensor of shape (N, L) with data type int64.
|
||||
* @param states It contains the states for the LM model
|
||||
* @return Return a pair containingo
|
||||
* - negative loglike
|
||||
* - updated states
|
||||
*
|
||||
* Caution: It returns negative log likelihood (nll), not log likelihood
|
||||
*/
|
||||
std::pair<Ort::Value, std::vector<Ort::Value>> Ort::Value Rescore(
|
||||
Ort::Value x, Ort::Value y, std::vector<Ort::Value> states) = 0;
|
||||
|
||||
// This function updates hyp.lm_lob_prob of hyps.
|
||||
//
|
||||
// @param scale LM score
|
||||
// @param context_size Context size of the transducer decoder model
|
||||
// @param hyps It is changed in-place.
|
||||
void ComputeLMScore(float scale, int32_t context_size,
|
||||
std::vector<Hypotheses> *hyps);
|
||||
/** TODO(fangjun):
|
||||
*
|
||||
* 1. Add two fields to Hypothesis
|
||||
* (a) int32_t lm_cur_pos = 0; number of scored tokens so far
|
||||
* (b) std::vector<Ort::Value> lm_states;
|
||||
* 2. When we want to score a hypothesis, we construct x and y as follows:
|
||||
*
|
||||
* std::vector x = {hyp.ys.begin() + context_size + lm_cur_pos,
|
||||
* hyp.ys.end() - 1};
|
||||
* std::vector y = {hyp.ys.begin() + context_size + lm_cur_pos + 1
|
||||
* hyp.ys.end()};
|
||||
* hyp.lm_cur_pos += hyp.ys.size() - context_size - lm_cur_pos;
|
||||
*/
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_ONLINE_LM_H_
|
||||
@@ -36,38 +36,6 @@ static void UseCachedDecoderOut(
|
||||
}
|
||||
}
|
||||
|
||||
static Ort::Value Repeat(OrtAllocator *allocator, Ort::Value *cur_encoder_out,
|
||||
const std::vector<int32_t> &hyps_num_split) {
|
||||
std::vector<int64_t> cur_encoder_out_shape =
|
||||
cur_encoder_out->GetTensorTypeAndShapeInfo().GetShape();
|
||||
|
||||
std::array<int64_t, 2> ans_shape{hyps_num_split.back(),
|
||||
cur_encoder_out_shape[1]};
|
||||
|
||||
Ort::Value ans = Ort::Value::CreateTensor<float>(allocator, ans_shape.data(),
|
||||
ans_shape.size());
|
||||
|
||||
const float *src = cur_encoder_out->GetTensorData<float>();
|
||||
float *dst = ans.GetTensorMutableData<float>();
|
||||
int32_t batch_size = static_cast<int32_t>(hyps_num_split.size()) - 1;
|
||||
for (int32_t b = 0; b != batch_size; ++b) {
|
||||
int32_t cur_stream_hyps_num = hyps_num_split[b + 1] - hyps_num_split[b];
|
||||
for (int32_t i = 0; i != cur_stream_hyps_num; ++i) {
|
||||
std::copy(src, src + cur_encoder_out_shape[1], dst);
|
||||
dst += cur_encoder_out_shape[1];
|
||||
}
|
||||
src += cur_encoder_out_shape[1];
|
||||
}
|
||||
return ans;
|
||||
}
|
||||
|
||||
static void LogSoftmax(float *in, int32_t w, int32_t h) {
|
||||
for (int32_t i = 0; i != h; ++i) {
|
||||
LogSoftmax(in, w);
|
||||
in += w;
|
||||
}
|
||||
}
|
||||
|
||||
OnlineTransducerDecoderResult
|
||||
OnlineTransducerModifiedBeamSearchDecoder::GetEmptyResult() const {
|
||||
int32_t context_size = model_->ContextSize();
|
||||
|
||||
@@ -193,4 +193,29 @@ std::vector<char> ReadFile(AAssetManager *mgr, const std::string &filename) {
|
||||
}
|
||||
#endif
|
||||
|
||||
Ort::Value Repeat(OrtAllocator *allocator, Ort::Value *cur_encoder_out,
|
||||
const std::vector<int32_t> &hyps_num_split) {
|
||||
std::vector<int64_t> cur_encoder_out_shape =
|
||||
cur_encoder_out->GetTensorTypeAndShapeInfo().GetShape();
|
||||
|
||||
std::array<int64_t, 2> ans_shape{hyps_num_split.back(),
|
||||
cur_encoder_out_shape[1]};
|
||||
|
||||
Ort::Value ans = Ort::Value::CreateTensor<float>(allocator, ans_shape.data(),
|
||||
ans_shape.size());
|
||||
|
||||
const float *src = cur_encoder_out->GetTensorData<float>();
|
||||
float *dst = ans.GetTensorMutableData<float>();
|
||||
int32_t batch_size = static_cast<int32_t>(hyps_num_split.size()) - 1;
|
||||
for (int32_t b = 0; b != batch_size; ++b) {
|
||||
int32_t cur_stream_hyps_num = hyps_num_split[b + 1] - hyps_num_split[b];
|
||||
for (int32_t i = 0; i != cur_stream_hyps_num; ++i) {
|
||||
std::copy(src, src + cur_encoder_out_shape[1], dst);
|
||||
dst += cur_encoder_out_shape[1];
|
||||
}
|
||||
src += cur_encoder_out_shape[1];
|
||||
}
|
||||
return ans;
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -86,6 +86,9 @@ std::vector<char> ReadFile(const std::string &filename);
|
||||
std::vector<char> ReadFile(AAssetManager *mgr, const std::string &filename);
|
||||
#endif
|
||||
|
||||
// TODO(fangjun): Document it
|
||||
Ort::Value Repeat(OrtAllocator *allocator, Ort::Value *cur_encoder_out,
|
||||
const std::vector<int32_t> &hyps_num_split);
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_ONNX_UTILS_H_
|
||||
|
||||
@@ -111,6 +111,9 @@ for a list of pre-trained models to download.
|
||||
|
||||
fprintf(stderr, "num threads: %d\n", config.model_config.num_threads);
|
||||
fprintf(stderr, "decoding method: %s\n", config.decoding_method.c_str());
|
||||
if (config.decoding_method == "modified_beam_search") {
|
||||
fprintf(stderr, "max active paths: %d\n", config.max_active_paths);
|
||||
}
|
||||
|
||||
fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
|
||||
float rtf = elapsed_seconds / duration;
|
||||
|
||||
@@ -117,6 +117,9 @@ for a list of pre-trained models to download.
|
||||
|
||||
fprintf(stderr, "num threads: %d\n", config.model_config.num_threads);
|
||||
fprintf(stderr, "decoding method: %s\n", config.decoding_method.c_str());
|
||||
if (config.decoding_method == "modified_beam_search") {
|
||||
fprintf(stderr, "max active paths: %d\n", config.max_active_paths);
|
||||
}
|
||||
|
||||
fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
|
||||
float rtf = elapsed_seconds / duration;
|
||||
|
||||
@@ -4,6 +4,7 @@ pybind11_add_module(_sherpa_onnx
|
||||
display.cc
|
||||
endpoint.cc
|
||||
features.cc
|
||||
offline-lm-config.cc
|
||||
offline-model-config.cc
|
||||
offline-nemo-enc-dec-ctc-model-config.cc
|
||||
offline-paraformer-model-config.cc
|
||||
|
||||
23
sherpa-onnx/python/csrc/offline-lm-config.cc
Normal file
23
sherpa-onnx/python/csrc/offline-lm-config.cc
Normal file
@@ -0,0 +1,23 @@
|
||||
// sherpa-onnx/python/csrc/offline-lm-config.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/python/csrc/offline-lm-config.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx//csrc/offline-lm-config.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void PybindOfflineLMConfig(py::module *m) {
|
||||
using PyClass = OfflineLMConfig;
|
||||
py::class_<PyClass>(*m, "OfflineLMConfig")
|
||||
.def(py::init<const std::string &, float>(), py::arg("model"),
|
||||
py::arg("scale"))
|
||||
.def_readwrite("model", &PyClass::model)
|
||||
.def_readwrite("scale", &PyClass::scale)
|
||||
.def("__str__", &PyClass::ToString);
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
16
sherpa-onnx/python/csrc/offline-lm-config.h
Normal file
16
sherpa-onnx/python/csrc/offline-lm-config.h
Normal file
@@ -0,0 +1,16 @@
|
||||
// sherpa-onnx/python/csrc/offline-lm-config.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_LM_CONFIG_H_
|
||||
#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_LM_CONFIG_H_
|
||||
|
||||
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void PybindOfflineLMConfig(py::module *m);
|
||||
|
||||
}
|
||||
|
||||
#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_LM_CONFIG_H_
|
||||
@@ -15,12 +15,17 @@ static void PybindOfflineRecognizerConfig(py::module *m) {
|
||||
using PyClass = OfflineRecognizerConfig;
|
||||
py::class_<PyClass>(*m, "OfflineRecognizerConfig")
|
||||
.def(py::init<const OfflineFeatureExtractorConfig &,
|
||||
const OfflineModelConfig &, const std::string &>(),
|
||||
const OfflineModelConfig &, const OfflineLMConfig &,
|
||||
const std::string &, int32_t>(),
|
||||
py::arg("feat_config"), py::arg("model_config"),
|
||||
py::arg("decoding_method"))
|
||||
py::arg("lm_config") = OfflineLMConfig(),
|
||||
py::arg("decoding_method") = "greedy_search",
|
||||
py::arg("max_active_paths") = 4)
|
||||
.def_readwrite("feat_config", &PyClass::feat_config)
|
||||
.def_readwrite("model_config", &PyClass::model_config)
|
||||
.def_readwrite("lm_config", &PyClass::lm_config)
|
||||
.def_readwrite("decoding_method", &PyClass::decoding_method)
|
||||
.def_readwrite("max_active_paths", &PyClass::max_active_paths)
|
||||
.def("__str__", &PyClass::ToString);
|
||||
}
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include "sherpa-onnx/python/csrc/display.h"
|
||||
#include "sherpa-onnx/python/csrc/endpoint.h"
|
||||
#include "sherpa-onnx/python/csrc/features.h"
|
||||
#include "sherpa-onnx/python/csrc/offline-lm-config.h"
|
||||
#include "sherpa-onnx/python/csrc/offline-model-config.h"
|
||||
#include "sherpa-onnx/python/csrc/offline-recognizer.h"
|
||||
#include "sherpa-onnx/python/csrc/offline-stream.h"
|
||||
@@ -28,6 +29,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
|
||||
PybindDisplay(&m);
|
||||
|
||||
PybindOfflineStream(&m);
|
||||
PybindOfflineLMConfig(&m);
|
||||
PybindOfflineModelConfig(&m);
|
||||
PybindOfflineRecognizer(&m);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user