JavaScript API (node-addon) for speaker diarization (#1408)

This commit is contained in:
Fangjun Kuang
2024-10-10 15:51:31 +08:00
committed by GitHub
parent a45e5dba99
commit 67349b52f2
11 changed files with 443 additions and 13 deletions

View File

@@ -21,6 +21,7 @@ set(srcs
src/audio-tagging.cc
src/keyword-spotting.cc
src/non-streaming-asr.cc
src/non-streaming-speaker-diarization.cc
src/non-streaming-tts.cc
src/punctuation.cc
src/sherpa-onnx-node-addon-api.cc

View File

@@ -0,0 +1,32 @@
const addon = require('./addon.js');
class OfflineSpeakerDiarization {
constructor(config) {
this.handle = addon.createOfflineSpeakerDiarization(config);
this.config = config;
this.sampleRate = addon.getOfflineSpeakerDiarizationSampleRate(this.handle);
}
/**
* samples is a 1-d float32 array. Each element of the array should be
* in the range [-1, 1].
*
* We assume its sample rate equals to this.sampleRate.
*
* Returns an array of object, where an object is
*
* {
* "start": start_time_in_seconds,
* "end": end_time_in_seconds,
* "speaker": an_integer,
* }
*/
process(samples) {
return addon.offlineSpeakerDiarizationProcess(this.handle, samples);
}
}
module.exports = {
OfflineSpeakerDiarization,
}

View File

@@ -8,6 +8,7 @@ const sid = require('./speaker-identification.js');
const at = require('./audio-tagg.js');
const punct = require('./punctuation.js');
const kws = require('./keyword-spotter.js');
const sd = require('./non-streaming-speaker-diarization.js');
module.exports = {
OnlineRecognizer: streaming_asr.OnlineRecognizer,
@@ -24,4 +25,5 @@ module.exports = {
AudioTagging: at.AudioTagging,
Punctuation: punct.Punctuation,
KeywordSpotter: kws.KeywordSpotter,
OfflineSpeakerDiarization: sd.OfflineSpeakerDiarization,
}

View File

@@ -1,7 +1,7 @@
{
"main": "lib/sherpa-onnx.js",
"version": "1.0.0",
"description": "Speech-to-text and text-to-speech using Next-gen Kaldi without internet connection",
"description": "Speech-to-text, text-to-speech, and speaker diarization using Next-gen Kaldi without internet connection",
"dependencies": {
"cmake-js": "^6.0.0",
"node-addon-api": "^1.1.0",
@@ -21,8 +21,18 @@
"transcription",
"real-time speech recognition",
"without internet connection",
"locally",
"local",
"embedded systems",
"open source",
"diarization",
"speaker diarization",
"speaker recognition",
"speaker",
"speaker segmentation",
"speaker verification",
"spoken language identification",
"sherpa",
"zipformer",
"asr",
"tts",

View File

@@ -0,0 +1,265 @@
// scripts/node-addon-api/src/non-streaming-speaker-diarization.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include <algorithm>
#include <cstring>  // memset
#include <sstream>

#include "macros.h"  // NOLINT
#include "napi.h"  // NOLINT
#include "sherpa-onnx/c-api/c-api.h"
// Reads obj.pyannote.model into the C config. Returns a zero-filled config
// when the "pyannote" field is absent or is not an object.
static SherpaOnnxOfflineSpeakerSegmentationPyannoteModelConfig
GetOfflineSpeakerSegmentationPyannoteModelConfig(Napi::Object obj) {
  SherpaOnnxOfflineSpeakerSegmentationPyannoteModelConfig c = {};

  if (obj.Has("pyannote") && obj.Get("pyannote").IsObject()) {
    // The SHERPA_ONNX_ASSIGN_ATTR_* macros read from `o` and write into `c`.
    Napi::Object o = obj.Get("pyannote").As<Napi::Object>();
    SHERPA_ONNX_ASSIGN_ATTR_STR(model, model);
  }

  return c;
}
// Reads obj.segmentation (pyannote model, numThreads, debug, provider) into
// the C config. Returns a zero-filled config when the "segmentation" field is
// absent or is not an object.
static SherpaOnnxOfflineSpeakerSegmentationModelConfig
GetOfflineSpeakerSegmentationModelConfig(Napi::Object obj) {
  SherpaOnnxOfflineSpeakerSegmentationModelConfig c = {};

  if (!(obj.Has("segmentation") && obj.Get("segmentation").IsObject())) {
    return c;
  }

  // The SHERPA_ONNX_ASSIGN_ATTR_* macros read from `o` and write into `c`.
  Napi::Object o = obj.Get("segmentation").As<Napi::Object>();

  c.pyannote = GetOfflineSpeakerSegmentationPyannoteModelConfig(o);

  SHERPA_ONNX_ASSIGN_ATTR_INT32(num_threads, numThreads);

  // "debug" may be given either as a boolean or as a number; a boolean wins.
  if (o.Has("debug")) {
    Napi::Value debug = o.Get("debug");
    if (debug.IsBoolean()) {
      c.debug = debug.As<Napi::Boolean>().Value();
    } else if (debug.IsNumber()) {
      c.debug = debug.As<Napi::Number>().Int32Value();
    }
  }

  SHERPA_ONNX_ASSIGN_ATTR_STR(provider, provider);

  return c;
}
// Reads obj.embedding (model, numThreads, debug, provider) into the C config.
// Returns a zero-filled config when the "embedding" field is absent or is not
// an object.
static SherpaOnnxSpeakerEmbeddingExtractorConfig
GetSpeakerEmbeddingExtractorConfig(Napi::Object obj) {
  SherpaOnnxSpeakerEmbeddingExtractorConfig c = {};

  if (!(obj.Has("embedding") && obj.Get("embedding").IsObject())) {
    return c;
  }

  // The SHERPA_ONNX_ASSIGN_ATTR_* macros read from `o` and write into `c`.
  Napi::Object o = obj.Get("embedding").As<Napi::Object>();

  SHERPA_ONNX_ASSIGN_ATTR_STR(model, model);
  SHERPA_ONNX_ASSIGN_ATTR_INT32(num_threads, numThreads);

  // "debug" may be given either as a boolean or as a number; a boolean wins.
  if (o.Has("debug")) {
    Napi::Value debug = o.Get("debug");
    if (debug.IsBoolean()) {
      c.debug = debug.As<Napi::Boolean>().Value();
    } else if (debug.IsNumber()) {
      c.debug = debug.As<Napi::Number>().Int32Value();
    }
  }

  SHERPA_ONNX_ASSIGN_ATTR_STR(provider, provider);

  return c;
}
// Reads obj.clustering (numClusters, threshold) into the C config. Returns a
// zero-filled config when the "clustering" field is absent or is not an
// object.
static SherpaOnnxFastClusteringConfig GetFastClusteringConfig(
    Napi::Object obj) {
  SherpaOnnxFastClusteringConfig c = {};

  if (obj.Has("clustering") && obj.Get("clustering").IsObject()) {
    // The SHERPA_ONNX_ASSIGN_ATTR_* macros read from `o` and write into `c`.
    Napi::Object o = obj.Get("clustering").As<Napi::Object>();
    SHERPA_ONNX_ASSIGN_ATTR_INT32(num_clusters, numClusters);
    SHERPA_ONNX_ASSIGN_ATTR_FLOAT(threshold, threshold);
  }

  return c;
}
// Creates a SherpaOnnxOfflineSpeakerDiarization from a JS config object and
// returns it wrapped in a Napi::External whose finalizer destroys it, so the
// native object is released when the JS value is garbage collected.
//
// info[0]: a config object with optional fields
//   {segmentation: {...}, embedding: {...}, clustering: {...},
//    minDurationOn: float, minDurationOff: float}
//
// Throws a JS TypeError (and returns an empty external) when the arguments
// are wrong or the C API rejects the config.
static Napi::External<SherpaOnnxOfflineSpeakerDiarization>
CreateOfflineSpeakerDiarizationWrapper(const Napi::CallbackInfo &info) {
  Napi::Env env = info.Env();
  if (info.Length() != 1) {
    std::ostringstream os;
    os << "Expect only 1 argument. Given: " << info.Length();
    Napi::TypeError::New(env, os.str()).ThrowAsJavaScriptException();
    return {};
  }
  if (!info[0].IsObject()) {
    Napi::TypeError::New(env, "Expect an object as the argument")
        .ThrowAsJavaScriptException();
    return {};
  }
  // The SHERPA_ONNX_ASSIGN_ATTR_* macros below read from `o` and write into
  // `c`; do not rename these two locals.
  Napi::Object o = info[0].As<Napi::Object>();
  SherpaOnnxOfflineSpeakerDiarizationConfig c;
  memset(&c, 0, sizeof(c));
  c.segmentation = GetOfflineSpeakerSegmentationModelConfig(o);
  c.embedding = GetSpeakerEmbeddingExtractorConfig(o);
  c.clustering = GetFastClusteringConfig(o);
  SHERPA_ONNX_ASSIGN_ATTR_FLOAT(min_duration_on, minDurationOn);
  SHERPA_ONNX_ASSIGN_ATTR_FLOAT(min_duration_off, minDurationOff);
  const SherpaOnnxOfflineSpeakerDiarization *sd =
      SherpaOnnxCreateOfflineSpeakerDiarization(&c);
  // The ASSIGN_ATTR_STR macros allocate the string fields with new[]; free
  // them now that the C object has been created (it keeps its own copies).
  if (c.segmentation.pyannote.model) {
    delete[] c.segmentation.pyannote.model;
  }
  if (c.segmentation.provider) {
    delete[] c.segmentation.provider;
  }
  if (c.embedding.model) {
    delete[] c.embedding.model;
  }
  if (c.embedding.provider) {
    delete[] c.embedding.provider;
  }
  if (!sd) {
    Napi::TypeError::New(env, "Please check your config!")
        .ThrowAsJavaScriptException();
    return {};
  }
  // The lambda is the finalizer: it runs when the External is collected.
  return Napi::External<SherpaOnnxOfflineSpeakerDiarization>::New(
      env, const_cast<SherpaOnnxOfflineSpeakerDiarization *>(sd),
      [](Napi::Env env, SherpaOnnxOfflineSpeakerDiarization *sd) {
        SherpaOnnxDestroyOfflineSpeakerDiarization(sd);
      });
}
// Returns the expected input sample rate of the diarization object passed as
// the only argument (an External created by
// CreateOfflineSpeakerDiarizationWrapper). Throws a JS TypeError on misuse.
static Napi::Number OfflineSpeakerDiarizationGetSampleRateWrapper(
    const Napi::CallbackInfo &info) {
  Napi::Env env = info.Env();

  if (info.Length() != 1) {
    std::ostringstream os;
    os << "Expect only 1 argument. Given: " << info.Length();
    Napi::TypeError::New(env, os.str()).ThrowAsJavaScriptException();
    return {};
  }

  if (!info[0].IsExternal()) {
    Napi::TypeError::New(
        env, "Argument 0 should be an offline speaker diarization pointer.")
        .ThrowAsJavaScriptException();
    return {};
  }

  const auto *sd =
      info[0].As<Napi::External<SherpaOnnxOfflineSpeakerDiarization>>().Data();

  return Napi::Number::New(
      env, SherpaOnnxOfflineSpeakerDiarizationGetSampleRate(sd));
}
// Runs diarization on a whole recording.
//
// info[0]: an External holding the SherpaOnnxOfflineSpeakerDiarization.
// info[1]: a Float32Array of audio samples in [-1, 1] at the object's
//          expected sample rate.
//
// Returns a JS array of {start, end, speaker} objects sorted by start time.
// Throws a JS TypeError on misuse.
static Napi::Array OfflineSpeakerDiarizationProcessWrapper(
    const Napi::CallbackInfo &info) {
  Napi::Env env = info.Env();
  if (info.Length() != 2) {
    std::ostringstream os;
    os << "Expect only 2 arguments. Given: " << info.Length();
    Napi::TypeError::New(env, os.str()).ThrowAsJavaScriptException();
    return {};
  }

  if (!info[0].IsExternal()) {
    Napi::TypeError::New(
        env, "Argument 0 should be an offline speaker diarization pointer.")
        .ThrowAsJavaScriptException();
    return {};
  }

  const SherpaOnnxOfflineSpeakerDiarization *sd =
      info[0].As<Napi::External<SherpaOnnxOfflineSpeakerDiarization>>().Data();

  if (!info[1].IsTypedArray()) {
    Napi::TypeError::New(env, "Argument 1 should be a typed array")
        .ThrowAsJavaScriptException();
    return {};
  }

  Napi::Float32Array samples = info[1].As<Napi::Float32Array>();

  const SherpaOnnxOfflineSpeakerDiarizationResult *r =
      SherpaOnnxOfflineSpeakerDiarizationProcess(
          sd, samples.Data(),
          static_cast<int32_t>(samples.ElementLength()));

  // Defensive: if processing failed and returned no result, report an empty
  // segment list instead of dereferencing a null pointer below.
  if (!r) {
    return Napi::Array::New(env, 0);
  }

  int32_t num_segments =
      SherpaOnnxOfflineSpeakerDiarizationResultGetNumSegments(r);

  const SherpaOnnxOfflineSpeakerDiarizationSegment *segments =
      SherpaOnnxOfflineSpeakerDiarizationResultSortByStartTime(r);

  Napi::Array ans = Napi::Array::New(env, num_segments);

  for (int32_t i = 0; i != num_segments; ++i) {
    Napi::Object obj = Napi::Object::New(env);
    obj.Set(Napi::String::New(env, "start"), segments[i].start);
    obj.Set(Napi::String::New(env, "end"), segments[i].end);
    obj.Set(Napi::String::New(env, "speaker"), segments[i].speaker);
    ans[i] = obj;
  }

  // Free the sorted segment copy first, then the result it came from.
  SherpaOnnxOfflineSpeakerDiarizationDestroySegment(segments);
  SherpaOnnxOfflineSpeakerDiarizationDestroyResult(r);

  return ans;
}
// Registers the speaker-diarization functions on the addon's `exports`
// object so they are callable from JavaScript.
void InitNonStreamingSpeakerDiarization(Napi::Env env, Napi::Object exports) {
  exports.Set("createOfflineSpeakerDiarization",
              Napi::Function::New(env, CreateOfflineSpeakerDiarizationWrapper));

  exports.Set(
      "getOfflineSpeakerDiarizationSampleRate",
      Napi::Function::New(env, OfflineSpeakerDiarizationGetSampleRateWrapper));

  exports.Set("offlineSpeakerDiarizationProcess",
              Napi::Function::New(env, OfflineSpeakerDiarizationProcessWrapper));
}

View File

@@ -25,6 +25,8 @@ void InitPunctuation(Napi::Env env, Napi::Object exports);
void InitKeywordSpotting(Napi::Env env, Napi::Object exports);
void InitNonStreamingSpeakerDiarization(Napi::Env env, Napi::Object exports);
Napi::Object Init(Napi::Env env, Napi::Object exports) {
InitStreamingAsr(env, exports);
InitNonStreamingAsr(env, exports);
@@ -37,6 +39,7 @@ Napi::Object Init(Napi::Env env, Napi::Object exports) {
InitAudioTagging(env, exports);
InitPunctuation(env, exports);
InitKeywordSpotting(env, exports);
InitNonStreamingSpeakerDiarization(env, exports);
return exports;
}