Support spoken language identification with whisper (#694)

This commit is contained in:
Fangjun Kuang
2024-03-24 22:57:00 +08:00
committed by GitHub
parent 3cdad9b5d1
commit 0d258dd150
36 changed files with 1173 additions and 200 deletions

View File

@@ -86,6 +86,8 @@ set(sources
silero-vad-model-config.cc
silero-vad-model.cc
slice.cc
spoken-language-identification-impl.cc
spoken-language-identification.cc
stack.cc
symbol-table.cc
text-utils.cc
@@ -184,6 +186,7 @@ if(SHERPA_ONNX_ENABLE_BINARY)
add_executable(sherpa-onnx-offline sherpa-onnx-offline.cc)
add_executable(sherpa-onnx-offline-parallel sherpa-onnx-offline-parallel.cc)
add_executable(sherpa-onnx-offline-tts sherpa-onnx-offline-tts.cc)
add_executable(sherpa-onnx-offline-language-identification sherpa-onnx-offline-language-identification.cc)
set(main_exes
sherpa-onnx
@@ -191,6 +194,7 @@ if(SHERPA_ONNX_ENABLE_BINARY)
sherpa-onnx-offline
sherpa-onnx-offline-parallel
sherpa-onnx-offline-tts
sherpa-onnx-offline-language-identification
)
foreach(exe IN LISTS main_exes)

View File

@@ -23,7 +23,7 @@ enum class ModelType {
kTdnn,
kZipformerCtc,
kWenetCtc,
kUnkown,
kUnknown,
};
} // namespace
@@ -59,7 +59,7 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
"run.sh\n"
"\n"
"for how to add metadta to model.onnx\n");
return ModelType::kUnkown;
return ModelType::kUnknown;
}
if (model_type.get() == std::string("EncDecCTCModelBPE")) {
@@ -72,13 +72,13 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
return ModelType::kWenetCtc;
} else {
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
return ModelType::kUnkown;
return ModelType::kUnknown;
}
}
std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
const OfflineModelConfig &config) {
ModelType model_type = ModelType::kUnkown;
ModelType model_type = ModelType::kUnknown;
std::string filename;
if (!config.nemo_ctc.model.empty()) {
@@ -113,7 +113,7 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
case ModelType::kWenetCtc:
return std::make_unique<OfflineWenetCtcModel>(config);
break;
case ModelType::kUnkown:
case ModelType::kUnknown:
SHERPA_ONNX_LOGE("Unknown model type in offline CTC!");
return nullptr;
}
@@ -125,7 +125,7 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
AAssetManager *mgr, const OfflineModelConfig &config) {
ModelType model_type = ModelType::kUnkown;
ModelType model_type = ModelType::kUnknown;
std::string filename;
if (!config.nemo_ctc.model.empty()) {
@@ -160,7 +160,7 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
case ModelType::kWenetCtc:
return std::make_unique<OfflineWenetCtcModel>(mgr, config);
break;
case ModelType::kUnkown:
case ModelType::kUnknown:
SHERPA_ONNX_LOGE("Unknown model type in offline CTC!");
return nullptr;
}

View File

@@ -114,7 +114,7 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
num_frames = max_num_frames - 50;
}
NormalizeFeatures(f.data(), num_frames, feat_dim);
model_->NormalizeFeatures(f.data(), num_frames, feat_dim);
// note that 1000 is an experience-value.
// You can replace 1000 by other values, say, 100.
@@ -162,38 +162,6 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
}
}
private:
static void NormalizeFeatures(float *features, int32_t num_frames,
int32_t feat_dim) {
// log_spec = torch.clamp(features, min=1e-10).log10()
// log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
// mel = (log_spec + 4.0) / 4.0
int32_t n = num_frames * feat_dim;
float max_v = -1e20;
for (int32_t i = 0; i != n; ++i) {
float f = features[i];
f = std::max<float>(f, 1e-10);
f = std::log10(f);
max_v = std::max(f, max_v);
features[i] = f;
}
max_v -= 8;
for (int32_t i = 0; i != n; ++i) {
float f = features[i];
f = std::max(f, max_v);
f = (f + 4) / 4;
features[i] = f;
}
}
private:
OfflineRecognizerConfig config_;
SymbolTable symbol_table_;

View File

