Support spoken language identification with whisper (#694)
This commit is contained in:
@@ -86,6 +86,8 @@ set(sources
|
||||
silero-vad-model-config.cc
|
||||
silero-vad-model.cc
|
||||
slice.cc
|
||||
spoken-language-identification-impl.cc
|
||||
spoken-language-identification.cc
|
||||
stack.cc
|
||||
symbol-table.cc
|
||||
text-utils.cc
|
||||
@@ -184,6 +186,7 @@ if(SHERPA_ONNX_ENABLE_BINARY)
|
||||
add_executable(sherpa-onnx-offline sherpa-onnx-offline.cc)
|
||||
add_executable(sherpa-onnx-offline-parallel sherpa-onnx-offline-parallel.cc)
|
||||
add_executable(sherpa-onnx-offline-tts sherpa-onnx-offline-tts.cc)
|
||||
add_executable(sherpa-onnx-offline-language-identification sherpa-onnx-offline-language-identification.cc)
|
||||
|
||||
set(main_exes
|
||||
sherpa-onnx
|
||||
@@ -191,6 +194,7 @@ if(SHERPA_ONNX_ENABLE_BINARY)
|
||||
sherpa-onnx-offline
|
||||
sherpa-onnx-offline-parallel
|
||||
sherpa-onnx-offline-tts
|
||||
sherpa-onnx-offline-language-identification
|
||||
)
|
||||
|
||||
foreach(exe IN LISTS main_exes)
|
||||
|
||||
@@ -23,7 +23,7 @@ enum class ModelType {
|
||||
kTdnn,
|
||||
kZipformerCtc,
|
||||
kWenetCtc,
|
||||
kUnkown,
|
||||
kUnknown,
|
||||
};
|
||||
|
||||
} // namespace
|
||||
@@ -59,7 +59,7 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
|
||||
"run.sh\n"
|
||||
"\n"
|
||||
"for how to add metadta to model.onnx\n");
|
||||
return ModelType::kUnkown;
|
||||
return ModelType::kUnknown;
|
||||
}
|
||||
|
||||
if (model_type.get() == std::string("EncDecCTCModelBPE")) {
|
||||
@@ -72,13 +72,13 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
|
||||
return ModelType::kWenetCtc;
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
|
||||
return ModelType::kUnkown;
|
||||
return ModelType::kUnknown;
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
|
||||
const OfflineModelConfig &config) {
|
||||
ModelType model_type = ModelType::kUnkown;
|
||||
ModelType model_type = ModelType::kUnknown;
|
||||
|
||||
std::string filename;
|
||||
if (!config.nemo_ctc.model.empty()) {
|
||||
@@ -113,7 +113,7 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
|
||||
case ModelType::kWenetCtc:
|
||||
return std::make_unique<OfflineWenetCtcModel>(config);
|
||||
break;
|
||||
case ModelType::kUnkown:
|
||||
case ModelType::kUnknown:
|
||||
SHERPA_ONNX_LOGE("Unknown model type in offline CTC!");
|
||||
return nullptr;
|
||||
}
|
||||
@@ -125,7 +125,7 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
|
||||
|
||||
std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
|
||||
AAssetManager *mgr, const OfflineModelConfig &config) {
|
||||
ModelType model_type = ModelType::kUnkown;
|
||||
ModelType model_type = ModelType::kUnknown;
|
||||
|
||||
std::string filename;
|
||||
if (!config.nemo_ctc.model.empty()) {
|
||||
@@ -160,7 +160,7 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
|
||||
case ModelType::kWenetCtc:
|
||||
return std::make_unique<OfflineWenetCtcModel>(mgr, config);
|
||||
break;
|
||||
case ModelType::kUnkown:
|
||||
case ModelType::kUnknown:
|
||||
SHERPA_ONNX_LOGE("Unknown model type in offline CTC!");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
@@ -114,7 +114,7 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
|
||||
num_frames = max_num_frames - 50;
|
||||
}
|
||||
|
||||
NormalizeFeatures(f.data(), num_frames, feat_dim);
|
||||
model_->NormalizeFeatures(f.data(), num_frames, feat_dim);
|
||||
|
||||
// note that 1000 is an experience-value.
|
||||
// You can replace 1000 by other values, say, 100.
|
||||
@@ -162,38 +162,6 @@ class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl {
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
static void NormalizeFeatures(float *features, int32_t num_frames,
|
||||
int32_t feat_dim) {
|
||||
// log_spec = torch.clamp(features, min=1e-10).log10()
|
||||
// log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
|
||||
// mel = (log_spec + 4.0) / 4.0
|
||||
|
||||
int32_t n = num_frames * feat_dim;
|
||||
float max_v = -1e20;
|
||||
for (int32_t i = 0; i != n; ++i) {
|
||||
float f = features[i];
|
||||
|
||||
f = std::max<float>(f, 1e-10);
|
||||
f = std::log10(f);
|
||||
|
||||
max_v = std::max(f, max_v);
|
||||
|
||||
features[i] = f;
|
||||
}
|
||||
|
||||
max_v -= 8;
|
||||
|
||||
for (int32_t i = 0; i != n; ++i) {
|
||||
float f = features[i];
|
||||
f = std::max(f, max_v);
|
||||
|
||||
f = (f + 4) / 4;
|
||||
|
||||
features[i] = f;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
OfflineRecognizerConfig config_;
|
||||
SymbolTable symbol_table_;
|
||||
|
||||
@@ -12,56 +12,6 @@
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
int32_t OfflineWhisperGreedySearchDecoder::DetectLanguage(
|
||||
Ort::Value &cross_k, Ort::Value &cross_v) const { // NOLINT
|
||||
int64_t token_val = model_->SOT();
|
||||
std::array<int64_t, 2> token_shape{1, 1};
|
||||
|
||||
auto memory_info =
|
||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||
|
||||
Ort::Value tokens = Ort::Value::CreateTensor(
|
||||
memory_info, &token_val, 1, token_shape.data(), token_shape.size());
|
||||
|
||||
auto self_kv_cache = model_->GetInitialSelfKVCache();
|
||||
|
||||
std::array<int64_t, 1> offset_shape{1};
|
||||
Ort::Value offset = Ort::Value::CreateTensor<int64_t>(
|
||||
model_->Allocator(), offset_shape.data(), offset_shape.size());
|
||||
*(offset.GetTensorMutableData<int64_t>()) = 0;
|
||||
|
||||
auto decoder_out = model_->ForwardDecoder(
|
||||
std::move(tokens), std::move(self_kv_cache.first),
|
||||
std::move(self_kv_cache.second), std::move(cross_k), std::move(cross_v),
|
||||
std::move(offset));
|
||||
|
||||
cross_k = std::move(std::get<3>(decoder_out));
|
||||
cross_v = std::move(std::get<4>(decoder_out));
|
||||
|
||||
const float *p_logits = std::get<0>(decoder_out).GetTensorData<float>();
|
||||
int32_t vocab_size = model_->VocabSize();
|
||||
const auto &all_language_ids = model_->GetAllLanguageIDs();
|
||||
|
||||
int32_t lang_id = all_language_ids[0];
|
||||
float this_logit = p_logits[lang_id];
|
||||
|
||||
for (int32_t i = 1; i != all_language_ids.size(); ++i) {
|
||||
int32_t id = all_language_ids[i];
|
||||
float p = p_logits[id];
|
||||
|
||||
if (p > this_logit) {
|
||||
this_logit = p;
|
||||
lang_id = id;
|
||||
}
|
||||
}
|
||||
#if 1
|
||||
SHERPA_ONNX_LOGE("Detected language: %s",
|
||||
model_->GetID2Lang().at(lang_id).c_str());
|
||||
#endif
|
||||
|
||||
return lang_id;
|
||||
}
|
||||
|
||||
std::vector<OfflineWhisperDecoderResult>
|
||||
OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
|
||||
Ort::Value cross_v) {
|
||||
@@ -89,7 +39,7 @@ OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k,
|
||||
// 0: sot, 1: lang_id, 2: task, 3: no_timestamps
|
||||
initial_tokens[1] = lang_id;
|
||||
} else {
|
||||
int32_t lang_id = DetectLanguage(cross_k, cross_v);
|
||||
int32_t lang_id = model_->DetectLanguage(cross_k, cross_v);
|
||||
|
||||
// 0: sot, 1: lang_id, 2: task, 3: no_timestamps
|
||||
initial_tokens[1] = lang_id;
|
||||
|
||||
@@ -22,9 +22,6 @@ class OfflineWhisperGreedySearchDecoder : public OfflineWhisperDecoder {
|
||||
std::vector<OfflineWhisperDecoderResult> Decode(Ort::Value cross_k,
|
||||
Ort::Value cross_v) override;
|
||||
|
||||
int32_t DetectLanguage(Ort::Value &cross_k, // NOLINT
|
||||
Ort::Value &cross_v) const; // NOLINT
|
||||
|
||||
private:
|
||||
OfflineWhisperModelConfig config_;
|
||||
OfflineWhisperModel *model_; // not owned
|
||||
|
||||
@@ -35,19 +35,28 @@ void OfflineWhisperModelConfig::Register(ParseOptions *po) {
|
||||
|
||||
po->Register(
|
||||
"whisper-tail-paddings", &tail_paddings,
|
||||
"Suggest value: 50 for English models. 300 for multilingual models. "
|
||||
"Suggested value: 50 for English models. 300 for multilingual models. "
|
||||
"Since we have removed the 30-second constraint, we need to add some "
|
||||
"tail padding frames "
|
||||
"so that whisper can detect the eot token. Leave it to -1 to use 50 for "
|
||||
"English models and 300 for multilingual models.");
|
||||
"so that whisper can detect the eot token. Leave it to -1 to use 1000.");
|
||||
}
|
||||
|
||||
bool OfflineWhisperModelConfig::Validate() const {
|
||||
if (encoder.empty()) {
|
||||
SHERPA_ONNX_LOGE("Please provide --whisper-encoder");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!FileExists(encoder)) {
|
||||
SHERPA_ONNX_LOGE("whisper encoder file %s does not exist", encoder.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
if (decoder.empty()) {
|
||||
SHERPA_ONNX_LOGE("Please provide --whisper-decoder");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!FileExists(decoder)) {
|
||||
SHERPA_ONNX_LOGE("whisper decoder file %s does not exist", decoder.c_str());
|
||||
return false;
|
||||
|
||||
@@ -24,6 +24,24 @@ class OfflineWhisperModel::Impl {
|
||||
env_(ORT_LOGGING_LEVEL_ERROR),
|
||||
sess_opts_(GetSessionOptions(config)),
|
||||
allocator_{} {
|
||||
debug_ = config_.debug;
|
||||
{
|
||||
auto buf = ReadFile(config.whisper.encoder);
|
||||
InitEncoder(buf.data(), buf.size());
|
||||
}
|
||||
|
||||
{
|
||||
auto buf = ReadFile(config.whisper.decoder);
|
||||
InitDecoder(buf.data(), buf.size());
|
||||
}
|
||||
}
|
||||
|
||||
explicit Impl(const SpokenLanguageIdentificationConfig &config)
|
||||
: lid_config_(config),
|
||||
env_(ORT_LOGGING_LEVEL_ERROR),
|
||||
sess_opts_(GetSessionOptions(config)),
|
||||
allocator_{} {
|
||||
debug_ = config_.debug;
|
||||
{
|
||||
auto buf = ReadFile(config.whisper.encoder);
|
||||
InitEncoder(buf.data(), buf.size());
|
||||
@@ -41,6 +59,7 @@ class OfflineWhisperModel::Impl {
|
||||
env_(ORT_LOGGING_LEVEL_ERROR),
|
||||
sess_opts_(GetSessionOptions(config)),
|
||||
allocator_{} {
|
||||
debug_ = config_.debug;
|
||||
{
|
||||
auto buf = ReadFile(mgr, config.whisper.encoder);
|
||||
InitEncoder(buf.data(), buf.size());
|
||||
@@ -85,6 +104,57 @@ class OfflineWhisperModel::Impl {
|
||||
std::move(decoder_input[4]), std::move(decoder_input[5])};
|
||||
}
|
||||
|
||||
int32_t DetectLanguage(Ort::Value &cross_k, // NOLINT
|
||||
Ort::Value &cross_v) { // NOLINT
|
||||
int64_t token_val = SOT();
|
||||
std::array<int64_t, 2> token_shape{1, 1};
|
||||
|
||||
auto memory_info =
|
||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||
|
||||
Ort::Value tokens = Ort::Value::CreateTensor(
|
||||
memory_info, &token_val, 1, token_shape.data(), token_shape.size());
|
||||
|
||||
auto self_kv_cache = GetInitialSelfKVCache();
|
||||
|
||||
std::array<int64_t, 1> offset_shape{1};
|
||||
Ort::Value offset = Ort::Value::CreateTensor<int64_t>(
|
||||
Allocator(), offset_shape.data(), offset_shape.size());
|
||||
*(offset.GetTensorMutableData<int64_t>()) = 0;
|
||||
|
||||
auto decoder_out =
|
||||
ForwardDecoder(std::move(tokens), std::move(self_kv_cache.first),
|
||||
std::move(self_kv_cache.second), std::move(cross_k),
|
||||
std::move(cross_v), std::move(offset));
|
||||
|
||||
cross_k = std::move(std::get<3>(decoder_out));
|
||||
cross_v = std::move(std::get<4>(decoder_out));
|
||||
|
||||
const float *p_logits = std::get<0>(decoder_out).GetTensorData<float>();
|
||||
int32_t vocab_size = VocabSize();
|
||||
const auto &all_language_ids = GetAllLanguageIDs();
|
||||
|
||||
int32_t lang_id = all_language_ids[0];
|
||||
float this_logit = p_logits[lang_id];
|
||||
|
||||
for (int32_t i = 1; i != all_language_ids.size(); ++i) {
|
||||
int32_t id = all_language_ids[i];
|
||||
float p = p_logits[id];
|
||||
|
||||
if (p > this_logit) {
|
||||
this_logit = p;
|
||||
lang_id = id;
|
||||
}
|
||||
}
|
||||
|
||||
if (debug_) {
|
||||
SHERPA_ONNX_LOGE("Detected language: %s",
|
||||
GetID2Lang().at(lang_id).c_str());
|
||||
}
|
||||
|
||||
return lang_id;
|
||||
}
|
||||
|
||||
std::pair<Ort::Value, Ort::Value> GetInitialSelfKVCache() {
|
||||
std::array<int64_t, 4> shape{n_text_layer_, 1, n_text_ctx_, n_text_state_};
|
||||
|
||||
@@ -148,7 +218,7 @@ class OfflineWhisperModel::Impl {
|
||||
|
||||
// get meta data
|
||||
Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata();
|
||||
if (config_.debug) {
|
||||
if (debug_) {
|
||||
std::ostringstream os;
|
||||
os << "---encoder---\n";
|
||||
PrintModelMetadata(os, meta_data);
|
||||
@@ -203,6 +273,8 @@ class OfflineWhisperModel::Impl {
|
||||
|
||||
private:
|
||||
OfflineModelConfig config_;
|
||||
SpokenLanguageIdentificationConfig lid_config_;
|
||||
bool debug_ = false;
|
||||
Ort::Env env_;
|
||||
Ort::SessionOptions sess_opts_;
|
||||
Ort::AllocatorWithDefaultOptions allocator_;
|
||||
@@ -246,6 +318,10 @@ class OfflineWhisperModel::Impl {
|
||||
OfflineWhisperModel::OfflineWhisperModel(const OfflineModelConfig &config)
|
||||
: impl_(std::make_unique<Impl>(config)) {}
|
||||
|
||||
OfflineWhisperModel::OfflineWhisperModel(
|
||||
const SpokenLanguageIdentificationConfig &config)
|
||||
: impl_(std::make_unique<Impl>(config)) {}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
OfflineWhisperModel::OfflineWhisperModel(AAssetManager *mgr,
|
||||
const OfflineModelConfig &config)
|
||||
@@ -273,6 +349,11 @@ OfflineWhisperModel::ForwardDecoder(Ort::Value tokens,
|
||||
std::move(n_layer_cross_v), std::move(offset));
|
||||
}
|
||||
|
||||
int32_t OfflineWhisperModel::DetectLanguage(Ort::Value &cross_k, // NOLINT
|
||||
Ort::Value &cross_v) { // NOLINT
|
||||
return impl_->DetectLanguage(cross_k, cross_v);
|
||||
}
|
||||
|
||||
std::pair<Ort::Value, Ort::Value> OfflineWhisperModel::GetInitialSelfKVCache()
|
||||
const {
|
||||
return impl_->GetInitialSelfKVCache();
|
||||
@@ -318,4 +399,35 @@ bool OfflineWhisperModel::IsMultiLingual() const {
|
||||
return impl_->IsMultiLingual();
|
||||
}
|
||||
|
||||
void OfflineWhisperModel::NormalizeFeatures(float *features, int32_t num_frames,
|
||||
int32_t feat_dim) {
|
||||
// log_spec = torch.clamp(features, min=1e-10).log10()
|
||||
// log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
|
||||
// mel = (log_spec + 4.0) / 4.0
|
||||
|
||||
int32_t n = num_frames * feat_dim;
|
||||
float max_v = -1e20;
|
||||
for (int32_t i = 0; i != n; ++i) {
|
||||
float f = features[i];
|
||||
|
||||
f = std::max<float>(f, 1e-10);
|
||||
f = std::log10(f);
|
||||
|
||||
max_v = std::max(f, max_v);
|
||||
|
||||
features[i] = f;
|
||||
}
|
||||
|
||||
max_v -= 8;
|
||||
|
||||
for (int32_t i = 0; i != n; ++i) {
|
||||
float f = features[i];
|
||||
f = std::max(f, max_v);
|
||||
|
||||
f = (f + 4) / 4;
|
||||
|
||||
features[i] = f;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -18,6 +18,7 @@
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/offline-model-config.h"
|
||||
#include "sherpa-onnx/csrc/spoken-language-identification.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
@@ -25,6 +26,9 @@ class OfflineWhisperModel {
|
||||
public:
|
||||
explicit OfflineWhisperModel(const OfflineModelConfig &config);
|
||||
|
||||
explicit OfflineWhisperModel(
|
||||
const SpokenLanguageIdentificationConfig &config);
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
OfflineWhisperModel(AAssetManager *mgr, const OfflineModelConfig &config);
|
||||
#endif
|
||||
@@ -72,7 +76,8 @@ class OfflineWhisperModel {
|
||||
Ort::Value n_layer_self_v_cache, Ort::Value n_layer_cross_k,
|
||||
Ort::Value n_layer_cross_v, Ort::Value offset) const;
|
||||
|
||||
int32_t DetectLanguage() const;
|
||||
int32_t DetectLanguage(Ort::Value &cross_k, // NOLINT
|
||||
Ort::Value &cross_v); // NOLINT
|
||||
|
||||
/** Return the initial self kv cache in a pair
|
||||
* - n_layer_self_k_cache A 4-D tensor of shape
|
||||
@@ -98,6 +103,9 @@ class OfflineWhisperModel {
|
||||
int32_t Translate() const;
|
||||
bool IsMultiLingual() const;
|
||||
|
||||
static void NormalizeFeatures(float *features, int32_t num_frames,
|
||||
int32_t feat_dim);
|
||||
|
||||
private:
|
||||
class Impl;
|
||||
std::unique_ptr<Impl> impl_;
|
||||
|
||||
@@ -28,7 +28,7 @@ enum class ModelType {
|
||||
kLstm,
|
||||
kZipformer,
|
||||
kZipformer2,
|
||||
kUnkown,
|
||||
kUnknown,
|
||||
};
|
||||
|
||||
} // namespace
|
||||
@@ -58,7 +58,7 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
|
||||
"No model_type in the metadata!\n"
|
||||
"Please make sure you are using the latest export-onnx.py from icefall "
|
||||
"to export your transducer models");
|
||||
return ModelType::kUnkown;
|
||||
return ModelType::kUnknown;
|
||||
}
|
||||
|
||||
if (model_type.get() == std::string("conformer")) {
|
||||
@@ -71,7 +71,7 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
|
||||
return ModelType::kZipformer2;
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
|
||||
return ModelType::kUnkown;
|
||||
return ModelType::kUnknown;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -93,7 +93,7 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
|
||||
model_type.c_str());
|
||||
}
|
||||
}
|
||||
ModelType model_type = ModelType::kUnkown;
|
||||
ModelType model_type = ModelType::kUnknown;
|
||||
|
||||
{
|
||||
auto buffer = ReadFile(config.transducer.encoder);
|
||||
@@ -110,7 +110,7 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
|
||||
return std::make_unique<OnlineZipformerTransducerModel>(config);
|
||||
case ModelType::kZipformer2:
|
||||
return std::make_unique<OnlineZipformer2TransducerModel>(config);
|
||||
case ModelType::kUnkown:
|
||||
case ModelType::kUnknown:
|
||||
SHERPA_ONNX_LOGE("Unknown model type in online transducer!");
|
||||
return nullptr;
|
||||
}
|
||||
@@ -185,7 +185,7 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
|
||||
return std::make_unique<OnlineZipformerTransducerModel>(mgr, config);
|
||||
case ModelType::kZipformer2:
|
||||
return std::make_unique<OnlineZipformer2TransducerModel>(mgr, config);
|
||||
case ModelType::kUnkown:
|
||||
case ModelType::kUnknown:
|
||||
SHERPA_ONNX_LOGE("Unknown model type in online transducer!");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
@@ -149,4 +149,9 @@ Ort::SessionOptions GetSessionOptions(
|
||||
return GetSessionOptionsImpl(config.num_threads, config.provider);
|
||||
}
|
||||
|
||||
Ort::SessionOptions GetSessionOptions(
|
||||
const SpokenLanguageIdentificationConfig &config) {
|
||||
return GetSessionOptionsImpl(config.num_threads, config.provider);
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
#include "sherpa-onnx/csrc/online-lm-config.h"
|
||||
#include "sherpa-onnx/csrc/online-model-config.h"
|
||||
#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
|
||||
#include "sherpa-onnx/csrc/spoken-language-identification.h"
|
||||
#include "sherpa-onnx/csrc/vad-model-config.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
@@ -30,6 +31,10 @@ Ort::SessionOptions GetSessionOptions(const OfflineTtsModelConfig &config);
|
||||
|
||||
Ort::SessionOptions GetSessionOptions(
|
||||
const SpeakerEmbeddingExtractorConfig &config);
|
||||
|
||||
Ort::SessionOptions GetSessionOptions(
|
||||
const SpokenLanguageIdentificationConfig &config);
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_SESSION_H_
|
||||
|
||||
107
sherpa-onnx/csrc/sherpa-onnx-offline-language-identification.cc
Normal file
107
sherpa-onnx/csrc/sherpa-onnx-offline-language-identification.cc
Normal file
@@ -0,0 +1,107 @@
|
||||
// sherpa-onnx/csrc/sherpa-onnx-offline-language-identification.cc
|
||||
//
|
||||
// Copyright (c) 2022-2024 Xiaomi Corporation
|
||||
|
||||
#include <stdio.h>
|
||||
|
||||
#include <chrono> // NOLINT
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
#include "sherpa-onnx/csrc/spoken-language-identification.h"
|
||||
#include "sherpa-onnx/csrc/wave-reader.h"
|
||||
|
||||
int main(int32_t argc, char *argv[]) {
|
||||
const char *kUsageMessage = R"usage(
|
||||
Spoken language identification with sherpa-onnx.
|
||||
|
||||
Usage:
|
||||
|
||||
(1) Use a whisper multilingual model
|
||||
|
||||
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-whisper-tiny.tar.bz2
|
||||
tar xvf sherpa-onnx-whisper-tiny.tar.bz2
|
||||
rm sherpa-onnx-whisper-tiny.tar.bz2
|
||||
|
||||
We only use the int8.onnx models below.
|
||||
|
||||
./bin/sherpa-onnx-offline-spoken-language-identification \
|
||||
--whisper-encoder=sherpa-onnx-whisper-tiny/tiny-encoder.int8.onnx \
|
||||
--whisper-decoder=sherpa-onnx-whisper-tiny/tiny-decoder.int8.onnx \
|
||||
--num-threads=1 \
|
||||
/path/to/foo.wav
|
||||
|
||||
foo.wav should be of single channel, 16-bit PCM encoded wave file; its
|
||||
sampling rate can be arbitrary and does not need to be 16kHz.
|
||||
You can find test waves for different languages at
|
||||
https://hf-mirror.com/spaces/k2-fsa/spoken-language-identification/tree/main/test_wavs
|
||||
|
||||
Please refer to
|
||||
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/index.html
|
||||
Note that only whisper multilingual models are supported. For instance,
|
||||
"tiny" is supported but "tiny.en" is not.
|
||||
for a list of pre-trained models to download.
|
||||
)usage";
|
||||
|
||||
sherpa_onnx::ParseOptions po(kUsageMessage);
|
||||
sherpa_onnx::SpokenLanguageIdentificationConfig config;
|
||||
config.Register(&po);
|
||||
|
||||
po.Read(argc, argv);
|
||||
if (po.NumArgs() != 1) {
|
||||
fprintf(stderr, "Error: Please provide 1 wave file.\n\n");
|
||||
po.PrintUsage();
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s\n", config.ToString().c_str());
|
||||
|
||||
if (!config.Validate()) {
|
||||
fprintf(stderr, "Errors in config!\n");
|
||||
return -1;
|
||||
}
|
||||
|
||||
fprintf(stderr, "Creating spoken language identifier ...\n");
|
||||
sherpa_onnx::SpokenLanguageIdentification slid(config);
|
||||
|
||||
fprintf(stderr, "Started\n");
|
||||
const std::string wav_filename = po.GetArg(1);
|
||||
|
||||
int32_t sampling_rate = -1;
|
||||
bool is_ok = false;
|
||||
const std::vector<float> samples =
|
||||
sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok);
|
||||
if (!is_ok) {
|
||||
fprintf(stderr, "Failed to read %s\n", wav_filename.c_str());
|
||||
return -1;
|
||||
}
|
||||
float duration = samples.size() / static_cast<float>(sampling_rate);
|
||||
|
||||
const auto begin = std::chrono::steady_clock::now();
|
||||
|
||||
auto s = slid.CreateStream();
|
||||
s->AcceptWaveform(sampling_rate, samples.data(), samples.size());
|
||||
|
||||
auto language = slid.Compute(s.get());
|
||||
|
||||
const auto end = std::chrono::steady_clock::now();
|
||||
|
||||
fprintf(stderr, "Done!\n\n");
|
||||
fprintf(stderr, "%s\nDetected language: %s\n", wav_filename.c_str(),
|
||||
language.c_str());
|
||||
|
||||
float elapsed_seconds =
|
||||
std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
|
||||
.count() /
|
||||
1000.;
|
||||
|
||||
fprintf(stderr, "num threads: %d\n", config.num_threads);
|
||||
|
||||
fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
|
||||
float rtf = elapsed_seconds / duration;
|
||||
fprintf(stderr, "Real time factor (RTF): %.3f / %.3f = %.3f\n",
|
||||
elapsed_seconds, duration, rtf);
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -16,7 +16,7 @@ enum class ModelType {
|
||||
kWeSpeaker,
|
||||
k3dSpeaker,
|
||||
kNeMo,
|
||||
kUnkown,
|
||||
kUnknown,
|
||||
};
|
||||
|
||||
} // namespace
|
||||
@@ -47,7 +47,7 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
|
||||
"https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wespeaker/"
|
||||
"add_meta_data.py"
|
||||
"to add metadata to models from WeSpeaker\n");
|
||||
return ModelType::kUnkown;
|
||||
return ModelType::kUnknown;
|
||||
}
|
||||
|
||||
if (model_type.get() == std::string("wespeaker")) {
|
||||
@@ -58,14 +58,14 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
|
||||
return ModelType::kNeMo;
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
|
||||
return ModelType::kUnkown;
|
||||
return ModelType::kUnknown;
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<SpeakerEmbeddingExtractorImpl>
|
||||
SpeakerEmbeddingExtractorImpl::Create(
|
||||
const SpeakerEmbeddingExtractorConfig &config) {
|
||||
ModelType model_type = ModelType::kUnkown;
|
||||
ModelType model_type = ModelType::kUnknown;
|
||||
|
||||
{
|
||||
auto buffer = ReadFile(config.model);
|
||||
@@ -80,9 +80,8 @@ SpeakerEmbeddingExtractorImpl::Create(
|
||||
return std::make_unique<SpeakerEmbeddingExtractorGeneralImpl>(config);
|
||||
case ModelType::kNeMo:
|
||||
return std::make_unique<SpeakerEmbeddingExtractorNeMoImpl>(config);
|
||||
case ModelType::kUnkown:
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Unknown model type in for speaker embedding extractor!");
|
||||
case ModelType::kUnknown:
|
||||
SHERPA_ONNX_LOGE("Unknown model type for speaker embedding extractor!");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
@@ -94,7 +93,7 @@ SpeakerEmbeddingExtractorImpl::Create(
|
||||
std::unique_ptr<SpeakerEmbeddingExtractorImpl>
|
||||
SpeakerEmbeddingExtractorImpl::Create(
|
||||
AAssetManager *mgr, const SpeakerEmbeddingExtractorConfig &config) {
|
||||
ModelType model_type = ModelType::kUnkown;
|
||||
ModelType model_type = ModelType::kUnknown;
|
||||
|
||||
{
|
||||
auto buffer = ReadFile(mgr, config.model);
|
||||
@@ -110,7 +109,7 @@ SpeakerEmbeddingExtractorImpl::Create(
|
||||
config);
|
||||
case ModelType::kNeMo:
|
||||
return std::make_unique<SpeakerEmbeddingExtractorNeMoImpl>(mgr, config);
|
||||
case ModelType::kUnkown:
|
||||
case ModelType::kUnknown:
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Unknown model type in for speaker embedding extractor!");
|
||||
return nullptr;
|
||||
|
||||
88
sherpa-onnx/csrc/spoken-language-identification-impl.cc
Normal file
88
sherpa-onnx/csrc/spoken-language-identification-impl.cc
Normal file
@@ -0,0 +1,88 @@
|
||||
// sherpa-onnx/csrc/spoken-language-identification-impl.cc
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
#include "sherpa-onnx/csrc/spoken-language-identification-impl.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/spoken-language-identification-whisper-impl.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
namespace {
|
||||
|
||||
enum class ModelType {
|
||||
kWhisper,
|
||||
kUnknown,
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
static ModelType GetModelType(char *model_data, size_t model_data_length,
|
||||
bool debug) {
|
||||
Ort::Env env(ORT_LOGGING_LEVEL_WARNING);
|
||||
Ort::SessionOptions sess_opts;
|
||||
|
||||
auto sess = std::make_unique<Ort::Session>(env, model_data, model_data_length,
|
||||
sess_opts);
|
||||
|
||||
Ort::ModelMetadata meta_data = sess->GetModelMetadata();
|
||||
if (debug) {
|
||||
std::ostringstream os;
|
||||
PrintModelMetadata(os, meta_data);
|
||||
SHERPA_ONNX_LOGE("%s", os.str().c_str());
|
||||
}
|
||||
|
||||
Ort::AllocatorWithDefaultOptions allocator;
|
||||
auto model_type =
|
||||
meta_data.LookupCustomMetadataMapAllocated("model_type", allocator);
|
||||
if (!model_type) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"No model_type in the metadata!\n"
|
||||
"Please make sure you have added metadata to the model.\n\n"
|
||||
"For instance, you can use\n"
|
||||
"https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/whisper/"
|
||||
"export-onnx.py "
|
||||
"to add metadata to models from whisper\n");
|
||||
return ModelType::kUnknown;
|
||||
}
|
||||
|
||||
auto model_type_str = std::string(model_type.get());
|
||||
if (model_type_str.find("whisper") == 0) {
|
||||
return ModelType::kWhisper;
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
|
||||
return ModelType::kUnknown;
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<SpokenLanguageIdentificationImpl>
|
||||
SpokenLanguageIdentificationImpl::Create(
|
||||
const SpokenLanguageIdentificationConfig &config) {
|
||||
ModelType model_type = ModelType::kUnknown;
|
||||
{
|
||||
if (config.whisper.encoder.empty()) {
|
||||
SHERPA_ONNX_LOGE("Only whisper models are supported at present");
|
||||
exit(-1);
|
||||
}
|
||||
auto buffer = ReadFile(config.whisper.encoder);
|
||||
|
||||
model_type = GetModelType(buffer.data(), buffer.size(), config.debug);
|
||||
}
|
||||
|
||||
switch (model_type) {
|
||||
case ModelType::kWhisper:
|
||||
return std::make_unique<SpokenLanguageIdentificationWhisperImpl>(config);
|
||||
case ModelType::kUnknown:
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Unknown model type for spoken language identification!");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// unreachable code
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
28
sherpa-onnx/csrc/spoken-language-identification-impl.h
Normal file
28
sherpa-onnx/csrc/spoken-language-identification-impl.h
Normal file
@@ -0,0 +1,28 @@
|
||||
// sherpa-onnx/csrc/spoken-language-identification-impl.h
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_IMPL_H_
|
||||
#define SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_IMPL_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/spoken-language-identification.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class SpokenLanguageIdentificationImpl {
|
||||
public:
|
||||
virtual ~SpokenLanguageIdentificationImpl() = default;
|
||||
|
||||
static std::unique_ptr<SpokenLanguageIdentificationImpl> Create(
|
||||
const SpokenLanguageIdentificationConfig &config);
|
||||
|
||||
virtual std::unique_ptr<OfflineStream> CreateStream() const = 0;
|
||||
|
||||
virtual std::string Compute(OfflineStream *s) const = 0;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_IMPL_H_
|
||||
119
sherpa-onnx/csrc/spoken-language-identification-whisper-impl.h
Normal file
119
sherpa-onnx/csrc/spoken-language-identification-whisper-impl.h
Normal file
@@ -0,0 +1,119 @@
|
||||
// sherpa-onnx/csrc/spoken-language-identification-whisper-impl.h
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_WHISPER_IMPL_H_
|
||||
#define SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_WHISPER_IMPL_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-whisper-model.h"
|
||||
#include "sherpa-onnx/csrc/spoken-language-identification-impl.h"
|
||||
#include "sherpa-onnx/csrc/transpose.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class SpokenLanguageIdentificationWhisperImpl
|
||||
: public SpokenLanguageIdentificationImpl {
|
||||
public:
|
||||
explicit SpokenLanguageIdentificationWhisperImpl(
|
||||
const SpokenLanguageIdentificationConfig &config)
|
||||
: config_(config), model_(std::make_unique<OfflineWhisperModel>(config)) {
|
||||
Check();
|
||||
}
|
||||
|
||||
std::unique_ptr<OfflineStream> CreateStream() const override {
|
||||
return std::make_unique<OfflineStream>(WhisperTag{});
|
||||
}
|
||||
|
||||
std::string Compute(OfflineStream *s) const override {
|
||||
int32_t max_num_frames = 3000;
|
||||
auto memory_info =
|
||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||
|
||||
int32_t feat_dim = s->FeatureDim();
|
||||
std::vector<float> f = s->GetFrames();
|
||||
int32_t num_frames = f.size() / feat_dim;
|
||||
|
||||
// we use 50 here so that there will be some zero tail paddings
|
||||
if (num_frames >= max_num_frames - 50) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Only waves less than 30 seconds are supported. We process only the "
|
||||
"first 30 seconds and discard the remaining data");
|
||||
num_frames = max_num_frames - 50;
|
||||
}
|
||||
|
||||
model_->NormalizeFeatures(f.data(), num_frames, feat_dim);
|
||||
|
||||
// note that 1000 is an experience-value.
|
||||
// You can replace 1000 by other values, say, 100.
|
||||
//
|
||||
// Since we have removed the 30 seconds constraint, we need
|
||||
// tail_padding_frames so that whisper is able to detect the eot token.
|
||||
int32_t tail_padding_frames = 1000;
|
||||
|
||||
if (config_.whisper.tail_paddings > 0) {
|
||||
tail_padding_frames = config_.whisper.tail_paddings;
|
||||
}
|
||||
|
||||
int32_t actual_frames =
|
||||
std::min(num_frames + tail_padding_frames, max_num_frames);
|
||||
|
||||
std::array<int64_t, 3> shape{1, actual_frames, feat_dim};
|
||||
|
||||
Ort::Value mel = Ort::Value::CreateTensor<float>(
|
||||
model_->Allocator(), shape.data(), shape.size());
|
||||
|
||||
float *p_mel = mel.GetTensorMutableData<float>();
|
||||
std::copy(f.data(), f.data() + num_frames * feat_dim, p_mel);
|
||||
|
||||
std::fill_n(p_mel + num_frames * feat_dim,
|
||||
(actual_frames - num_frames) * feat_dim, 0);
|
||||
|
||||
mel = Transpose12(model_->Allocator(), &mel);
|
||||
|
||||
try {
|
||||
auto cross_kv = model_->ForwardEncoder(std::move(mel));
|
||||
int32_t lang_id = model_->DetectLanguage(cross_kv.first, cross_kv.second);
|
||||
const auto &id2lang = model_->GetID2Lang();
|
||||
if (id2lang.count(lang_id)) {
|
||||
return id2lang.at(lang_id);
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE("Unknown language ID: %d. Return an empty string.",
|
||||
lang_id);
|
||||
return "";
|
||||
}
|
||||
} catch (const Ort::Exception &ex) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"\n\nCaught exception:\n\n%s\n\nReturn an empty result. Number of "
|
||||
"input frames: %d, Current tail "
|
||||
"paddings: %d. If you see a lot of such exceptions, please consider "
|
||||
"using a larger --whisper-tail-paddings",
|
||||
ex.what(), num_frames, tail_padding_frames);
|
||||
return "";
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
void Check() const {
|
||||
if (!model_->IsMultiLingual()) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Only whisper multilingual models can be used for spoken language "
|
||||
"identification. Given: %s,%s",
|
||||
config_.whisper.encoder.c_str(), config_.whisper.decoder.c_str());
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
SpokenLanguageIdentificationConfig config_;
|
||||
std::unique_ptr<OfflineWhisperModel> model_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_WHISPER_IMPL_H_
|
||||
117
sherpa-onnx/csrc/spoken-language-identification.cc
Normal file
117
sherpa-onnx/csrc/spoken-language-identification.cc
Normal file
@@ -0,0 +1,117 @@
|
||||
// sherpa-onnx/csrc/spoken-language-identification.cc
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/spoken-language-identification.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/spoken-language-identification-impl.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void SpokenLanguageIdentificationWhisperConfig::Register(ParseOptions *po) {
|
||||
po->Register(
|
||||
"whisper-encoder", &encoder,
|
||||
"Path to then encoder of a whisper multilingual model. Support only "
|
||||
"tiny, base, small, medium, large.");
|
||||
|
||||
po->Register(
|
||||
"whisper-decoder", &decoder,
|
||||
"Path to the decoder of a whisper multilingual model. Support only "
|
||||
"tiny, base, small, medium, large.");
|
||||
|
||||
po->Register(
|
||||
"whisper-tail-paddings", &tail_paddings,
|
||||
"Suggested value: 300 for multilingual models. "
|
||||
"Since we have removed the 30-second constraint, we need to add some "
|
||||
"tail padding frames "
|
||||
"so that whisper can detect the eot token. Leave it to -1 to use 1000");
|
||||
}
|
||||
|
||||
bool SpokenLanguageIdentificationWhisperConfig::Validate() const {
|
||||
if (encoder.empty()) {
|
||||
SHERPA_ONNX_LOGE("Please provide --whisper-encoder");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!FileExists(encoder)) {
|
||||
SHERPA_ONNX_LOGE("whisper encoder file %s does not exist", encoder.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
if (decoder.empty()) {
|
||||
SHERPA_ONNX_LOGE("Please provide --whisper-decoder");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!FileExists(decoder)) {
|
||||
SHERPA_ONNX_LOGE("whisper decoder file %s does not exist", decoder.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string SpokenLanguageIdentificationWhisperConfig::ToString() const {
|
||||
std::ostringstream os;
|
||||
|
||||
os << "SpokenLanguageIdentificationWhisperConfig(";
|
||||
os << "encoder=\"" << encoder << "\", ";
|
||||
os << "decoder=\"" << decoder << "\", ";
|
||||
os << "tail_paddings=" << tail_paddings << ")";
|
||||
|
||||
return os.str();
|
||||
}
|
||||
|
||||
void SpokenLanguageIdentificationConfig::Register(ParseOptions *po) {
|
||||
whisper.Register(po);
|
||||
|
||||
po->Register("num-threads", &num_threads,
|
||||
"Number of threads to run the neural network");
|
||||
|
||||
po->Register("debug", &debug,
|
||||
"true to print model information while loading it.");
|
||||
|
||||
po->Register("provider", &provider,
|
||||
"Specify a provider to use: cpu, cuda, coreml");
|
||||
}
|
||||
|
||||
bool SpokenLanguageIdentificationConfig::Validate() const {
|
||||
if (!whisper.Validate()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string SpokenLanguageIdentificationConfig::ToString() const {
|
||||
std::ostringstream os;
|
||||
|
||||
os << "SpokenLanguageIdentificationConfig(";
|
||||
os << "whisper=\"" << whisper.ToString() << "\", ";
|
||||
os << "num_threads=" << num_threads << ", ";
|
||||
os << "debug=" << (debug ? "True" : "False") << ", ";
|
||||
os << "provider=\"" << provider << "\")";
|
||||
|
||||
return os.str();
|
||||
}
|
||||
|
||||
SpokenLanguageIdentification::SpokenLanguageIdentification(
|
||||
const SpokenLanguageIdentificationConfig &config)
|
||||
: impl_(SpokenLanguageIdentificationImpl::Create(config)) {}
|
||||
|
||||
SpokenLanguageIdentification::~SpokenLanguageIdentification() = default;
|
||||
|
||||
std::unique_ptr<OfflineStream> SpokenLanguageIdentification::CreateStream()
|
||||
const {
|
||||
return impl_->CreateStream();
|
||||
}
|
||||
|
||||
std::string SpokenLanguageIdentification::Compute(OfflineStream *s) const {
|
||||
return impl_->Compute(s);
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
89
sherpa-onnx/csrc/spoken-language-identification.h
Normal file
89
sherpa-onnx/csrc/spoken-language-identification.h
Normal file
@@ -0,0 +1,89 @@
|
||||
// sherpa-onnx/csrc/spoken-language-identification.h
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_H_
|
||||
#define SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-stream.h"
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
struct SpokenLanguageIdentificationWhisperConfig {
|
||||
// Requires a multi-lingual whisper model.
|
||||
// That is, it supports only tiny, base, small, medium, large.
|
||||
// Note: It does NOT support tiny.en, base.en, small.en, medium.en
|
||||
std::string encoder;
|
||||
std::string decoder;
|
||||
|
||||
// Number of tail padding frames.
|
||||
//
|
||||
// Since we remove the 30-second constraint, we need to add some paddings
|
||||
// at the end.
|
||||
//
|
||||
// Recommended values:
|
||||
// - 50 for English models
|
||||
// - 300 for multilingual models
|
||||
int32_t tail_paddings = -1;
|
||||
|
||||
SpokenLanguageIdentificationWhisperConfig() = default;
|
||||
|
||||
SpokenLanguageIdentificationWhisperConfig(const std::string &encoder,
|
||||
const std::string &decoder,
|
||||
int32_t tail_paddings)
|
||||
: encoder(encoder), decoder(decoder), tail_paddings(tail_paddings) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
std::string ToString() const;
|
||||
};
|
||||
|
||||
struct SpokenLanguageIdentificationConfig {
|
||||
SpokenLanguageIdentificationWhisperConfig whisper;
|
||||
|
||||
int32_t num_threads = 1;
|
||||
bool debug = false;
|
||||
std::string provider = "cpu";
|
||||
|
||||
SpokenLanguageIdentificationConfig() = default;
|
||||
|
||||
SpokenLanguageIdentificationConfig(
|
||||
const SpokenLanguageIdentificationWhisperConfig &whisper,
|
||||
int32_t num_threads, bool debug, const std::string &provider)
|
||||
: whisper(whisper),
|
||||
num_threads(num_threads),
|
||||
debug(debug),
|
||||
provider(provider) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
std::string ToString() const;
|
||||
};
|
||||
|
||||
class SpokenLanguageIdentificationImpl;
|
||||
|
||||
class SpokenLanguageIdentification {
|
||||
public:
|
||||
explicit SpokenLanguageIdentification(
|
||||
const SpokenLanguageIdentificationConfig &config);
|
||||
|
||||
~SpokenLanguageIdentification();
|
||||
|
||||
// Create a stream to accept audio samples and compute features
|
||||
std::unique_ptr<OfflineStream> CreateStream() const;
|
||||
|
||||
// Return a string containing the language, e.g., en, zh, de,
|
||||
// etc.
|
||||
// Note: en is for English, zh is for Chinese, de is for German, etc.
|
||||
std::string Compute(OfflineStream *s) const;
|
||||
|
||||
private:
|
||||
std::unique_ptr<SpokenLanguageIdentificationImpl> impl_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_H_
|
||||
@@ -33,6 +33,7 @@ set(srcs
|
||||
silero-vad-model-config.cc
|
||||
speaker-embedding-extractor.cc
|
||||
speaker-embedding-manager.cc
|
||||
spoken-language-identification.cc
|
||||
vad-model-config.cc
|
||||
vad-model.cc
|
||||
voice-activity-detector.cc
|
||||
|
||||
@@ -22,6 +22,7 @@
|
||||
#include "sherpa-onnx/python/csrc/online-stream.h"
|
||||
#include "sherpa-onnx/python/csrc/speaker-embedding-extractor.h"
|
||||
#include "sherpa-onnx/python/csrc/speaker-embedding-manager.h"
|
||||
#include "sherpa-onnx/python/csrc/spoken-language-identification.h"
|
||||
#include "sherpa-onnx/python/csrc/vad-model-config.h"
|
||||
#include "sherpa-onnx/python/csrc/vad-model.h"
|
||||
#include "sherpa-onnx/python/csrc/voice-activity-detector.h"
|
||||
@@ -55,6 +56,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
|
||||
PybindOfflineTts(&m);
|
||||
PybindSpeakerEmbeddingExtractor(&m);
|
||||
PybindSpeakerEmbeddingManager(&m);
|
||||
PybindSpokenLanguageIdentification(&m);
|
||||
|
||||
PybindAlsa(&m);
|
||||
}
|
||||
|
||||
60
sherpa-onnx/python/csrc/spoken-language-identification.cc
Normal file
60
sherpa-onnx/python/csrc/spoken-language-identification.cc
Normal file
@@ -0,0 +1,60 @@
|
||||
// sherpa-onnx/python/csrc/spoken-language-identification.cc
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/python/csrc/spoken-language-identification.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/spoken-language-identification.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
static void PybindSpokenLanguageIdentificationWhisperConfig(py::module *m) {
|
||||
using PyClass = SpokenLanguageIdentificationWhisperConfig;
|
||||
|
||||
py::class_<PyClass>(*m, "SpokenLanguageIdentificationWhisperConfig")
|
||||
.def(py::init<>())
|
||||
.def(py::init<const std::string &, const std::string &, int32_t>(),
|
||||
py::arg("encoder"), py::arg("decoder"),
|
||||
py::arg("tail_paddings") = -1)
|
||||
.def_readwrite("encoder", &PyClass::encoder)
|
||||
.def_readwrite("decoder", &PyClass::decoder)
|
||||
.def_readwrite("tail_paddings", &PyClass::tail_paddings)
|
||||
.def("validate", &PyClass::Validate)
|
||||
.def("__str__", &PyClass::ToString);
|
||||
}
|
||||
|
||||
static void PybindSpokenLanguageIdentificationConfig(py::module *m) {
|
||||
PybindSpokenLanguageIdentificationWhisperConfig(m);
|
||||
|
||||
using PyClass = SpokenLanguageIdentificationConfig;
|
||||
|
||||
py::class_<PyClass>(*m, "SpokenLanguageIdentificationConfig")
|
||||
.def(py::init<>())
|
||||
.def(py::init<const SpokenLanguageIdentificationWhisperConfig &, int32_t,
|
||||
bool, const std::string>(),
|
||||
py::arg("whisper"), py::arg("num_threads") = 1,
|
||||
py::arg("debug") = false, py::arg("provider") = "cpu")
|
||||
.def_readwrite("whisper", &PyClass::whisper)
|
||||
.def_readwrite("num_threads", &PyClass::num_threads)
|
||||
.def_readwrite("debug", &PyClass::debug)
|
||||
.def_readwrite("provider", &PyClass::provider)
|
||||
.def("validate", &PyClass::Validate)
|
||||
.def("__str__", &PyClass::ToString);
|
||||
}
|
||||
|
||||
void PybindSpokenLanguageIdentification(py::module *m) {
|
||||
PybindSpokenLanguageIdentificationConfig(m);
|
||||
|
||||
using PyClass = SpokenLanguageIdentification;
|
||||
py::class_<PyClass>(*m, "SpokenLanguageIdentification")
|
||||
.def(py::init<const SpokenLanguageIdentificationConfig &>(),
|
||||
py::arg("config"), py::call_guard<py::gil_scoped_release>())
|
||||
.def("create_stream", &PyClass::CreateStream,
|
||||
py::call_guard<py::gil_scoped_release>())
|
||||
.def("compute", &PyClass::Compute,
|
||||
py::call_guard<py::gil_scoped_release>());
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
16
sherpa-onnx/python/csrc/spoken-language-identification.h
Normal file
16
sherpa-onnx/python/csrc/spoken-language-identification.h
Normal file
@@ -0,0 +1,16 @@
|
||||
// sherpa-onnx/python/csrc/spoken-language-identification.h
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_PYTHON_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_H_
|
||||
#define SHERPA_ONNX_PYTHON_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_H_
|
||||
|
||||
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void PybindSpokenLanguageIdentification(py::module *m);
|
||||
|
||||
}
|
||||
|
||||
#endif // SHERPA_ONNX_PYTHON_CSRC_SPOKEN_LANGUAGE_IDENTIFICATION_H_
|
||||
@@ -13,6 +13,9 @@ from _sherpa_onnx import (
|
||||
SpeakerEmbeddingExtractorConfig,
|
||||
SpeakerEmbeddingManager,
|
||||
SpeechSegment,
|
||||
SpokenLanguageIdentification,
|
||||
SpokenLanguageIdentificationConfig,
|
||||
SpokenLanguageIdentificationWhisperConfig,
|
||||
VadModel,
|
||||
VadModelConfig,
|
||||
VoiceActivityDetector,
|
||||
|
||||
Reference in New Issue
Block a user