NOTE(review): this patch text appears to have lost the template arguments of
`std::vector` (e.g. `std::vector Decode(`); confirm against the upstream files
before applying.

diff --git a/sherpa-onnx/csrc/offline-recognizer-whisper-impl.h b/sherpa-onnx/csrc/offline-recognizer-whisper-impl.h
index 605d864b..7c928b09 100644
--- a/sherpa-onnx/csrc/offline-recognizer-whisper-impl.h
+++ b/sherpa-onnx/csrc/offline-recognizer-whisper-impl.h
@@ -131,7 +131,7 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
       auto cross_kv = model_->ForwardEncoder(std::move(mel));
 
       auto results = decoder_->Decode(std::move(cross_kv.first),
-                                      std::move(cross_kv.second));
+                                      std::move(cross_kv.second), num_frames);
 
       auto r = Convert(results[0], symbol_table_);
       s->SetResult(r);
diff --git a/sherpa-onnx/csrc/offline-whisper-decoder.h b/sherpa-onnx/csrc/offline-whisper-decoder.h
index 9cb5088e..6432757f 100644
--- a/sherpa-onnx/csrc/offline-whisper-decoder.h
+++ b/sherpa-onnx/csrc/offline-whisper-decoder.h
@@ -33,7 +33,10 @@ class OfflineWhisperDecoder {
+   * @param num_feature_frames Number of feature frames of the input; an
+   *                           implementation may use it to bound decoding.
    * @return Return a vector of size `N` containing the decoded results.
    */
   virtual std::vector Decode(
-      Ort::Value n_layer_cross_k, Ort::Value n_layer_cross_v) = 0;
+      Ort::Value n_layer_cross_k, Ort::Value n_layer_cross_v,
+      int32_t num_feature_frames) = 0;
 
   virtual void SetConfig(const OfflineWhisperModelConfig &config) = 0;
 };
diff --git a/sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.cc b/sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.cc
index 061b6ad1..c391e323 100644
--- a/sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.cc
+++ b/sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.cc
@@ -19,7 +19,8 @@ void OfflineWhisperGreedySearchDecoder::SetConfig(
 
 std::vector
 OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
-                                          Ort::Value cross_v) {
+                                          Ort::Value cross_v,
+                                          int32_t num_feature_frames) {
   auto memory_info =
       Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
 
@@ -99,7 +100,15 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
   int32_t n_text_ctx = model_->TextCtx();
 
   std::vector predicted_tokens;
-  for (int32_t i = 0; i < n_text_ctx / 2; ++i) {
+
+  // Upper bound on decoding steps: assume at most 6 tokens per second of
+  // audio, at 100 feature frames per second. Multiply before dividing and
+  // round up so that clips shorter than one second (fewer than 100 frames)
+  // still get a non-zero budget instead of being truncated to 0.
+  int32_t num_possible_tokens = (num_feature_frames * 6 + 99) / 100;
+  num_possible_tokens = std::min(num_possible_tokens, n_text_ctx / 2);
+
+  for (int32_t i = 0; i < num_possible_tokens; ++i) {
     if (max_token_id == model_->EOT()) {
       break;
     }
diff --git a/sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.h b/sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.h
index 9692d90d..d2a0f527 100644
--- a/sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.h
+++ b/sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.h
@@ -18,8 +18,9 @@ class OfflineWhisperGreedySearchDecoder : public OfflineWhisperDecoder {
                                     OfflineWhisperModel *model)
       : config_(config), model_(model) {}
 
-  std::vector Decode(Ort::Value cross_k,
-                     Ort::Value cross_v) override;
+  std::vector Decode(
+      Ort::Value cross_k, Ort::Value cross_v,
+      int32_t num_feature_frames) override;
 
   void SetConfig(const OfflineWhisperModelConfig &config) override;
 