Support non-streaming WeNet CTC models. (#426)

This commit is contained in:
Fangjun Kuang
2023-11-15 14:23:20 +08:00
committed by GitHub
parent d34640e3a3
commit b83b3e3cd1
21 changed files with 469 additions and 32 deletions

View File

@@ -75,6 +75,12 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
#endif
void Init() {
if (!config_.model_config.wenet_ctc.model.empty()) {
// WeNet CTC models assume input samples are in the range
// [-32768, 32767], so we set normalize_samples to false
config_.feat_config.normalize_samples = false;
}
config_.feat_config.nemo_normalize_type =
model_->FeatureNormalizationMethod();
@@ -85,10 +91,11 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
config_.ctc_fst_decoder_config);
} else if (config_.decoding_method == "greedy_search") {
if (!symbol_table_.contains("<blk>") &&
!symbol_table_.contains("<eps>")) {
!symbol_table_.contains("<eps>") &&
!symbol_table_.contains("<blank>")) {
SHERPA_ONNX_LOGE(
"We expect that tokens.txt contains "
"the symbol <blk> or <eps> and its ID.");
"the symbol <blk> or <eps> or <blank> and its ID.");
exit(-1);
}
@@ -98,6 +105,9 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
} else if (symbol_table_.contains("<eps>")) {
// for tdnn models of the yesno recipe from icefall
blank_id = symbol_table_["<eps>"];
} else if (symbol_table_.contains("<blank>")) {
// for Wenet CTC models
blank_id = symbol_table_["<blank>"];
}
decoder_ = std::make_unique<OfflineCtcGreedySearchDecoder>(blank_id);
@@ -113,6 +123,15 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
}
void DecodeStreams(OfflineStream **ss, int32_t n) const override {
if (!model_->SupportBatchProcessing()) {
// If the model does not support batch process,
// we process each stream independently.
for (int32_t i = 0; i != n; ++i) {
DecodeStream(ss[i]);
}
return;
}
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
@@ -164,6 +183,38 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
}
}
private:
// Decode a single stream.
// Some models do not support batch size > 1, e.g., WeNet CTC models.
void DecodeStream(OfflineStream *s) const {
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
int32_t feat_dim = config_.feat_config.feature_dim;
std::vector<float> f = s->GetFrames();
int32_t num_frames = f.size() / feat_dim;
std::array<int64_t, 3> shape = {1, num_frames, feat_dim};
Ort::Value x = Ort::Value::CreateTensor(memory_info, f.data(), f.size(),
shape.data(), shape.size());
int64_t x_length_scalar = num_frames;
std::array<int64_t, 1> x_length_shape = {1};
Ort::Value x_length =
Ort::Value::CreateTensor(memory_info, &x_length_scalar, 1,
x_length_shape.data(), x_length_shape.size());
auto t = model_->Forward(std::move(x), std::move(x_length));
auto results = decoder_->Decode(std::move(t[0]), std::move(t[1]));
int32_t frame_shift_ms = 10;
auto r = Convert(results[0], symbol_table_, frame_shift_ms,
model_->SubsamplingFactor());
s->SetResult(r);
}
private:
OfflineRecognizerConfig config_;
SymbolTable symbol_table_;