Limit number of tokens per second for whisper. (#1958)
Otherwise, it spends lots of time in the loop if the EOT token is not predicted.
This commit is contained in:
@@ -131,7 +131,7 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
|
|||||||
auto cross_kv = model_->ForwardEncoder(std::move(mel));
|
auto cross_kv = model_->ForwardEncoder(std::move(mel));
|
||||||
|
|
||||||
auto results = decoder_->Decode(std::move(cross_kv.first),
|
auto results = decoder_->Decode(std::move(cross_kv.first),
|
||||||
std::move(cross_kv.second));
|
std::move(cross_kv.second), num_frames);
|
||||||
|
|
||||||
auto r = Convert(results[0], symbol_table_);
|
auto r = Convert(results[0], symbol_table_);
|
||||||
s->SetResult(r);
|
s->SetResult(r);
|
||||||
|
|||||||
@@ -33,7 +33,8 @@ class OfflineWhisperDecoder {
|
|||||||
* @return Return a vector of size `N` containing the decoded results.
|
* @return Return a vector of size `N` containing the decoded results.
|
||||||
*/
|
*/
|
||||||
virtual std::vector<OfflineWhisperDecoderResult> Decode(
|
virtual std::vector<OfflineWhisperDecoderResult> Decode(
|
||||||
Ort::Value n_layer_cross_k, Ort::Value n_layer_cross_v) = 0;
|
Ort::Value n_layer_cross_k, Ort::Value n_layer_cross_v,
|
||||||
|
int32_t num_feature_frames) = 0;
|
||||||
|
|
||||||
virtual void SetConfig(const OfflineWhisperModelConfig &config) = 0;
|
virtual void SetConfig(const OfflineWhisperModelConfig &config) = 0;
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -19,7 +19,8 @@ void OfflineWhisperGreedySearchDecoder::SetConfig(
|
|||||||
|
|
||||||
std::vector<OfflineWhisperDecoderResult>
|
std::vector<OfflineWhisperDecoderResult>
|
||||||
OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
|
OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
|
||||||
Ort::Value cross_v) {
|
Ort::Value cross_v,
|
||||||
|
int32_t num_feature_frames) {
|
||||||
auto memory_info =
|
auto memory_info =
|
||||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||||
|
|
||||||
@@ -99,7 +100,12 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
|
|||||||
int32_t n_text_ctx = model_->TextCtx();
|
int32_t n_text_ctx = model_->TextCtx();
|
||||||
|
|
||||||
std::vector<int32_t> predicted_tokens;
|
std::vector<int32_t> predicted_tokens;
|
||||||
for (int32_t i = 0; i < n_text_ctx / 2; ++i) {
|
|
||||||
|
// assume at most 6 tokens per second
|
||||||
|
int32_t num_possible_tokens = num_feature_frames / 100 * 6;
|
||||||
|
num_possible_tokens = std::min<int32_t>(num_possible_tokens, n_text_ctx / 2);
|
||||||
|
|
||||||
|
for (int32_t i = 0; i < num_possible_tokens; ++i) {
|
||||||
if (max_token_id == model_->EOT()) {
|
if (max_token_id == model_->EOT()) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -18,8 +18,9 @@ class OfflineWhisperGreedySearchDecoder : public OfflineWhisperDecoder {
|
|||||||
OfflineWhisperModel *model)
|
OfflineWhisperModel *model)
|
||||||
: config_(config), model_(model) {}
|
: config_(config), model_(model) {}
|
||||||
|
|
||||||
std::vector<OfflineWhisperDecoderResult> Decode(Ort::Value cross_k,
|
std::vector<OfflineWhisperDecoderResult> Decode(
|
||||||
Ort::Value cross_v) override;
|
Ort::Value cross_k, Ort::Value cross_v,
|
||||||
|
int32_t num_feature_frames) override;
|
||||||
|
|
||||||
void SetConfig(const OfflineWhisperModelConfig &config) override;
|
void SetConfig(const OfflineWhisperModelConfig &config) override;
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user