Add speaker diarization API for HarmonyOS. (#1609)

This commit is contained in:
Fangjun Kuang
2024-12-10 16:03:03 +08:00
committed by GitHub
parent 14944d8c81
commit 1bae4085ca
18 changed files with 279 additions and 79 deletions

View File

@@ -1,11 +1,6 @@
export { export { listRawfileDir, readWave, readWaveFromBinary, } from "libsherpa_onnx.so";
listRawfileDir,
readWave,
readWaveFromBinary,
} from "libsherpa_onnx.so";
export { export { CircularBuffer,
CircularBuffer,
SileroVadConfig, SileroVadConfig,
SpeechSegment, SpeechSegment,
Vad, Vad,
@@ -13,8 +8,7 @@ export {
} from './src/main/ets/components/Vad'; } from './src/main/ets/components/Vad';
export { export { Samples,
Samples,
OfflineStream, OfflineStream,
FeatureConfig, FeatureConfig,
OfflineTransducerModelConfig, OfflineTransducerModelConfig,
@@ -31,8 +25,7 @@ export {
OfflineRecognizer, OfflineRecognizer,
} from './src/main/ets/components/NonStreamingAsr'; } from './src/main/ets/components/NonStreamingAsr';
export { export { OnlineStream,
OnlineStream,
OnlineTransducerModelConfig, OnlineTransducerModelConfig,
OnlineParaformerModelConfig, OnlineParaformerModelConfig,
OnlineZipformer2CtcModelConfig, OnlineZipformer2CtcModelConfig,
@@ -43,8 +36,7 @@ export {
OnlineRecognizer, OnlineRecognizer,
} from './src/main/ets/components/StreamingAsr'; } from './src/main/ets/components/StreamingAsr';
export { export { OfflineTtsVitsModelConfig,
OfflineTtsVitsModelConfig,
OfflineTtsModelConfig, OfflineTtsModelConfig,
OfflineTtsConfig, OfflineTtsConfig,
OfflineTts, OfflineTts,
@@ -52,8 +44,15 @@ export {
TtsInput, TtsInput,
} from './src/main/ets/components/NonStreamingTts'; } from './src/main/ets/components/NonStreamingTts';
export { export { SpeakerEmbeddingExtractorConfig,
SpeakerEmbeddingExtractorConfig,
SpeakerEmbeddingExtractor, SpeakerEmbeddingExtractor,
SpeakerEmbeddingManager, SpeakerEmbeddingManager,
} from './src/main/ets/components/SpeakerIdentification'; } from './src/main/ets/components/SpeakerIdentification';
export { OfflineSpeakerSegmentationPyannoteModelConfig,
OfflineSpeakerSegmentationModelConfig,
OfflineSpeakerDiarizationConfig,
OfflineSpeakerDiarizationSegment,
OfflineSpeakerDiarization,
FastClusteringConfig,
} from './src/main/ets/components/NonStreamingSpeakerDiarization';

View File

@@ -101,6 +101,17 @@ static SherpaOnnxFastClusteringConfig GetFastClusteringConfig(
static Napi::External<SherpaOnnxOfflineSpeakerDiarization> static Napi::External<SherpaOnnxOfflineSpeakerDiarization>
CreateOfflineSpeakerDiarizationWrapper(const Napi::CallbackInfo &info) { CreateOfflineSpeakerDiarizationWrapper(const Napi::CallbackInfo &info) {
Napi::Env env = info.Env(); Napi::Env env = info.Env();
#if __OHOS__
if (info.Length() != 2) {
std::ostringstream os;
os << "Expect only 2 arguments. Given: " << info.Length();
Napi::TypeError::New(env, os.str()).ThrowAsJavaScriptException();
return {};
}
#else
if (info.Length() != 1) { if (info.Length() != 1) {
std::ostringstream os; std::ostringstream os;
os << "Expect only 1 argument. Given: " << info.Length(); os << "Expect only 1 argument. Given: " << info.Length();
@@ -109,6 +120,7 @@ CreateOfflineSpeakerDiarizationWrapper(const Napi::CallbackInfo &info) {
return {}; return {};
} }
#endif
if (!info[0].IsObject()) { if (!info[0].IsObject()) {
Napi::TypeError::New(env, "Expect an object as the argument") Napi::TypeError::New(env, "Expect an object as the argument")
@@ -129,8 +141,18 @@ CreateOfflineSpeakerDiarizationWrapper(const Napi::CallbackInfo &info) {
SHERPA_ONNX_ASSIGN_ATTR_FLOAT(min_duration_on, minDurationOn); SHERPA_ONNX_ASSIGN_ATTR_FLOAT(min_duration_on, minDurationOn);
SHERPA_ONNX_ASSIGN_ATTR_FLOAT(min_duration_off, minDurationOff); SHERPA_ONNX_ASSIGN_ATTR_FLOAT(min_duration_off, minDurationOff);
#if __OHOS__
std::unique_ptr<NativeResourceManager,
decltype(&OH_ResourceManager_ReleaseNativeResourceManager)>
mgr(OH_ResourceManager_InitNativeResourceManager(env, info[1]),
&OH_ResourceManager_ReleaseNativeResourceManager);
const SherpaOnnxOfflineSpeakerDiarization *sd =
SherpaOnnxCreateOfflineSpeakerDiarizationOHOS(&c, mgr.get());
#else
const SherpaOnnxOfflineSpeakerDiarization *sd = const SherpaOnnxOfflineSpeakerDiarization *sd =
SherpaOnnxCreateOfflineSpeakerDiarization(&c); SherpaOnnxCreateOfflineSpeakerDiarization(&c);
#endif
if (c.segmentation.pyannote.model) { if (c.segmentation.pyannote.model) {
delete[] c.segmentation.pyannote.model; delete[] c.segmentation.pyannote.model;
@@ -224,9 +246,17 @@ static Napi::Array OfflineSpeakerDiarizationProcessWrapper(
Napi::Float32Array samples = info[1].As<Napi::Float32Array>(); Napi::Float32Array samples = info[1].As<Napi::Float32Array>();
#if __OHOS__
// Note(fangjun): For unknown reasons on HarmonyOS, we need to divide it by
// sizeof(float) here
const SherpaOnnxOfflineSpeakerDiarizationResult *r =
SherpaOnnxOfflineSpeakerDiarizationProcess(
sd, samples.Data(), samples.ElementLength() / sizeof(float));
#else
const SherpaOnnxOfflineSpeakerDiarizationResult *r = const SherpaOnnxOfflineSpeakerDiarizationResult *r =
SherpaOnnxOfflineSpeakerDiarizationProcess(sd, samples.Data(), SherpaOnnxOfflineSpeakerDiarizationProcess(sd, samples.Data(),
samples.ElementLength()); samples.ElementLength());
#endif
int32_t num_segments = int32_t num_segments =
SherpaOnnxOfflineSpeakerDiarizationResultGetNumSegments(r); SherpaOnnxOfflineSpeakerDiarizationResultGetNumSegments(r);

View File

@@ -62,3 +62,8 @@ export const speakerEmbeddingManagerVerify: (handle: object, obj: {name: string,
export const speakerEmbeddingManagerContains: (handle: object, name: string) => boolean; export const speakerEmbeddingManagerContains: (handle: object, name: string) => boolean;
export const speakerEmbeddingManagerNumSpeakers: (handle: object) => number; export const speakerEmbeddingManagerNumSpeakers: (handle: object) => number;
export const speakerEmbeddingManagerGetAllSpeakers: (handle: object) => Array<string>; export const speakerEmbeddingManagerGetAllSpeakers: (handle: object) => Array<string>;
// Native bindings for offline (non-streaming) speaker diarization.

// Creates a native diarization handle from `config`.
// NOTE(review): `mgr` is typed as optional here, but the HarmonyOS native
// wrapper (presumably the one bound to this name) rejects calls that do not
// pass exactly 2 arguments and hands the second one to
// OH_ResourceManager_InitNativeResourceManager -- confirm whether callers
// must always pass a resource manager.
export const createOfflineSpeakerDiarization: (config: object, mgr?: object) => object;
// Returns the sample rate associated with the given handle.
// NOTE(review): presumably the sample rate input audio is expected to have -- confirm.
export const getOfflineSpeakerDiarizationSampleRate: (handle: object) => number;
// Runs diarization on `samples`.
// NOTE(review): the native process wrapper is declared as returning a
// Napi::Array, so the `object` returned here is presumably an array of
// segment objects -- confirm.
export const offlineSpeakerDiarizationProcess: (handle: object, samples: Float32Array) => object;
// Applies `config` to an existing handle.
export const offlineSpeakerDiarizationSetConfig: (handle: object, config: object) => void;

View File

@@ -67,10 +67,15 @@ static Napi::Boolean WriteWaveWrapper(const Napi::CallbackInfo &info) {
Napi::Float32Array samples = obj.Get("samples").As<Napi::Float32Array>(); Napi::Float32Array samples = obj.Get("samples").As<Napi::Float32Array>();
int32_t sample_rate = obj.Get("sampleRate").As<Napi::Number>().Int32Value(); int32_t sample_rate = obj.Get("sampleRate").As<Napi::Number>().Int32Value();
#if __OHOS__
int32_t ok = SherpaOnnxWriteWave(
samples.Data(), samples.ElementLength() / sizeof(float), sample_rate,
info[0].As<Napi::String>().Utf8Value().c_str());
#else
int32_t ok = int32_t ok =
SherpaOnnxWriteWave(samples.Data(), samples.ElementLength(), sample_rate, SherpaOnnxWriteWave(samples.Data(), samples.ElementLength(), sample_rate,
info[0].As<Napi::String>().Utf8Value().c_str()); info[0].As<Napi::String>().Utf8Value().c_str());
#endif
return Napi::Boolean::New(env, ok); return Napi::Boolean::New(env, ok);
} }

View File

@@ -0,0 +1,73 @@
import {
createOfflineSpeakerDiarization,
getOfflineSpeakerDiarizationSampleRate,
offlineSpeakerDiarizationProcess,
offlineSpeakerDiarizationSetConfig,
} from 'libsherpa_onnx.so';
import { SpeakerEmbeddingExtractorConfig } from './SpeakerIdentification';
/**
 * Configuration of the pyannote speaker segmentation model.
 */
export class OfflineSpeakerSegmentationPyannoteModelConfig {
// Name/path of the model file. Defaults to an empty string.
// NOTE(review): presumably an empty value means "no pyannote model
// configured" -- confirm against the native config validation.
public model: string = '';
}
/**
 * Configuration of the speaker segmentation stage.
 */
export class OfflineSpeakerSegmentationModelConfig {
// Settings for the pyannote segmentation model.
public pyannote: OfflineSpeakerSegmentationPyannoteModelConfig = new OfflineSpeakerSegmentationPyannoteModelConfig();
// NOTE(review): presumably the number of threads used for inference -- confirm.
public numThreads: number = 1;
// NOTE(review): presumably enables extra diagnostic logging on the native
// side -- confirm.
public debug: boolean = false;
// NOTE(review): presumably the inference backend to run on -- confirm the
// accepted values.
public provider: string = 'cpu';
}
/**
 * Configuration of the clustering stage that groups speaker embeddings.
 */
export class FastClusteringConfig {
// NOTE(review): presumably the number of speakers, with the default -1
// meaning "unknown, use `threshold` instead" -- confirm against the native
// FastClusteringConfig.
public numClusters: number = -1;
// NOTE(review): presumably the clustering distance threshold used when
// numClusters is not given -- confirm.
public threshold: number = 0.5;
}
/**
 * Top-level configuration for offline speaker diarization. It bundles the
 * segmentation, speaker-embedding and clustering settings.
 */
export class OfflineSpeakerDiarizationConfig {
// Speaker segmentation model settings.
public segmentation: OfflineSpeakerSegmentationModelConfig = new OfflineSpeakerSegmentationModelConfig();
// Speaker embedding extractor settings.
public embedding: SpeakerEmbeddingExtractorConfig = new SpeakerEmbeddingExtractorConfig();
// Clustering settings.
public clustering: FastClusteringConfig = new FastClusteringConfig();
// NOTE(review): presumably the minimum duration, in seconds, for a speech
// segment to be kept -- confirm.
public minDurationOn: number = 0.2;
// NOTE(review): presumably the minimum gap, in seconds, between segments of
// the same speaker before they are merged -- confirm.
public minDurationOff: number = 0.5;
}
/**
 * One diarization result segment: a time span attributed to a single speaker.
 */
export class OfflineSpeakerDiarizationSegment {
  // in seconds
  public start: number = 0;

  // in seconds
  public end: number = 0;

  // ID of the speaker; count from 0
  public speaker: number = 0;
}
/**
 * Offline (non-streaming) speaker diarization backed by the native
 * libsherpa_onnx.so addon.
 */
export class OfflineSpeakerDiarization {
  // The configuration this object was created with. setConfig() only
  // refreshes its `clustering` part.
  public config: OfflineSpeakerDiarizationConfig;

  // Sample rate reported by the native object; process() assumes its input
  // uses this rate.
  public sampleRate: number;

  // Opaque handle to the native diarization object.
  private handle: object;

  /**
   * @param config Diarization configuration passed to the native addon.
   * @param mgr    Optional resource manager forwarded to the native addon.
   */
  constructor(config: OfflineSpeakerDiarizationConfig, mgr?: object) {
    this.handle = createOfflineSpeakerDiarization(config, mgr);
    this.config = config;

    this.sampleRate = getOfflineSpeakerDiarizationSampleRate(this.handle);
  }

  /**
   * samples is a 1-d float32 array. Each element of the array should be
   * in the range [-1, 1].
   *
   * We assume its sample rate equals to this.sampleRate.
   *
   * Returns an array of object, where an object is
   *
   *  {
   *    "start": start_time_in_seconds,
   *    "end": end_time_in_seconds,
   *    "speaker": an_integer,
   *  }
   */
  process(samples: Float32Array): OfflineSpeakerDiarizationSegment[] {
    // The native call yields an array of segments, so the result is typed as
    // an array rather than a single segment.
    return offlineSpeakerDiarizationProcess(this.handle, samples) as OfflineSpeakerDiarizationSegment[];
  }

  /**
   * Passes `config` to the native object and records the new clustering
   * settings on this.config. Only `clustering` is copied back; the other
   * fields of this.config are left as they were.
   */
  setConfig(config: OfflineSpeakerDiarizationConfig): void {
    offlineSpeakerDiarizationSetConfig(this.handle, config);
    this.config.clustering = config.clustering;
  }
}

View File

@@ -35,8 +35,7 @@ export class SpeakerEmbeddingExtractor {
} }
createStream(): OnlineStream { createStream(): OnlineStream {
return new OnlineStream( return new OnlineStream(speakerEmbeddingExtractorCreateStream(this.handle));
speakerEmbeddingExtractorCreateStream(this.handle));
} }
isReady(stream: OnlineStream): boolean { isReady(stream: OnlineStream): boolean {
@@ -44,8 +43,7 @@ export class SpeakerEmbeddingExtractor {
} }
compute(stream: OnlineStream, enableExternalBuffer: boolean = true): Float32Array { compute(stream: OnlineStream, enableExternalBuffer: boolean = true): Float32Array {
return speakerEmbeddingExtractorComputeEmbedding( return speakerEmbeddingExtractorComputeEmbedding(this.handle, stream.handle, enableExternalBuffer);
this.handle, stream.handle, enableExternalBuffer);
} }
} }
@@ -106,9 +104,7 @@ export class SpeakerEmbeddingManager {
addMulti(speaker: SpeakerNameWithEmbeddingList): boolean { addMulti(speaker: SpeakerNameWithEmbeddingList): boolean {
const c: SpeakerNameWithEmbeddingN = { const c: SpeakerNameWithEmbeddingN = {
name: speaker.name, name: speaker.name, vv: flatten(speaker.v), n: speaker.v.length,
vv: flatten(speaker.v),
n: speaker.v.length,
}; };
return speakerEmbeddingManagerAddListFlattened(this.handle, c); return speakerEmbeddingManagerAddListFlattened(this.handle, c);
} }

View File

@@ -125,8 +125,7 @@ export class OnlineRecognizer {
} }
getResult(stream: OnlineStream): OnlineRecognizerResult { getResult(stream: OnlineStream): OnlineRecognizerResult {
const jsonStr: string = const jsonStr: string = getOnlineStreamResultAsJson(this.handle, stream.handle);
getOnlineStreamResultAsJson(this.handle, stream.handle);
let o = JSON.parse(jsonStr) as OnlineRecognizerResultJson; let o = JSON.parse(jsonStr) as OnlineRecognizerResultJson;

View File

@@ -62,8 +62,7 @@ export class CircularBuffer {
// return a float32 array // return a float32 array
get(startIndex: number, n: number, enableExternalBuffer: boolean = true): Float32Array { get(startIndex: number, n: number, enableExternalBuffer: boolean = true): Float32Array {
return circularBufferGet( return circularBufferGet(this.handle, startIndex, n, enableExternalBuffer);
this.handle, startIndex, n, enableExternalBuffer);
} }
pop(n: number) { pop(n: number) {
@@ -93,8 +92,7 @@ export class Vad {
private handle: object; private handle: object;
constructor(config: VadConfig, bufferSizeInSeconds?: number, mgr?: object) { constructor(config: VadConfig, bufferSizeInSeconds?: number, mgr?: object) {
this.handle = this.handle = createVoiceActivityDetector(config, bufferSizeInSeconds, mgr);
createVoiceActivityDetector(config, bufferSizeInSeconds, mgr);
this.config = config; this.config = config;
} }

View File

@@ -27,7 +27,7 @@ class OfflineSpeakerDiarization {
} }
setConfig(config) { setConfig(config) {
addon.offlineSpeakerDiarizationSetConfig(config); addon.offlineSpeakerDiarizationSetConfig(this.handle, config);
this.config.clustering = config.clustering; this.config.clustering = config.clustering;
} }
} }

View File

@@ -1784,8 +1784,8 @@ struct SherpaOnnxOfflineSpeakerDiarizationResult {
sherpa_onnx::OfflineSpeakerDiarizationResult impl; sherpa_onnx::OfflineSpeakerDiarizationResult impl;
}; };
const SherpaOnnxOfflineSpeakerDiarization * static sherpa_onnx::OfflineSpeakerDiarizationConfig
SherpaOnnxCreateOfflineSpeakerDiarization( GetOfflineSpeakerDiarizationConfig(
const SherpaOnnxOfflineSpeakerDiarizationConfig *config) { const SherpaOnnxOfflineSpeakerDiarizationConfig *config) {
sherpa_onnx::OfflineSpeakerDiarizationConfig sd_config; sherpa_onnx::OfflineSpeakerDiarizationConfig sd_config;
@@ -1820,6 +1820,22 @@ SherpaOnnxCreateOfflineSpeakerDiarization(
sd_config.min_duration_off = SHERPA_ONNX_OR(config->min_duration_off, 0.5); sd_config.min_duration_off = SHERPA_ONNX_OR(config->min_duration_off, 0.5);
if (sd_config.segmentation.debug || sd_config.embedding.debug) {
#if __OHOS__
SHERPA_ONNX_LOGE("%{public}s\n", sd_config.ToString().c_str());
#else
SHERPA_ONNX_LOGE("%s\n", sd_config.ToString().c_str());
#endif
}
return sd_config;
}
const SherpaOnnxOfflineSpeakerDiarization *
SherpaOnnxCreateOfflineSpeakerDiarization(
const SherpaOnnxOfflineSpeakerDiarizationConfig *config) {
auto sd_config = GetOfflineSpeakerDiarizationConfig(config);
if (!sd_config.Validate()) { if (!sd_config.Validate()) {
SHERPA_ONNX_LOGE("Errors in config"); SHERPA_ONNX_LOGE("Errors in config");
return nullptr; return nullptr;
@@ -1831,10 +1847,6 @@ SherpaOnnxCreateOfflineSpeakerDiarization(
sd->impl = sd->impl =
std::make_unique<sherpa_onnx::OfflineSpeakerDiarization>(sd_config); std::make_unique<sherpa_onnx::OfflineSpeakerDiarization>(sd_config);
if (sd_config.segmentation.debug || sd_config.embedding.debug) {
SHERPA_ONNX_LOGE("%s\n", sd_config.ToString().c_str());
}
return sd; return sd;
} }
@@ -2029,5 +2041,32 @@ SherpaOnnxOfflineTts *SherpaOnnxCreateOfflineTtsOHOS(
} }
#endif // #if SHERPA_ONNX_ENABLE_TTS == 1 #endif // #if SHERPA_ONNX_ENABLE_TTS == 1
//
#if SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION == 1
// Creates an offline speaker diarization object for HarmonyOS.
//
// @param config C-API configuration; it is converted with
//               GetOfflineSpeakerDiarizationConfig().
// @param mgr    Native resource manager forwarded to the C++ implementation.
//               If it is null, this function falls back to
//               SherpaOnnxCreateOfflineSpeakerDiarization(config).
// @return A newly allocated object whose ownership passes to the caller, or
//         nullptr if the converted config fails validation.
const SherpaOnnxOfflineSpeakerDiarization *
SherpaOnnxCreateOfflineSpeakerDiarizationOHOS(
    const SherpaOnnxOfflineSpeakerDiarizationConfig *config,
    NativeResourceManager *mgr) {
  if (!mgr) {
    return SherpaOnnxCreateOfflineSpeakerDiarization(config);
  }

  auto sd_config = GetOfflineSpeakerDiarizationConfig(config);

  if (!sd_config.Validate()) {
    SHERPA_ONNX_LOGE("Errors in config");
    return nullptr;
  }

  // Hold the wrapper in a unique_ptr while the implementation is being
  // constructed, so it is not leaked if that constructor throws. Ownership
  // is released to the caller only after everything has succeeded.
  auto sd = std::make_unique<SherpaOnnxOfflineSpeakerDiarization>();

  sd->impl =
      std::make_unique<sherpa_onnx::OfflineSpeakerDiarization>(mgr, sd_config);

  return sd.release();
}
#endif // #if SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION == 1
#endif // #ifdef __OHOS__ #endif // #ifdef __OHOS__

View File

@@ -1577,6 +1577,11 @@ SHERPA_ONNX_API const SherpaOnnxSpeakerEmbeddingExtractor *
SherpaOnnxCreateSpeakerEmbeddingExtractorOHOS( SherpaOnnxCreateSpeakerEmbeddingExtractorOHOS(
const SherpaOnnxSpeakerEmbeddingExtractorConfig *config, const SherpaOnnxSpeakerEmbeddingExtractorConfig *config,
NativeResourceManager *mgr); NativeResourceManager *mgr);
SHERPA_ONNX_API const SherpaOnnxOfflineSpeakerDiarization *
SherpaOnnxCreateOfflineSpeakerDiarizationOHOS(
const SherpaOnnxOfflineSpeakerDiarizationConfig *config,
NativeResourceManager *mgr);
#endif #endif
#if defined(__GNUC__) #if defined(__GNUC__)

View File

@@ -6,6 +6,15 @@
#include <memory> #include <memory>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#if __OHOS__
#include "rawfile/raw_file_manager.h"
#endif
#include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h" #include "sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h"
@@ -23,10 +32,10 @@ OfflineSpeakerDiarizationImpl::Create(
return nullptr; return nullptr;
} }
#if __ANDROID_API__ >= 9 template <typename Manager>
std::unique_ptr<OfflineSpeakerDiarizationImpl> std::unique_ptr<OfflineSpeakerDiarizationImpl>
OfflineSpeakerDiarizationImpl::Create( OfflineSpeakerDiarizationImpl::Create(
AAssetManager *mgr, const OfflineSpeakerDiarizationConfig &config) { Manager *mgr, const OfflineSpeakerDiarizationConfig &config) {
if (!config.segmentation.pyannote.model.empty()) { if (!config.segmentation.pyannote.model.empty()) {
return std::make_unique<OfflineSpeakerDiarizationPyannoteImpl>(mgr, config); return std::make_unique<OfflineSpeakerDiarizationPyannoteImpl>(mgr, config);
} }
@@ -35,6 +44,17 @@ OfflineSpeakerDiarizationImpl::Create(
return nullptr; return nullptr;
} }
#if __ANDROID_API__ >= 9
template std::unique_ptr<OfflineSpeakerDiarizationImpl>
OfflineSpeakerDiarizationImpl::Create(
AAssetManager *mgr, const OfflineSpeakerDiarizationConfig &config);
#endif
#if __OHOS__
template std::unique_ptr<OfflineSpeakerDiarizationImpl>
OfflineSpeakerDiarizationImpl::Create(
NativeResourceManager *mgr, const OfflineSpeakerDiarizationConfig &config);
#endif #endif
} // namespace sherpa_onnx } // namespace sherpa_onnx

View File

@@ -8,11 +8,6 @@
#include <functional> #include <functional>
#include <memory> #include <memory>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "sherpa-onnx/csrc/offline-speaker-diarization.h" #include "sherpa-onnx/csrc/offline-speaker-diarization.h"
namespace sherpa_onnx { namespace sherpa_onnx {
@@ -21,10 +16,9 @@ class OfflineSpeakerDiarizationImpl {
static std::unique_ptr<OfflineSpeakerDiarizationImpl> Create( static std::unique_ptr<OfflineSpeakerDiarizationImpl> Create(
const OfflineSpeakerDiarizationConfig &config); const OfflineSpeakerDiarizationConfig &config);
#if __ANDROID_API__ >= 9 template <typename Manager>
static std::unique_ptr<OfflineSpeakerDiarizationImpl> Create( static std::unique_ptr<OfflineSpeakerDiarizationImpl> Create(
AAssetManager *mgr, const OfflineSpeakerDiarizationConfig &config); Manager *mgr, const OfflineSpeakerDiarizationConfig &config);
#endif
virtual ~OfflineSpeakerDiarizationImpl() = default; virtual ~OfflineSpeakerDiarizationImpl() = default;

View File

@@ -11,11 +11,6 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "Eigen/Dense" #include "Eigen/Dense"
#include "sherpa-onnx/csrc/fast-clustering.h" #include "sherpa-onnx/csrc/fast-clustering.h"
#include "sherpa-onnx/csrc/math.h" #include "sherpa-onnx/csrc/math.h"
@@ -71,16 +66,15 @@ class OfflineSpeakerDiarizationPyannoteImpl
Init(); Init();
} }
#if __ANDROID_API__ >= 9 template <typename Manager>
OfflineSpeakerDiarizationPyannoteImpl( OfflineSpeakerDiarizationPyannoteImpl(
AAssetManager *mgr, const OfflineSpeakerDiarizationConfig &config) Manager *mgr, const OfflineSpeakerDiarizationConfig &config)
: config_(config), : config_(config),
segmentation_model_(mgr, config_.segmentation), segmentation_model_(mgr, config_.segmentation),
embedding_extractor_(mgr, config_.embedding), embedding_extractor_(mgr, config_.embedding),
clustering_(std::make_unique<FastClustering>(config_.clustering)) { clustering_(std::make_unique<FastClustering>(config_.clustering)) {
Init(); Init();
} }
#endif
int32_t SampleRate() const override { int32_t SampleRate() const override {
const auto &meta_data = segmentation_model_.GetModelMetaData(); const auto &meta_data = segmentation_model_.GetModelMetaData();
@@ -213,8 +207,13 @@ class OfflineSpeakerDiarizationPyannoteImpl
} }
} }
} else { } else {
#if __OHOS__
SHERPA_ONNX_LOGE(
"powerset_max_classes = %{public}d is currently not supported!", i);
#else
SHERPA_ONNX_LOGE( SHERPA_ONNX_LOGE(
"powerset_max_classes = %d is currently not supported!", i); "powerset_max_classes = %d is currently not supported!", i);
#endif
SHERPA_ONNX_EXIT(-1); SHERPA_ONNX_EXIT(-1);
} }
} }
@@ -229,10 +228,17 @@ class OfflineSpeakerDiarizationPyannoteImpl
int32_t window_shift = meta_data.window_shift; int32_t window_shift = meta_data.window_shift;
if (n <= 0) { if (n <= 0) {
#if __OHOS__
SHERPA_ONNX_LOGE(
"number of audio samples is %{public}d (<= 0). Please provide a "
"positive number",
n);
#else
SHERPA_ONNX_LOGE( SHERPA_ONNX_LOGE(
"number of audio samples is %d (<= 0). Please provide a positive " "number of audio samples is %d (<= 0). Please provide a positive "
"number", "number",
n); n);
#endif
return {}; return {};
} }

View File

@@ -7,6 +7,15 @@
#include <string> #include <string>
#include <utility> #include <utility>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#if __OHOS__
#include "rawfile/raw_file_manager.h"
#endif
#include "sherpa-onnx/csrc/offline-speaker-diarization-impl.h" #include "sherpa-onnx/csrc/offline-speaker-diarization-impl.h"
namespace sherpa_onnx { namespace sherpa_onnx {
@@ -74,11 +83,10 @@ OfflineSpeakerDiarization::OfflineSpeakerDiarization(
const OfflineSpeakerDiarizationConfig &config) const OfflineSpeakerDiarizationConfig &config)
: impl_(OfflineSpeakerDiarizationImpl::Create(config)) {} : impl_(OfflineSpeakerDiarizationImpl::Create(config)) {}
#if __ANDROID_API__ >= 9 template <typename Manager>
OfflineSpeakerDiarization::OfflineSpeakerDiarization( OfflineSpeakerDiarization::OfflineSpeakerDiarization(
AAssetManager *mgr, const OfflineSpeakerDiarizationConfig &config) Manager *mgr, const OfflineSpeakerDiarizationConfig &config)
: impl_(OfflineSpeakerDiarizationImpl::Create(mgr, config)) {} : impl_(OfflineSpeakerDiarizationImpl::Create(mgr, config)) {}
#endif
OfflineSpeakerDiarization::~OfflineSpeakerDiarization() = default; OfflineSpeakerDiarization::~OfflineSpeakerDiarization() = default;
@@ -98,4 +106,14 @@ OfflineSpeakerDiarizationResult OfflineSpeakerDiarization::Process(
return impl_->Process(audio, n, std::move(callback), callback_arg); return impl_->Process(audio, n, std::move(callback), callback_arg);
} }
#if __ANDROID_API__ >= 9
template OfflineSpeakerDiarization::OfflineSpeakerDiarization(
AAssetManager *mgr, const OfflineSpeakerDiarizationConfig &config);
#endif
#if __OHOS__
template OfflineSpeakerDiarization::OfflineSpeakerDiarization(
NativeResourceManager *mgr, const OfflineSpeakerDiarizationConfig &config);
#endif
} // namespace sherpa_onnx } // namespace sherpa_onnx

View File

@@ -9,11 +9,6 @@
#include <memory> #include <memory>
#include <string> #include <string>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "sherpa-onnx/csrc/fast-clustering-config.h" #include "sherpa-onnx/csrc/fast-clustering-config.h"
#include "sherpa-onnx/csrc/offline-speaker-diarization-result.h" #include "sherpa-onnx/csrc/offline-speaker-diarization-result.h"
#include "sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h" #include "sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h"
@@ -62,10 +57,9 @@ class OfflineSpeakerDiarization {
explicit OfflineSpeakerDiarization( explicit OfflineSpeakerDiarization(
const OfflineSpeakerDiarizationConfig &config); const OfflineSpeakerDiarizationConfig &config);
#if __ANDROID_API__ >= 9 template <typename Manager>
OfflineSpeakerDiarization(AAssetManager *mgr, OfflineSpeakerDiarization(Manager *mgr,
const OfflineSpeakerDiarizationConfig &config); const OfflineSpeakerDiarizationConfig &config);
#endif
~OfflineSpeakerDiarization(); ~OfflineSpeakerDiarization();

View File

@@ -8,6 +8,15 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#if __OHOS__
#include "rawfile/raw_file_manager.h"
#endif
#include "sherpa-onnx/csrc/onnx-utils.h" #include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/session.h" #include "sherpa-onnx/csrc/session.h"
@@ -24,8 +33,8 @@ class OfflineSpeakerSegmentationPyannoteModel::Impl {
Init(buf.data(), buf.size()); Init(buf.data(), buf.size());
} }
#if __ANDROID_API__ >= 9 template <typename Manager>
Impl(AAssetManager *mgr, const OfflineSpeakerSegmentationModelConfig &config) Impl(Manager *mgr, const OfflineSpeakerSegmentationModelConfig &config)
: config_(config), : config_(config),
env_(ORT_LOGGING_LEVEL_ERROR), env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)), sess_opts_(GetSessionOptions(config)),
@@ -33,7 +42,6 @@ class OfflineSpeakerSegmentationPyannoteModel::Impl {
auto buf = ReadFile(mgr, config_.pyannote.model); auto buf = ReadFile(mgr, config_.pyannote.model);
Init(buf.data(), buf.size()); Init(buf.data(), buf.size());
} }
#endif
const OfflineSpeakerSegmentationPyannoteModelMetaData &GetModelMetaData() const OfflineSpeakerSegmentationPyannoteModelMetaData &GetModelMetaData()
const { const {
@@ -61,7 +69,11 @@ class OfflineSpeakerSegmentationPyannoteModel::Impl {
if (config_.debug) { if (config_.debug) {
std::ostringstream os; std::ostringstream os;
PrintModelMetadata(os, meta_data); PrintModelMetadata(os, meta_data);
#if __OHOS__
SHERPA_ONNX_LOGE("%{public}s\n", os.str().c_str());
#else
SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
#endif
} }
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
@@ -103,12 +115,11 @@ OfflineSpeakerSegmentationPyannoteModel::
const OfflineSpeakerSegmentationModelConfig &config) const OfflineSpeakerSegmentationModelConfig &config)
: impl_(std::make_unique<Impl>(config)) {} : impl_(std::make_unique<Impl>(config)) {}
#if __ANDROID_API__ >= 9 template <typename Manager>
OfflineSpeakerSegmentationPyannoteModel:: OfflineSpeakerSegmentationPyannoteModel::
OfflineSpeakerSegmentationPyannoteModel( OfflineSpeakerSegmentationPyannoteModel(
AAssetManager *mgr, const OfflineSpeakerSegmentationModelConfig &config) Manager *mgr, const OfflineSpeakerSegmentationModelConfig &config)
: impl_(std::make_unique<Impl>(mgr, config)) {} : impl_(std::make_unique<Impl>(mgr, config)) {}
#endif
OfflineSpeakerSegmentationPyannoteModel:: OfflineSpeakerSegmentationPyannoteModel::
~OfflineSpeakerSegmentationPyannoteModel() = default; ~OfflineSpeakerSegmentationPyannoteModel() = default;
@@ -123,4 +134,18 @@ Ort::Value OfflineSpeakerSegmentationPyannoteModel::Forward(
return impl_->Forward(std::move(x)); return impl_->Forward(std::move(x));
} }
#if __ANDROID_API__ >= 9
template OfflineSpeakerSegmentationPyannoteModel::
OfflineSpeakerSegmentationPyannoteModel(
AAssetManager *mgr,
const OfflineSpeakerSegmentationModelConfig &config);
#endif
#if __OHOS__
template OfflineSpeakerSegmentationPyannoteModel::
OfflineSpeakerSegmentationPyannoteModel(
NativeResourceManager *mgr,
const OfflineSpeakerSegmentationModelConfig &config);
#endif
} // namespace sherpa_onnx } // namespace sherpa_onnx

View File

@@ -6,11 +6,6 @@
#include <memory> #include <memory>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "onnxruntime_cxx_api.h" // NOLINT #include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h" #include "sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h"
#include "sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-meta-data.h" #include "sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-meta-data.h"
@@ -22,10 +17,9 @@ class OfflineSpeakerSegmentationPyannoteModel {
explicit OfflineSpeakerSegmentationPyannoteModel( explicit OfflineSpeakerSegmentationPyannoteModel(
const OfflineSpeakerSegmentationModelConfig &config); const OfflineSpeakerSegmentationModelConfig &config);
#if __ANDROID_API__ >= 9 template <typename Manager>
OfflineSpeakerSegmentationPyannoteModel( OfflineSpeakerSegmentationPyannoteModel(
AAssetManager *mgr, const OfflineSpeakerSegmentationModelConfig &config); Manager *mgr, const OfflineSpeakerSegmentationModelConfig &config);
#endif
~OfflineSpeakerSegmentationPyannoteModel(); ~OfflineSpeakerSegmentationPyannoteModel();