@@ -12,56 +12,6 @@
namespace sherpa_onnx {
int32_t OfflineWhisperGreedySearchDecoder::DetectLanguage(
Ort::Value &cross_k, Ort::Value &cross_v) const { // NOLINT
int64_t token_val = model_->SOT();
std::array<int64_t, 2> token_shape{1, 1};
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
Ort::Value tokens = Ort::Value::CreateTensor(
memory_info, &token_val, 1, token_shape.data(), token_shape.size());
auto self_kv_cache = model_->GetInitialSelfKVCache();
std::array<int64_t, 1> offset_shape{1};
Ort::Value offset = Ort::Value::CreateTensor<int64_t>(
model_->Allocator(), offset_shape.data(), offset_shape.size());
*(offset.GetTensorMutableData<int64_t>()) = 0;
auto decoder_out = model_->ForwardDecoder(
std::move(tokens), std::move(self_kv_cache.first),
std::move(self_kv_cache.second), std::move(cross_k), std::move(cross_v),
std::move(offset));
cross_k = std::move(std::get<3>(decoder_out));
cross_v = std::move(std::get<4>(decoder_out));
const float *p_logits = std::get<0>(decoder_out).GetTensorData<float>();
int32_t vocab_size = model_->VocabSize();
const auto &all_language_ids = model_->GetAllLanguageIDs();
int32_t lang_id = all_language_ids[0];
float this_logit = p_logits[lang_id];
for (int32_t i = 1; i != all_language_ids.size(); ++i) {
int32_t id = all_language_ids[i];
float p = p_logits[id];
if (p > this_logit) {
this_logit = p;
lang_id = id;
}
}
#if 1
SHERPA_ONNX_LOGE("Detected language: %s",
model_->GetID2Lang().at(lang_id).c_str());
#endif
return lang_id;
}
std::vector<OfflineWhisperDecoderResult>
OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
Ort::Value cross_v) {
@@ -89,7 +39,7 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
// 0: sot, 1: lang_id, 2: task, 3: no_timestamps
initial_tokens[1] = lang_id;
} else {
int32_t lang_id = DetectLanguage(cross_k, cross_v);
int32_t lang_id = model_->DetectLanguage(cross_k, cross_v);
// 0: sot, 1: lang_id, 2: task, 3: no_timestamps
initial_tokens[1] = lang_id;

View File

@@ -22,9 +22,6 @@ class OfflineWhisperGreedySearchDecoder : public OfflineWhisperDecoder {
std::vector<OfflineWhisperDecoderResult> Decode(Ort::Value cross_k,
Ort::Value cross_v) override;
int32_t DetectLanguage(Ort::Value &cross_k, // NOLINT
Ort::Value &cross_v) const; // NOLINT
private:
OfflineWhisperModelConfig config_;
OfflineWhisperModel *model_; // not owned

View File

@@ -35,19 +35,28 @@ void OfflineWhisperModelConfig::Register(ParseOptions *po) {
po->Register(
"whisper-tail-paddings", &tail_paddings,
"Suggest value: 50 for English models. 300 for multilingual models. "
"Suggested value: 50 for English models. 300 for multilingual models. "
"Since we have removed the 30-second constraint, we need to add some "
"tail padding frames "
"so that whisper can detect the eot token. Leave it to -1 to use 50 for "
"English models and 300 for multilingual models.");
"so that whisper can detect the eot token. Leave it to -1 to use 1000.");
}
bool OfflineWhisperModelConfig::Validate() const {
if (encoder.empty()) {
SHERPA_ONNX_LOGE("Please provide --whisper-encoder");
return false;
}
if (!FileExists(encoder)) {
SHERPA_ONNX_LOGE("whisper encoder file %s does not exist", encoder.c_str());
return false;
}
if (decoder.empty()) {
SHERPA_ONNX_LOGE("Please provide --whisper-decoder");
return false;
}
if (!FileExists(decoder)) {
SHERPA_ONNX_LOGE("whisper decoder file %s does not exist", decoder.c_str());
return false;

View File

@@ -24,6 +24,24 @@ class OfflineWhisperModel::Impl {
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
debug_ = config_.debug;
{
auto buf = ReadFile(config.whisper.encoder);
InitEncoder(buf.data(), buf.size());
}
{
auto buf = ReadFile(config.whisper.decoder);
InitDecoder(buf.data(), buf.size());
}
}
explicit Impl(const SpokenLanguageIdentificationConfig &config)
: lid_config_(config),
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
debug_ = config_.debug;
{
auto buf = ReadFile(config.whisper.encoder);
InitEncoder(buf.data(), buf.size());
@@ -41,6 +59,7 @@ class OfflineWhisperModel::Impl {
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
debug_ = config_.debug;
{
auto buf = ReadFile(mgr, config.whisper.encoder);
InitEncoder(buf.data(), buf.size());
@@ -85,6 +104,57 @@ class OfflineWhisperModel::Impl {
std::move(decoder_input[4]), std::move(decoder_input[5])};
}
// Detect the spoken language from the encoder's cross-attention caches.
//
// Runs a single decoder step whose only input token is SOT() at offset 0,
// then returns the entry of GetAllLanguageIDs() with the largest logit.
//
// @param cross_k Cross-attention key cache. It is moved into the decoder and
//                restored from the decoder output before returning.
// @param cross_v Cross-attention value cache, handled like cross_k.
// @return The token ID of the most likely language. Ties keep the first
//         candidate in GetAllLanguageIDs().
//
// NOTE(review): assumes GetAllLanguageIDs() is non-empty, i.e. that the
// model is multilingual -- callers are expected to check that first.
int32_t DetectLanguage(Ort::Value &cross_k,    // NOLINT
                       Ort::Value &cross_v) {  // NOLINT
  int64_t token_val = SOT();
  std::array<int64_t, 2> token_shape{1, 1};

  auto memory_info =
      Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

  // The tensor wraps token_val in place, so token_val has to stay alive
  // until ForwardDecoder() below has returned.
  Ort::Value tokens = Ort::Value::CreateTensor(
      memory_info, &token_val, 1, token_shape.data(), token_shape.size());

  auto self_kv_cache = GetInitialSelfKVCache();

  std::array<int64_t, 1> offset_shape{1};
  Ort::Value offset = Ort::Value::CreateTensor<int64_t>(
      Allocator(), offset_shape.data(), offset_shape.size());
  *(offset.GetTensorMutableData<int64_t>()) = 0;

  auto decoder_out =
      ForwardDecoder(std::move(tokens), std::move(self_kv_cache.first),
                     std::move(self_kv_cache.second), std::move(cross_k),
                     std::move(cross_v), std::move(offset));

  // Hand the cross caches back to the caller so they can be reused.
  cross_k = std::move(std::get<3>(decoder_out));
  cross_v = std::move(std::get<4>(decoder_out));

  const float *p_logits = std::get<0>(decoder_out).GetTensorData<float>();
  const auto &all_language_ids = GetAllLanguageIDs();

  int32_t lang_id = all_language_ids[0];
  float this_logit = p_logits[lang_id];

  // Argmax restricted to the language tokens. Revisiting the first entry is
  // harmless because the comparison is strict.
  for (int32_t id : all_language_ids) {
    float p = p_logits[id];

    if (p > this_logit) {
      this_logit = p;
      lang_id = id;
    }
  }

  if (debug_) {
    SHERPA_ONNX_LOGE("Detected language: %s",
                     GetID2Lang().at(lang_id).c_str());
  }

  return lang_id;
}
std::pair<Ort::Value, Ort::Value> GetInitialSelfKVCache() {
std::array<int64_t, 4> shape{n_text_layer_, 1, n_text_ctx_, n_text_state_};
@@ -148,7 +218,7 @@ class OfflineWhisperModel::Impl {
// get meta data
Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata();
if (config_.debug) {
if (debug_) {
std::ostringstream os;
os << "---encoder---\n";
PrintModelMetadata(os, meta_data);
@@ -203,6 +273,8 @@ class OfflineWhisperModel::Impl {
private:
OfflineModelConfig config_;
SpokenLanguageIdentificationConfig lid_config_;
bool debug_ = false;
Ort::Env env_;
Ort::SessionOptions sess_opts_;
Ort::AllocatorWithDefaultOptions allocator_;
@@ -246,6 +318,10 @@ class OfflineWhisperModel::Impl {
OfflineWhisperModel::OfflineWhisperModel(const OfflineModelConfig &config)
: impl_(std::make_unique<Impl>(config)) {}
// Construct a whisper model from a spoken-language-identification config.
// All of the work is delegated to the matching Impl constructor.
OfflineWhisperModel::OfflineWhisperModel(
const SpokenLanguageIdentificationConfig &config)
: impl_(std::make_unique<Impl>(config)) {}
#if __ANDROID_API__ >= 9
OfflineWhisperModel::OfflineWhisperModel(AAssetManager *mgr,
const OfflineModelConfig &config)
@@ -273,6 +349,11 @@ OfflineWhisperModel::ForwardDecoder(Ort::Value tokens,
std::move(n_layer_cross_v), std::move(offset));
}
// Forward to Impl::DetectLanguage() and return the language token ID it
// picks. cross_k and cross_v are taken by non-const reference and passed
// through unchanged to the implementation.
int32_t OfflineWhisperModel::DetectLanguage(Ort::Value &cross_k, // NOLINT
Ort::Value &cross_v) { // NOLINT
return impl_->DetectLanguage(cross_k, cross_v);
}
std::pair<Ort::Value, Ort::Value> OfflineWhisperModel::GetInitialSelfKVCache()
const {
return impl_->GetInitialSelfKVCache();
@@ -318,4 +399,35 @@ bool OfflineWhisperModel::IsMultiLingual() const {
return impl_->IsMultiLingual();
}
// Normalize mel features in place, following whisper's reference recipe:
//
//   log_spec = torch.clamp(features, min=1e-10).log10()
//   log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
//   mel = (log_spec + 4.0) / 4.0
//
// Only the first num_frames * feat_dim values of `features` are touched.
void OfflineWhisperModel::NormalizeFeatures(float *features, int32_t num_frames,
                                            int32_t feat_dim) {
  const int32_t total = num_frames * feat_dim;
  float *const last = features + total;

  // Pass 1: clamp, take log10, and remember the largest log value seen.
  float peak = -1e20;
  for (float *p = features; p != last; ++p) {
    const float log_v = std::log10(std::max<float>(*p, 1e-10));
    peak = std::max(log_v, peak);
    *p = log_v;
  }

  // Pass 2: limit the dynamic range to 8 below the peak, then rescale.
  const float floor_v = peak - 8;
  for (float *p = features; p != last; ++p) {
    const float clipped = std::max(*p, floor_v);
    *p = (clipped + 4) / 4;
  }
}
} // namespace sherpa_onnx

View File

@@ -18,6 +18,7 @@
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/offline-model-config.h"
#include "sherpa-onnx/csrc/spoken-language-identification.h"
namespace sherpa_onnx {
@@ -25,6 +26,9 @@ class OfflineWhisperModel {
public:
explicit OfflineWhisperModel(const OfflineModelConfig &config);
explicit OfflineWhisperModel(
const SpokenLanguageIdentificationConfig &config);
#if __ANDROID_API__ >= 9
OfflineWhisperModel(AAssetManager *mgr, const OfflineModelConfig &config);
#endif
@@ -72,7 +76,8 @@ class OfflineWhisperModel {
Ort::Value n_layer_self_v_cache, Ort::Value n_layer_cross_k,
Ort::Value n_layer_cross_v, Ort::Value offset) const;
int32_t DetectLanguage() const;
int32_t DetectLanguage(Ort::Value &cross_k, // NOLINT
Ort::Value &cross_v); // NOLINT
/** Return the initial self kv cache in a pair
* - n_layer_self_k_cache A 4-D tensor of shape
@@ -98,6 +103,9 @@ class OfflineWhisperModel {
int32_t Translate() const;
bool IsMultiLingual() const;
static void NormalizeFeatures(float *features, int32_t num_frames,
int32_t feat_dim);
private:
class Impl;
std::unique_ptr<Impl> impl_;

View File

@@ -28,7 +28,7 @@ enum class ModelType {
kLstm,
kZipformer,
kZipformer2,
kUnkown,
kUnknown,
};
} // namespace
@@ -58,7 +58,7 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
"No model_type in the metadata!\n"
"Please make sure you are using the latest export-onnx.py from icefall "
"to export your transducer models");
return ModelType::kUnkown;
return ModelType::kUnknown;
}
if (model_type.get() == std::string("conformer")) {
@@ -71,7 +71,7 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
return ModelType::kZipformer2;
} else {
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
return ModelType::kUnkown;
return ModelType::kUnknown;
}
}
@@ -93,7 +93,7 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
model_type.c_str());
}
}
ModelType model_type = ModelType::kUnkown;
ModelType model_type = ModelType::kUnknown;
{
auto buffer = ReadFile(config.transducer.encoder);
@@ -110,7 +110,7 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
return std::make_unique<OnlineZipformerTransducerModel>(config);
case ModelType::kZipformer2:
return std::make_unique<OnlineZipformer2TransducerModel>(config);
case ModelType::kUnkown:
case ModelType::kUnknown:
SHERPA_ONNX_LOGE("Unknown model type in online transducer!");
return nullptr;
}
@@ -185,7 +185,7 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
return std::make_unique<OnlineZipformerTransducerModel>(mgr, config);
case ModelType::kZipformer2:
return std::make_unique<OnlineZipformer2TransducerModel>(mgr, config);
case ModelType::kUnkown:
case ModelType::kUnknown:
SHERPA_ONNX_LOGE("Unknown model type in online transducer!");
return nullptr;
}

View File

@@ -149,4 +149,9 @@ Ort::SessionOptions GetSessionOptions(
return GetSessionOptionsImpl(config.num_threads, config.provider);
}
// Build ONNX Runtime session options for spoken language identification,
// using only config.num_threads and config.provider.
Ort::SessionOptions GetSessionOptions(
const SpokenLanguageIdentificationConfig &config) {
return GetSessionOptionsImpl(config.num_threads, config.provider);
}
} // namespace sherpa_onnx

View File

@@ -12,6 +12,7 @@
#include "sherpa-onnx/csrc/online-lm-config.h"
#include "sherpa-onnx/csrc/online-model-config.h"
#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
#include "sherpa-onnx/csrc/spoken-language-identification.h"
#include "sherpa-onnx/csrc/vad-model-config.h"
namespace sherpa_onnx {
@@ -30,6 +31,10 @@ Ort::SessionOptions GetSessionOptions(const OfflineTtsModelConfig &config);
Ort::SessionOptions GetSessionOptions(
const SpeakerEmbeddingExtractorConfig &config);
Ort::SessionOptions GetSessionOptions(
const SpokenLanguageIdentificationConfig &config);
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_SESSION_H_

View File

@@ -0,0 +1,107 @@
// sherpa-onnx/csrc/sherpa-onnx-offline-language-identification.cc
//
// Copyright (c) 2022-2024 Xiaomi Corporation
#include <stdio.h>
#include <chrono> // NOLINT
#include <string>
#include <vector>
#include "sherpa-onnx/csrc/parse-options.h"
#include "sherpa-onnx/csrc/spoken-language-identification.h"
#include "sherpa-onnx/csrc/wave-reader.h"
// Command-line tool: identify the language spoken in a single wave file.
//
// Exactly one positional argument (the wave file) is required. The detected
// language, the elapsed time and the real time factor are printed to stderr.
// Returns 0 on success and -1 when the config is invalid or the wave file
// cannot be read; a wrong number of arguments exits with EXIT_FAILURE.
int main(int32_t argc, char *argv[]) {
  // The raw string is printed verbatim, so its lines must stay at column 0.
  const char *kUsageMessage = R"usage(
Spoken language identification with sherpa-onnx.
Usage:
(1) Use a whisper multilingual model
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.tar.bz2
tar xvf sherpa-onnx-whisper-tiny.tar.bz2
rm sherpa-onnx-whisper-tiny.tar.bz2
We only use the int8.onnx models below.
./bin/sherpa-onnx-offline-language-identification \
--whisper-encoder=sherpa-onnx-whisper-tiny/tiny-encoder.int8.onnx \
--whisper-decoder=sherpa-onnx-whisper-tiny/tiny-decoder.int8.onnx \
--num-threads=1 \
/path/to/foo.wav
foo.wav should be of single channel, 16-bit PCM encoded wave file; its
sampling rate can be arbitrary and does not need to be 16kHz.
You can find test waves for different languages at
https://hf-mirror.com/spaces/k2-fsa/spoken-language-identification/tree/main/test_wavs
Note that only whisper multilingual models are supported. For instance,
"tiny" is supported but "tiny.en" is not.
Please refer to
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/index.html
for a list of pre-trained models to download.
)usage";

  sherpa_onnx::ParseOptions po(kUsageMessage);
  sherpa_onnx::SpokenLanguageIdentificationConfig config;
  config.Register(&po);

  po.Read(argc, argv);
  if (po.NumArgs() != 1) {
    fprintf(stderr, "Error: Please provide 1 wave file.\n\n");
    po.PrintUsage();
    exit(EXIT_FAILURE);
  }

  fprintf(stderr, "%s\n", config.ToString().c_str());

  if (!config.Validate()) {
    fprintf(stderr, "Errors in config!\n");
    return -1;
  }

  fprintf(stderr, "Creating spoken language identifier ...\n");
  sherpa_onnx::SpokenLanguageIdentification slid(config);

  fprintf(stderr, "Started\n");
  const std::string wav_filename = po.GetArg(1);

  int32_t sampling_rate = -1;
  bool is_ok = false;
  const std::vector<float> samples =
      sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok);
  if (!is_ok) {
    fprintf(stderr, "Failed to read %s\n", wav_filename.c_str());
    return -1;
  }

  // Audio length in seconds; used only for the real time factor below.
  float duration = samples.size() / static_cast<float>(sampling_rate);

  // Time feature extraction plus language detection. Model loading happened
  // above and is not included.
  const auto begin = std::chrono::steady_clock::now();

  auto s = slid.CreateStream();
  s->AcceptWaveform(sampling_rate, samples.data(), samples.size());

  auto language = slid.Compute(s.get());

  const auto end = std::chrono::steady_clock::now();

  fprintf(stderr, "Done!\n\n");
  fprintf(stderr, "%s\nDetected language: %s\n", wav_filename.c_str(),
          language.c_str());

  float elapsed_seconds =
      std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
          .count() /
      1000.;

  fprintf(stderr, "num threads: %d\n", config.num_threads);

  fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
  float rtf = elapsed_seconds / duration;
  fprintf(stderr, "Real time factor (RTF): %.3f / %.3f = %.3f\n",
          elapsed_seconds, duration, rtf);

  return 0;
}

View File

@@ -16,7 +16,7 @@ enum class ModelType {
kWeSpeaker,
k3dSpeaker,
kNeMo,
kUnkown,
kUnknown,
};
} // namespace
@@ -47,7 +47,7 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
"https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wespeaker/"
"add_meta_data.py"
"to add metadata to models from WeSpeaker\n");
return ModelType::kUnkown;
return ModelType::kUnknown;
}
if (model_type.get() == std::string("wespeaker")) {
@@ -58,14 +58,14 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
return ModelType::kNeMo;
} else {
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
return ModelType::kUnkown;
return ModelType::kUnknown;
}
}
std::unique_ptr<SpeakerEmbeddingExtractorImpl>
SpeakerEmbeddingExtractorImpl::Create(
const SpeakerEmbeddingExtractorConfig &config) {
ModelType model_type = ModelType::kUnkown;
ModelType model_type = ModelType::kUnknown;
{
auto buffer = ReadFile(config.model);
@@ -80,9 +80,8 @@ SpeakerEmbeddingExtractorImpl::Create(
return std::make_unique<SpeakerEmbeddingExtractorGeneralImpl>(config);
case ModelType::kNeMo:
return std::make_unique<SpeakerEmbeddingExtractorNeMoImpl>(config);
case ModelType::kUnkown:
SHERPA_ONNX_LOGE(
"Unknown model type in for speaker embedding extractor!");
case ModelType::kUnknown:
SHERPA_ONNX_LOGE("Unknown model type for speaker embedding extractor!");
return nullptr;
}
@@ -94,7 +93,7 @@ SpeakerEmbeddingExtractorImpl::Create(
std::unique_ptr<SpeakerEmbeddingExtractorImpl>
SpeakerEmbeddingExtractorImpl::Create(
AAssetManager *mgr, const SpeakerEmbeddingExtractorConfig &config) {
ModelType model_type = ModelType::kUnkown;
ModelType model_type = ModelType::kUnknown;
{
auto buffer = ReadFile(mgr, config.model);
@@ -110,7 +109,7 @@ SpeakerEmbeddingExtractorImpl::Create(
config);
case ModelType::kNeMo:
return std::make_unique<SpeakerEmbeddingExtractorNeMoImpl>(mgr, config);
case ModelType::kUnkown:
case ModelType::kUnknown:
SHERPA_ONNX_LOGE(
"Unknown model type in for speaker embedding extractor!");
return nullptr;

View File

@@ -0,0 +1,88 @@
// sherpa-onnx/csrc/spoken-language-identification-impl.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/spoken-language-identification-impl.h"
#include <memory>
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/spoken-language-identification-whisper-impl.h"
namespace sherpa_onnx {
namespace {
// Kinds of models that can back spoken language identification.
// File-local; GetModelType() below maps model metadata onto these values.
enum class ModelType {
// metadata "model_type" starts with "whisper"
kWhisper,
// "model_type" is missing from the metadata or has any other value
kUnknown,
};
}
// Work out which kind of model is stored in an in-memory ONNX file.
//
// Loads the model into a temporary session and reads the "model_type" entry
// of its custom metadata. Any value starting with "whisper" maps to
// ModelType::kWhisper; a missing or different value is logged and mapped to
// ModelType::kUnknown.
//
// @param model_data         Start of the serialized ONNX model.
// @param model_data_length  Size of model_data in bytes.
// @param debug              If true, log the model metadata.
static ModelType GetModelType(char *model_data, size_t model_data_length,
                              bool debug) {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING);
  Ort::SessionOptions session_options;

  Ort::Session session(env, model_data, model_data_length, session_options);
  Ort::ModelMetadata metadata = session.GetModelMetadata();

  if (debug) {
    std::ostringstream dump;
    PrintModelMetadata(dump, metadata);
    SHERPA_ONNX_LOGE("%s", dump.str().c_str());
  }

  Ort::AllocatorWithDefaultOptions default_allocator;
  auto type_entry = metadata.LookupCustomMetadataMapAllocated(
      "model_type", default_allocator);

  if (!type_entry) {
    SHERPA_ONNX_LOGE(
        "No model_type in the metadata!\n"
        "Please make sure you have added metadata to the model.\n\n"
        "For instance, you can use\n"
        "https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/whisper/"
        "export-onnx.py "
        "to add metadata to models from whisper\n");
    return ModelType::kUnknown;
  }

  // Prefix match: rfind() anchored at position 0 can only match there.
  const std::string type_name(type_entry.get());
  const bool is_whisper = type_name.rfind("whisper", 0) == 0;
  if (!is_whisper) {
    SHERPA_ONNX_LOGE("Unsupported model_type: %s", type_entry.get());
    return ModelType::kUnknown;
  }

  return ModelType::kWhisper;
}
std::unique_ptr<SpokenLanguageIdentificationImpl>
SpokenLanguageIdentificationImpl::Create(
const SpokenLanguageIdentificationConfig &config) {
ModelType model_type = ModelType::kUnknown;
{
if (config.whisper.encoder.empty()) {
SHERPA_ONNX_LOGE("Only whisper models are supported at present");
exit(-1);
}
auto buffer = ReadFile(config.whisper.encoder);
model_type = GetModelType(buffer.data(), buffer.size(), config.debug);
}
switch (model_type) {
case ModelType::kWhisper:
return std::make_unique<SpokenLanguageIdentificationWhisperImpl>(config);
case ModelType::kUnknown:
SHERPA_ONNX_LOGE(
"Unknown model type for spoken language identification!");
return nullptr;
}
// unreachable code
return nullptr;
}
} // namespace sherpa_onnx

View File

@@ -0,0 +1,28 @@
// sherpa-onnx/csrc/spoken-language-identification-impl.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_IMPL_H_
#define SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_IMPL_H_
#include <memory>
#include <string>
#include "sherpa-onnx/csrc/spoken-language-identification.h"
namespace sherpa_onnx {
// Abstract interface behind SpokenLanguageIdentification. Concrete
// implementations are obtained through Create().
class SpokenLanguageIdentificationImpl {
public:
virtual ~SpokenLanguageIdentificationImpl() = default;
// Factory that selects an implementation based on `config`.
static std::unique_ptr<SpokenLanguageIdentificationImpl> Create(
const SpokenLanguageIdentificationConfig &config);
// Create a stream that accepts audio samples for a later Compute() call.
virtual std::unique_ptr<OfflineStream> CreateStream() const = 0;
// Return the language identified for the audio held by `s`, as a string.
virtual std::string Compute(OfflineStream *s) const = 0;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_IMPL_H_

View File

@@ -0,0 +1,119 @@
// sherpa-onnx/csrc/spoken-language-identification-whisper-impl.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_WHISPER_IMPL_H_
#define SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_WHISPER_IMPL_H_
#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/offline-whisper-model.h"
#include "sherpa-onnx/csrc/spoken-language-identification-impl.h"
#include "sherpa-onnx/csrc/transpose.h"
namespace sherpa_onnx {
// Spoken language identification backed by a multilingual whisper model.
//
// The encoder is run on the (padded) features of a stream and the model's
// DetectLanguage() picks the language token; the token is then mapped to a
// language string through the model's ID-to-language table.
class SpokenLanguageIdentificationWhisperImpl
: public SpokenLanguageIdentificationImpl {
public:
// Load the whisper model described by `config`. The process exits (see
// Check()) if the model is not multilingual.
explicit SpokenLanguageIdentificationWhisperImpl(
const SpokenLanguageIdentificationConfig &config)
: config_(config), model_(std::make_unique<OfflineWhisperModel>(config)) {
Check();
}
// Create a stream configured with WhisperTag.
std::unique_ptr<OfflineStream> CreateStream() const override {
return std::make_unique<OfflineStream>(WhisperTag{});
}
// Identify the language of the audio accepted by `s`.
//
// Returns the language string for the detected token, or an empty string
// if the token is not in the ID-to-language table or ONNX Runtime throws.
std::string Compute(OfflineStream *s) const override {
// Upper bound on the number of frames fed to the encoder.
// NOTE(review): presumably 3000 frames == 30 seconds, as the log message
// below suggests -- confirm against the feature extractor's frame shift.
int32_t max_num_frames = 3000;
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
int32_t feat_dim = s->FeatureDim();
std::vector<float> f = s->GetFrames();
int32_t num_frames = f.size() / feat_dim;
// we use 50 here so that there will be some zero tail paddings
if (num_frames >= max_num_frames - 50) {
SHERPA_ONNX_LOGE(
"Only waves less than 30 seconds are supported. We process only the "
"first 30 seconds and discard the remaining data");
num_frames = max_num_frames - 50;
}
// Normalize only the frames that are kept; anything past num_frames in
// `f` is ignored from here on.
model_->NormalizeFeatures(f.data(), num_frames, feat_dim);
// note that 1000 is an experience-value.
// You can replace 1000 by other values, say, 100.
//
// Since we have removed the 30 seconds constraint, we need
// tail_padding_frames so that whisper is able to detect the eot token.
int32_t tail_padding_frames = 1000;
if (config_.whisper.tail_paddings > 0) {
tail_padding_frames = config_.whisper.tail_paddings;
}
// Frames plus padding, capped at max_num_frames.
int32_t actual_frames =
std::min(num_frames + tail_padding_frames, max_num_frames);
std::array<int64_t, 3> shape{1, actual_frames, feat_dim};
Ort::Value mel = Ort::Value::CreateTensor<float>(
model_->Allocator(), shape.data(), shape.size());
float *p_mel = mel.GetTensorMutableData<float>();
// Copy the normalized frames, then zero-fill the tail padding.
std::copy(f.data(), f.data() + num_frames * feat_dim, p_mel);
std::fill_n(p_mel + num_frames * feat_dim,
(actual_frames - num_frames) * feat_dim, 0);
// Swap axes 1 and 2 before handing the tensor to the encoder.
mel = Transpose12(model_->Allocator(), &mel);
try {
auto cross_kv = model_->ForwardEncoder(std::move(mel));
int32_t lang_id = model_->DetectLanguage(cross_kv.first, cross_kv.second);
const auto &id2lang = model_->GetID2Lang();
if (id2lang.count(lang_id)) {
return id2lang.at(lang_id);
} else {
SHERPA_ONNX_LOGE("Unknown language ID: %d. Return an empty string.",
lang_id);
return "";
}
} catch (const Ort::Exception &ex) {
SHERPA_ONNX_LOGE(
"\n\nCaught exception:\n\n%s\n\nReturn an empty result. Number of "
"input frames: %d, Current tail "
"paddings: %d. If you see a lot of such exceptions, please consider "
"using a larger --whisper-tail-paddings",
ex.what(), num_frames, tail_padding_frames);
return "";
}
}
private:
// Terminate the process with exit(-1) unless the loaded model is
// multilingual.
void Check() const {
if (!model_->IsMultiLingual()) {
SHERPA_ONNX_LOGE(
"Only whisper multilingual models can be used for spoken language "
"identification. Given: %s,%s",
config_.whisper.encoder.c_str(), config_.whisper.decoder.c_str());
exit(-1);
}
}
private:
SpokenLanguageIdentificationConfig config_;
std::unique_ptr<OfflineWhisperModel> model_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_WHISPER_IMPL_H_

View File

@@ -0,0 +1,117 @@
// sherpa-onnx/csrc/spoken-language-identification.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/spoken-language-identification.h"
#include <string>
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/spoken-language-identification-impl.h"
namespace sherpa_onnx {
// Register the whisper-specific command-line options with `po`:
//   --whisper-encoder, --whisper-decoder and --whisper-tail-paddings.
// Parsed values are written directly into this object's members.
void SpokenLanguageIdentificationWhisperConfig::Register(ParseOptions *po) {
  po->Register(
      "whisper-encoder", &encoder,
      "Path to the encoder of a whisper multilingual model. Support only "
      "tiny, base, small, medium, large.");

  po->Register(
      "whisper-decoder", &decoder,
      "Path to the decoder of a whisper multilingual model. Support only "
      "tiny, base, small, medium, large.");

  po->Register(
      "whisper-tail-paddings", &tail_paddings,
      "Suggested value: 300 for multilingual models. "
      "Since we have removed the 30-second constraint, we need to add some "
      "tail padding frames "
      "so that whisper can detect the eot token. Leave it to -1 to use 1000");
}
// Check that both model paths are set and point to existing files.
//
// The checks run in a fixed order (encoder given, encoder exists, decoder
// given, decoder exists); the first failing one is logged and makes the
// function return false. Returns true only when all four pass.
bool SpokenLanguageIdentificationWhisperConfig::Validate() const {
  bool all_good = false;

  if (encoder.empty()) {
    SHERPA_ONNX_LOGE("Please provide --whisper-encoder");
  } else if (!FileExists(encoder)) {
    SHERPA_ONNX_LOGE("whisper encoder file %s does not exist", encoder.c_str());
  } else if (decoder.empty()) {
    SHERPA_ONNX_LOGE("Please provide --whisper-decoder");
  } else if (!FileExists(decoder)) {
    SHERPA_ONNX_LOGE("whisper decoder file %s does not exist", decoder.c_str());
  } else {
    all_good = true;
  }

  return all_good;
}
// Render this config on one line, e.g. for logging. The two paths are
// wrapped in double quotes; tail_paddings is printed as a plain integer.
std::string SpokenLanguageIdentificationWhisperConfig::ToString() const {
  std::ostringstream text;

  text << "SpokenLanguageIdentificationWhisperConfig("
       << "encoder=\"" << encoder << "\", "
       << "decoder=\"" << decoder << "\", "
       << "tail_paddings=" << tail_paddings << ")";

  return text.str();
}
// Register all command-line options of this config with `po`: first the
// whisper options, then --num-threads, --debug and --provider.
void SpokenLanguageIdentificationConfig::Register(ParseOptions *po) {
whisper.Register(po);
po->Register("num-threads", &num_threads,
"Number of threads to run the neural network");
po->Register("debug", &debug,
"true to print model information while loading it.");
po->Register("provider", &provider,
"Specify a provider to use: cpu, cuda, coreml");
}
// The config is valid exactly when its whisper part is valid; num_threads,
// debug and provider are not checked here.
bool SpokenLanguageIdentificationConfig::Validate() const {
  return whisper.Validate();
}
// Render this config on one line, e.g. for logging.
//
// The nested whisper config is embedded without surrounding quotes: its own
// ToString() output already contains double quotes around the model paths,
// so quoting it again would produce unbalanced-looking nested quotes.
std::string SpokenLanguageIdentificationConfig::ToString() const {
  std::ostringstream os;

  os << "SpokenLanguageIdentificationConfig(";
  os << "whisper=" << whisper.ToString() << ", ";
  os << "num_threads=" << num_threads << ", ";
  os << "debug=" << (debug ? "True" : "False") << ", ";
  os << "provider=\"" << provider << "\")";

  return os.str();
}
// Build the identifier by asking SpokenLanguageIdentificationImpl::Create()
// for an implementation that matches `config`.
//
// NOTE(review): Create() returns nullptr for an unrecognized model type; in
// that case impl_ is null and CreateStream()/Compute() below would
// dereference it. Consider failing loudly here.
SpokenLanguageIdentification::SpokenLanguageIdentification(
const SpokenLanguageIdentificationConfig &config)
: impl_(SpokenLanguageIdentificationImpl::Create(config)) {}
// Defaulted in this translation unit, which includes the full definition of
// SpokenLanguageIdentificationImpl.
SpokenLanguageIdentification::~SpokenLanguageIdentification() = default;
// Forward to the implementation's CreateStream().
std::unique_ptr<OfflineStream> SpokenLanguageIdentification::CreateStream()
const {
return impl_->CreateStream();
}
// Forward to the implementation's Compute() and return its result.
std::string SpokenLanguageIdentification::Compute(OfflineStream *s) const {
return impl_->Compute(s);
}
} // namespace sherpa_onnx

View File

@@ -0,0 +1,89 @@
// sherpa-onnx/csrc/spoken-language-identification.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_H_
#define SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_H_
#include <memory>
#include <string>
#include "sherpa-onnx/csrc/offline-stream.h"
#include "sherpa-onnx/csrc/parse-options.h"
namespace sherpa_onnx {
// Whisper-specific settings for spoken language identification.
struct SpokenLanguageIdentificationWhisperConfig {
// Requires a multi-lingual whisper model.
// That is, it supports only tiny, base, small, medium, large.
// Note: It does NOT support tiny.en, base.en, small.en, medium.en
//
// Paths to the encoder and decoder model files.
std::string encoder;
std::string decoder;
// Number of tail padding frames.
//
// Since we remove the 30-second constraint, we need to add some paddings
// at the end.
//
// Suggested value: 300 (only multilingual models are supported here).
// The default -1 means "not set"; the --whisper-tail-paddings help text
// says 1000 frames are used in that case.
int32_t tail_paddings = -1;
SpokenLanguageIdentificationWhisperConfig() = default;
SpokenLanguageIdentificationWhisperConfig(const std::string &encoder,
const std::string &decoder,
int32_t tail_paddings)
: encoder(encoder), decoder(decoder), tail_paddings(tail_paddings) {}
// Register the command-line options for the fields above with `po`.
void Register(ParseOptions *po);
// Return true if the settings are usable; see the definition for the
// exact checks.
bool Validate() const;
// Return a one-line, human-readable description of this config.
std::string ToString() const;
};
// Top-level configuration for spoken language identification.
struct SpokenLanguageIdentificationConfig {
// Settings of the whisper model to use.
SpokenLanguageIdentificationWhisperConfig whisper;
// Number of threads to run the neural network.
int32_t num_threads = 1;
// True to print model information while loading it.
bool debug = false;
// Execution provider name; defaults to "cpu".
std::string provider = "cpu";
SpokenLanguageIdentificationConfig() = default;
SpokenLanguageIdentificationConfig(
const SpokenLanguageIdentificationWhisperConfig &whisper,
int32_t num_threads, bool debug, const std::string &provider)
: whisper(whisper),
num_threads(num_threads),
debug(debug),
provider(provider) {}
// Register the command-line options for the fields above with `po`.
void Register(ParseOptions *po);
// Return true if the configuration is usable.
bool Validate() const;
// Return a one-line, human-readable description of this config.
std::string ToString() const;
};
class SpokenLanguageIdentificationImpl;
// Public entry point for spoken language identification. All work is
// delegated to an implementation object held through impl_.
//
// Typical use: create a stream with CreateStream(), feed audio samples to
// the stream, then pass it to Compute().
class SpokenLanguageIdentification {
public:
// Create an identifier for the model described by `config`.
explicit SpokenLanguageIdentification(
const SpokenLanguageIdentificationConfig &config);
~SpokenLanguageIdentification();
// Create a stream to accept audio samples and compute features
std::unique_ptr<OfflineStream> CreateStream() const;
// Return a string containing the language, e.g., en, zh, de,
// etc.
// Note: en is for English, zh is for Chinese, de is for German, etc.
std::string Compute(OfflineStream *s) const;
private:
// Implementation object; only forward-declared in this header.
std::unique_ptr<SpokenLanguageIdentificationImpl> impl_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_H_