Add Android demo for spoken language identification using Whisper multilingual models (#783)
This commit is contained in:
@@ -4,6 +4,8 @@
|
||||
|
||||
#include "sherpa-onnx/csrc/audio-tagging-impl.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
|
||||
@@ -4,6 +4,8 @@
|
||||
|
||||
#include "sherpa-onnx/csrc/audio-tagging.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
|
||||
@@ -70,6 +70,23 @@ class OfflineWhisperModel::Impl {
|
||||
InitDecoder(buf.data(), buf.size());
|
||||
}
|
||||
}
|
||||
|
||||
// Builds the whisper model used for spoken language identification, reading
// the encoder and decoder through the Android asset manager.
//
// @param mgr     Asset manager handed to ReadFile() to load the model files.
// @param config  Spoken language identification config. config.whisper.encoder
//                and config.whisper.decoder name the files to read.
Impl(AAssetManager *mgr, const SpokenLanguageIdentificationConfig &config)
    : lid_config_(config),
      env_(ORT_LOGGING_LEVEL_ERROR),
      sess_opts_(GetSessionOptions(config)),
      allocator_{} {
  // config_ is not initialized from the argument in this constructor (only
  // lid_config_ is), so the debug flag must come from the argument itself.
  debug_ = config.debug;

  {
    // The encoder bytes only need to live until InitEncoder() returns.
    auto buf = ReadFile(mgr, config.whisper.encoder);
    InitEncoder(buf.data(), buf.size());
  }

  {
    // Same for the decoder bytes and InitDecoder().
    auto buf = ReadFile(mgr, config.whisper.decoder);
    InitDecoder(buf.data(), buf.size());
  }
}
|
||||
#endif
|
||||
|
||||
std::pair<Ort::Value, Ort::Value> ForwardEncoder(Ort::Value features) {
|
||||
@@ -326,6 +343,11 @@ OfflineWhisperModel::OfflineWhisperModel(
|
||||
OfflineWhisperModel::OfflineWhisperModel(AAssetManager *mgr,
|
||||
const OfflineModelConfig &config)
|
||||
: impl_(std::make_unique<Impl>(mgr, config)) {}
|
||||
|
||||
OfflineWhisperModel::OfflineWhisperModel(
|
||||
AAssetManager *mgr, const SpokenLanguageIdentificationConfig &config)
|
||||
: impl_(std::make_unique<Impl>(mgr, config)) {}
|
||||
|
||||
#endif
|
||||
|
||||
// Defaulted destructor, defined out of class.
// NOTE(review): presumably kept out of line so that impl_ can be destroyed
// where Impl is a complete type -- confirm against the header.
OfflineWhisperModel::~OfflineWhisperModel() = default;
|
||||
|
||||
@@ -31,6 +31,8 @@ class OfflineWhisperModel {
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
OfflineWhisperModel(AAssetManager *mgr, const OfflineModelConfig &config);
|
||||
OfflineWhisperModel(AAssetManager *mgr,
|
||||
const SpokenLanguageIdentificationConfig &config);
|
||||
#endif
|
||||
|
||||
~OfflineWhisperModel();
|
||||
|
||||
@@ -5,6 +5,11 @@
|
||||
|
||||
#include <memory>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/spoken-language-identification-whisper-impl.h"
|
||||
@@ -85,4 +90,34 @@ SpokenLanguageIdentificationImpl::Create(
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
std::unique_ptr<SpokenLanguageIdentificationImpl>
|
||||
SpokenLanguageIdentificationImpl::Create(
|
||||
AAssetManager *mgr, const SpokenLanguageIdentificationConfig &config) {
|
||||
ModelType model_type = ModelType::kUnknown;
|
||||
{
|
||||
if (config.whisper.encoder.empty()) {
|
||||
SHERPA_ONNX_LOGE("Only whisper models are supported at present");
|
||||
exit(-1);
|
||||
}
|
||||
auto buffer = ReadFile(mgr, config.whisper.encoder);
|
||||
|
||||
model_type = GetModelType(buffer.data(), buffer.size(), config.debug);
|
||||
}
|
||||
|
||||
switch (model_type) {
|
||||
case ModelType::kWhisper:
|
||||
return std::make_unique<SpokenLanguageIdentificationWhisperImpl>(mgr,
|
||||
config);
|
||||
case ModelType::kUnknown:
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Unknown model type for spoken language identification!");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// unreachable code
|
||||
return nullptr;
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -7,6 +7,11 @@
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/spoken-language-identification.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
@@ -18,6 +23,11 @@ class SpokenLanguageIdentificationImpl {
|
||||
static std::unique_ptr<SpokenLanguageIdentificationImpl> Create(
|
||||
const SpokenLanguageIdentificationConfig &config);
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
static std::unique_ptr<SpokenLanguageIdentificationImpl> Create(
|
||||
AAssetManager *mgr, const SpokenLanguageIdentificationConfig &config);
|
||||
#endif
|
||||
|
||||
virtual std::unique_ptr<OfflineStream> CreateStream() const = 0;
|
||||
|
||||
virtual std::string Compute(OfflineStream *s) const = 0;
|
||||
|
||||
@@ -11,6 +11,11 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-whisper-model.h"
|
||||
#include "sherpa-onnx/csrc/spoken-language-identification-impl.h"
|
||||
#include "sherpa-onnx/csrc/transpose.h"
|
||||
@@ -26,6 +31,15 @@ class SpokenLanguageIdentificationWhisperImpl
|
||||
Check();
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
SpokenLanguageIdentificationWhisperImpl(
|
||||
AAssetManager *mgr, const SpokenLanguageIdentificationConfig &config)
|
||||
: config_(config),
|
||||
model_(std::make_unique<OfflineWhisperModel>(mgr, config)) {
|
||||
Check();
|
||||
}
|
||||
#endif
|
||||
|
||||
std::unique_ptr<OfflineStream> CreateStream() const override {
|
||||
return std::make_unique<OfflineStream>(WhisperTag{});
|
||||
}
|
||||
|
||||
@@ -6,6 +6,11 @@
|
||||
|
||||
#include <string>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/spoken-language-identification-impl.h"
|
||||
@@ -103,6 +108,12 @@ SpokenLanguageIdentification::SpokenLanguageIdentification(
|
||||
const SpokenLanguageIdentificationConfig &config)
|
||||
: impl_(SpokenLanguageIdentificationImpl::Create(config)) {}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
SpokenLanguageIdentification::SpokenLanguageIdentification(
|
||||
AAssetManager *mgr, const SpokenLanguageIdentificationConfig &config)
|
||||
: impl_(SpokenLanguageIdentificationImpl::Create(mgr, config)) {}
|
||||
#endif
|
||||
|
||||
// Defaulted destructor, defined out of class.
// NOTE(review): presumably kept out of line so that impl_ can be destroyed
// where SpokenLanguageIdentificationImpl is a complete type -- confirm
// against the header.
SpokenLanguageIdentification::~SpokenLanguageIdentification() = default;
|
||||
|
||||
std::unique_ptr<OfflineStream> SpokenLanguageIdentification::CreateStream()
|
||||
|
||||
@@ -7,6 +7,11 @@
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-stream.h"
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
|
||||
@@ -70,6 +75,11 @@ class SpokenLanguageIdentification {
|
||||
explicit SpokenLanguageIdentification(
|
||||
const SpokenLanguageIdentificationConfig &config);
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
SpokenLanguageIdentification(
|
||||
AAssetManager *mgr, const SpokenLanguageIdentificationConfig &config);
|
||||
#endif
|
||||
|
||||
~SpokenLanguageIdentification();
|
||||
|
||||
// Create a stream to accept audio samples and compute features
|
||||
|
||||
@@ -54,6 +54,32 @@ static SpokenLanguageIdentificationConfig GetSpokenLanguageIdentificationConfig(
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
// JNI entry point: creates a native SpokenLanguageIdentification whose model
// files are read through the Android asset manager.
//
// @param env            JNI environment.
// @param asset_manager  Java AssetManager object, converted with
//                       AAssetManager_fromJava().
// @param _config        Java config object, converted with
//                       GetSpokenLanguageIdentificationConfig().
// @return The address of the heap-allocated native object as a jlong, or 0
//         if the native asset manager could not be obtained. The object is
//         allocated with new; the matching _delete JNI function frees it.
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL
Java_com_k2fsa_sherpa_onnx_SpokenLanguageIdentification_newFromAsset(
    JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) {
#if __ANDROID_API__ >= 9
  AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
  if (!mgr) {
    SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
    // Do not continue with a null asset manager: the model files cannot be
    // read without it. Report failure to the caller as a null handle.
    return 0;
  }
#endif

  auto config =
      sherpa_onnx::GetSpokenLanguageIdentificationConfig(env, _config);
  SHERPA_ONNX_LOGE("spoken language identification newFromAsset config:\n%s",
                   config.ToString().c_str());

  // Ownership passes to the Java side, which holds the pointer as a jlong.
  auto slid = new sherpa_onnx::SpokenLanguageIdentification(
#if __ANDROID_API__ >= 9
      mgr,
#endif
      config);
  SHERPA_ONNX_LOGE("slid %p", slid);

  return reinterpret_cast<jlong>(slid);
}
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_com_k2fsa_sherpa_onnx_SpokenLanguageIdentification_newFromFile(
|
||||
@@ -73,6 +99,14 @@ Java_com_k2fsa_sherpa_onnx_SpokenLanguageIdentification_newFromFile(
|
||||
return (jlong)tagger;
|
||||
}
|
||||
|
||||
// JNI entry point: destroys the native SpokenLanguageIdentification whose
// address is stored in ptr.
//
// @param env  JNI environment (unused).
// @param ptr  Native object address held by the Java side as a jlong.
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL
Java_com_k2fsa_sherpa_onnx_SpokenLanguageIdentification_delete(JNIEnv *env,
                                                               jobject /*obj*/,
                                                               jlong ptr) {
  auto *slid =
      reinterpret_cast<sherpa_onnx::SpokenLanguageIdentification *>(ptr);
  delete slid;
}
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_com_k2fsa_sherpa_onnx_SpokenLanguageIdentification_createStream(
|
||||
|
||||
Reference in New Issue
Block a user