Code refactoring (#74)

* Don't reset model state and feature extractor on endpointing

* support passing decoding_method from commandline

* Add modified_beam_search to Python API

* fix C API example

* Fix style issues
This commit is contained in:
Fangjun Kuang
2023-03-03 12:10:59 +08:00
committed by GitHub
parent c241f93c40
commit 7f72c13d9a
34 changed files with 744 additions and 374 deletions

View File

@@ -9,6 +9,7 @@
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/display.h"
#include "sherpa-onnx/csrc/online-recognizer.h"
struct SherpaOnnxOnlineRecognizer {
@@ -21,6 +22,10 @@ struct SherpaOnnxOnlineStream {
: impl(std::move(p)) {}
};
struct SherpaOnnxDisplay {
std::unique_ptr<sherpa_onnx::Display> impl;
};
SherpaOnnxOnlineRecognizer *CreateOnlineRecognizer(
const SherpaOnnxOnlineRecognizerConfig *config) {
sherpa_onnx::OnlineRecognizerConfig recognizer_config;
@@ -37,6 +42,9 @@ SherpaOnnxOnlineRecognizer *CreateOnlineRecognizer(
recognizer_config.model_config.num_threads = config->model_config.num_threads;
recognizer_config.model_config.debug = config->model_config.debug;
recognizer_config.decoding_method = config->decoding_method;
recognizer_config.max_active_paths = config->max_active_paths;
recognizer_config.enable_endpoint = config->enable_endpoint;
recognizer_config.endpoint_config.rule1.min_trailing_silence =
@@ -124,3 +132,15 @@ int32_t IsEndpoint(SherpaOnnxOnlineRecognizer *recognizer,
SherpaOnnxOnlineStream *stream) {
return recognizer->impl->IsEndpoint(stream->impl.get());
}
// Allocates a display object wrapping sherpa_onnx::Display.
// The caller owns the returned pointer and must release it with
// DestroyDisplay() to avoid a memory leak.
SherpaOnnxDisplay *CreateDisplay(int32_t max_word_per_line) {
  auto *display = new SherpaOnnxDisplay;
  display->impl = std::make_unique<sherpa_onnx::Display>(max_word_per_line);
  return display;
}
// Frees a display object created by CreateDisplay().
// Passing nullptr is a no-op (delete on a null pointer is well defined).
void DestroyDisplay(SherpaOnnxDisplay *display) {
  delete display;
}
// Prints the recognition result `s` for segment `idx` by forwarding to the
// underlying sherpa_onnx::Display instance.
void SherpaOnnxPrint(SherpaOnnxDisplay *display, int32_t idx, const char *s) {
  auto &impl = *display->impl;
  impl.Print(idx, s);
}

View File

@@ -48,6 +48,13 @@ typedef struct SherpaOnnxOnlineRecognizerConfig {
SherpaOnnxFeatureConfig feat_config;
SherpaOnnxOnlineTransducerModelConfig model_config;
/// Possible values are: greedy_search, modified_beam_search
const char *decoding_method;
/// Used only when decoding_method is modified_beam_search
/// Example value: 4
int32_t max_active_paths;
/// 0 to disable endpoint detection.
/// A non-zero value to enable endpoint detection.
int32_t enable_endpoint;
@@ -187,6 +194,18 @@ void InputFinished(SherpaOnnxOnlineStream *stream);
int32_t IsEndpoint(SherpaOnnxOnlineRecognizer *recognizer,
SherpaOnnxOnlineStream *stream);
// for displaying results on Linux/macOS.
typedef struct SherpaOnnxDisplay SherpaOnnxDisplay;
/// Create a display object. Must be freed using DestroyDisplay to avoid
/// memory leak.
SherpaOnnxDisplay *CreateDisplay(int32_t max_word_per_line);
void DestroyDisplay(SherpaOnnxDisplay *display);
/// Print the result.
void SherpaOnnxPrint(SherpaOnnxDisplay *display, int32_t idx, const char *s);
#ifdef __cplusplus
} /* extern "C" */
#endif

View File

@@ -9,10 +9,11 @@ set(sources
online-lstm-transducer-model.cc
online-recognizer.cc
online-stream.cc
online-transducer-decoder.cc
online-transducer-greedy-search-decoder.cc
online-transducer-model-config.cc
online-transducer-modified-beam-search-decoder.cc
online-transducer-model.cc
online-transducer-modified-beam-search-decoder.cc
online-zipformer-transducer-model.cc
onnx-utils.cc
parse-options.cc

View File

@@ -12,9 +12,16 @@ namespace sherpa_onnx {
class Display {
public:
explicit Display(int32_t max_word_per_line = 60)
: max_word_per_line_(max_word_per_line) {}
void Print(int32_t segment_id, const std::string &s) {
#ifdef _MSC_VER
fprintf(stderr, "%d:%s\n", segment_id, s.c_str());
if (segment_id != -1) {
fprintf(stderr, "%d:%s\n", segment_id, s.c_str());
} else {
fprintf(stderr, "%s\n", s.c_str());
}
return;
#endif
if (last_segment_ == segment_id) {
@@ -27,7 +34,9 @@ class Display {
num_previous_lines_ = 0;
}
fprintf(stderr, "\r%d:", segment_id);
if (segment_id != -1) {
fprintf(stderr, "\r%d:", segment_id);
}
int32_t i = 0;
for (size_t n = 0; n < s.size();) {
@@ -69,7 +78,7 @@ class Display {
void GoUpOneLine() const { fprintf(stderr, "\033[1A\r"); }
private:
int32_t max_word_per_line_ = 60;
int32_t max_word_per_line_;
int32_t num_previous_lines_ = 0;
int32_t last_segment_ = -1;
};

View File

@@ -28,7 +28,8 @@ std::string FeatureExtractorConfig::ToString() const {
os << "FeatureExtractorConfig(";
os << "sampling_rate=" << sampling_rate << ", ";
os << "feature_dim=" << feature_dim << ")";
os << "feature_dim=" << feature_dim << ", ";
os << "max_feature_vectors=" << max_feature_vectors << ")";
return os.str();
}
@@ -40,9 +41,7 @@ class FeatureExtractor::Impl {
opts_.frame_opts.snip_edges = false;
opts_.frame_opts.samp_freq = config.sampling_rate;
// cache 100 seconds of feature frames, which is more than enough
// for real needs
opts_.frame_opts.max_feature_vectors = 100 * 100;
opts_.frame_opts.max_feature_vectors = config.max_feature_vectors;
opts_.mel_opts.num_bins = config.feature_dim;

View File

@@ -16,6 +16,7 @@ namespace sherpa_onnx {
struct FeatureExtractorConfig {
float sampling_rate = 16000;
int32_t feature_dim = 80;
int32_t max_feature_vectors = -1;
std::string ToString() const;

View File

@@ -18,7 +18,7 @@ namespace sherpa_onnx {
struct Hypothesis {
// The predicted tokens so far. Newly predicted tokens are appended.
std::vector<int32_t> ys;
std::vector<int64_t> ys;
// timestamps[i] contains the frame number after subsampling
// on which ys[i] is decoded.
@@ -30,7 +30,7 @@ struct Hypothesis {
int32_t num_trailing_blanks = 0;
Hypothesis() = default;
Hypothesis(const std::vector<int32_t> &ys, double log_prob)
Hypothesis(const std::vector<int64_t> &ys, double log_prob)
: ys(ys), log_prob(log_prob) {}
// If two Hypotheses have the same `Key`, then they contain

View File

@@ -43,7 +43,7 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) {
"True to enable endpoint detection. False to disable it.");
po->Register("max-active-paths", &max_active_paths,
"beam size used in modified beam search.");
po->Register("decoding-mothod", &decoding_method,
po->Register("decoding-method", &decoding_method,
"decoding method,"
"now support greedy_search and modified_beam_search.");
}
@@ -59,8 +59,8 @@ std::string OnlineRecognizerConfig::ToString() const {
os << "feat_config=" << feat_config.ToString() << ", ";
os << "model_config=" << model_config.ToString() << ", ";
os << "endpoint_config=" << endpoint_config.ToString() << ", ";
os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ",";
os << "max_active_paths=" << max_active_paths << ",";
os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ", ";
os << "max_active_paths=" << max_active_paths << ", ";
os << "decoding_method=\"" << decoding_method << "\")";
return os.str();
@@ -187,16 +187,14 @@ class OnlineRecognizer::Impl {
}
void Reset(OnlineStream *s) const {
// reset result, neural network model state, and
// the feature extractor state
// reset result
// we keep the decoder_out
decoder_->UpdateDecoderOut(&s->GetResult());
Ort::Value decoder_out = std::move(s->GetResult().decoder_out);
s->SetResult(decoder_->GetEmptyResult());
s->GetResult().decoder_out = std::move(decoder_out);
// reset neural network model state
s->SetStates(model_->GetEncoderInitStates());
// reset feature extractor
// Note: We only update counters. The underlying audio samples
// are not discarded.
s->Reset();
}

View File

@@ -33,21 +33,26 @@ struct OnlineRecognizerConfig {
OnlineTransducerModelConfig model_config;
EndpointConfig endpoint_config;
bool enable_endpoint = true;
int32_t max_active_paths = 4;
std::string decoding_method = "modified_beam_search";
std::string decoding_method = "greedy_search";
// now support modified_beam_search and greedy_search
int32_t max_active_paths = 4; // used only for modified_beam_search
OnlineRecognizerConfig() = default;
OnlineRecognizerConfig(const FeatureExtractorConfig &feat_config,
const OnlineTransducerModelConfig &model_config,
const EndpointConfig &endpoint_config,
bool enable_endpoint)
bool enable_endpoint,
const std::string &decoding_method,
int32_t max_active_paths)
: feat_config(feat_config),
model_config(model_config),
endpoint_config(endpoint_config),
enable_endpoint(enable_endpoint) {}
enable_endpoint(enable_endpoint),
decoding_method(decoding_method),
max_active_paths(max_active_paths) {}
void Register(ParseOptions *po);
bool Validate() const;

View File

@@ -22,18 +22,21 @@ class OnlineStream::Impl {
void InputFinished() { feat_extractor_.InputFinished(); }
int32_t NumFramesReady() const { return feat_extractor_.NumFramesReady(); }
int32_t NumFramesReady() const {
return feat_extractor_.NumFramesReady() - start_frame_index_;
}
bool IsLastFrame(int32_t frame) const {
return feat_extractor_.IsLastFrame(frame);
}
std::vector<float> GetFrames(int32_t frame_index, int32_t n) const {
return feat_extractor_.GetFrames(frame_index, n);
return feat_extractor_.GetFrames(frame_index + start_frame_index_, n);
}
void Reset() {
feat_extractor_.Reset();
// we don't reset the feature extractor
start_frame_index_ += num_processed_frames_;
num_processed_frames_ = 0;
}
@@ -41,7 +44,7 @@ class OnlineStream::Impl {
void SetResult(const OnlineTransducerDecoderResult &r) { result_ = r; }
const OnlineTransducerDecoderResult &GetResult() const { return result_; }
OnlineTransducerDecoderResult &GetResult() { return result_; }
int32_t FeatureDim() const { return feat_extractor_.FeatureDim(); }
@@ -54,6 +57,7 @@ class OnlineStream::Impl {
private:
FeatureExtractor feat_extractor_;
int32_t num_processed_frames_ = 0; // before subsampling
int32_t start_frame_index_ = 0; // never reset
OnlineTransducerDecoderResult result_;
std::vector<Ort::Value> states_;
};
@@ -93,7 +97,7 @@ void OnlineStream::SetResult(const OnlineTransducerDecoderResult &r) {
impl_->SetResult(r);
}
const OnlineTransducerDecoderResult &OnlineStream::GetResult() const {
OnlineTransducerDecoderResult &OnlineStream::GetResult() {
return impl_->GetResult();
}

View File

@@ -63,7 +63,7 @@ class OnlineStream {
int32_t &GetNumProcessedFrames();
void SetResult(const OnlineTransducerDecoderResult &r);
const OnlineTransducerDecoderResult &GetResult() const;
OnlineTransducerDecoderResult &GetResult();
void SetStates(std::vector<Ort::Value> states);
std::vector<Ort::Value> &GetStates();

View File

@@ -0,0 +1,60 @@
// sherpa-onnx/csrc/online-transducer-decoder.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
#include <utility>
#include <vector>
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/onnx-utils.h"
namespace sherpa_onnx {
// Copy constructor: first default-constructs (so decoder_out starts as an
// empty Ort::Value), then delegates to the copy-assignment operator so the
// deep-copy logic lives in one place.
OnlineTransducerDecoderResult::OnlineTransducerDecoderResult(
    const OnlineTransducerDecoderResult &other)
    : OnlineTransducerDecoderResult() {
  *this = other;
}
// Copy assignment: deep-copies every field. Ort::Value is move-only, so
// decoder_out is cloned into a freshly allocated tensor when present.
OnlineTransducerDecoderResult &OnlineTransducerDecoderResult::operator=(
    const OnlineTransducerDecoderResult &other) {
  if (this == &other) {
    return *this;
  }

  tokens = other.tokens;
  num_trailing_blanks = other.num_trailing_blanks;

  Ort::AllocatorWithDefaultOptions allocator;
  if (other.decoder_out) {
    decoder_out = Clone(allocator, &other.decoder_out);
  } else {
    // Fix: previously a stale decoder_out in *this survived the assignment
    // when `other` had none, leaving *this unequal to `other` afterwards.
    // Reset it so copy assignment preserves value semantics.
    decoder_out = Ort::Value{nullptr};
  }

  hyps = other.hyps;

  return *this;
}
// Move constructor: default-constructs, then delegates to the
// move-assignment operator so the field-by-field move lives in one place.
OnlineTransducerDecoderResult::OnlineTransducerDecoderResult(
    OnlineTransducerDecoderResult &&other)
    : OnlineTransducerDecoderResult() {
  *this = std::move(other);
}
// Move assignment: steals every movable field from `rhs`; the plain
// integer counter is simply copied. `rhs` is left valid but unspecified.
OnlineTransducerDecoderResult &OnlineTransducerDecoderResult::operator=(
    OnlineTransducerDecoderResult &&rhs) {
  if (this != &rhs) {
    tokens = std::move(rhs.tokens);
    hyps = std::move(rhs.hyps);
    decoder_out = std::move(rhs.decoder_out);
    num_trailing_blanks = rhs.num_trailing_blanks;
  }
  return *this;
}
} // namespace sherpa_onnx

View File

@@ -19,8 +19,24 @@ struct OnlineTransducerDecoderResult {
/// number of trailing blank frames decoded so far
int32_t num_trailing_blanks = 0;
// Cache decoder_out for endpointing
Ort::Value decoder_out;
// used only in modified beam_search
Hypotheses hyps;
OnlineTransducerDecoderResult()
: tokens{}, num_trailing_blanks(0), decoder_out{nullptr}, hyps{} {}
OnlineTransducerDecoderResult(const OnlineTransducerDecoderResult &other);
OnlineTransducerDecoderResult &operator=(
const OnlineTransducerDecoderResult &other);
OnlineTransducerDecoderResult(OnlineTransducerDecoderResult &&other);
OnlineTransducerDecoderResult &operator=(
OnlineTransducerDecoderResult &&other);
};
class OnlineTransducerDecoder {
@@ -53,6 +69,9 @@ class OnlineTransducerDecoder {
*/
virtual void Decode(Ort::Value encoder_out,
std::vector<OnlineTransducerDecoderResult> *result) = 0;
// used for endpointing. We need to keep decoder_out after reset
virtual void UpdateDecoderOut(OnlineTransducerDecoderResult *result) {}
};
} // namespace sherpa_onnx

View File

@@ -13,6 +13,43 @@
namespace sherpa_onnx {
// Overwrites rows of the freshly computed `decoder_out` tensor with the
// decoder output cached in each result (when one exists), so decoding can
// resume after an endpoint reset without losing decoder state.
//
// NOTE(review): assumes `decoder_out` is a 2-D row-major float tensor of
// shape (results.size(), decoder_out_dim) -- confirm against the caller
// that builds it via RunDecoder.
static void UseCachedDecoderOut(
    const std::vector<OnlineTransducerDecoderResult> &results,
    Ort::Value *decoder_out) {
  std::vector<int64_t> shape =
      decoder_out->GetTensorTypeAndShapeInfo().GetShape();

  float *dst = decoder_out->GetTensorMutableData<float>();
  for (const auto &r : results) {
    if (r.decoder_out) {
      // Replace this row with the cached copy.
      const float *src = r.decoder_out.GetTensorData<float>();
      std::copy(src, src + shape[1], dst);
    }
    // Advance one row per result even when nothing was cached.
    dst += shape[1];
  }
}
// Caches one row of `decoder_out` into each result so the decoder state
// survives an endpoint reset. The per-result (1, decoder_out_dim) tensor is
// allocated lazily on first use from `allocator`.
//
// NOTE(review): assumes `decoder_out` is a 2-D row-major float tensor of
// shape (results->size(), decoder_out_dim) -- confirm against the caller.
static void UpdateCachedDecoderOut(
    OrtAllocator *allocator, const Ort::Value *decoder_out,
    std::vector<OnlineTransducerDecoderResult> *results) {
  std::vector<int64_t> shape =
      decoder_out->GetTensorTypeAndShapeInfo().GetShape();

  // Fix: removed an unused local `memory_info`
  // (Ort::MemoryInfo::CreateCpu); the tensor below is created from
  // `allocator`, so the memory info was never consumed.
  std::array<int64_t, 2> v_shape{1, shape[1]};

  const float *src = decoder_out->GetTensorData<float>();
  for (auto &r : *results) {
    if (!r.decoder_out) {
      r.decoder_out = Ort::Value::CreateTensor<float>(allocator, v_shape.data(),
                                                      v_shape.size());
    }
    float *dst = r.decoder_out.GetTensorMutableData<float>();
    std::copy(src, src + shape[1], dst);
    src += shape[1];  // next row of the batched decoder output
  }
}
OnlineTransducerDecoderResult
OnlineTransducerGreedySearchDecoder::GetEmptyResult() const {
int32_t context_size = model_->ContextSize();
@@ -53,6 +90,7 @@ void OnlineTransducerGreedySearchDecoder::Decode(
Ort::Value decoder_input = model_->BuildDecoderInput(*result);
Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input));
UseCachedDecoderOut(*result, &decoder_out);
for (int32_t t = 0; t != num_frames; ++t) {
Ort::Value cur_encoder_out =
@@ -77,10 +115,12 @@ void OnlineTransducerGreedySearchDecoder::Decode(
}
}
if (emitted) {
decoder_input = model_->BuildDecoderInput(*result);
Ort::Value decoder_input = model_->BuildDecoderInput(*result);
decoder_out = model_->RunDecoder(std::move(decoder_input));
}
}
UpdateCachedDecoderOut(model_->Allocator(), &decoder_out, result);
}
} // namespace sherpa_onnx

View File

@@ -13,6 +13,29 @@
namespace sherpa_onnx {
// Beam-search variant: restores cached decoder output into the freshly
// computed `decoder_out` tensor, one row per hypothesis. A result's cache is
// used only when it contributed exactly one hypothesis and has a cached
// tensor; otherwise its rows are left untouched.
//
// `hyps_num_split[i]`..`hyps_num_split[i+1]` delimit the rows belonging to
// results[i] (prefix-sum layout).
//
// NOTE(review): `context_size` is accepted but never read in this body --
// consider removing it at the call site, or confirm it is needed.
static void UseCachedDecoderOut(
    const std::vector<int32_t> &hyps_num_split,
    const std::vector<OnlineTransducerDecoderResult> &results,
    int32_t context_size, Ort::Value *decoder_out) {
  std::vector<int64_t> shape =
      decoder_out->GetTensorTypeAndShapeInfo().GetShape();

  float *dst = decoder_out->GetTensorMutableData<float>();

  int32_t batch_size = static_cast<int32_t>(results.size());
  for (int32_t i = 0; i != batch_size; ++i) {
    int32_t num_hyps = hyps_num_split[i + 1] - hyps_num_split[i];
    if (num_hyps > 1 || !results[i].decoder_out) {
      // Multiple hypotheses or no cache: skip this result's rows unchanged.
      dst += num_hyps * shape[1];
      continue;
    }

    const float *src = results[i].decoder_out.GetTensorData<float>();
    std::copy(src, src + shape[1], dst);
    dst += shape[1];
  }
}
static Ort::Value Repeat(OrtAllocator *allocator, Ort::Value *cur_encoder_out,
const std::vector<int32_t> &hyps_num_split) {
std::vector<int64_t> cur_encoder_out_shape =
@@ -50,7 +73,7 @@ OnlineTransducerModifiedBeamSearchDecoder::GetEmptyResult() const {
int32_t context_size = model_->ContextSize();
int32_t blank_id = 0; // always 0
OnlineTransducerDecoderResult r;
std::vector<int32_t> blanks(context_size, blank_id);
std::vector<int64_t> blanks(context_size, blank_id);
Hypotheses blank_hyp({{blanks, 0}});
r.hyps = std::move(blank_hyp);
return r;
@@ -110,6 +133,10 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
Ort::Value decoder_input = model_->BuildDecoderInput(prev);
Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input));
if (t == 0) {
UseCachedDecoderOut(hyps_num_split, *result, model_->ContextSize(),
&decoder_out);
}
Ort::Value cur_encoder_out =
GetEncoderOutFrame(model_->Allocator(), &encoder_out, t);
@@ -147,8 +174,23 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
}
for (int32_t b = 0; b != batch_size; ++b) {
(*result)[b].hyps = std::move(cur[b]);
auto &hyps = cur[b];
auto best_hyp = hyps.GetMostProbable(true);
(*result)[b].hyps = std::move(hyps);
(*result)[b].tokens = std::move(best_hyp.ys);
(*result)[b].num_trailing_blanks = best_hyp.num_trailing_blanks;
}
}
// Recomputes and caches the decoder network output for `result` so decoding
// can resume after an endpoint reset.
//
// When result->tokens holds only the initial context (its size equals the
// model's context size, i.e. nothing has been emitted yet), there is nothing
// worth caching, so the cached tensor is cleared instead.
void OnlineTransducerModifiedBeamSearchDecoder::UpdateDecoderOut(
    OnlineTransducerDecoderResult *result) {
  if (result->tokens.size() == model_->ContextSize()) {
    result->decoder_out = Ort::Value{nullptr};
    return;
  }

  Ort::Value decoder_input = model_->BuildDecoderInput({*result});
  result->decoder_out = model_->RunDecoder(std::move(decoder_input));
}
} // namespace sherpa_onnx

View File

@@ -27,6 +27,8 @@ class OnlineTransducerModifiedBeamSearchDecoder
void Decode(Ort::Value encoder_out,
std::vector<OnlineTransducerDecoderResult> *result) override;
void UpdateDecoderOut(OnlineTransducerDecoderResult *result) override;
private:
OnlineTransducerModel *model_; // Not owned
int32_t max_active_paths_;

View File

@@ -21,7 +21,7 @@ static void Handler(int sig) {
}
int main(int32_t argc, char *argv[]) {
if (argc < 6 || argc > 7) {
if (argc < 6 || argc > 8) {
const char *usage = R"usage(
Usage:
./bin/sherpa-onnx-alsa \
@@ -30,7 +30,10 @@ Usage:
/path/to/decoder.onnx \
/path/to/joiner.onnx \
device_name \
[num_threads]
[num_threads [decoding_method]]
Default value for num_threads is 2.
Valid values for decoding_method: greedy_search (default), modified_beam_search.
Please refer to
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
@@ -79,6 +82,11 @@ as the device_name.
config.model_config.num_threads = atoi(argv[6]);
}
if (argc == 8) {
config.decoding_method = argv[7];
}
config.max_active_paths = 4;
config.enable_endpoint = true;
config.endpoint_config.rule1.min_trailing_silence = 2.4;

View File

@@ -36,7 +36,7 @@ static void Handler(int32_t sig) {
}
int32_t main(int32_t argc, char *argv[]) {
if (argc < 5 || argc > 6) {
if (argc < 5 || argc > 7) {
const char *usage = R"usage(
Usage:
./bin/sherpa-onnx-microphone \
@@ -44,7 +44,10 @@ Usage:
/path/to/encoder.onnx\
/path/to/decoder.onnx\
/path/to/joiner.onnx\
[num_threads]
[num_threads [decoding_method]]
Default value for num_threads is 2.
Valid values for decoding_method: greedy_search (default), modified_beam_search.
Please refer to
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
@@ -70,6 +73,11 @@ for a list of pre-trained models to download.
config.model_config.num_threads = atoi(argv[5]);
}
if (argc == 7) {
config.decoding_method = argv[6];
}
config.max_active_paths = 4;
config.enable_endpoint = true;
config.endpoint_config.rule1.min_trailing_silence = 2.4;

View File

@@ -14,7 +14,7 @@
#include "sherpa-onnx/csrc/wave-reader.h"
int main(int32_t argc, char *argv[]) {
if (argc < 6 || argc > 7) {
if (argc < 6 || argc > 8) {
const char *usage = R"usage(
Usage:
./bin/sherpa-onnx \
@@ -22,7 +22,10 @@ Usage:
/path/to/encoder.onnx \
/path/to/decoder.onnx \
/path/to/joiner.onnx \
/path/to/foo.wav [num_threads]
/path/to/foo.wav [num_threads [decoding_method]]
Default value for num_threads is 2.
Valid values for decoding_method: greedy_search (default), modified_beam_search.
Please refer to
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
@@ -45,9 +48,15 @@ for a list of pre-trained models to download.
std::string wav_filename = argv[5];
config.model_config.num_threads = 2;
if (argc == 7) {
if (argc == 7 && atoi(argv[6]) > 0) {
config.model_config.num_threads = atoi(argv[6]);
}
if (argc == 8) {
config.decoding_method = argv[7];
}
config.max_active_paths = 4;
fprintf(stderr, "%s\n", config.ToString().c_str());
sherpa_onnx::OnlineRecognizer recognizer(config);
@@ -98,6 +107,7 @@ for a list of pre-trained models to download.
1000.;
fprintf(stderr, "num threads: %d\n", config.model_config.num_threads);
fprintf(stderr, "decoding method: %s\n", config.decoding_method.c_str());
fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
float rtf = elapsed_seconds / duration;

View File

@@ -1,12 +1,13 @@
include_directories(${CMAKE_SOURCE_DIR})
pybind11_add_module(_sherpa_onnx
display.cc
endpoint.cc
features.cc
online-recognizer.cc
online-stream.cc
online-transducer-model-config.cc
sherpa-onnx.cc
endpoint.cc
online-stream.cc
online-recognizer.cc
)
if(APPLE)

View File

@@ -0,0 +1,18 @@
// sherpa-onnx/python/csrc/display.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/python/csrc/display.h"
#include "sherpa-onnx/csrc/display.h"
namespace sherpa_onnx {
// Exposes sherpa_onnx::Display to Python as `Display`, with a
// `max_word_per_line` constructor argument (default 60) and a
// `print(idx, s)` method forwarding to Display::Print.
void PybindDisplay(py::module *m) {
  using PyClass = Display;
  py::class_<PyClass>(*m, "Display")
      .def(py::init<int32_t>(), py::arg("max_word_per_line") = 60)
      .def("print", &PyClass::Print, py::arg("idx"), py::arg("s"));
}
} // namespace sherpa_onnx

View File

@@ -0,0 +1,16 @@
// sherpa-onnx/python/csrc/display.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_PYTHON_CSRC_DISPLAY_H_
#define SHERPA_ONNX_PYTHON_CSRC_DISPLAY_H_
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
namespace sherpa_onnx {
void PybindDisplay(py::module *m);
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_PYTHON_CSRC_DISPLAY_H_

View File

@@ -11,10 +11,12 @@ namespace sherpa_onnx {
static void PybindFeatureExtractorConfig(py::module *m) {
using PyClass = FeatureExtractorConfig;
py::class_<PyClass>(*m, "FeatureExtractorConfig")
.def(py::init<float, int32_t>(), py::arg("sampling_rate") = 16000,
py::arg("feature_dim") = 80)
.def(py::init<float, int32_t, int32_t>(),
py::arg("sampling_rate") = 16000, py::arg("feature_dim") = 80,
py::arg("max_feature_vectors") = -1)
.def_readwrite("sampling_rate", &PyClass::sampling_rate)
.def_readwrite("feature_dim", &PyClass::feature_dim)
.def_readwrite("max_feature_vectors", &PyClass::max_feature_vectors)
.def("__str__", &PyClass::ToString);
}

View File

@@ -22,13 +22,16 @@ static void PybindOnlineRecognizerConfig(py::module *m) {
py::class_<PyClass>(*m, "OnlineRecognizerConfig")
.def(py::init<const FeatureExtractorConfig &,
const OnlineTransducerModelConfig &, const EndpointConfig &,
bool>(),
bool, const std::string &, int32_t>(),
py::arg("feat_config"), py::arg("model_config"),
py::arg("endpoint_config"), py::arg("enable_endpoint"))
py::arg("endpoint_config"), py::arg("enable_endpoint"),
py::arg("decoding_method"), py::arg("max_active_paths"))
.def_readwrite("feat_config", &PyClass::feat_config)
.def_readwrite("model_config", &PyClass::model_config)
.def_readwrite("endpoint_config", &PyClass::endpoint_config)
.def_readwrite("enable_endpoint", &PyClass::enable_endpoint)
.def_readwrite("decoding_method", &PyClass::decoding_method)
.def_readwrite("max_active_paths", &PyClass::max_active_paths)
.def("__str__", &PyClass::ToString);
}

View File

@@ -4,6 +4,7 @@
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
#include "sherpa-onnx/python/csrc/display.h"
#include "sherpa-onnx/python/csrc/endpoint.h"
#include "sherpa-onnx/python/csrc/features.h"
#include "sherpa-onnx/python/csrc/online-recognizer.h"
@@ -19,6 +20,8 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
PybindOnlineStream(&m);
PybindEndpoint(&m);
PybindOnlineRecognizer(&m);
PybindDisplay(&m);
}
} // namespace sherpa_onnx

View File

@@ -1,9 +1,3 @@
from _sherpa_onnx import (
EndpointConfig,
FeatureExtractorConfig,
OnlineRecognizerConfig,
OnlineStream,
OnlineTransducerModelConfig,
)
from _sherpa_onnx import Display
from .online_recognizer import OnlineRecognizer

View File

@@ -32,6 +32,9 @@ class OnlineRecognizer(object):
rule1_min_trailing_silence: int = 2.4,
rule2_min_trailing_silence: int = 1.2,
rule3_min_utterance_length: int = 20,
decoding_method: str = "greedy_search",
max_active_paths: int = 4,
max_feature_vectors: int = -1,
):
"""
Please refer to
@@ -74,6 +77,14 @@ class OnlineRecognizer(object):
Used only when enable_endpoint_detection is True. If the utterance
length in seconds is larger than this value, we assume an endpoint
is detected.
decoding_method:
Valid values are greedy_search, modified_beam_search.
max_active_paths:
Use only when decoding_method is modified_beam_search. It specifies
the maximum number of active paths during beam search.
max_feature_vectors:
Number of feature vectors to cache. -1 means to cache all feature
frames that have been processed.
"""
_assert_file_exists(tokens)
_assert_file_exists(encoder)
@@ -93,6 +104,7 @@ class OnlineRecognizer(object):
feat_config = FeatureExtractorConfig(
sampling_rate=sample_rate,
feature_dim=feature_dim,
max_feature_vectors=max_feature_vectors,
)
endpoint_config = EndpointConfig(
@@ -106,6 +118,8 @@ class OnlineRecognizer(object):
model_config=model_config,
endpoint_config=endpoint_config,
enable_endpoint=enable_endpoint_detection,
decoding_method=decoding_method,
max_active_paths=max_active_paths,
)
self.recognizer = _Recognizer(recognizer_config)