// sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.cc // // Copyright (c) 2023 Xiaomi Corporation #include "sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.h" #include #include namespace sherpa_onnx { std::vector OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k, Ort::Value cross_v) { auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); auto self_kv_cache = model_->GetInitialSelfKVCache(); std::vector initial_tokens = model_->GetInitialTokens(); int32_t batch_size = 1; std::array token_shape{ batch_size, static_cast(initial_tokens.size())}; Ort::Value tokens = Ort::Value::CreateTensor( memory_info, initial_tokens.data(), initial_tokens.size(), token_shape.data(), token_shape.size()); std::array offset_shape{1}; Ort::Value offset = Ort::Value::CreateTensor( model_->Allocator(), offset_shape.data(), offset_shape.size()); *(offset.GetTensorMutableData()) = 0; auto decoder_out = model_->ForwardDecoder( std::move(tokens), std::move(self_kv_cache.first), std::move(self_kv_cache.second), std::move(cross_k), std::move(cross_v), std::move(offset)); const auto &logits = std::get<0>(decoder_out); const float *p_logits = logits.GetTensorData(); auto logits_shape = logits.GetTensorTypeAndShapeInfo().GetShape(); int32_t vocab_size = logits_shape[2]; int32_t max_token_id = static_cast(std::distance( p_logits, std::max_element(p_logits, p_logits + vocab_size))); int32_t n_text_ctx = model_->TextCtx(); std::vector predicted_tokens; for (int32_t i = 0; i < n_text_ctx; ++i) { if (max_token_id == model_->EOT()) { break; } predicted_tokens.push_back(max_token_id); std::array token_shape{1, 1}; Ort::Value tokens = Ort::Value::CreateTensor( model_->Allocator(), token_shape.data(), token_shape.size()); int64_t *p_tokens = tokens.GetTensorMutableData(); p_tokens[0] = max_token_id; int64_t *p_offset = std::get<5>(decoder_out).GetTensorMutableData(); if (i == 0) { *p_offset = initial_tokens.size(); } else { *p_offset += 1; } decoder_out = model_->ForwardDecoder(std::move(tokens), std::move(std::get<1>(decoder_out)), std::move(std::get<2>(decoder_out)), std::move(std::get<3>(decoder_out)), std::move(std::get<4>(decoder_out)), std::move(std::get<5>(decoder_out))); const auto &logits = std::get<0>(decoder_out); const float *p_logits = logits.GetTensorData(); max_token_id = static_cast(std::distance( p_logits, std::max_element(p_logits, p_logits + vocab_size))); } std::vector ans(1); ans[0].tokens = std::move(predicted_tokens); return ans; } } // namespace sherpa_onnx