Add CTC HLG decoding using OpenFst (#349)
This commit is contained in:
@@ -16,6 +16,7 @@
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-ctc-decoder.h"
|
||||
#include "sherpa-onnx/csrc/offline-ctc-fst-decoder.h"
|
||||
#include "sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.h"
|
||||
#include "sherpa-onnx/csrc/offline-ctc-model.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
|
||||
@@ -25,9 +26,12 @@
|
||||
namespace sherpa_onnx {
|
||||
|
||||
static OfflineRecognitionResult Convert(const OfflineCtcDecoderResult &src,
|
||||
const SymbolTable &sym_table) {
|
||||
const SymbolTable &sym_table,
|
||||
int32_t frame_shift_ms,
|
||||
int32_t subsampling_factor) {
|
||||
OfflineRecognitionResult r;
|
||||
r.tokens.reserve(src.tokens.size());
|
||||
r.timestamps.reserve(src.timestamps.size());
|
||||
|
||||
std::string text;
|
||||
|
||||
@@ -42,6 +46,12 @@ static OfflineRecognitionResult Convert(const OfflineCtcDecoderResult &src,
|
||||
}
|
||||
r.text = std::move(text);
|
||||
|
||||
float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor;
|
||||
for (auto t : src.timestamps) {
|
||||
float time = frame_shift_s * t;
|
||||
r.timestamps.push_back(time);
|
||||
}
|
||||
|
||||
return r;
|
||||
}
|
||||
|
||||
@@ -68,7 +78,12 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
|
||||
config_.feat_config.nemo_normalize_type =
|
||||
model_->FeatureNormalizationMethod();
|
||||
|
||||
if (config_.decoding_method == "greedy_search") {
|
||||
if (!config_.ctc_fst_decoder_config.graph.empty()) {
|
||||
// TODO(fangjun): Support android to read the graph from
|
||||
// asset_manager
|
||||
decoder_ = std::make_unique<OfflineCtcFstDecoder>(
|
||||
config_.ctc_fst_decoder_config);
|
||||
} else if (config_.decoding_method == "greedy_search") {
|
||||
if (!symbol_table_.contains("<blk>") &&
|
||||
!symbol_table_.contains("<eps>")) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
@@ -139,10 +154,12 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
|
||||
-23.025850929940457f);
|
||||
auto t = model_->Forward(std::move(x), std::move(x_length));
|
||||
|
||||
auto results = decoder_->Decode(std::move(t.first), std::move(t.second));
|
||||
auto results = decoder_->Decode(std::move(t[0]), std::move(t[1]));
|
||||
|
||||
int32_t frame_shift_ms = 10;
|
||||
for (int32_t i = 0; i != n; ++i) {
|
||||
auto r = Convert(results[i], symbol_table_);
|
||||
auto r = Convert(results[i], symbol_table_, frame_shift_ms,
|
||||
model_->SubsamplingFactor());
|
||||
ss[i]->SetResult(r);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user