Fix code style issues (#774)
This commit is contained in:
@@ -26,8 +26,7 @@ void FeatureExtractorConfig::Register(ParseOptions *po) {
|
|||||||
po->Register("feat-dim", &feature_dim,
|
po->Register("feat-dim", &feature_dim,
|
||||||
"Feature dimension. Must match the one expected by the model.");
|
"Feature dimension. Must match the one expected by the model.");
|
||||||
|
|
||||||
po->Register("low-freq", &low_freq,
|
po->Register("low-freq", &low_freq, "Low cutoff frequency for mel bins");
|
||||||
"Low cutoff frequency for mel bins");
|
|
||||||
|
|
||||||
po->Register("high-freq", &high_freq,
|
po->Register("high-freq", &high_freq,
|
||||||
"High cutoff frequency for mel bins "
|
"High cutoff frequency for mel bins "
|
||||||
@@ -67,7 +66,7 @@ class FeatureExtractor::Impl {
|
|||||||
opts_.mel_opts.num_bins = config.feature_dim;
|
opts_.mel_opts.num_bins = config.feature_dim;
|
||||||
|
|
||||||
opts_.mel_opts.high_freq = config.high_freq;
|
opts_.mel_opts.high_freq = config.high_freq;
|
||||||
opts_.mel_opts.low_freq = config.low_freq;
|
opts_.mel_opts.low_freq = config.low_freq;
|
||||||
|
|
||||||
opts_.mel_opts.is_librosa = config.is_librosa;
|
opts_.mel_opts.is_librosa = config.is_librosa;
|
||||||
|
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ void OfflineLMConfig::Register(ParseOptions *po) {
|
|||||||
po->Register("lm", &model, "Path to LM model.");
|
po->Register("lm", &model, "Path to LM model.");
|
||||||
po->Register("lm-scale", &scale, "LM scale.");
|
po->Register("lm-scale", &scale, "LM scale.");
|
||||||
po->Register("lm-num-threads", &lm_num_threads,
|
po->Register("lm-num-threads", &lm_num_threads,
|
||||||
"Number of threads to run the neural network of LM model");
|
"Number of threads to run the neural network of LM model");
|
||||||
po->Register("lm-provider", &lm_provider,
|
po->Register("lm-provider", &lm_provider,
|
||||||
"Specify a provider to LM model use: cpu, cuda, coreml");
|
"Specify a provider to LM model use: cpu, cuda, coreml");
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -80,9 +80,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
|
|||||||
InitHotwords();
|
InitHotwords();
|
||||||
}
|
}
|
||||||
if (config_.decoding_method == "greedy_search") {
|
if (config_.decoding_method == "greedy_search") {
|
||||||
decoder_ =
|
decoder_ = std::make_unique<OfflineTransducerGreedySearchDecoder>(
|
||||||
std::make_unique<OfflineTransducerGreedySearchDecoder>(
|
model_.get(), config_.blank_penalty);
|
||||||
model_.get(), config_.blank_penalty);
|
|
||||||
} else if (config_.decoding_method == "modified_beam_search") {
|
} else if (config_.decoding_method == "modified_beam_search") {
|
||||||
if (!config_.lm_config.model.empty()) {
|
if (!config_.lm_config.model.empty()) {
|
||||||
lm_ = OfflineLM::Create(config.lm_config);
|
lm_ = OfflineLM::Create(config.lm_config);
|
||||||
@@ -106,9 +105,8 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
|
|||||||
model_(std::make_unique<OfflineTransducerModel>(mgr,
|
model_(std::make_unique<OfflineTransducerModel>(mgr,
|
||||||
config_.model_config)) {
|
config_.model_config)) {
|
||||||
if (config_.decoding_method == "greedy_search") {
|
if (config_.decoding_method == "greedy_search") {
|
||||||
decoder_ =
|
decoder_ = std::make_unique<OfflineTransducerGreedySearchDecoder>(
|
||||||
std::make_unique<OfflineTransducerGreedySearchDecoder>(
|
model_.get(), config_.blank_penalty);
|
||||||
model_.get(), config_.blank_penalty);
|
|
||||||
} else if (config_.decoding_method == "modified_beam_search") {
|
} else if (config_.decoding_method == "modified_beam_search") {
|
||||||
if (!config_.lm_config.model.empty()) {
|
if (!config_.lm_config.model.empty()) {
|
||||||
lm_ = OfflineLM::Create(mgr, config.lm_config);
|
lm_ = OfflineLM::Create(mgr, config.lm_config);
|
||||||
|
|||||||
@@ -16,8 +16,7 @@ class OfflineTransducerGreedySearchDecoder : public OfflineTransducerDecoder {
|
|||||||
public:
|
public:
|
||||||
explicit OfflineTransducerGreedySearchDecoder(OfflineTransducerModel *model,
|
explicit OfflineTransducerGreedySearchDecoder(OfflineTransducerModel *model,
|
||||||
float blank_penalty)
|
float blank_penalty)
|
||||||
: model_(model),
|
: model_(model), blank_penalty_(blank_penalty) {}
|
||||||
blank_penalty_(blank_penalty) {}
|
|
||||||
|
|
||||||
std::vector<OfflineTransducerDecoderResult> Decode(
|
std::vector<OfflineTransducerDecoderResult> Decode(
|
||||||
Ort::Value encoder_out, Ort::Value encoder_out_length,
|
Ort::Value encoder_out, Ort::Value encoder_out_length,
|
||||||
|
|||||||
@@ -102,9 +102,9 @@ void OfflineWebsocketDecoder::Decode() {
|
|||||||
asio::post(server_->GetConnectionContext(),
|
asio::post(server_->GetConnectionContext(),
|
||||||
[this, hdl, result = ss[i]->GetResult()]() {
|
[this, hdl, result = ss[i]->GetResult()]() {
|
||||||
websocketpp::lib::error_code ec;
|
websocketpp::lib::error_code ec;
|
||||||
server_->GetServer().send(
|
server_->GetServer().send(hdl, result.AsJsonString(),
|
||||||
hdl, result.AsJsonString(),
|
websocketpp::frame::opcode::text,
|
||||||
websocketpp::frame::opcode::text, ec);
|
ec);
|
||||||
if (ec) {
|
if (ec) {
|
||||||
server_->GetServer().get_alog().write(
|
server_->GetServer().get_alog().write(
|
||||||
websocketpp::log::alevel::app, ec.message());
|
websocketpp::log::alevel::app, ec.message());
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ void OnlineLMConfig::Register(ParseOptions *po) {
|
|||||||
po->Register("lm", &model, "Path to LM model.");
|
po->Register("lm", &model, "Path to LM model.");
|
||||||
po->Register("lm-scale", &scale, "LM scale.");
|
po->Register("lm-scale", &scale, "LM scale.");
|
||||||
po->Register("lm-num-threads", &lm_num_threads,
|
po->Register("lm-num-threads", &lm_num_threads,
|
||||||
"Number of threads to run the neural network of LM model");
|
"Number of threads to run the neural network of LM model");
|
||||||
po->Register("lm-provider", &lm_provider,
|
po->Register("lm-provider", &lm_provider,
|
||||||
"Specify a provider to LM model use: cpu, cuda, coreml");
|
"Specify a provider to LM model use: cpu, cuda, coreml");
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ struct OnlineLMConfig {
|
|||||||
OnlineLMConfig() = default;
|
OnlineLMConfig() = default;
|
||||||
|
|
||||||
OnlineLMConfig(const std::string &model, float scale, int32_t lm_num_threads,
|
OnlineLMConfig(const std::string &model, float scale, int32_t lm_num_threads,
|
||||||
const std::string &lm_provider)
|
const std::string &lm_provider)
|
||||||
: model(model),
|
: model(model),
|
||||||
scale(scale),
|
scale(scale),
|
||||||
lm_num_threads(lm_num_threads),
|
lm_num_threads(lm_num_threads),
|
||||||
|
|||||||
@@ -40,8 +40,7 @@ struct OnlineModelConfig {
|
|||||||
const OnlineWenetCtcModelConfig &wenet_ctc,
|
const OnlineWenetCtcModelConfig &wenet_ctc,
|
||||||
const OnlineZipformer2CtcModelConfig &zipformer2_ctc,
|
const OnlineZipformer2CtcModelConfig &zipformer2_ctc,
|
||||||
const std::string &tokens, int32_t num_threads,
|
const std::string &tokens, int32_t num_threads,
|
||||||
int32_t warm_up, bool debug,
|
int32_t warm_up, bool debug, const std::string &provider,
|
||||||
const std::string &provider,
|
|
||||||
const std::string &model_type)
|
const std::string &model_type)
|
||||||
: transducer(transducer),
|
: transducer(transducer),
|
||||||
paraformer(paraformer),
|
paraformer(paraformer),
|
||||||
|
|||||||
@@ -30,9 +30,9 @@
|
|||||||
#include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h"
|
#include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h"
|
||||||
#include "sherpa-onnx/csrc/online-transducer-model.h"
|
#include "sherpa-onnx/csrc/online-transducer-model.h"
|
||||||
#include "sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h"
|
#include "sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.h"
|
||||||
|
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||||
#include "sherpa-onnx/csrc/symbol-table.h"
|
#include "sherpa-onnx/csrc/symbol-table.h"
|
||||||
#include "sherpa-onnx/csrc/utils.h"
|
#include "sherpa-onnx/csrc/utils.h"
|
||||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
|
||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
@@ -185,7 +185,7 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Warmping up engine with wp: warm_up count and max-batch-size
|
// Warmping up engine with wp: warm_up count and max-batch-size
|
||||||
void WarmpUpRecognizer(int32_t warmup, int32_t mbs) const {
|
void WarmpUpRecognizer(int32_t warmup, int32_t mbs) const override {
|
||||||
auto max_batch_size = mbs;
|
auto max_batch_size = mbs;
|
||||||
if (warmup <= 0 || warmup > 100) {
|
if (warmup <= 0 || warmup > 100) {
|
||||||
return;
|
return;
|
||||||
@@ -210,8 +210,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
|||||||
for (int32_t i = 0; i != warmup; ++i) {
|
for (int32_t i = 0; i != warmup; ++i) {
|
||||||
auto states = model_->StackStates(states_vec);
|
auto states = model_->StackStates(states_vec);
|
||||||
Ort::Value x = Ort::Value::CreateTensor(memory_info, features_vec.data(),
|
Ort::Value x = Ort::Value::CreateTensor(memory_info, features_vec.data(),
|
||||||
features_vec.size(), x_shape.data(),
|
features_vec.size(),
|
||||||
x_shape.size());
|
x_shape.data(), x_shape.size());
|
||||||
auto x_copy = Clone(model_->Allocator(), &x);
|
auto x_copy = Clone(model_->Allocator(), &x);
|
||||||
auto pair = model_->RunEncoder(std::move(x), std::move(states),
|
auto pair = model_->RunEncoder(std::move(x), std::move(states),
|
||||||
std::move(x_copy));
|
std::move(x_copy));
|
||||||
|
|||||||
@@ -168,7 +168,7 @@ class OnlineRecognizer {
|
|||||||
*
|
*
|
||||||
* @param warmup Number of warmups.
|
* @param warmup Number of warmups.
|
||||||
* @param mbs : max-batch-size Max batch size for the models
|
* @param mbs : max-batch-size Max batch size for the models
|
||||||
*/
|
*/
|
||||||
void WarmpUpRecognizer(int32_t warmup, int32_t mbs) const;
|
void WarmpUpRecognizer(int32_t warmup, int32_t mbs) const;
|
||||||
|
|
||||||
/** Decode multiple streams in parallel
|
/** Decode multiple streams in parallel
|
||||||
|
|||||||
@@ -12,8 +12,8 @@
|
|||||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||||
#include "sherpa-onnx/csrc/macros.h"
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||||
#include "sherpa-onnx/csrc/text-utils.h"
|
|
||||||
#include "sherpa-onnx/csrc/session.h"
|
#include "sherpa-onnx/csrc/session.h"
|
||||||
|
#include "sherpa-onnx/csrc/text-utils.h"
|
||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
@@ -42,10 +42,9 @@ class OnlineRnnLM::Impl {
|
|||||||
// nn_lm_scores
|
// nn_lm_scores
|
||||||
std::array<int64_t, 2> x_shape{1, 1};
|
std::array<int64_t, 2> x_shape{1, 1};
|
||||||
Ort::Value x = Ort::Value::CreateTensor<int64_t>(allocator_, x_shape.data(),
|
Ort::Value x = Ort::Value::CreateTensor<int64_t>(allocator_, x_shape.data(),
|
||||||
x_shape.size());
|
x_shape.size());
|
||||||
*x.GetTensorMutableData<int64_t>() = hyp->ys.back();
|
*x.GetTensorMutableData<int64_t>() = hyp->ys.back();
|
||||||
auto lm_out =
|
auto lm_out = ScoreToken(std::move(x), Convert(hyp->nn_lm_states));
|
||||||
ScoreToken(std::move(x), Convert(hyp->nn_lm_states));
|
|
||||||
hyp->nn_lm_scores.value = std::move(lm_out.first);
|
hyp->nn_lm_scores.value = std::move(lm_out.first);
|
||||||
hyp->nn_lm_states = Convert(std::move(lm_out.second));
|
hyp->nn_lm_states = Convert(std::move(lm_out.second));
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -71,11 +71,9 @@ void OnlineTransducerGreedySearchDecoder::StripLeadingBlanks(
|
|||||||
r->tokens = std::vector<int64_t>(start, end);
|
r->tokens = std::vector<int64_t>(start, end);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
void OnlineTransducerGreedySearchDecoder::Decode(
|
void OnlineTransducerGreedySearchDecoder::Decode(
|
||||||
Ort::Value encoder_out,
|
Ort::Value encoder_out,
|
||||||
std::vector<OnlineTransducerDecoderResult> *result) {
|
std::vector<OnlineTransducerDecoderResult> *result) {
|
||||||
|
|
||||||
std::vector<int64_t> encoder_out_shape =
|
std::vector<int64_t> encoder_out_shape =
|
||||||
encoder_out.GetTensorTypeAndShapeInfo().GetShape();
|
encoder_out.GetTensorTypeAndShapeInfo().GetShape();
|
||||||
|
|
||||||
@@ -106,7 +104,8 @@ void OnlineTransducerGreedySearchDecoder::Decode(
|
|||||||
r.decoder_out.GetTensorTypeAndShapeInfo().GetShape();
|
r.decoder_out.GetTensorTypeAndShapeInfo().GetShape();
|
||||||
decoder_out_shape[0] = batch_size;
|
decoder_out_shape[0] = batch_size;
|
||||||
decoder_out = Ort::Value::CreateTensor<float>(model_->Allocator(),
|
decoder_out = Ort::Value::CreateTensor<float>(model_->Allocator(),
|
||||||
decoder_out_shape.data(), decoder_out_shape.size());
|
decoder_out_shape.data(),
|
||||||
|
decoder_out_shape.size());
|
||||||
UseCachedDecoderOut(*result, &decoder_out);
|
UseCachedDecoderOut(*result, &decoder_out);
|
||||||
} else {
|
} else {
|
||||||
Ort::Value decoder_input = model_->BuildDecoderInput(*result);
|
Ort::Value decoder_input = model_->BuildDecoderInput(*result);
|
||||||
@@ -116,8 +115,8 @@ void OnlineTransducerGreedySearchDecoder::Decode(
|
|||||||
for (int32_t t = 0; t != num_frames; ++t) {
|
for (int32_t t = 0; t != num_frames; ++t) {
|
||||||
Ort::Value cur_encoder_out =
|
Ort::Value cur_encoder_out =
|
||||||
GetEncoderOutFrame(model_->Allocator(), &encoder_out, t);
|
GetEncoderOutFrame(model_->Allocator(), &encoder_out, t);
|
||||||
Ort::Value logit = model_->RunJoiner(
|
Ort::Value logit =
|
||||||
std::move(cur_encoder_out), View(&decoder_out));
|
model_->RunJoiner(std::move(cur_encoder_out), View(&decoder_out));
|
||||||
|
|
||||||
float *p_logit = logit.GetTensorMutableData<float>();
|
float *p_logit = logit.GetTensorMutableData<float>();
|
||||||
|
|
||||||
@@ -145,9 +144,9 @@ void OnlineTransducerGreedySearchDecoder::Decode(
|
|||||||
|
|
||||||
// export the per-token log scores
|
// export the per-token log scores
|
||||||
if (y != 0 && y != unk_id_) {
|
if (y != 0 && y != unk_id_) {
|
||||||
LogSoftmax(p_logit, vocab_size); // renormalize probabilities,
|
LogSoftmax(p_logit, vocab_size); // renormalize probabilities,
|
||||||
// save time by doing it only for
|
// save time by doing it only for
|
||||||
// emitted symbols
|
// emitted symbols
|
||||||
const float *p_logprob = p_logit; // rename p_logit as p_logprob,
|
const float *p_logprob = p_logit; // rename p_logit as p_logprob,
|
||||||
// now it contains normalized
|
// now it contains normalized
|
||||||
// probability
|
// probability
|
||||||
|
|||||||
@@ -15,8 +15,7 @@ namespace sherpa_onnx {
|
|||||||
class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder {
|
class OnlineTransducerGreedySearchDecoder : public OnlineTransducerDecoder {
|
||||||
public:
|
public:
|
||||||
OnlineTransducerGreedySearchDecoder(OnlineTransducerModel *model,
|
OnlineTransducerGreedySearchDecoder(OnlineTransducerModel *model,
|
||||||
int32_t unk_id,
|
int32_t unk_id, float blank_penalty)
|
||||||
float blank_penalty)
|
|
||||||
: model_(model), unk_id_(unk_id), blank_penalty_(blank_penalty) {}
|
: model_(model), unk_id_(unk_id), blank_penalty_(blank_penalty) {}
|
||||||
|
|
||||||
OnlineTransducerDecoderResult GetEmptyResult() const override;
|
OnlineTransducerDecoderResult GetEmptyResult() const override;
|
||||||
|
|||||||
@@ -69,7 +69,7 @@ class OnlineTransducerModel {
|
|||||||
* This has to be called before GetEncoderInitStates(), so the `encoder_embed`
|
* This has to be called before GetEncoderInitStates(), so the `encoder_embed`
|
||||||
* init state has the correct `embed_dim` of its output.
|
* init state has the correct `embed_dim` of its output.
|
||||||
*/
|
*/
|
||||||
virtual void SetFeatureDim(int32_t feature_dim) { }
|
virtual void SetFeatureDim(int32_t feature_dim) {}
|
||||||
|
|
||||||
/** Run the encoder.
|
/** Run the encoder.
|
||||||
*
|
*
|
||||||
|
|||||||
@@ -188,7 +188,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
|||||||
// score of the transducer
|
// score of the transducer
|
||||||
// export the per-token log scores
|
// export the per-token log scores
|
||||||
if (new_token != 0 && new_token != unk_id_) {
|
if (new_token != 0 && new_token != unk_id_) {
|
||||||
const Hypothesis& prev_i = prev[hyp_index];
|
const Hypothesis &prev_i = prev[hyp_index];
|
||||||
// subtract 'prev[i]' path scores, which were added before
|
// subtract 'prev[i]' path scores, which were added before
|
||||||
// getting topk tokens
|
// getting topk tokens
|
||||||
float y_prob = p_logprob[k] - prev_i.log_prob - prev_i.lm_log_prob;
|
float y_prob = p_logprob[k] - prev_i.log_prob - prev_i.lm_log_prob;
|
||||||
|
|||||||
@@ -16,10 +16,10 @@ TEST(Stack, Test1DTensors) {
|
|||||||
std::array<int64_t, 1> b_shape{3};
|
std::array<int64_t, 1> b_shape{3};
|
||||||
|
|
||||||
Ort::Value a = Ort::Value::CreateTensor<float>(allocator, a_shape.data(),
|
Ort::Value a = Ort::Value::CreateTensor<float>(allocator, a_shape.data(),
|
||||||
a_shape.size());
|
a_shape.size());
|
||||||
|
|
||||||
Ort::Value b = Ort::Value::CreateTensor<float>(allocator, b_shape.data(),
|
Ort::Value b = Ort::Value::CreateTensor<float>(allocator, b_shape.data(),
|
||||||
b_shape.size());
|
b_shape.size());
|
||||||
float *pa = a.GetTensorMutableData<float>();
|
float *pa = a.GetTensorMutableData<float>();
|
||||||
float *pb = b.GetTensorMutableData<float>();
|
float *pb = b.GetTensorMutableData<float>();
|
||||||
for (int32_t i = 0; i != static_cast<int32_t>(a_shape[0]); ++i) {
|
for (int32_t i = 0; i != static_cast<int32_t>(a_shape[0]); ++i) {
|
||||||
@@ -51,11 +51,11 @@ TEST(Stack, Test2DTensorsDim0) {
|
|||||||
std::array<int64_t, 2> a_shape{2, 3};
|
std::array<int64_t, 2> a_shape{2, 3};
|
||||||
std::array<int64_t, 2> b_shape{2, 3};
|
std::array<int64_t, 2> b_shape{2, 3};
|
||||||
|
|
||||||
Ort::Value a = Ort::Value::CreateTensor<float>(
|
Ort::Value a = Ort::Value::CreateTensor<float>(allocator, a_shape.data(),
|
||||||
allocator, a_shape.data(), a_shape.size());
|
a_shape.size());
|
||||||
|
|
||||||
Ort::Value b = Ort::Value::CreateTensor<float>(
|
Ort::Value b = Ort::Value::CreateTensor<float>(allocator, b_shape.data(),
|
||||||
allocator, b_shape.data(), b_shape.size());
|
b_shape.size());
|
||||||
|
|
||||||
float *pa = a.GetTensorMutableData<float>();
|
float *pa = a.GetTensorMutableData<float>();
|
||||||
float *pb = b.GetTensorMutableData<float>();
|
float *pb = b.GetTensorMutableData<float>();
|
||||||
|
|||||||
@@ -12,10 +12,8 @@ static void PybindFeatureExtractorConfig(py::module *m) {
|
|||||||
using PyClass = FeatureExtractorConfig;
|
using PyClass = FeatureExtractorConfig;
|
||||||
py::class_<PyClass>(*m, "FeatureExtractorConfig")
|
py::class_<PyClass>(*m, "FeatureExtractorConfig")
|
||||||
.def(py::init<int32_t, int32_t, float, float, float>(),
|
.def(py::init<int32_t, int32_t, float, float, float>(),
|
||||||
py::arg("sampling_rate") = 16000,
|
py::arg("sampling_rate") = 16000, py::arg("feature_dim") = 80,
|
||||||
py::arg("feature_dim") = 80,
|
py::arg("low_freq") = 20.0f, py::arg("high_freq") = -400.0f,
|
||||||
py::arg("low_freq") = 20.0f,
|
|
||||||
py::arg("high_freq") = -400.0f,
|
|
||||||
py::arg("dither") = 0.0f)
|
py::arg("dither") = 0.0f)
|
||||||
.def_readwrite("sampling_rate", &PyClass::sampling_rate)
|
.def_readwrite("sampling_rate", &PyClass::sampling_rate)
|
||||||
.def_readwrite("feature_dim", &PyClass::feature_dim)
|
.def_readwrite("feature_dim", &PyClass::feature_dim)
|
||||||
|
|||||||
@@ -23,8 +23,7 @@ static void PybindOfflineRecognizerConfig(py::module *m) {
|
|||||||
py::arg("ctc_fst_decoder_config") = OfflineCtcFstDecoderConfig(),
|
py::arg("ctc_fst_decoder_config") = OfflineCtcFstDecoderConfig(),
|
||||||
py::arg("decoding_method") = "greedy_search",
|
py::arg("decoding_method") = "greedy_search",
|
||||||
py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "",
|
py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "",
|
||||||
py::arg("hotwords_score") = 1.5,
|
py::arg("hotwords_score") = 1.5, py::arg("blank_penalty") = 0.0)
|
||||||
py::arg("blank_penalty") = 0.0)
|
|
||||||
.def_readwrite("feat_config", &PyClass::feat_config)
|
.def_readwrite("feat_config", &PyClass::feat_config)
|
||||||
.def_readwrite("model_config", &PyClass::model_config)
|
.def_readwrite("model_config", &PyClass::model_config)
|
||||||
.def_readwrite("lm_config", &PyClass::lm_config)
|
.def_readwrite("lm_config", &PyClass::lm_config)
|
||||||
|
|||||||
@@ -4,7 +4,6 @@
|
|||||||
|
|
||||||
#include "sherpa-onnx/python/csrc/offline-transducer-model-config.h"
|
#include "sherpa-onnx/python/csrc/offline-transducer-model-config.h"
|
||||||
|
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
@@ -16,7 +15,7 @@ void PybindOfflineTransducerModelConfig(py::module *m) {
|
|||||||
using PyClass = OfflineTransducerModelConfig;
|
using PyClass = OfflineTransducerModelConfig;
|
||||||
py::class_<PyClass>(*m, "OfflineTransducerModelConfig")
|
py::class_<PyClass>(*m, "OfflineTransducerModelConfig")
|
||||||
.def(py::init<const std::string &, const std::string &,
|
.def(py::init<const std::string &, const std::string &,
|
||||||
const std::string &>(),
|
const std::string &>(),
|
||||||
py::arg("encoder_filename"), py::arg("decoder_filename"),
|
py::arg("encoder_filename"), py::arg("decoder_filename"),
|
||||||
py::arg("joiner_filename"))
|
py::arg("joiner_filename"))
|
||||||
.def_readwrite("encoder_filename", &PyClass::encoder_filename)
|
.def_readwrite("encoder_filename", &PyClass::encoder_filename)
|
||||||
|
|||||||
@@ -27,9 +27,9 @@ void PybindOnlineModelConfig(py::module *m) {
|
|||||||
.def(py::init<const OnlineTransducerModelConfig &,
|
.def(py::init<const OnlineTransducerModelConfig &,
|
||||||
const OnlineParaformerModelConfig &,
|
const OnlineParaformerModelConfig &,
|
||||||
const OnlineWenetCtcModelConfig &,
|
const OnlineWenetCtcModelConfig &,
|
||||||
const OnlineZipformer2CtcModelConfig &,
|
const OnlineZipformer2CtcModelConfig &, const std::string &,
|
||||||
const std::string &, int32_t, int32_t,
|
int32_t, int32_t, bool, const std::string &,
|
||||||
bool, const std::string &, const std::string &>(),
|
const std::string &>(),
|
||||||
py::arg("transducer") = OnlineTransducerModelConfig(),
|
py::arg("transducer") = OnlineTransducerModelConfig(),
|
||||||
py::arg("paraformer") = OnlineParaformerModelConfig(),
|
py::arg("paraformer") = OnlineParaformerModelConfig(),
|
||||||
py::arg("wenet_ctc") = OnlineWenetCtcModelConfig(),
|
py::arg("wenet_ctc") = OnlineWenetCtcModelConfig(),
|
||||||
|
|||||||
Reference in New Issue
Block a user