// sherpa-onnx/csrc/offline-fire-red-asr-greedy-search-decoder.cc // // Copyright (c) 2025 Xiaomi Corporation #include "sherpa-onnx/csrc/offline-fire-red-asr-greedy-search-decoder.h" #include #include #include #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/onnx-utils.h" namespace sherpa_onnx { // Note: this functions works only for batch size == 1 at present std::vector OfflineFireRedAsrGreedySearchDecoder::Decode(Ort::Value cross_k, Ort::Value cross_v) { const auto &meta_data = model_->GetModelMetadata(); auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); // For multilingual models, initial_tokens contains [sot, language, task] // - language is English by default // - task is transcribe by default // // For non-multilingual models, initial_tokens contains [sot] std::array token_shape = {1, 1}; int64_t token = meta_data.sos_id; int32_t batch_size = 1; Ort::Value tokens = Ort::Value::CreateTensor( memory_info, &token, 1, token_shape.data(), token_shape.size()); std::array offset_shape{1}; Ort::Value offset = Ort::Value::CreateTensor( model_->Allocator(), offset_shape.data(), offset_shape.size()); *(offset.GetTensorMutableData()) = 0; std::vector ans(1); auto self_kv_cache = model_->GetInitialSelfKVCache(); std::tuple decoder_out = {Ort::Value{nullptr}, std::move(self_kv_cache.first), std::move(self_kv_cache.second), std::move(cross_k), std::move(cross_v), std::move(offset)}; for (int32_t i = 0; i < meta_data.max_len; ++i) { decoder_out = model_->ForwardDecoder(View(&tokens), std::move(std::get<1>(decoder_out)), std::move(std::get<2>(decoder_out)), std::move(std::get<3>(decoder_out)), std::move(std::get<4>(decoder_out)), std::move(std::get<5>(decoder_out))); const auto &logits = std::get<0>(decoder_out); const float *p_logits = logits.GetTensorData(); auto logits_shape = logits.GetTensorTypeAndShapeInfo().GetShape(); int32_t vocab_size = logits_shape[2]; int32_t max_token_id = static_cast(std::distance( p_logits, std::max_element(p_logits, p_logits + vocab_size))); if (max_token_id == meta_data.eos_id) { break; } ans[0].tokens.push_back(max_token_id); token = max_token_id; // increment offset *(std::get<5>(decoder_out).GetTensorMutableData()) += 1; } return ans; } } // namespace sherpa_onnx