Add lm rescore to online-modified-beam-search (#133)
This commit is contained in:
@@ -182,9 +182,10 @@ class MainActivity : AppCompatActivity() {
|
||||
val config = OnlineRecognizerConfig(
|
||||
featConfig = getFeatureConfig(sampleRate = sampleRateInHz, featureDim = 80),
|
||||
modelConfig = getModelConfig(type = type)!!,
|
||||
lmConfig = getOnlineLMConfig(type = type),
|
||||
endpointConfig = getEndpointConfig(),
|
||||
enableEndpoint = true,
|
||||
decodingMethod = "greedy_search",
|
||||
decodingMethod = "modified_beam_search",
|
||||
maxActivePaths = 4,
|
||||
)
|
||||
|
||||
|
||||
@@ -23,6 +23,11 @@ data class OnlineTransducerModelConfig(
|
||||
var debug: Boolean = false,
|
||||
)
|
||||
|
||||
/**
 * Settings of the language model passed to the recognizer.
 *
 * Both fields have defaults, so `OnlineLMConfig()` yields an empty model path
 * and a scale of 0.5.
 */
data class OnlineLMConfig(
|
||||
// Path of the LM model file; empty by default.
var model: String = "",
|
||||
// Scale applied to the LM score; 0.5 by default.
var scale: Float = 0.5f,
|
||||
)
|
||||
|
||||
data class FeatureConfig(
|
||||
var sampleRate: Int = 16000,
|
||||
var featureDim: Int = 80,
|
||||
@@ -31,6 +36,7 @@ data class FeatureConfig(
|
||||
data class OnlineRecognizerConfig(
|
||||
var featConfig: FeatureConfig = FeatureConfig(),
|
||||
var modelConfig: OnlineTransducerModelConfig,
|
||||
var lmConfig : OnlineLMConfig,
|
||||
var endpointConfig: EndpointConfig = EndpointConfig(),
|
||||
var enableEndpoint: Boolean = true,
|
||||
var decodingMethod: String = "greedy_search",
|
||||
@@ -151,6 +157,32 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? {
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
 * Returns the LM configuration that belongs to the pre-trained model selected by [type].
 *
 * Please see
 * https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
 * for a list of pre-trained models.
 *
 * Only a few are listed here. Change the code below to add your own LM model.
 * (It should be straightforward to train a new NN LM model by following
 * https://github.com/k2-fsa/icefall/blob/master/icefall/rnn_lm/train.py)
 *
 * @param type
 * 0 - sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 (Bilingual, Chinese + English)
 *     https://k2-fsa.github.io/sherpa/onnx/pretrained_models/zipformer-transducer-models.html#sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20-bilingual-chinese-english
 *
 * @return the LM configuration for [type]; for any other value, a default-constructed
 * [OnlineLMConfig] (empty model path).
 */
fun getOnlineLMConfig(type: Int): OnlineLMConfig =
    when (type) {
        0 -> {
            val modelDir = "sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20"
            OnlineLMConfig(
                model = "$modelDir/with-state-epoch-99-avg-1.int8.onnx",
                scale = 0.5f,
            )
        }

        // Unknown type: fall back to the defaults.
        else -> OnlineLMConfig()
    }
|
||||
|
||||
fun getEndpointConfig(): EndpointConfig {
|
||||
return EndpointConfig(
|
||||
rule1 = EndpointRule(false, 2.4f, 0.0f),
|
||||
|
||||
@@ -22,8 +22,11 @@ fun main() {
|
||||
|
||||
var endpointConfig = EndpointConfig()
|
||||
|
||||
var lmConfig = OnlineLMConfig()
|
||||
|
||||
var config = OnlineRecognizerConfig(
|
||||
modelConfig = modelConfig,
|
||||
lmConfig = lmConfig,
|
||||
featConfig = featConfig,
|
||||
endpointConfig = endpointConfig,
|
||||
enableEndpoint = true,
|
||||
|
||||
@@ -34,9 +34,11 @@ set(sources
|
||||
offline-transducer-model-config.cc
|
||||
offline-transducer-model.cc
|
||||
offline-transducer-modified-beam-search-decoder.cc
|
||||
online-lm.cc
|
||||
online-lm-config.cc
|
||||
online-lstm-transducer-model.cc
|
||||
online-recognizer.cc
|
||||
online-rnn-lm.cc
|
||||
online-stream.cc
|
||||
online-transducer-decoder.cc
|
||||
online-transducer-greedy-search-decoder.cc
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
/**
|
||||
* Copyright (c) 2023 Xiaomi Corporation
|
||||
*
|
||||
* Copyright (c) 2023 Pingfeng Luo
|
||||
*/
|
||||
|
||||
#include "sherpa-onnx/csrc/hypothesis.h"
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
/**
|
||||
* Copyright (c) 2023 Xiaomi Corporation
|
||||
* Copyright (c) 2023 Pingfeng Luo
|
||||
*
|
||||
*/
|
||||
|
||||
@@ -12,7 +13,9 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/math.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
@@ -31,6 +34,13 @@ struct Hypothesis {
|
||||
// LM log prob if any.
|
||||
double lm_log_prob = 0;
|
||||
|
||||
int32_t cur_scored_pos = 0; // cur scored tokens by RNN LM
|
||||
std::vector<CopyableOrtValue> nn_lm_states;
|
||||
|
||||
// TODO(fangjun): Make it configurable
|
||||
// the minimum number of tokens in a chunk for streaming RNN LM
|
||||
int32_t lm_rescore_min_chunk = 2; // a const
|
||||
|
||||
int32_t num_trailing_blanks = 0;
|
||||
|
||||
Hypothesis() = default;
|
||||
|
||||
@@ -96,17 +96,15 @@ void LogSoftmax(T *in, int32_t w, int32_t h) {
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(fangjun): use std::partial_sort to replace std::sort.
|
||||
// Remember also to fix sherpa-ncnn
|
||||
template <class T>
|
||||
std::vector<int32_t> TopkIndex(const T *vec, int32_t size, int32_t topk) {
|
||||
std::vector<int32_t> vec_index(size);
|
||||
std::iota(vec_index.begin(), vec_index.end(), 0);
|
||||
|
||||
std::sort(vec_index.begin(), vec_index.end(),
|
||||
[vec](int32_t index_1, int32_t index_2) {
|
||||
return vec[index_1] > vec[index_2];
|
||||
});
|
||||
std::partial_sort(vec_index.begin(), vec_index.begin() + topk,
|
||||
vec_index.end(), [vec](int32_t index_1, int32_t index_2) {
|
||||
return vec[index_1] > vec[index_2];
|
||||
});
|
||||
|
||||
int32_t k_num = std::min<int32_t>(size, topk);
|
||||
std::vector<int32_t> index(vec_index.begin(), vec_index.begin() + k_num);
|
||||
|
||||
@@ -15,7 +15,7 @@ struct OfflineLMConfig {
|
||||
std::string model;
|
||||
|
||||
// LM scale
|
||||
float scale = 1.0;
|
||||
float scale = 0.5;
|
||||
|
||||
OfflineLMConfig() = default;
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ struct OnlineLMConfig {
|
||||
std::string model;
|
||||
|
||||
// LM scale
|
||||
float scale = 1.0;
|
||||
float scale = 0.5;
|
||||
|
||||
OnlineLMConfig() = default;
|
||||
|
||||
|
||||
92
sherpa-onnx/csrc/online-lm.cc
Normal file
92
sherpa-onnx/csrc/online-lm.cc
Normal file
@@ -0,0 +1,92 @@
|
||||
// sherpa-onnx/csrc/online-lm.cc
|
||||
//
|
||||
// Copyright (c) 2023 Pingfeng Luo
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/online-lm.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/online-rnn-lm.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// Wrap every Ort::Value of `values` in a CopyableOrtValue.
// The tensors are moved into the wrappers, not copied.
static std::vector<CopyableOrtValue> Convert(std::vector<Ort::Value> values) {
  std::vector<CopyableOrtValue> wrapped;
  wrapped.reserve(values.size());

  for (auto it = values.begin(); it != values.end(); ++it) {
    wrapped.emplace_back(std::move(*it));
  }

  return wrapped;
}
|
||||
|
||||
static std::vector<Ort::Value> Convert(std::vector<CopyableOrtValue> values) {
|
||||
std::vector<Ort::Value> ans;
|
||||
ans.reserve(values.size());
|
||||
|
||||
for (auto &v : values) {
|
||||
ans.emplace_back(std::move(v.value));
|
||||
}
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
std::unique_ptr<OnlineLM> OnlineLM::Create(const OnlineLMConfig &config) {
|
||||
return std::make_unique<OnlineRnnLM>(config);
|
||||
}
|
||||
|
||||
void OnlineLM::ComputeLMScore(float scale, int32_t context_size,
|
||||
std::vector<Hypotheses> *hyps) {
|
||||
Ort::AllocatorWithDefaultOptions allocator;
|
||||
|
||||
for (auto &hyp : *hyps) {
|
||||
for (auto &h_m : hyp) {
|
||||
auto &h = h_m.second;
|
||||
auto &ys = h.ys;
|
||||
const int32_t token_num_in_chunk =
|
||||
ys.size() - context_size - h.cur_scored_pos - 1;
|
||||
|
||||
if (token_num_in_chunk < 1) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (h.nn_lm_states.empty()) {
|
||||
h.nn_lm_states = Convert(GetInitStates());
|
||||
}
|
||||
|
||||
if (token_num_in_chunk >= h.lm_rescore_min_chunk) {
|
||||
std::array<int64_t, 2> x_shape{1, token_num_in_chunk};
|
||||
// shape of x and y are same
|
||||
Ort::Value x = Ort::Value::CreateTensor<int64_t>(
|
||||
allocator, x_shape.data(), x_shape.size());
|
||||
Ort::Value y = Ort::Value::CreateTensor<int64_t>(
|
||||
allocator, x_shape.data(), x_shape.size());
|
||||
int64_t *p_x = x.GetTensorMutableData<int64_t>();
|
||||
int64_t *p_y = y.GetTensorMutableData<int64_t>();
|
||||
std::copy(ys.begin() + context_size + h.cur_scored_pos, ys.end() - 1,
|
||||
p_x);
|
||||
std::copy(ys.begin() + context_size + h.cur_scored_pos + 1, ys.end(),
|
||||
p_y);
|
||||
|
||||
// streaming forward by NN LM
|
||||
auto out = Rescore(std::move(x), std::move(y),
|
||||
Convert(std::move(h.nn_lm_states)));
|
||||
|
||||
// update NN LM score in hyp
|
||||
const float *p_nll = out.first.GetTensorData<float>();
|
||||
h.lm_log_prob = -scale * (*p_nll);
|
||||
|
||||
// update NN LM states in hyp
|
||||
h.nn_lm_states = Convert(std::move(out.second));
|
||||
|
||||
h.cur_scored_pos += token_num_in_chunk;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
@@ -34,7 +34,7 @@ class OnlineLM {
|
||||
*
|
||||
* Caution: It returns negative log likelihood (nll), not log likelihood
|
||||
*/
|
||||
std::pair<Ort::Value, std::vector<Ort::Value>> Ort::Value Rescore(
|
||||
virtual std::pair<Ort::Value, std::vector<Ort::Value>> Rescore(
|
||||
Ort::Value x, Ort::Value y, std::vector<Ort::Value> states) = 0;
|
||||
|
||||
// This function updates hyp.lm_log_prob of hyps.
|
||||
@@ -44,19 +44,6 @@ class OnlineLM {
|
||||
// @param hyps It is changed in-place.
|
||||
void ComputeLMScore(float scale, int32_t context_size,
|
||||
std::vector<Hypotheses> *hyps);
|
||||
/** TODO(fangjun):
|
||||
*
|
||||
* 1. Add two fields to Hypothesis
|
||||
* (a) int32_t lm_cur_pos = 0; number of scored tokens so far
|
||||
* (b) std::vector<Ort::Value> lm_states;
|
||||
* 2. When we want to score a hypothesis, we construct x and y as follows:
|
||||
*
|
||||
* std::vector x = {hyp.ys.begin() + context_size + lm_cur_pos,
|
||||
* hyp.ys.end() - 1};
|
||||
* std::vector y = {hyp.ys.begin() + context_size + lm_cur_pos + 1
|
||||
* hyp.ys.end()};
|
||||
* hyp.lm_cur_pos += hyp.ys.size() - context_size - lm_cur_pos;
|
||||
*/
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -16,6 +16,8 @@
|
||||
|
||||
#include "nlohmann/json.hpp"
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/online-lm.h"
|
||||
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
|
||||
#include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h"
|
||||
#include "sherpa-onnx/csrc/online-transducer-model.h"
|
||||
@@ -80,6 +82,7 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) {
|
||||
feat_config.Register(po);
|
||||
model_config.Register(po);
|
||||
endpoint_config.Register(po);
|
||||
lm_config.Register(po);
|
||||
|
||||
po->Register("enable-endpoint", &enable_endpoint,
|
||||
"True to enable endpoint detection. False to disable it.");
|
||||
@@ -91,6 +94,14 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) {
|
||||
}
|
||||
|
||||
bool OnlineRecognizerConfig::Validate() const {
|
||||
if (decoding_method == "modified_beam_search" && !lm_config.model.empty()) {
|
||||
if (max_active_paths <= 0) {
|
||||
SHERPA_ONNX_LOGE("max_active_paths is less than 0! Given: %d",
|
||||
max_active_paths);
|
||||
return false;
|
||||
}
|
||||
if (!lm_config.Validate()) return false;
|
||||
}
|
||||
return model_config.Validate();
|
||||
}
|
||||
|
||||
@@ -100,6 +111,7 @@ std::string OnlineRecognizerConfig::ToString() const {
|
||||
os << "OnlineRecognizerConfig(";
|
||||
os << "feat_config=" << feat_config.ToString() << ", ";
|
||||
os << "model_config=" << model_config.ToString() << ", ";
|
||||
os << "lm_config=" << lm_config.ToString() << ", ";
|
||||
os << "endpoint_config=" << endpoint_config.ToString() << ", ";
|
||||
os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ", ";
|
||||
os << "max_active_paths=" << max_active_paths << ", ";
|
||||
@@ -116,8 +128,13 @@ class OnlineRecognizer::Impl {
|
||||
sym_(config.model_config.tokens),
|
||||
endpoint_(config_.endpoint_config) {
|
||||
if (config.decoding_method == "modified_beam_search") {
|
||||
if (!config_.lm_config.model.empty()) {
|
||||
lm_ = OnlineLM::Create(config.lm_config);
|
||||
}
|
||||
|
||||
decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
|
||||
model_.get(), config_.max_active_paths);
|
||||
model_.get(), lm_.get(), config_.max_active_paths,
|
||||
config_.lm_config.scale);
|
||||
} else if (config.decoding_method == "greedy_search") {
|
||||
decoder_ =
|
||||
std::make_unique<OnlineTransducerGreedySearchDecoder>(model_.get());
|
||||
@@ -136,7 +153,8 @@ class OnlineRecognizer::Impl {
|
||||
endpoint_(config_.endpoint_config) {
|
||||
if (config.decoding_method == "modified_beam_search") {
|
||||
decoder_ = std::make_unique<OnlineTransducerModifiedBeamSearchDecoder>(
|
||||
model_.get(), config_.max_active_paths);
|
||||
model_.get(), lm_.get(), config_.max_active_paths,
|
||||
config_.lm_config.scale);
|
||||
} else if (config.decoding_method == "greedy_search") {
|
||||
decoder_ =
|
||||
std::make_unique<OnlineTransducerGreedySearchDecoder>(model_.get());
|
||||
@@ -246,6 +264,7 @@ class OnlineRecognizer::Impl {
|
||||
private:
|
||||
OnlineRecognizerConfig config_;
|
||||
std::unique_ptr<OnlineTransducerModel> model_;
|
||||
std::unique_ptr<OnlineLM> lm_;
|
||||
std::unique_ptr<OnlineTransducerDecoder> decoder_;
|
||||
SymbolTable sym_;
|
||||
Endpoint endpoint_;
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
|
||||
#include "sherpa-onnx/csrc/endpoint.h"
|
||||
#include "sherpa-onnx/csrc/features.h"
|
||||
#include "sherpa-onnx/csrc/online-lm-config.h"
|
||||
#include "sherpa-onnx/csrc/online-stream.h"
|
||||
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
@@ -67,10 +68,11 @@ struct OnlineRecognizerResult {
|
||||
struct OnlineRecognizerConfig {
|
||||
FeatureExtractorConfig feat_config;
|
||||
OnlineTransducerModelConfig model_config;
|
||||
OnlineLMConfig lm_config;
|
||||
EndpointConfig endpoint_config;
|
||||
bool enable_endpoint = true;
|
||||
|
||||
std::string decoding_method = "greedy_search";
|
||||
std::string decoding_method = "modified_beam_search";
|
||||
// now support modified_beam_search and greedy_search
|
||||
|
||||
int32_t max_active_paths = 4; // used only for modified_beam_search
|
||||
@@ -79,6 +81,7 @@ struct OnlineRecognizerConfig {
|
||||
|
||||
OnlineRecognizerConfig(const FeatureExtractorConfig &feat_config,
|
||||
const OnlineTransducerModelConfig &model_config,
|
||||
const OnlineLMConfig &lm_config,
|
||||
const EndpointConfig &endpoint_config,
|
||||
bool enable_endpoint,
|
||||
const std::string &decoding_method,
|
||||
|
||||
140
sherpa-onnx/csrc/online-rnn-lm.cc
Normal file
140
sherpa-onnx/csrc/online-rnn-lm.cc
Normal file
@@ -0,0 +1,140 @@
|
||||
// sherpa-onnx/csrc/online-rnn-lm.cc
|
||||
//
|
||||
// Copyright (c) 2023 Pingfeng Luo
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/online-rnn-lm.h"
|
||||
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/text-utils.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// ONNX Runtime based implementation behind OnlineRnnLM.
class OnlineRnnLM::Impl {
 public:
  // Reads the model file named by config.model, creates the session and
  // precomputes the initial LM states.
  explicit Impl(const OnlineLMConfig &config)
      : config_(config),
        env_(ORT_LOGGING_LEVEL_ERROR),
        sess_opts_{},
        allocator_{} {
    Init(config);
  }

  // Run the model once.
  //
  // @param x       First model input.
  // @param y       Second model input.
  // @param states  states[0] and states[1] are passed as the third and
  //                fourth model inputs.
  // @return A pair (model output 0, {model output 1, model output 2}); the
  //         second element holds the states for the next call.
  std::pair<Ort::Value, std::vector<Ort::Value>> Rescore(
      Ort::Value x, Ort::Value y, std::vector<Ort::Value> states) {
    std::array<Ort::Value, 4> model_inputs = {
        std::move(x), std::move(y), std::move(states[0]), std::move(states[1])};

    auto model_outputs = sess_->Run(
        {}, input_names_ptr_.data(), model_inputs.data(), model_inputs.size(),
        output_names_ptr_.data(), output_names_ptr_.size());

    std::vector<Ort::Value> updated_states;
    updated_states.reserve(2);
    for (int32_t i = 1; i <= 2; ++i) {
      updated_states.push_back(std::move(model_outputs[i]));
    }

    return {std::move(model_outputs[0]), std::move(updated_states)};
  }

  // Return a fresh clone of the precomputed initial states, so that callers
  // never share tensors with this object.
  std::vector<Ort::Value> GetInitStates() const {
    std::vector<Ort::Value> copies;
    copies.reserve(init_states_.size());

    for (const Ort::Value &state : init_states_) {
      copies.push_back(Clone(allocator_, &state));
    }

    return copies;
  }

 private:
  // Create the session from the model file, read the input/output names and
  // the model meta data, then compute the initial states.
  void Init(const OnlineLMConfig &config) {
    auto model_bytes = ReadFile(config_.model);

    sess_ = std::make_unique<Ort::Session>(env_, model_bytes.data(),
                                           model_bytes.size(), sess_opts_);

    GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
    GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);

    // The names `meta_data` and `allocator` must not be changed: they are
    // picked up by the macro below.
    Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
    Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below
    SHERPA_ONNX_READ_META_DATA(rnn_num_layers_, "num_layers");
    SHERPA_ONNX_READ_META_DATA(rnn_hidden_size_, "hidden_size");
    SHERPA_ONNX_READ_META_DATA(sos_id_, "sos_id");

    ComputeInitStates();
  }

  // Feed sos_id_ (as both x and y) through the model, starting from
  // zero-filled h and c tensors of shape
  // (rnn_num_layers_, 1, rnn_hidden_size_), and keep the returned states as
  // the initial states.
  void ComputeInitStates() {
    constexpr int32_t kBatchSize = 1;

    // h and c share the same shape.
    std::array<int64_t, 3> state_shape{rnn_num_layers_, kBatchSize,
                                       rnn_hidden_size_};
    Ort::Value h = Ort::Value::CreateTensor<float>(
        allocator_, state_shape.data(), state_shape.size());
    Ort::Value c = Ort::Value::CreateTensor<float>(
        allocator_, state_shape.data(), state_shape.size());
    Fill<float>(&h, 0);
    Fill<float>(&c, 0);

    // x and y share the same shape: a single token.
    std::array<int64_t, 2> token_shape{1, 1};
    Ort::Value x = Ort::Value::CreateTensor<int64_t>(
        allocator_, token_shape.data(), token_shape.size());
    Ort::Value y = Ort::Value::CreateTensor<int64_t>(
        allocator_, token_shape.data(), token_shape.size());
    x.GetTensorMutableData<int64_t>()[0] = sos_id_;
    y.GetTensorMutableData<int64_t>()[0] = sos_id_;

    std::vector<Ort::Value> zero_states;
    zero_states.push_back(std::move(h));
    zero_states.push_back(std::move(c));

    auto result = Rescore(std::move(x), std::move(y), std::move(zero_states));
    init_states_ = std::move(result.second);
  }

 private:
  OnlineLMConfig config_;
  Ort::Env env_;
  Ort::SessionOptions sess_opts_;
  Ort::AllocatorWithDefaultOptions allocator_;

  std::unique_ptr<Ort::Session> sess_;

  std::vector<std::string> input_names_;
  std::vector<const char *> input_names_ptr_;

  std::vector<std::string> output_names_;
  std::vector<const char *> output_names_ptr_;

  // Computed once in ComputeInitStates(); cloned by GetInitStates().
  std::vector<Ort::Value> init_states_;

  // Defaults; overwritten from the model meta data in Init().
  int32_t rnn_num_layers_ = 2;
  int32_t rnn_hidden_size_ = 512;
  int32_t sos_id_ = 1;
};
|
||||
|
||||
// Construct the LM; all the work is delegated to the private Impl object.
OnlineRnnLM::OnlineRnnLM(const OnlineLMConfig &config)
|
||||
: impl_(std::make_unique<Impl>(config)) {}
|
||||
|
||||
// Defined out of line, in the file where Impl is defined.
OnlineRnnLM::~OnlineRnnLM() = default;
|
||||
|
||||
// Forwards to Impl::GetInitStates().
std::vector<Ort::Value> OnlineRnnLM::GetInitStates() {
|
||||
return impl_->GetInitStates();
|
||||
}
|
||||
|
||||
// Forwards to Impl::Rescore(); x, y and states are moved into the call.
std::pair<Ort::Value, std::vector<Ort::Value>> OnlineRnnLM::Rescore(
|
||||
Ort::Value x, Ort::Value y, std::vector<Ort::Value> states) {
|
||||
return impl_->Rescore(std::move(x), std::move(y), std::move(states));
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
48
sherpa-onnx/csrc/online-rnn-lm.h
Normal file
48
sherpa-onnx/csrc/online-rnn-lm.h
Normal file
@@ -0,0 +1,48 @@
|
||||
// sherpa-onnx/csrc/online-rnn-lm.h
|
||||
//
|
||||
// Copyright (c) 2023 Pingfeng Luo
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_ONLINE_RNN_LM_H_
|
||||
#define SHERPA_ONNX_CSRC_ONLINE_RNN_LM_H_
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/online-lm-config.h"
|
||||
#include "sherpa-onnx/csrc/online-lm.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class OnlineRnnLM : public OnlineLM {
|
||||
public:
|
||||
~OnlineRnnLM() override;
|
||||
|
||||
explicit OnlineRnnLM(const OnlineLMConfig &config);
|
||||
|
||||
std::vector<Ort::Value> GetInitStates() override;
|
||||
|
||||
/** Rescore a batch of sentences.
|
||||
*
|
||||
* @param x A 2-D tensor of shape (N, L) with data type int64.
|
||||
* @param y A 2-D tensor of shape (N, L) with data type int64.
|
||||
* @param states It contains the states for the LM model
|
||||
* @return Return a pair containingo
|
||||
* - negative loglike
|
||||
* - updated states
|
||||
*
|
||||
* Caution: It returns negative log likelihood (nll), not log likelihood
|
||||
*/
|
||||
std::pair<Ort::Value, std::vector<Ort::Value>> Rescore(
|
||||
Ort::Value x, Ort::Value y, std::vector<Ort::Value> states) override;
|
||||
|
||||
private:
|
||||
class Impl;
|
||||
std::unique_ptr<Impl> impl_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_ONLINE_RNN_LM_H_
|
||||
@@ -156,6 +156,10 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
||||
} // for (int32_t b = 0; b != batch_size; ++b)
|
||||
}
|
||||
|
||||
if (lm_) {
|
||||
lm_->ComputeLMScore(lm_scale_, model_->ContextSize(), &cur);
|
||||
}
|
||||
|
||||
for (int32_t b = 0; b != batch_size; ++b) {
|
||||
auto &hyps = cur[b];
|
||||
auto best_hyp = hyps.GetMostProbable(true);
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/online-lm.h"
|
||||
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
|
||||
#include "sherpa-onnx/csrc/online-transducer-model.h"
|
||||
|
||||
@@ -17,8 +18,13 @@ class OnlineTransducerModifiedBeamSearchDecoder
|
||||
: public OnlineTransducerDecoder {
|
||||
public:
|
||||
OnlineTransducerModifiedBeamSearchDecoder(OnlineTransducerModel *model,
|
||||
int32_t max_active_paths)
|
||||
: model_(model), max_active_paths_(max_active_paths) {}
|
||||
OnlineLM *lm,
|
||||
int32_t max_active_paths,
|
||||
float lm_scale)
|
||||
: model_(model),
|
||||
lm_(lm),
|
||||
max_active_paths_(max_active_paths),
|
||||
lm_scale_(lm_scale) {}
|
||||
|
||||
OnlineTransducerDecoderResult GetEmptyResult() const override;
|
||||
|
||||
@@ -31,7 +37,10 @@ class OnlineTransducerModifiedBeamSearchDecoder
|
||||
|
||||
private:
|
||||
OnlineTransducerModel *model_; // Not owned
|
||||
OnlineLM *lm_; // Not owned
|
||||
|
||||
int32_t max_active_paths_;
|
||||
float lm_scale_; // used only when lm_ is not nullptr
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
// sherpa-onnx/csrc/onnx-utils.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
// Copyright (c) 2023 Pingfeng Luo
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
@@ -218,4 +218,31 @@ Ort::Value Repeat(OrtAllocator *allocator, Ort::Value *cur_encoder_out,
|
||||
return ans;
|
||||
}
|
||||
|
||||
// Copy constructor. `value` starts out null (its default member initializer),
// then copy assignment takes over.
CopyableOrtValue::CopyableOrtValue(const CopyableOrtValue &other) {
  this->operator=(other);
}
|
||||
|
||||
// Copy assignment.
//
// A non-empty `other` is deep-copied with Clone() using the default
// allocator. An empty `other` resets this object to empty as well, so the
// destination never keeps a stale tensor after the assignment.
CopyableOrtValue &CopyableOrtValue::operator=(const CopyableOrtValue &other) {
  if (this == &other) {
    return *this;
  }

  if (other.value) {
    Ort::AllocatorWithDefaultOptions allocator;
    value = Clone(allocator, &other.value);
  } else {
    // Mirror the empty source instead of keeping the previous content.
    value = Ort::Value{nullptr};
  }

  return *this;
}
|
||||
|
||||
// Move constructor: takes over the tensor held by `other`.
CopyableOrtValue::CopyableOrtValue(CopyableOrtValue &&other)
    : value(std::move(other.value)) {}
|
||||
|
||||
// Move assignment: takes over the tensor held by `other`.
// Self-assignment is a no-op.
CopyableOrtValue &CopyableOrtValue::operator=(CopyableOrtValue &&other) {
  if (this != &other) {
    value = std::move(other.value);
  }

  return *this;
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
// sherpa-onnx/csrc/onnx-utils.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
// Copyright (c) 2023 Pingfeng Luo
|
||||
#ifndef SHERPA_ONNX_CSRC_ONNX_UTILS_H_
|
||||
#define SHERPA_ONNX_CSRC_ONNX_UTILS_H_
|
||||
|
||||
@@ -13,6 +14,7 @@
|
||||
#include <cassert>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
@@ -89,6 +91,24 @@ std::vector<char> ReadFile(AAssetManager *mgr, const std::string &filename);
|
||||
// TODO(fangjun): Document it
|
||||
Ort::Value Repeat(OrtAllocator *allocator, Ort::Value *cur_encoder_out,
|
||||
const std::vector<int32_t> &hyps_num_split);
|
||||
|
||||
// A wrapper that holds an Ort::Value and declares copy operations in addition
// to move operations, so that it can be stored in containers that get copied.
struct CopyableOrtValue {
|
||||
// The wrapped tensor; null for a default-constructed wrapper.
Ort::Value value{nullptr};
|
||||
|
||||
CopyableOrtValue() = default;
|
||||
|
||||
// Intentionally not explicit: allows implicit conversion from Ort::Value.
// `v` is moved into the wrapper.
/*explicit*/ CopyableOrtValue(Ort::Value v) // NOLINT
|
||||
: value(std::move(v)) {}
|
||||
|
||||
// Copy operations; defined out of line.
CopyableOrtValue(const CopyableOrtValue &other);
|
||||
|
||||
CopyableOrtValue &operator=(const CopyableOrtValue &other);
|
||||
|
||||
// Move operations; defined out of line.
CopyableOrtValue(CopyableOrtValue &&other);
|
||||
|
||||
CopyableOrtValue &operator=(CopyableOrtValue &&other);
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_ONNX_UTILS_H_
|
||||
|
||||
@@ -13,8 +13,9 @@
|
||||
#include "sherpa-onnx/csrc/symbol-table.h"
|
||||
#include "sherpa-onnx/csrc/wave-reader.h"
|
||||
|
||||
// TODO(fangjun): Use ParseOptions as we are getting more args
|
||||
int main(int32_t argc, char *argv[]) {
|
||||
if (argc < 6 || argc > 8) {
|
||||
if (argc < 6 || argc > 9) {
|
||||
const char *usage = R"usage(
|
||||
Usage:
|
||||
./bin/sherpa-onnx \
|
||||
@@ -22,7 +23,7 @@ Usage:
|
||||
/path/to/encoder.onnx \
|
||||
/path/to/decoder.onnx \
|
||||
/path/to/joiner.onnx \
|
||||
/path/to/foo.wav [num_threads [decoding_method]]
|
||||
/path/to/foo.wav [num_threads [decoding_method [/path/to/rnn_lm.onnx]]]
|
||||
|
||||
Default value for num_threads is 2.
|
||||
Valid values for decoding_method: greedy_search (default), modified_beam_search.
|
||||
@@ -53,10 +54,12 @@ for a list of pre-trained models to download.
|
||||
if (argc == 7 && atoi(argv[6]) > 0) {
|
||||
config.model_config.num_threads = atoi(argv[6]);
|
||||
}
|
||||
|
||||
if (argc == 8) {
|
||||
config.decoding_method = argv[7];
|
||||
}
|
||||
if (argc == 9) {
|
||||
config.lm_config.model = argv[8];
|
||||
}
|
||||
config.max_active_paths = 4;
|
||||
|
||||
fprintf(stderr, "%s\n", config.ToString().c_str());
|
||||
|
||||
@@ -16,9 +16,8 @@
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#else
|
||||
#include <fstream>
|
||||
#endif
|
||||
#include <fstream>
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/online-recognizer.h"
|
||||
@@ -188,6 +187,21 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
|
||||
fid = env->GetFieldID(model_config_cls, "debug", "Z");
|
||||
ans.model_config.debug = env->GetBooleanField(model_config, fid);
|
||||
|
||||
//---------- rnn lm model config ----------
|
||||
fid = env->GetFieldID(cls, "lmConfig",
|
||||
"Lcom/k2fsa/sherpa/onnx/OnlineLMConfig;");
|
||||
jobject lm_model_config = env->GetObjectField(config, fid);
|
||||
jclass lm_model_config_cls = env->GetObjectClass(lm_model_config);
|
||||
|
||||
fid = env->GetFieldID(lm_model_config_cls, "model", "Ljava/lang/String;");
|
||||
s = (jstring)env->GetObjectField(lm_model_config, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.lm_config.model = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
fid = env->GetFieldID(lm_model_config_cls, "scale", "F");
|
||||
ans.lm_config.scale = env->GetFloatField(lm_model_config, fid);
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ pybind11_add_module(_sherpa_onnx
|
||||
offline-recognizer.cc
|
||||
offline-stream.cc
|
||||
offline-transducer-model-config.cc
|
||||
online-lm-config.cc
|
||||
online-recognizer.cc
|
||||
online-stream.cc
|
||||
online-transducer-model-config.cc
|
||||
|
||||
23
sherpa-onnx/python/csrc/online-lm-config.cc
Normal file
23
sherpa-onnx/python/csrc/online-lm-config.cc
Normal file
@@ -0,0 +1,23 @@
|
||||
// sherpa-onnx/python/csrc/online-lm-config.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/python/csrc/online-lm-config.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx//csrc/online-lm-config.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// Expose OnlineLMConfig to Python as `OnlineLMConfig`.
//
// The Python constructor takes `model` (default "") and `scale`
// (default 0.5); both are also exposed as read/write attributes, and
// `__str__` maps to OnlineLMConfig::ToString().
void PybindOnlineLMConfig(py::module *m) {
  py::class_<OnlineLMConfig> lm_config(*m, "OnlineLMConfig");

  lm_config.def(py::init<const std::string &, float>(), py::arg("model") = "",
                py::arg("scale") = 0.5f);
  lm_config.def_readwrite("model", &OnlineLMConfig::model);
  lm_config.def_readwrite("scale", &OnlineLMConfig::scale);
  lm_config.def("__str__", &OnlineLMConfig::ToString);
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
16
sherpa-onnx/python/csrc/online-lm-config.h
Normal file
16
sherpa-onnx/python/csrc/online-lm-config.h
Normal file
@@ -0,0 +1,16 @@
|
||||
// sherpa-onnx/python/csrc/online-lm-config.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_PYTHON_CSRC_ONLINE_LM_CONFIG_H_
|
||||
#define SHERPA_ONNX_PYTHON_CSRC_ONLINE_LM_CONFIG_H_
|
||||
|
||||
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void PybindOnlineLMConfig(py::module *m);
|
||||
|
||||
}
|
||||
|
||||
#endif // SHERPA_ONNX_PYTHON_CSRC_ONLINE_LM_CONFIG_H_
|
||||
@@ -21,11 +21,13 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
|
||||
using PyClass = OnlineRecognizerConfig;
|
||||
py::class_<PyClass>(*m, "OnlineRecognizerConfig")
|
||||
.def(py::init<const FeatureExtractorConfig &,
|
||||
const OnlineTransducerModelConfig &, const EndpointConfig &,
|
||||
bool, const std::string &, int32_t>(),
|
||||
const OnlineTransducerModelConfig &, const OnlineLMConfig &,
|
||||
const EndpointConfig &, bool, const std::string &,
|
||||
int32_t>(),
|
||||
py::arg("feat_config"), py::arg("model_config"),
|
||||
py::arg("endpoint_config"), py::arg("enable_endpoint"),
|
||||
py::arg("decoding_method"), py::arg("max_active_paths"))
|
||||
py::arg("lm_config") = OnlineLMConfig(), py::arg("endpoint_config"),
|
||||
py::arg("enable_endpoint"), py::arg("decoding_method"),
|
||||
py::arg("max_active_paths"))
|
||||
.def_readwrite("feat_config", &PyClass::feat_config)
|
||||
.def_readwrite("model_config", &PyClass::model_config)
|
||||
.def_readwrite("endpoint_config", &PyClass::endpoint_config)
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#include "sherpa-onnx/python/csrc/offline-model-config.h"
|
||||
#include "sherpa-onnx/python/csrc/offline-recognizer.h"
|
||||
#include "sherpa-onnx/python/csrc/offline-stream.h"
|
||||
#include "sherpa-onnx/python/csrc/online-lm-config.h"
|
||||
#include "sherpa-onnx/python/csrc/online-recognizer.h"
|
||||
#include "sherpa-onnx/python/csrc/online-stream.h"
|
||||
#include "sherpa-onnx/python/csrc/online-transducer-model-config.h"
|
||||
@@ -22,6 +23,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
|
||||
|
||||
PybindFeatures(&m);
|
||||
PybindOnlineTransducerModelConfig(&m);
|
||||
PybindOnlineLMConfig(&m);
|
||||
PybindOnlineStream(&m);
|
||||
PybindEndpoint(&m);
|
||||
PybindOnlineRecognizer(&m);
|
||||
|
||||
Reference in New Issue
Block a user