From aa48b76d4b6781266d140f0b4f83c449a07cd579 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 9 Aug 2023 12:33:47 +0800 Subject: [PATCH] Fix initial tokens to decoding (#246) --- CMakeLists.txt | 2 +- sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.cc | 3 ++- .../csrc/offline-transducer-modified-beam-search-decoder.cc | 3 ++- sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc | 3 ++- .../csrc/online-transducer-modified-beam-search-decoder.cc | 4 +++- 5 files changed, 10 insertions(+), 5 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 79d5b4b8..0720c72c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,7 +1,7 @@ cmake_minimum_required(VERSION 3.13 FATAL_ERROR) project(sherpa-onnx) -set(SHERPA_ONNX_VERSION "1.6.1") +set(SHERPA_ONNX_VERSION "1.6.2") # Disable warning about # diff --git a/sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.cc b/sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.cc index d9ef5f8d..99ac3338 100644 --- a/sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.cc +++ b/sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.cc @@ -30,8 +30,9 @@ OfflineTransducerGreedySearchDecoder::Decode(Ort::Value encoder_out, std::vector ans(batch_size); for (auto &r : ans) { + r.tokens.resize(context_size, -1); // 0 is the ID of the blank token - r.tokens.resize(context_size, 0); + r.tokens.back() = 0; } auto decoder_input = model_->BuildDecoderInput(ans, ans.size()); diff --git a/sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.cc b/sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.cc index 1401a839..e845b313 100644 --- a/sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.cc +++ b/sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.cc @@ -32,7 +32,8 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode( int32_t vocab_size = model_->VocabSize(); int32_t context_size = model_->ContextSize(); - std::vector blanks(context_size, 0); + std::vector blanks(context_size, -1); + blanks.back() = 0; std::deque finalized; std::vector cur; diff --git a/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc b/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc index 0df46d32..965285ce 100644 --- a/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc +++ b/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc @@ -55,7 +55,8 @@ OnlineTransducerGreedySearchDecoder::GetEmptyResult() const { int32_t context_size = model_->ContextSize(); int32_t blank_id = 0; // always 0 OnlineTransducerDecoderResult r; - r.tokens.resize(context_size, blank_id); + r.tokens.resize(context_size, -1); + r.tokens.back() = blank_id; return r; } diff --git a/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc b/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc index 7e2a4a97..fef67347 100644 --- a/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc +++ b/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc @@ -42,7 +42,9 @@ OnlineTransducerModifiedBeamSearchDecoder::GetEmptyResult() const { int32_t context_size = model_->ContextSize(); int32_t blank_id = 0; // always 0 OnlineTransducerDecoderResult r; - std::vector blanks(context_size, blank_id); + std::vector blanks(context_size, -1); + blanks.back() = blank_id; + Hypotheses blank_hyp({{blanks, 0}}); r.hyps = std::move(blank_hyp); r.tokens = std::move(blanks);