Support spoken language identification with whisper (#694)

This commit is contained in:
Fangjun Kuang
2024-03-24 22:57:00 +08:00
committed by GitHub
parent 3cdad9b5d1
commit 0d258dd150
36 changed files with 1173 additions and 200 deletions

View File

@@ -24,6 +24,24 @@ class OfflineWhisperModel::Impl {
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
debug_ = config_.debug;
{
auto buf = ReadFile(config.whisper.encoder);
InitEncoder(buf.data(), buf.size());
}
{
auto buf = ReadFile(config.whisper.decoder);
InitDecoder(buf.data(), buf.size());
}
}
explicit Impl(const SpokenLanguageIdentificationConfig &config)
: lid_config_(config),
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
debug_ = config_.debug;
{
auto buf = ReadFile(config.whisper.encoder);
InitEncoder(buf.data(), buf.size());
@@ -41,6 +59,7 @@ class OfflineWhisperModel::Impl {
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
debug_ = config_.debug;
{
auto buf = ReadFile(mgr, config.whisper.encoder);
InitEncoder(buf.data(), buf.size());
@@ -85,6 +104,57 @@ class OfflineWhisperModel::Impl {
std::move(decoder_input[4]), std::move(decoder_input[5])};
}
int32_t DetectLanguage(Ort::Value &cross_k, // NOLINT
Ort::Value &cross_v) { // NOLINT
int64_t token_val = SOT();
std::array<int64_t, 2> token_shape{1, 1};
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
Ort::Value tokens = Ort::Value::CreateTensor(
memory_info, &token_val, 1, token_shape.data(), token_shape.size());
auto self_kv_cache = GetInitialSelfKVCache();
std::array<int64_t, 1> offset_shape{1};
Ort::Value offset = Ort::Value::CreateTensor<int64_t>(
Allocator(), offset_shape.data(), offset_shape.size());
*(offset.GetTensorMutableData<int64_t>()) = 0;
auto decoder_out =
ForwardDecoder(std::move(tokens), std::move(self_kv_cache.first),
std::move(self_kv_cache.second), std::move(cross_k),
std::move(cross_v), std::move(offset));
cross_k = std::move(std::get<3>(decoder_out));
cross_v = std::move(std::get<4>(decoder_out));
const float *p_logits = std::get<0>(decoder_out).GetTensorData<float>();
int32_t vocab_size = VocabSize();
const auto &all_language_ids = GetAllLanguageIDs();
int32_t lang_id = all_language_ids[0];
float this_logit = p_logits[lang_id];
for (int32_t i = 1; i != all_language_ids.size(); ++i) {
int32_t id = all_language_ids[i];
float p = p_logits[id];
if (p > this_logit) {
this_logit = p;
lang_id = id;
}
}
if (debug_) {
SHERPA_ONNX_LOGE("Detected language: %s",
GetID2Lang().at(lang_id).c_str());
}
return lang_id;
}
std::pair<Ort::Value, Ort::Value> GetInitialSelfKVCache() {
std::array<int64_t, 4> shape{n_text_layer_, 1, n_text_ctx_, n_text_state_};
@@ -148,7 +218,7 @@ class OfflineWhisperModel::Impl {
// get meta data
Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata();
if (config_.debug) {
if (debug_) {
std::ostringstream os;
os << "---encoder---\n";
PrintModelMetadata(os, meta_data);
@@ -203,6 +273,8 @@ class OfflineWhisperModel::Impl {
private:
OfflineModelConfig config_;
SpokenLanguageIdentificationConfig lid_config_;
bool debug_ = false;
Ort::Env env_;
Ort::SessionOptions sess_opts_;
Ort::AllocatorWithDefaultOptions allocator_;
@@ -246,6 +318,10 @@ class OfflineWhisperModel::Impl {
OfflineWhisperModel::OfflineWhisperModel(const OfflineModelConfig &config)
: impl_(std::make_unique<Impl>(config)) {}
OfflineWhisperModel::OfflineWhisperModel(
const SpokenLanguageIdentificationConfig &config)
: impl_(std::make_unique<Impl>(config)) {}
#if __ANDROID_API__ >= 9
OfflineWhisperModel::OfflineWhisperModel(AAssetManager *mgr,
const OfflineModelConfig &config)
@@ -273,6 +349,11 @@ OfflineWhisperModel::ForwardDecoder(Ort::Value tokens,
std::move(n_layer_cross_v), std::move(offset));
}
int32_t OfflineWhisperModel::DetectLanguage(Ort::Value &cross_k, // NOLINT
Ort::Value &cross_v) { // NOLINT
return impl_->DetectLanguage(cross_k, cross_v);
}
std::pair<Ort::Value, Ort::Value> OfflineWhisperModel::GetInitialSelfKVCache()
const {
return impl_->GetInitialSelfKVCache();
@@ -318,4 +399,35 @@ bool OfflineWhisperModel::IsMultiLingual() const {
return impl_->IsMultiLingual();
}
void OfflineWhisperModel::NormalizeFeatures(float *features, int32_t num_frames,
int32_t feat_dim) {
// log_spec = torch.clamp(features, min=1e-10).log10()
// log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
// mel = (log_spec + 4.0) / 4.0
int32_t n = num_frames * feat_dim;
float max_v = -1e20;
for (int32_t i = 0; i != n; ++i) {
float f = features[i];
f = std::max<float>(f, 1e-10);
f = std::log10(f);
max_v = std::max(f, max_v);
features[i] = f;
}
max_v -= 8;
for (int32_t i = 0; i != n; ++i) {
float f = features[i];
f = std::max(f, max_v);
f = (f + 4) / 4;
features[i] = f;
}
}
} // namespace sherpa_onnx