Add RNN LM rescore for offline ASR with modified_beam_search (#125)
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -56,3 +56,4 @@ run-offline-decode-files.sh
|
|||||||
sherpa-onnx-nemo-ctc-en-citrinet-512
|
sherpa-onnx-nemo-ctc-en-citrinet-512
|
||||||
run-offline-decode-files-nemo-ctc.sh
|
run-offline-decode-files-nemo-ctc.sh
|
||||||
*.jar
|
*.jar
|
||||||
|
sherpa-onnx-nemo-ctc-*
|
||||||
|
|||||||
@@ -51,6 +51,11 @@ if(DEFINED ANDROID_ABI)
|
|||||||
set(SHERPA_ONNX_ENABLE_JNI ON CACHE BOOL "" FORCE)
|
set(SHERPA_ONNX_ENABLE_JNI ON CACHE BOOL "" FORCE)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
if(SHERPA_ONNX_ENABLE_PYTHON AND NOT BUILD_SHARED_LIBS)
|
||||||
|
message(STATUS "Set BUILD_SHARED_LIBS to ON since SHERPA_ONNX_ENABLE_PYTHON is ON")
|
||||||
|
set(BUILD_SHARED_LIBS ON CACHE BOOL "" FORCE)
|
||||||
|
endif()
|
||||||
|
|
||||||
if(SHERPA_ONNX_ENABLE_JNI AND NOT BUILD_SHARED_LIBS)
|
if(SHERPA_ONNX_ENABLE_JNI AND NOT BUILD_SHARED_LIBS)
|
||||||
message(STATUS "Set BUILD_SHARED_LIBS to ON since SHERPA_ONNX_ENABLE_JNI is ON")
|
message(STATUS "Set BUILD_SHARED_LIBS to ON since SHERPA_ONNX_ENABLE_JNI is ON")
|
||||||
set(BUILD_SHARED_LIBS ON CACHE BOOL "" FORCE)
|
set(BUILD_SHARED_LIBS ON CACHE BOOL "" FORCE)
|
||||||
|
|||||||
@@ -18,6 +18,8 @@ set(sources
|
|||||||
hypothesis.cc
|
hypothesis.cc
|
||||||
offline-ctc-greedy-search-decoder.cc
|
offline-ctc-greedy-search-decoder.cc
|
||||||
offline-ctc-model.cc
|
offline-ctc-model.cc
|
||||||
|
offline-lm-config.cc
|
||||||
|
offline-lm.cc
|
||||||
offline-model-config.cc
|
offline-model-config.cc
|
||||||
offline-nemo-enc-dec-ctc-model-config.cc
|
offline-nemo-enc-dec-ctc-model-config.cc
|
||||||
offline-nemo-enc-dec-ctc-model.cc
|
offline-nemo-enc-dec-ctc-model.cc
|
||||||
@@ -26,10 +28,13 @@ set(sources
|
|||||||
offline-paraformer-model.cc
|
offline-paraformer-model.cc
|
||||||
offline-recognizer-impl.cc
|
offline-recognizer-impl.cc
|
||||||
offline-recognizer.cc
|
offline-recognizer.cc
|
||||||
|
offline-rnn-lm.cc
|
||||||
offline-stream.cc
|
offline-stream.cc
|
||||||
offline-transducer-greedy-search-decoder.cc
|
offline-transducer-greedy-search-decoder.cc
|
||||||
offline-transducer-model-config.cc
|
offline-transducer-model-config.cc
|
||||||
offline-transducer-model.cc
|
offline-transducer-model.cc
|
||||||
|
offline-transducer-modified-beam-search-decoder.cc
|
||||||
|
online-lm-config.cc
|
||||||
online-lstm-transducer-model.cc
|
online-lstm-transducer-model.cc
|
||||||
online-recognizer.cc
|
online-recognizer.cc
|
||||||
online-stream.cc
|
online-stream.cc
|
||||||
|
|||||||
@@ -17,6 +17,9 @@ void Hypotheses::Add(Hypothesis hyp) {
|
|||||||
hyps_dict_[key] = std::move(hyp);
|
hyps_dict_[key] = std::move(hyp);
|
||||||
} else {
|
} else {
|
||||||
it->second.log_prob = LogAdd<double>()(it->second.log_prob, hyp.log_prob);
|
it->second.log_prob = LogAdd<double>()(it->second.log_prob, hyp.log_prob);
|
||||||
|
|
||||||
|
it->second.lm_log_prob =
|
||||||
|
LogAdd<double>()(it->second.lm_log_prob, hyp.lm_log_prob);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -24,8 +27,8 @@ Hypothesis Hypotheses::GetMostProbable(bool length_norm) const {
|
|||||||
if (length_norm == false) {
|
if (length_norm == false) {
|
||||||
return std::max_element(hyps_dict_.begin(), hyps_dict_.end(),
|
return std::max_element(hyps_dict_.begin(), hyps_dict_.end(),
|
||||||
[](const auto &left, auto &right) -> bool {
|
[](const auto &left, auto &right) -> bool {
|
||||||
return left.second.log_prob <
|
return left.second.TotalLogProb() <
|
||||||
right.second.log_prob;
|
right.second.TotalLogProb();
|
||||||
})
|
})
|
||||||
->second;
|
->second;
|
||||||
} else {
|
} else {
|
||||||
@@ -33,8 +36,8 @@ Hypothesis Hypotheses::GetMostProbable(bool length_norm) const {
|
|||||||
return std::max_element(
|
return std::max_element(
|
||||||
hyps_dict_.begin(), hyps_dict_.end(),
|
hyps_dict_.begin(), hyps_dict_.end(),
|
||||||
[](const auto &left, const auto &right) -> bool {
|
[](const auto &left, const auto &right) -> bool {
|
||||||
return left.second.log_prob / left.second.ys.size() <
|
return left.second.TotalLogProb() / left.second.ys.size() <
|
||||||
right.second.log_prob / right.second.ys.size();
|
right.second.TotalLogProb() / right.second.ys.size();
|
||||||
})
|
})
|
||||||
->second;
|
->second;
|
||||||
}
|
}
|
||||||
@@ -47,15 +50,16 @@ std::vector<Hypothesis> Hypotheses::GetTopK(int32_t k, bool length_norm) const {
|
|||||||
std::vector<Hypothesis> all_hyps = Vec();
|
std::vector<Hypothesis> all_hyps = Vec();
|
||||||
|
|
||||||
if (length_norm == false) {
|
if (length_norm == false) {
|
||||||
std::partial_sort(
|
std::partial_sort(all_hyps.begin(), all_hyps.begin() + k, all_hyps.end(),
|
||||||
all_hyps.begin(), all_hyps.begin() + k, all_hyps.end(),
|
[](const auto &a, const auto &b) {
|
||||||
[](const auto &a, const auto &b) { return a.log_prob > b.log_prob; });
|
return a.TotalLogProb() > b.TotalLogProb();
|
||||||
|
});
|
||||||
} else {
|
} else {
|
||||||
// for length_norm is true
|
// for length_norm is true
|
||||||
std::partial_sort(all_hyps.begin(), all_hyps.begin() + k, all_hyps.end(),
|
std::partial_sort(all_hyps.begin(), all_hyps.begin() + k, all_hyps.end(),
|
||||||
[](const auto &a, const auto &b) {
|
[](const auto &a, const auto &b) {
|
||||||
return a.log_prob / a.ys.size() >
|
return a.TotalLogProb() / a.ys.size() >
|
||||||
b.log_prob / b.ys.size();
|
b.TotalLogProb() / b.ys.size();
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -25,14 +25,20 @@ struct Hypothesis {
|
|||||||
std::vector<int32_t> timestamps;
|
std::vector<int32_t> timestamps;
|
||||||
|
|
||||||
// The total score of ys in log space.
|
// The total score of ys in log space.
|
||||||
|
// It contains only acoustic scores
|
||||||
double log_prob = 0;
|
double log_prob = 0;
|
||||||
|
|
||||||
|
// LM log prob if any.
|
||||||
|
double lm_log_prob = 0;
|
||||||
|
|
||||||
int32_t num_trailing_blanks = 0;
|
int32_t num_trailing_blanks = 0;
|
||||||
|
|
||||||
Hypothesis() = default;
|
Hypothesis() = default;
|
||||||
Hypothesis(const std::vector<int64_t> &ys, double log_prob)
|
Hypothesis(const std::vector<int64_t> &ys, double log_prob)
|
||||||
: ys(ys), log_prob(log_prob) {}
|
: ys(ys), log_prob(log_prob) {}
|
||||||
|
|
||||||
|
double TotalLogProb() const { return log_prob + lm_log_prob; }
|
||||||
|
|
||||||
// If two Hypotheses have the same `Key`, then they contain
|
// If two Hypotheses have the same `Key`, then they contain
|
||||||
// the same token sequence.
|
// the same token sequence.
|
||||||
std::string Key() const {
|
std::string Key() const {
|
||||||
@@ -94,6 +100,9 @@ class Hypotheses {
|
|||||||
const auto begin() const { return hyps_dict_.begin(); }
|
const auto begin() const { return hyps_dict_.begin(); }
|
||||||
const auto end() const { return hyps_dict_.end(); }
|
const auto end() const { return hyps_dict_.end(); }
|
||||||
|
|
||||||
|
auto begin() { return hyps_dict_.begin(); }
|
||||||
|
auto end() { return hyps_dict_.end(); }
|
||||||
|
|
||||||
void Clear() { hyps_dict_.clear(); }
|
void Clear() { hyps_dict_.clear(); }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|||||||
@@ -88,6 +88,16 @@ void LogSoftmax(T *input, int32_t input_len) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void LogSoftmax(T *in, int32_t w, int32_t h) {
|
||||||
|
for (int32_t i = 0; i != h; ++i) {
|
||||||
|
LogSoftmax(in, w);
|
||||||
|
in += w;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(fangjun): use std::partial_sort to replace std::sort.
|
||||||
|
// Remember also to fix sherpa-ncnn
|
||||||
template <class T>
|
template <class T>
|
||||||
std::vector<int32_t> TopkIndex(const T *vec, int32_t size, int32_t topk) {
|
std::vector<int32_t> TopkIndex(const T *vec, int32_t size, int32_t topk) {
|
||||||
std::vector<int32_t> vec_index(size);
|
std::vector<int32_t> vec_index(size);
|
||||||
|
|||||||
38
sherpa-onnx/csrc/offline-lm-config.cc
Normal file
38
sherpa-onnx/csrc/offline-lm-config.cc
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
// sherpa-onnx/csrc/offline-lm-config.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/offline-lm-config.h"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/file-utils.h"
|
||||||
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
void OfflineLMConfig::Register(ParseOptions *po) {
|
||||||
|
po->Register("lm", &model, "Path to LM model.");
|
||||||
|
po->Register("lm-scale", &scale, "LM scale.");
|
||||||
|
}
|
||||||
|
|
||||||
|
bool OfflineLMConfig::Validate() const {
|
||||||
|
if (!FileExists(model)) {
|
||||||
|
SHERPA_ONNX_LOGE("%s does not exist", model.c_str());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string OfflineLMConfig::ToString() const {
|
||||||
|
std::ostringstream os;
|
||||||
|
|
||||||
|
os << "OfflineLMConfig(";
|
||||||
|
os << "model=\"" << model << "\", ";
|
||||||
|
os << "scale=" << scale << ")";
|
||||||
|
|
||||||
|
return os.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
33
sherpa-onnx/csrc/offline-lm-config.h
Normal file
33
sherpa-onnx/csrc/offline-lm-config.h
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
// sherpa-onnx/csrc/offline-lm-config.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_OFFLINE_LM_CONFIG_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_OFFLINE_LM_CONFIG_H_
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/parse-options.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
struct OfflineLMConfig {
|
||||||
|
// path to the onnx model
|
||||||
|
std::string model;
|
||||||
|
|
||||||
|
// LM scale
|
||||||
|
float scale = 1.0;
|
||||||
|
|
||||||
|
OfflineLMConfig() = default;
|
||||||
|
|
||||||
|
OfflineLMConfig(const std::string &model, float scale)
|
||||||
|
: model(model), scale(scale) {}
|
||||||
|
|
||||||
|
void Register(ParseOptions *po);
|
||||||
|
bool Validate() const;
|
||||||
|
|
||||||
|
std::string ToString() const;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_OFFLINE_LM_CONFIG_H_
|
||||||
71
sherpa-onnx/csrc/offline-lm.cc
Normal file
71
sherpa-onnx/csrc/offline-lm.cc
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
// sherpa-onnx/csrc/offline-lm.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/offline-lm.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/offline-rnn-lm.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
std::unique_ptr<OfflineLM> OfflineLM::Create(const OfflineLMConfig &config) {
|
||||||
|
return std::make_unique<OfflineRnnLM>(config);
|
||||||
|
}
|
||||||
|
|
||||||
|
void OfflineLM::ComputeLMScore(float scale, int32_t context_size,
|
||||||
|
std::vector<Hypotheses> *hyps) {
|
||||||
|
// compute the max token seq so that we know how much space to allocate
|
||||||
|
int32_t max_token_seq = 0;
|
||||||
|
int32_t num_hyps = 0;
|
||||||
|
|
||||||
|
// we subtract context_size below since each token sequence is prepended
|
||||||
|
// with context_size blanks
|
||||||
|
for (const auto &h : *hyps) {
|
||||||
|
num_hyps += h.Size();
|
||||||
|
for (const auto &t : h) {
|
||||||
|
max_token_seq =
|
||||||
|
std::max<int32_t>(max_token_seq, t.second.ys.size() - context_size);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ort::AllocatorWithDefaultOptions allocator;
|
||||||
|
std::array<int64_t, 2> x_shape{num_hyps, max_token_seq};
|
||||||
|
Ort::Value x = Ort::Value::CreateTensor<int64_t>(allocator, x_shape.data(),
|
||||||
|
x_shape.size());
|
||||||
|
|
||||||
|
std::array<int64_t, 1> x_lens_shape{num_hyps};
|
||||||
|
Ort::Value x_lens = Ort::Value::CreateTensor<int64_t>(
|
||||||
|
allocator, x_lens_shape.data(), x_lens_shape.size());
|
||||||
|
|
||||||
|
int64_t *p = x.GetTensorMutableData<int64_t>();
|
||||||
|
std::fill(p, p + num_hyps * max_token_seq, 0);
|
||||||
|
|
||||||
|
int64_t *p_lens = x_lens.GetTensorMutableData<int64_t>();
|
||||||
|
|
||||||
|
for (const auto &h : *hyps) {
|
||||||
|
for (const auto &t : h) {
|
||||||
|
const auto &ys = t.second.ys;
|
||||||
|
int32_t len = ys.size() - context_size;
|
||||||
|
std::copy(ys.begin() + context_size, ys.end(), p);
|
||||||
|
*p_lens = len;
|
||||||
|
|
||||||
|
p += max_token_seq;
|
||||||
|
++p_lens;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
auto negative_loglike = Rescore(std::move(x), std::move(x_lens));
|
||||||
|
const float *p_nll = negative_loglike.GetTensorData<float>();
|
||||||
|
for (auto &h : *hyps) {
|
||||||
|
for (auto &t : h) {
|
||||||
|
// Use -scale here since we want to change negative loglike to loglike.
|
||||||
|
t.second.lm_log_prob = -scale * (*p_nll);
|
||||||
|
++p_nll;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
46
sherpa-onnx/csrc/offline-lm.h
Normal file
46
sherpa-onnx/csrc/offline-lm.h
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
// sherpa-onnx/csrc/offline-lm.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_OFFLINE_LM_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_OFFLINE_LM_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||||
|
#include "sherpa-onnx/csrc/hypothesis.h"
|
||||||
|
#include "sherpa-onnx/csrc/offline-lm-config.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
class OfflineLM {
|
||||||
|
public:
|
||||||
|
virtual ~OfflineLM() = default;
|
||||||
|
|
||||||
|
static std::unique_ptr<OfflineLM> Create(const OfflineLMConfig &config);
|
||||||
|
|
||||||
|
/** Rescore a batch of sentences.
|
||||||
|
*
|
||||||
|
* @param x A 2-D tensor of shape (N, L) with data type int64.
|
||||||
|
* @param x_lens A 1-D tensor of shape (N,) with data type int64.
|
||||||
|
* It contains number of valid tokens in x before padding.
|
||||||
|
* @return Return a 1-D tensor of shape (N,) containing the negative log
|
||||||
|
* likelihood of each utterance. Its data type is float32.
|
||||||
|
*
|
||||||
|
* Caution: It returns negative log likelihood (nll), not log likelihood
|
||||||
|
*/
|
||||||
|
virtual Ort::Value Rescore(Ort::Value x, Ort::Value x_lens) = 0;
|
||||||
|
|
||||||
|
// This function updates hyp.lm_log_prob of hyps.
|
||||||
|
//
|
||||||
|
// @param scale LM scale. The LM log prob is multiplied by it.
|
||||||
|
// @param context_size Context size of the transducer decoder model
|
||||||
|
// @param hyps It is changed in-place.
|
||||||
|
void ComputeLMScore(float scale, int32_t context_size,
|
||||||
|
std::vector<Hypotheses> *hyps);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_OFFLINE_LM_H_
|
||||||
@@ -16,6 +16,7 @@
|
|||||||
#include "sherpa-onnx/csrc/offline-transducer-decoder.h"
|
#include "sherpa-onnx/csrc/offline-transducer-decoder.h"
|
||||||
#include "sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.h"
|
#include "sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.h"
|
||||||
#include "sherpa-onnx/csrc/offline-transducer-model.h"
|
#include "sherpa-onnx/csrc/offline-transducer-model.h"
|
||||||
|
#include "sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.h"
|
||||||
#include "sherpa-onnx/csrc/pad-sequence.h"
|
#include "sherpa-onnx/csrc/pad-sequence.h"
|
||||||
#include "sherpa-onnx/csrc/symbol-table.h"
|
#include "sherpa-onnx/csrc/symbol-table.h"
|
||||||
|
|
||||||
@@ -57,8 +58,13 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
|
|||||||
decoder_ =
|
decoder_ =
|
||||||
std::make_unique<OfflineTransducerGreedySearchDecoder>(model_.get());
|
std::make_unique<OfflineTransducerGreedySearchDecoder>(model_.get());
|
||||||
} else if (config_.decoding_method == "modified_beam_search") {
|
} else if (config_.decoding_method == "modified_beam_search") {
|
||||||
SHERPA_ONNX_LOGE("TODO: modified_beam_search is to be implemented");
|
if (!config_.lm_config.model.empty()) {
|
||||||
exit(-1);
|
lm_ = OfflineLM::Create(config.lm_config);
|
||||||
|
}
|
||||||
|
|
||||||
|
decoder_ = std::make_unique<OfflineTransducerModifiedBeamSearchDecoder>(
|
||||||
|
model_.get(), lm_.get(), config_.max_active_paths,
|
||||||
|
config_.lm_config.scale);
|
||||||
} else {
|
} else {
|
||||||
SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
|
SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
|
||||||
config_.decoding_method.c_str());
|
config_.decoding_method.c_str());
|
||||||
@@ -127,6 +133,7 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
|
|||||||
SymbolTable symbol_table_;
|
SymbolTable symbol_table_;
|
||||||
std::unique_ptr<OfflineTransducerModel> model_;
|
std::unique_ptr<OfflineTransducerModel> model_;
|
||||||
std::unique_ptr<OfflineTransducerDecoder> decoder_;
|
std::unique_ptr<OfflineTransducerDecoder> decoder_;
|
||||||
|
std::unique_ptr<OfflineLM> lm_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace sherpa_onnx
|
} // namespace sherpa_onnx
|
||||||
|
|||||||
@@ -8,6 +8,7 @@
|
|||||||
|
|
||||||
#include "sherpa-onnx/csrc/file-utils.h"
|
#include "sherpa-onnx/csrc/file-utils.h"
|
||||||
#include "sherpa-onnx/csrc/macros.h"
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
|
#include "sherpa-onnx/csrc/offline-lm-config.h"
|
||||||
#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
|
#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
|
||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
@@ -15,13 +16,28 @@ namespace sherpa_onnx {
|
|||||||
void OfflineRecognizerConfig::Register(ParseOptions *po) {
|
void OfflineRecognizerConfig::Register(ParseOptions *po) {
|
||||||
feat_config.Register(po);
|
feat_config.Register(po);
|
||||||
model_config.Register(po);
|
model_config.Register(po);
|
||||||
|
lm_config.Register(po);
|
||||||
|
|
||||||
po->Register("decoding-method", &decoding_method,
|
po->Register(
|
||||||
|
"decoding-method", &decoding_method,
|
||||||
"decoding method,"
|
"decoding method,"
|
||||||
"Valid values: greedy_search.");
|
"Valid values: greedy_search, modified_beam_search. "
|
||||||
|
"modified_beam_search is applicable only for transducer models.");
|
||||||
|
|
||||||
|
po->Register("max-active-paths", &max_active_paths,
|
||||||
|
"Used only when decoding_method is modified_beam_search");
|
||||||
}
|
}
|
||||||
|
|
||||||
bool OfflineRecognizerConfig::Validate() const {
|
bool OfflineRecognizerConfig::Validate() const {
|
||||||
|
if (decoding_method == "modified_beam_search" && !lm_config.model.empty()) {
|
||||||
|
if (max_active_paths <= 0) {
|
||||||
|
SHERPA_ONNX_LOGE("max_active_paths is less than 0! Given: %d",
|
||||||
|
max_active_paths);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (!lm_config.Validate()) return false;
|
||||||
|
}
|
||||||
|
|
||||||
return model_config.Validate();
|
return model_config.Validate();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -31,7 +47,9 @@ std::string OfflineRecognizerConfig::ToString() const {
|
|||||||
os << "OfflineRecognizerConfig(";
|
os << "OfflineRecognizerConfig(";
|
||||||
os << "feat_config=" << feat_config.ToString() << ", ";
|
os << "feat_config=" << feat_config.ToString() << ", ";
|
||||||
os << "model_config=" << model_config.ToString() << ", ";
|
os << "model_config=" << model_config.ToString() << ", ";
|
||||||
os << "decoding_method=\"" << decoding_method << "\")";
|
os << "lm_config=" << lm_config.ToString() << ", ";
|
||||||
|
os << "decoding_method=\"" << decoding_method << "\", ";
|
||||||
|
os << "max_active_paths=" << max_active_paths << ")";
|
||||||
|
|
||||||
return os.str();
|
return os.str();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,6 +9,7 @@
|
|||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/offline-lm-config.h"
|
||||||
#include "sherpa-onnx/csrc/offline-model-config.h"
|
#include "sherpa-onnx/csrc/offline-model-config.h"
|
||||||
#include "sherpa-onnx/csrc/offline-stream.h"
|
#include "sherpa-onnx/csrc/offline-stream.h"
|
||||||
#include "sherpa-onnx/csrc/offline-transducer-model-config.h"
|
#include "sherpa-onnx/csrc/offline-transducer-model-config.h"
|
||||||
@@ -21,18 +22,24 @@ struct OfflineRecognitionResult;
|
|||||||
struct OfflineRecognizerConfig {
|
struct OfflineRecognizerConfig {
|
||||||
OfflineFeatureExtractorConfig feat_config;
|
OfflineFeatureExtractorConfig feat_config;
|
||||||
OfflineModelConfig model_config;
|
OfflineModelConfig model_config;
|
||||||
|
OfflineLMConfig lm_config;
|
||||||
|
|
||||||
std::string decoding_method = "greedy_search";
|
std::string decoding_method = "greedy_search";
|
||||||
|
int32_t max_active_paths = 4;
|
||||||
// only greedy_search is implemented
|
// only greedy_search is implemented
|
||||||
// TODO(fangjun): Implement modified_beam_search
|
// TODO(fangjun): Implement modified_beam_search
|
||||||
|
|
||||||
OfflineRecognizerConfig() = default;
|
OfflineRecognizerConfig() = default;
|
||||||
OfflineRecognizerConfig(const OfflineFeatureExtractorConfig &feat_config,
|
OfflineRecognizerConfig(const OfflineFeatureExtractorConfig &feat_config,
|
||||||
const OfflineModelConfig &model_config,
|
const OfflineModelConfig &model_config,
|
||||||
const std::string &decoding_method)
|
const OfflineLMConfig &lm_config,
|
||||||
|
const std::string &decoding_method,
|
||||||
|
int32_t max_active_paths)
|
||||||
: feat_config(feat_config),
|
: feat_config(feat_config),
|
||||||
model_config(model_config),
|
model_config(model_config),
|
||||||
decoding_method(decoding_method) {}
|
lm_config(lm_config),
|
||||||
|
decoding_method(decoding_method),
|
||||||
|
max_active_paths(max_active_paths) {}
|
||||||
|
|
||||||
void Register(ParseOptions *po);
|
void Register(ParseOptions *po);
|
||||||
bool Validate() const;
|
bool Validate() const;
|
||||||
|
|||||||
74
sherpa-onnx/csrc/offline-rnn-lm.cc
Normal file
74
sherpa-onnx/csrc/offline-rnn-lm.cc
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
// sherpa-onnx/csrc/offline-rnn-lm.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/offline-rnn-lm.h"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||||
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
|
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||||
|
#include "sherpa-onnx/csrc/text-utils.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
class OfflineRnnLM::Impl {
|
||||||
|
public:
|
||||||
|
explicit Impl(const OfflineLMConfig &config)
|
||||||
|
: config_(config),
|
||||||
|
env_(ORT_LOGGING_LEVEL_ERROR),
|
||||||
|
sess_opts_{},
|
||||||
|
allocator_{} {
|
||||||
|
Init(config);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ort::Value Rescore(Ort::Value x, Ort::Value x_lens) {
|
||||||
|
std::array<Ort::Value, 2> inputs = {std::move(x), std::move(x_lens)};
|
||||||
|
|
||||||
|
auto out =
|
||||||
|
sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
|
||||||
|
output_names_ptr_.data(), output_names_ptr_.size());
|
||||||
|
|
||||||
|
return std::move(out[0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
void Init(const OfflineLMConfig &config) {
|
||||||
|
auto buf = ReadFile(config_.model);
|
||||||
|
|
||||||
|
sess_ = std::make_unique<Ort::Session>(env_, buf.data(), buf.size(),
|
||||||
|
sess_opts_);
|
||||||
|
|
||||||
|
GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
|
||||||
|
|
||||||
|
GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
OfflineLMConfig config_;
|
||||||
|
Ort::Env env_;
|
||||||
|
Ort::SessionOptions sess_opts_;
|
||||||
|
Ort::AllocatorWithDefaultOptions allocator_;
|
||||||
|
|
||||||
|
std::unique_ptr<Ort::Session> sess_;
|
||||||
|
|
||||||
|
std::vector<std::string> input_names_;
|
||||||
|
std::vector<const char *> input_names_ptr_;
|
||||||
|
|
||||||
|
std::vector<std::string> output_names_;
|
||||||
|
std::vector<const char *> output_names_ptr_;
|
||||||
|
};
|
||||||
|
|
||||||
|
OfflineRnnLM::OfflineRnnLM(const OfflineLMConfig &config)
|
||||||
|
: impl_(std::make_unique<Impl>(config)) {}
|
||||||
|
|
||||||
|
OfflineRnnLM::~OfflineRnnLM() = default;
|
||||||
|
|
||||||
|
Ort::Value OfflineRnnLM::Rescore(Ort::Value x, Ort::Value x_lens) {
|
||||||
|
return impl_->Rescore(std::move(x), std::move(x_lens));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
41
sherpa-onnx/csrc/offline-rnn-lm.h
Normal file
41
sherpa-onnx/csrc/offline-rnn-lm.h
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
// sherpa-onnx/csrc/offline-rnn-lm.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_OFFLINE_RNN_LM_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_OFFLINE_RNN_LM_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||||
|
#include "sherpa-onnx/csrc/offline-lm-config.h"
|
||||||
|
#include "sherpa-onnx/csrc/offline-lm.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
class OfflineRnnLM : public OfflineLM {
|
||||||
|
public:
|
||||||
|
~OfflineRnnLM() override;
|
||||||
|
|
||||||
|
explicit OfflineRnnLM(const OfflineLMConfig &config);
|
||||||
|
|
||||||
|
/** Rescore a batch of sentences.
|
||||||
|
*
|
||||||
|
* @param x A 2-D tensor of shape (N, L) with data type int64.
|
||||||
|
* @param x_lens A 1-D tensor of shape (N,) with data type int64.
|
||||||
|
* It contains number of valid tokens in x before padding.
|
||||||
|
* @return Return a 1-D tensor of shape (N,) containing the negative log
|
||||||
|
* likelihood of each utterance. Its data type is float32.
|
||||||
|
*
|
||||||
|
* Caution: It returns negative log likelihood (nll), not log likelihood.
|
||||||
|
*/
|
||||||
|
Ort::Value Rescore(Ort::Value x, Ort::Value x_lens) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
class Impl;
|
||||||
|
std::unique_ptr<Impl> impl_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_OFFLINE_RNN_LM_H_
|
||||||
@@ -95,6 +95,30 @@ class OfflineTransducerModel::Impl {
|
|||||||
std::copy(begin, end, p);
|
std::copy(begin, end, p);
|
||||||
p += context_size;
|
p += context_size;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return decoder_input;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ort::Value BuildDecoderInput(const std::vector<Hypothesis> &results,
|
||||||
|
int32_t end_index) const {
|
||||||
|
assert(end_index <= results.size());
|
||||||
|
|
||||||
|
int32_t batch_size = end_index;
|
||||||
|
int32_t context_size = ContextSize();
|
||||||
|
std::array<int64_t, 2> shape{batch_size, context_size};
|
||||||
|
|
||||||
|
Ort::Value decoder_input = Ort::Value::CreateTensor<int64_t>(
|
||||||
|
Allocator(), shape.data(), shape.size());
|
||||||
|
int64_t *p = decoder_input.GetTensorMutableData<int64_t>();
|
||||||
|
|
||||||
|
for (int32_t i = 0; i != batch_size; ++i) {
|
||||||
|
const auto &r = results[i];
|
||||||
|
const int64_t *begin = r.ys.data() + r.ys.size() - context_size;
|
||||||
|
const int64_t *end = r.ys.data() + r.ys.size();
|
||||||
|
std::copy(begin, end, p);
|
||||||
|
p += context_size;
|
||||||
|
}
|
||||||
|
|
||||||
return decoder_input;
|
return decoder_input;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -234,4 +258,9 @@ Ort::Value OfflineTransducerModel::BuildDecoderInput(
|
|||||||
return impl_->BuildDecoderInput(results, end_index);
|
return impl_->BuildDecoderInput(results, end_index);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Ort::Value OfflineTransducerModel::BuildDecoderInput(
|
||||||
|
const std::vector<Hypothesis> &results, int32_t end_index) const {
|
||||||
|
return impl_->BuildDecoderInput(results, end_index);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace sherpa_onnx
|
} // namespace sherpa_onnx
|
||||||
|
|||||||
@@ -9,6 +9,7 @@
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||||
|
#include "sherpa-onnx/csrc/hypothesis.h"
|
||||||
#include "sherpa-onnx/csrc/offline-model-config.h"
|
#include "sherpa-onnx/csrc/offline-model-config.h"
|
||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
@@ -79,13 +80,16 @@ class OfflineTransducerModel {
|
|||||||
*
|
*
|
||||||
* @param results Current decoded results.
|
* @param results Current decoded results.
|
||||||
* @param end_index We only use results[0:end_index] to build
|
* @param end_index We only use results[0:end_index] to build
|
||||||
* the decoder_input.
|
* the decoder_input. results[end_index] is not used.
|
||||||
* @return Return a tensor of shape (results.size(), ContextSize())
|
* @return Return a tensor of shape (results.size(), ContextSize())
|
||||||
*/
|
*/
|
||||||
Ort::Value BuildDecoderInput(
|
Ort::Value BuildDecoderInput(
|
||||||
const std::vector<OfflineTransducerDecoderResult> &results,
|
const std::vector<OfflineTransducerDecoderResult> &results,
|
||||||
int32_t end_index) const;
|
int32_t end_index) const;
|
||||||
|
|
||||||
|
Ort::Value BuildDecoderInput(const std::vector<Hypothesis> &results,
|
||||||
|
int32_t end_index) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
class Impl;
|
class Impl;
|
||||||
std::unique_ptr<Impl> impl_;
|
std::unique_ptr<Impl> impl_;
|
||||||
|
|||||||
@@ -0,0 +1,163 @@
|
|||||||
|
// sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.h"
|
||||||
|
|
||||||
|
#include <deque>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/hypothesis.h"
|
||||||
|
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||||
|
#include "sherpa-onnx/csrc/packed-sequence.h"
|
||||||
|
#include "sherpa-onnx/csrc/slice.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
static std::vector<int32_t> GetHypsRowSplits(
|
||||||
|
const std::vector<Hypotheses> &hyps) {
|
||||||
|
std::vector<int32_t> row_splits;
|
||||||
|
row_splits.reserve(hyps.size() + 1);
|
||||||
|
|
||||||
|
row_splits.push_back(0);
|
||||||
|
int32_t s = 0;
|
||||||
|
for (const auto &h : hyps) {
|
||||||
|
s += h.Size();
|
||||||
|
row_splits.push_back(s);
|
||||||
|
}
|
||||||
|
|
||||||
|
return row_splits;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<OfflineTransducerDecoderResult>
|
||||||
|
OfflineTransducerModifiedBeamSearchDecoder::Decode(
|
||||||
|
Ort::Value encoder_out, Ort::Value encoder_out_length) {
|
||||||
|
PackedSequence packed_encoder_out = PackPaddedSequence(
|
||||||
|
model_->Allocator(), &encoder_out, &encoder_out_length);
|
||||||
|
|
||||||
|
int32_t batch_size =
|
||||||
|
static_cast<int32_t>(packed_encoder_out.sorted_indexes.size());
|
||||||
|
|
||||||
|
int32_t vocab_size = model_->VocabSize();
|
||||||
|
int32_t context_size = model_->ContextSize();
|
||||||
|
|
||||||
|
std::vector<int64_t> blanks(context_size, 0);
|
||||||
|
Hypotheses blank_hyp({{blanks, 0}});
|
||||||
|
|
||||||
|
std::deque<Hypotheses> finalized;
|
||||||
|
std::vector<Hypotheses> cur(batch_size, blank_hyp);
|
||||||
|
std::vector<Hypothesis> prev;
|
||||||
|
|
||||||
|
int32_t start = 0;
|
||||||
|
int32_t t = 0;
|
||||||
|
for (auto n : packed_encoder_out.batch_sizes) {
|
||||||
|
Ort::Value cur_encoder_out = packed_encoder_out.Get(start, n);
|
||||||
|
start += n;
|
||||||
|
|
||||||
|
if (n < static_cast<int32_t>(cur.size())) {
|
||||||
|
for (int32_t k = static_cast<int32_t>(cur.size()) - 1; k >= n; --k) {
|
||||||
|
finalized.push_front(std::move(cur[k]));
|
||||||
|
}
|
||||||
|
|
||||||
|
cur.erase(cur.begin() + n, cur.end());
|
||||||
|
} // if (n < static_cast<int32_t>(cur.size()))
|
||||||
|
|
||||||
|
// Due to merging paths with identical token sequences,
|
||||||
|
// not all utterances have "max_active_paths" paths.
|
||||||
|
auto hyps_row_splits = GetHypsRowSplits(cur);
|
||||||
|
int32_t num_hyps = hyps_row_splits.back();
|
||||||
|
|
||||||
|
prev.clear();
|
||||||
|
prev.reserve(num_hyps);
|
||||||
|
|
||||||
|
for (auto &hyps : cur) {
|
||||||
|
for (auto &h : hyps) {
|
||||||
|
prev.push_back(std::move(h.second));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cur.clear();
|
||||||
|
cur.reserve(n);
|
||||||
|
|
||||||
|
auto decoder_input = model_->BuildDecoderInput(prev, num_hyps);
|
||||||
|
// decoder_input shape: (num_hyps, context_size)
|
||||||
|
|
||||||
|
auto decoder_out = model_->RunDecoder(std::move(decoder_input));
|
||||||
|
// decoder_out is (num_hyps, joiner_dim)
|
||||||
|
|
||||||
|
cur_encoder_out =
|
||||||
|
Repeat(model_->Allocator(), &cur_encoder_out, hyps_row_splits);
|
||||||
|
// now cur_encoder_out is of shape (num_hyps, joiner_dim)
|
||||||
|
|
||||||
|
Ort::Value logit = model_->RunJoiner(
|
||||||
|
std::move(cur_encoder_out), Clone(model_->Allocator(), &decoder_out));
|
||||||
|
|
||||||
|
float *p_logit = logit.GetTensorMutableData<float>();
|
||||||
|
LogSoftmax(p_logit, vocab_size, num_hyps);
|
||||||
|
|
||||||
|
// now p_logit contains log_softmax output, we rename it to p_logprob
|
||||||
|
// to match what it actually contains
|
||||||
|
float *p_logprob = p_logit;
|
||||||
|
|
||||||
|
// add log_prob of each hypothesis to p_logprob before taking top_k
|
||||||
|
for (int32_t i = 0; i != num_hyps; ++i) {
|
||||||
|
float log_prob = prev[i].log_prob;
|
||||||
|
for (int32_t k = 0; k != vocab_size; ++k, ++p_logprob) {
|
||||||
|
*p_logprob += log_prob;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
p_logprob = p_logit; // we changed p_logprob in the above for loop
|
||||||
|
|
||||||
|
// Now compute top_k for each utterance
|
||||||
|
for (int32_t i = 0; i != n; ++i) {
|
||||||
|
int32_t start = hyps_row_splits[i];
|
||||||
|
int32_t end = hyps_row_splits[i + 1];
|
||||||
|
auto topk =
|
||||||
|
TopkIndex(p_logprob, vocab_size * (end - start), max_active_paths_);
|
||||||
|
|
||||||
|
Hypotheses hyps;
|
||||||
|
for (auto k : topk) {
|
||||||
|
int32_t hyp_index = k / vocab_size + start;
|
||||||
|
int32_t new_token = k % vocab_size;
|
||||||
|
Hypothesis new_hyp = prev[hyp_index];
|
||||||
|
|
||||||
|
if (new_token != 0) {
|
||||||
|
// blank id is fixed to 0
|
||||||
|
new_hyp.ys.push_back(new_token);
|
||||||
|
new_hyp.timestamps.push_back(t);
|
||||||
|
}
|
||||||
|
|
||||||
|
new_hyp.log_prob = p_logprob[k];
|
||||||
|
hyps.Add(std::move(new_hyp));
|
||||||
|
} // for (auto k : topk)
|
||||||
|
p_logprob += (end - start) * vocab_size;
|
||||||
|
cur.push_back(std::move(hyps));
|
||||||
|
} // for (int32_t i = 0; i != n; ++i)
|
||||||
|
|
||||||
|
++t;
|
||||||
|
} // for (auto n : packed_encoder_out.batch_sizes)
|
||||||
|
|
||||||
|
for (auto &h : finalized) {
|
||||||
|
cur.push_back(std::move(h));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (lm_) {
|
||||||
|
// use LM for rescoring
|
||||||
|
lm_->ComputeLMScore(lm_scale_, context_size, &cur);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<OfflineTransducerDecoderResult> unsorted_ans(batch_size);
|
||||||
|
for (int32_t i = 0; i != batch_size; ++i) {
|
||||||
|
Hypothesis hyp = cur[i].GetMostProbable(true);
|
||||||
|
|
||||||
|
auto &r = unsorted_ans[packed_encoder_out.sorted_indexes[i]];
|
||||||
|
|
||||||
|
// strip leading blanks
|
||||||
|
r.tokens = {hyp.ys.begin() + context_size, hyp.ys.end()};
|
||||||
|
r.timestamps = std::move(hyp.timestamps);
|
||||||
|
}
|
||||||
|
|
||||||
|
return unsorted_ans;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
@@ -0,0 +1,41 @@
|
|||||||
|
// sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_MODIFIED_BEAM_SEARCH_DECODER_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_MODIFIED_BEAM_SEARCH_DECODER_H_
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/offline-lm.h"
|
||||||
|
#include "sherpa-onnx/csrc/offline-transducer-decoder.h"
|
||||||
|
#include "sherpa-onnx/csrc/offline-transducer-model.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
class OfflineTransducerModifiedBeamSearchDecoder
|
||||||
|
: public OfflineTransducerDecoder {
|
||||||
|
public:
|
||||||
|
OfflineTransducerModifiedBeamSearchDecoder(OfflineTransducerModel *model,
|
||||||
|
OfflineLM *lm,
|
||||||
|
int32_t max_active_paths,
|
||||||
|
float lm_scale)
|
||||||
|
: model_(model),
|
||||||
|
lm_(lm),
|
||||||
|
max_active_paths_(max_active_paths),
|
||||||
|
lm_scale_(lm_scale) {}
|
||||||
|
|
||||||
|
std::vector<OfflineTransducerDecoderResult> Decode(
|
||||||
|
Ort::Value encoder_out, Ort::Value encoder_out_length) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
OfflineTransducerModel *model_; // Not owned
|
||||||
|
OfflineLM *lm_; // Not owned; may be nullptr
|
||||||
|
|
||||||
|
int32_t max_active_paths_;
|
||||||
|
float lm_scale_; // used only when lm_ is not nullptr
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_MODIFIED_BEAM_SEARCH_DECODER_H_
|
||||||
38
sherpa-onnx/csrc/online-lm-config.cc
Normal file
38
sherpa-onnx/csrc/online-lm-config.cc
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
// sherpa-onnx/csrc/online-lm-config.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/online-lm-config.h"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/file-utils.h"
|
||||||
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
// Exposes the LM settings as the command-line options "lm" and "lm-scale".
void OnlineLMConfig::Register(ParseOptions *po) {
  // --lm fills `model`.
  po->Register("lm", &model, "Path to LM model.");

  // --lm-scale fills `scale`.
  po->Register("lm-scale", &scale, "LM scale.");
}
|
||||||
|
|
||||||
|
bool OnlineLMConfig::Validate() const {
|
||||||
|
if (!FileExists(model)) {
|
||||||
|
SHERPA_ONNX_LOGE("%s does not exist", model.c_str());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Formats the config as OnlineLMConfig(model="<model>", scale=<scale>).
std::string OnlineLMConfig::ToString() const {
  std::ostringstream oss;

  oss << "OnlineLMConfig("
      << "model=\"" << model << "\", "
      << "scale=" << scale << ")";

  return oss.str();
}
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
33
sherpa-onnx/csrc/online-lm-config.h
Normal file
33
sherpa-onnx/csrc/online-lm-config.h
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
// sherpa-onnx/csrc/online-lm-config.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_ONLINE_LM_CONFIG_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_ONLINE_LM_CONFIG_H_
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/parse-options.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
struct OnlineLMConfig {
|
||||||
|
// path to the onnx model
|
||||||
|
std::string model;
|
||||||
|
|
||||||
|
// LM scale
|
||||||
|
float scale = 1.0;
|
||||||
|
|
||||||
|
OnlineLMConfig() = default;
|
||||||
|
|
||||||
|
OnlineLMConfig(const std::string &model, float scale)
|
||||||
|
: model(model), scale(scale) {}
|
||||||
|
|
||||||
|
void Register(ParseOptions *po);
|
||||||
|
bool Validate() const;
|
||||||
|
|
||||||
|
std::string ToString() const;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_ONLINE_LM_CONFIG_H_
|
||||||
64
sherpa-onnx/csrc/online-lm.h
Normal file
64
sherpa-onnx/csrc/online-lm.h
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
// sherpa-onnx/csrc/online-lm.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_ONLINE_LM_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_ONLINE_LM_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||||
|
#include "sherpa-onnx/csrc/hypothesis.h"
|
||||||
|
#include "sherpa-onnx/csrc/online-lm-config.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
class OnlineLM {
|
||||||
|
public:
|
||||||
|
virtual ~OnlineLM() = default;
|
||||||
|
|
||||||
|
static std::unique_ptr<OnlineLM> Create(const OnlineLMConfig &config);
|
||||||
|
|
||||||
|
virtual std::vector<Ort::Value> GetInitStates() = 0;
|
||||||
|
|
||||||
|
/** Rescore a batch of sentences.
|
||||||
|
*
|
||||||
|
* @param x A 2-D tensor of shape (N, L) with data type int64.
|
||||||
|
* @param y A 2-D tensor of shape (N, L) with data type int64.
|
||||||
|
* @param states It contains the states for the LM model
|
||||||
|
* @return Return a pair containingo
|
||||||
|
* - negative loglike
|
||||||
|
* - updated states
|
||||||
|
*
|
||||||
|
* Caution: It returns negative log likelihood (nll), not log likelihood
|
||||||
|
*/
|
||||||
|
std::pair<Ort::Value, std::vector<Ort::Value>> Ort::Value Rescore(
|
||||||
|
Ort::Value x, Ort::Value y, std::vector<Ort::Value> states) = 0;
|
||||||
|
|
||||||
|
// This function updates hyp.lm_lob_prob of hyps.
|
||||||
|
//
|
||||||
|
// @param scale LM score
|
||||||
|
// @param context_size Context size of the transducer decoder model
|
||||||
|
// @param hyps It is changed in-place.
|
||||||
|
void ComputeLMScore(float scale, int32_t context_size,
|
||||||
|
std::vector<Hypotheses> *hyps);
|
||||||
|
/** TODO(fangjun):
|
||||||
|
*
|
||||||
|
* 1. Add two fields to Hypothesis
|
||||||
|
* (a) int32_t lm_cur_pos = 0; number of scored tokens so far
|
||||||
|
* (b) std::vector<Ort::Value> lm_states;
|
||||||
|
* 2. When we want to score a hypothesis, we construct x and y as follows:
|
||||||
|
*
|
||||||
|
* std::vector x = {hyp.ys.begin() + context_size + lm_cur_pos,
|
||||||
|
* hyp.ys.end() - 1};
|
||||||
|
* std::vector y = {hyp.ys.begin() + context_size + lm_cur_pos + 1
|
||||||
|
* hyp.ys.end()};
|
||||||
|
* hyp.lm_cur_pos += hyp.ys.size() - context_size - lm_cur_pos;
|
||||||
|
*/
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_ONLINE_LM_H_
|
||||||
@@ -36,38 +36,6 @@ static void UseCachedDecoderOut(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Expands a 2-D float tensor row-wise: row b of *cur_encoder_out is written
// hyps_num_split[b + 1] - hyps_num_split[b] times, consecutively, into a new
// tensor with hyps_num_split.back() rows and the same number of columns.
static Ort::Value Repeat(OrtAllocator *allocator, Ort::Value *cur_encoder_out,
                         const std::vector<int32_t> &hyps_num_split) {
  std::vector<int64_t> in_shape =
      cur_encoder_out->GetTensorTypeAndShapeInfo().GetShape();
  const int64_t dim = in_shape[1];

  std::array<int64_t, 2> out_shape{hyps_num_split.back(), dim};
  Ort::Value out = Ort::Value::CreateTensor<float>(allocator, out_shape.data(),
                                                   out_shape.size());

  const float *in_row = cur_encoder_out->GetTensorData<float>();
  float *out_row = out.GetTensorMutableData<float>();

  int32_t num_streams = static_cast<int32_t>(hyps_num_split.size()) - 1;
  for (int32_t s = 0; s != num_streams; ++s, in_row += dim) {
    int32_t copies = hyps_num_split[s + 1] - hyps_num_split[s];
    for (int32_t c = 0; c != copies; ++c) {
      // std::copy returns the end of the written range, i.e. the next row.
      out_row = std::copy(in_row, in_row + dim, out_row);
    }
  }

  return out;
}
|
|
||||||
|
|
||||||
static void LogSoftmax(float *in, int32_t w, int32_t h) {
|
|
||||||
for (int32_t i = 0; i != h; ++i) {
|
|
||||||
LogSoftmax(in, w);
|
|
||||||
in += w;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
OnlineTransducerDecoderResult
|
OnlineTransducerDecoderResult
|
||||||
OnlineTransducerModifiedBeamSearchDecoder::GetEmptyResult() const {
|
OnlineTransducerModifiedBeamSearchDecoder::GetEmptyResult() const {
|
||||||
int32_t context_size = model_->ContextSize();
|
int32_t context_size = model_->ContextSize();
|
||||||
|
|||||||
@@ -193,4 +193,29 @@ std::vector<char> ReadFile(AAssetManager *mgr, const std::string &filename) {
|
|||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
// Repeats the rows of a 2-D float tensor.
//
// Row b of *cur_encoder_out is copied
// hyps_num_split[b + 1] - hyps_num_split[b] times into the result, whose
// shape is (hyps_num_split.back(), number of columns of *cur_encoder_out).
// The result is allocated with `allocator`.
Ort::Value Repeat(OrtAllocator *allocator, Ort::Value *cur_encoder_out,
                  const std::vector<int32_t> &hyps_num_split) {
  const std::vector<int64_t> src_shape =
      cur_encoder_out->GetTensorTypeAndShapeInfo().GetShape();
  const int64_t num_cols = src_shape[1];

  std::array<int64_t, 2> ans_shape{hyps_num_split.back(), num_cols};
  Ort::Value ans = Ort::Value::CreateTensor<float>(allocator, ans_shape.data(),
                                                   ans_shape.size());

  const float *src = cur_encoder_out->GetTensorData<float>();
  float *dst = ans.GetTensorMutableData<float>();

  const int32_t num_rows = static_cast<int32_t>(hyps_num_split.size()) - 1;
  int32_t row = 0;
  while (row != num_rows) {
    const int32_t repeats = hyps_num_split[row + 1] - hyps_num_split[row];
    for (int32_t r = 0; r != repeats; ++r) {
      std::copy(src, src + num_cols, dst);
      dst += num_cols;
    }
    src += num_cols;
    ++row;
  }

  return ans;
}
|
||||||
|
|
||||||
} // namespace sherpa_onnx
|
} // namespace sherpa_onnx
|
||||||
|
|||||||
@@ -86,6 +86,9 @@ std::vector<char> ReadFile(const std::string &filename);
|
|||||||
std::vector<char> ReadFile(AAssetManager *mgr, const std::string &filename);
|
std::vector<char> ReadFile(AAssetManager *mgr, const std::string &filename);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
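// Repeat each row of a 2-D float tensor a per-row number of times.
//
// @param allocator Allocator used to create the returned tensor.
// @param cur_encoder_out Pointer to a 2-D float tensor. Its b-th row is the
//                        source of the b-th group of rows in the result.
// @param hyps_num_split Row splits. Row b of cur_encoder_out is copied
//                       hyps_num_split[b + 1] - hyps_num_split[b] times.
//                       It must not be empty.
// @return A float tensor of shape
//         (hyps_num_split.back(), cur_encoder_out.shape[1]).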
|
||||||
|
Ort::Value Repeat(OrtAllocator *allocator, Ort::Value *cur_encoder_out,
|
||||||
|
const std::vector<int32_t> &hyps_num_split);
|
||||||
} // namespace sherpa_onnx
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
#endif // SHERPA_ONNX_CSRC_ONNX_UTILS_H_
|
#endif // SHERPA_ONNX_CSRC_ONNX_UTILS_H_
|
||||||
|
|||||||
@@ -111,6 +111,9 @@ for a list of pre-trained models to download.
|
|||||||
|
|
||||||
fprintf(stderr, "num threads: %d\n", config.model_config.num_threads);
|
fprintf(stderr, "num threads: %d\n", config.model_config.num_threads);
|
||||||
fprintf(stderr, "decoding method: %s\n", config.decoding_method.c_str());
|
fprintf(stderr, "decoding method: %s\n", config.decoding_method.c_str());
|
||||||
|
if (config.decoding_method == "modified_beam_search") {
|
||||||
|
fprintf(stderr, "max active paths: %d\n", config.max_active_paths);
|
||||||
|
}
|
||||||
|
|
||||||
fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
|
fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
|
||||||
float rtf = elapsed_seconds / duration;
|
float rtf = elapsed_seconds / duration;
|
||||||
|
|||||||
@@ -117,6 +117,9 @@ for a list of pre-trained models to download.
|
|||||||
|
|
||||||
fprintf(stderr, "num threads: %d\n", config.model_config.num_threads);
|
fprintf(stderr, "num threads: %d\n", config.model_config.num_threads);
|
||||||
fprintf(stderr, "decoding method: %s\n", config.decoding_method.c_str());
|
fprintf(stderr, "decoding method: %s\n", config.decoding_method.c_str());
|
||||||
|
if (config.decoding_method == "modified_beam_search") {
|
||||||
|
fprintf(stderr, "max active paths: %d\n", config.max_active_paths);
|
||||||
|
}
|
||||||
|
|
||||||
fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
|
fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
|
||||||
float rtf = elapsed_seconds / duration;
|
float rtf = elapsed_seconds / duration;
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ pybind11_add_module(_sherpa_onnx
|
|||||||
display.cc
|
display.cc
|
||||||
endpoint.cc
|
endpoint.cc
|
||||||
features.cc
|
features.cc
|
||||||
|
offline-lm-config.cc
|
||||||
offline-model-config.cc
|
offline-model-config.cc
|
||||||
offline-nemo-enc-dec-ctc-model-config.cc
|
offline-nemo-enc-dec-ctc-model-config.cc
|
||||||
offline-paraformer-model-config.cc
|
offline-paraformer-model-config.cc
|
||||||
|
|||||||
23
sherpa-onnx/python/csrc/offline-lm-config.cc
Normal file
23
sherpa-onnx/python/csrc/offline-lm-config.cc
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
// sherpa-onnx/python/csrc/offline-lm-config.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#include "sherpa-onnx/python/csrc/offline-lm-config.h"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "sherpa-onnx//csrc/offline-lm-config.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
// Exposes OfflineLMConfig to Python as the class "OfflineLMConfig" with a
// (model, scale) constructor, read/write attributes `model` and `scale`, and
// __str__ backed by ToString().
void PybindOfflineLMConfig(py::module *m) {
  using PyClass = OfflineLMConfig;

  py::class_<PyClass> cls(*m, "OfflineLMConfig");

  cls.def(py::init<const std::string &, float>(), py::arg("model"),
          py::arg("scale"));
  cls.def_readwrite("model", &PyClass::model);
  cls.def_readwrite("scale", &PyClass::scale);
  cls.def("__str__", &PyClass::ToString);
}
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
16
sherpa-onnx/python/csrc/offline-lm-config.h
Normal file
16
sherpa-onnx/python/csrc/offline-lm-config.h
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
// sherpa-onnx/python/csrc/offline-lm-config.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_LM_CONFIG_H_
|
||||||
|
#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_LM_CONFIG_H_
|
||||||
|
|
||||||
|
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
void PybindOfflineLMConfig(py::module *m);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_LM_CONFIG_H_
|
||||||
@@ -15,12 +15,17 @@ static void PybindOfflineRecognizerConfig(py::module *m) {
|
|||||||
using PyClass = OfflineRecognizerConfig;
|
using PyClass = OfflineRecognizerConfig;
|
||||||
py::class_<PyClass>(*m, "OfflineRecognizerConfig")
|
py::class_<PyClass>(*m, "OfflineRecognizerConfig")
|
||||||
.def(py::init<const OfflineFeatureExtractorConfig &,
|
.def(py::init<const OfflineFeatureExtractorConfig &,
|
||||||
const OfflineModelConfig &, const std::string &>(),
|
const OfflineModelConfig &, const OfflineLMConfig &,
|
||||||
|
const std::string &, int32_t>(),
|
||||||
py::arg("feat_config"), py::arg("model_config"),
|
py::arg("feat_config"), py::arg("model_config"),
|
||||||
py::arg("decoding_method"))
|
py::arg("lm_config") = OfflineLMConfig(),
|
||||||
|
py::arg("decoding_method") = "greedy_search",
|
||||||
|
py::arg("max_active_paths") = 4)
|
||||||
.def_readwrite("feat_config", &PyClass::feat_config)
|
.def_readwrite("feat_config", &PyClass::feat_config)
|
||||||
.def_readwrite("model_config", &PyClass::model_config)
|
.def_readwrite("model_config", &PyClass::model_config)
|
||||||
|
.def_readwrite("lm_config", &PyClass::lm_config)
|
||||||
.def_readwrite("decoding_method", &PyClass::decoding_method)
|
.def_readwrite("decoding_method", &PyClass::decoding_method)
|
||||||
|
.def_readwrite("max_active_paths", &PyClass::max_active_paths)
|
||||||
.def("__str__", &PyClass::ToString);
|
.def("__str__", &PyClass::ToString);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,7 @@
|
|||||||
#include "sherpa-onnx/python/csrc/display.h"
|
#include "sherpa-onnx/python/csrc/display.h"
|
||||||
#include "sherpa-onnx/python/csrc/endpoint.h"
|
#include "sherpa-onnx/python/csrc/endpoint.h"
|
||||||
#include "sherpa-onnx/python/csrc/features.h"
|
#include "sherpa-onnx/python/csrc/features.h"
|
||||||
|
#include "sherpa-onnx/python/csrc/offline-lm-config.h"
|
||||||
#include "sherpa-onnx/python/csrc/offline-model-config.h"
|
#include "sherpa-onnx/python/csrc/offline-model-config.h"
|
||||||
#include "sherpa-onnx/python/csrc/offline-recognizer.h"
|
#include "sherpa-onnx/python/csrc/offline-recognizer.h"
|
||||||
#include "sherpa-onnx/python/csrc/offline-stream.h"
|
#include "sherpa-onnx/python/csrc/offline-stream.h"
|
||||||
@@ -28,6 +29,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
|
|||||||
PybindDisplay(&m);
|
PybindDisplay(&m);
|
||||||
|
|
||||||
PybindOfflineStream(&m);
|
PybindOfflineStream(&m);
|
||||||
|
PybindOfflineLMConfig(&m);
|
||||||
PybindOfflineModelConfig(&m);
|
PybindOfflineModelConfig(&m);
|
||||||
PybindOfflineRecognizer(&m);
|
PybindOfflineRecognizer(&m);
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user