Support spoken language identification with whisper (#694)
This commit is contained in:
@@ -12,56 +12,6 @@
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
int32_t OfflineWhisperGreedySearchDecoder::DetectLanguage(
|
||||
Ort::Value &cross_k, Ort::Value &cross_v) const { // NOLINT
|
||||
int64_t token_val = model_->SOT();
|
||||
std::array<int64_t, 2> token_shape{1, 1};
|
||||
|
||||
auto memory_info =
|
||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||
|
||||
Ort::Value tokens = Ort::Value::CreateTensor(
|
||||
memory_info, &token_val, 1, token_shape.data(), token_shape.size());
|
||||
|
||||
auto self_kv_cache = model_->GetInitialSelfKVCache();
|
||||
|
||||
std::array<int64_t, 1> offset_shape{1};
|
||||
Ort::Value offset = Ort::Value::CreateTensor<int64_t>(
|
||||
model_->Allocator(), offset_shape.data(), offset_shape.size());
|
||||
*(offset.GetTensorMutableData<int64_t>()) = 0;
|
||||
|
||||
auto decoder_out = model_->ForwardDecoder(
|
||||
std::move(tokens), std::move(self_kv_cache.first),
|
||||
std::move(self_kv_cache.second), std::move(cross_k), std::move(cross_v),
|
||||
std::move(offset));
|
||||
|
||||
cross_k = std::move(std::get<3>(decoder_out));
|
||||
cross_v = std::move(std::get<4>(decoder_out));
|
||||
|
||||
const float *p_logits = std::get<0>(decoder_out).GetTensorData<float>();
|
||||
int32_t vocab_size = model_->VocabSize();
|
||||
const auto &all_language_ids = model_->GetAllLanguageIDs();
|
||||
|
||||
int32_t lang_id = all_language_ids[0];
|
||||
float this_logit = p_logits[lang_id];
|
||||
|
||||
for (int32_t i = 1; i != all_language_ids.size(); ++i) {
|
||||
int32_t id = all_language_ids[i];
|
||||
float p = p_logits[id];
|
||||
|
||||
if (p > this_logit) {
|
||||
this_logit = p;
|
||||
lang_id = id;
|
||||
}
|
||||
}
|
||||
#if 1
|
||||
SHERPA_ONNX_LOGE("Detected language: %s",
|
||||
model_->GetID2Lang().at(lang_id).c_str());
|
||||
#endif
|
||||
|
||||
return lang_id;
|
||||
}
|
||||
|
||||
std::vector<OfflineWhisperDecoderResult>
|
||||
OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
|
||||
Ort::Value cross_v) {
|
||||
@@ -89,7 +39,7 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
|
||||
// 0: sot, 1: lang_id, 2: task, 3: no_timestamps
|
||||
initial_tokens[1] = lang_id;
|
||||
} else {
|
||||
int32_t lang_id = DetectLanguage(cross_k, cross_v);
|
||||
int32_t lang_id = model_->DetectLanguage(cross_k, cross_v);
|
||||
|
||||
// 0: sot, 1: lang_id, 2: task, 3: no_timestamps
|
||||
initial_tokens[1] = lang_id;
|
||||
|
||||
Reference in New Issue
Block a user