WebAssembly example for speaker diarization (#1411)

This commit is contained in:
Fangjun Kuang
2024-10-10 22:14:45 +08:00
committed by GitHub
parent 67349b52f2
commit 1d061df355
37 changed files with 1008 additions and 24 deletions

View File

@@ -18,6 +18,10 @@ if(SHERPA_ONNX_ENABLE_WASM_VAD_ASR)
add_subdirectory(vad-asr)
endif()
# Build the WASM speaker-diarization example only when explicitly enabled
# via -DSHERPA_ONNX_ENABLE_WASM_SPEAKER_DIARIZATION=ON.
if(SHERPA_ONNX_ENABLE_WASM_SPEAKER_DIARIZATION)
add_subdirectory(speaker-diarization)
endif()
# Build the Node.js WASM bindings when enabled.
if(SHERPA_ONNX_ENABLE_WASM_NODEJS)
add_subdirectory(nodejs)
endif()

View File

@@ -0,0 +1,61 @@
# This project must be configured through the helper build script, which sets
# up the Emscripten toolchain and the required environment.
# NOTE(review): $ENV{...} is evaluated at configure time only.
if(NOT $ENV{SHERPA_ONNX_IS_USING_BUILD_WASM_SH})
message(FATAL_ERROR "Please use ./build-wasm-simd-speaker-diarization.sh to build for WASM for speaker diarization")
endif()
# Both model files must be present in assets/ before configuring; they are
# bundled into the generated .data file at build time.
if(NOT EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/assets/segmentation.onnx" OR NOT EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/assets/embedding.onnx")
message(FATAL_ERROR "Please read ${CMAKE_CURRENT_SOURCE_DIR}/assets/README.md before you continue")
endif()
# C symbols that must survive in the generated WASM module so that the
# JavaScript wrapper can call them as Module._<name>.
set(exported_functions
  MyPrint
  SherpaOnnxCreateOfflineSpeakerDiarization
  SherpaOnnxDestroyOfflineSpeakerDiarization
  SherpaOnnxOfflineSpeakerDiarizationDestroyResult
  SherpaOnnxOfflineSpeakerDiarizationDestroySegment
  SherpaOnnxOfflineSpeakerDiarizationGetSampleRate
  SherpaOnnxOfflineSpeakerDiarizationProcess
  SherpaOnnxOfflineSpeakerDiarizationProcessWithCallback
  SherpaOnnxOfflineSpeakerDiarizationResultGetNumSegments
  SherpaOnnxOfflineSpeakerDiarizationResultSortByStartTime
  SherpaOnnxOfflineSpeakerDiarizationSetConfig
)

# Emscripten exposes C functions with a leading underscore; build the
# comma-separated list expected by -sEXPORTED_FUNCTIONS.
set(mangled_exported_functions ${exported_functions})
list(TRANSFORM mangled_exported_functions PREPEND "_")
list(JOIN mangled_exported_functions "," all_exported_functions)
# Headers are included as "sherpa-onnx/...", so the repository root must be on
# the include path.
# NOTE(review): directory-scoped on purpose to match the sibling wasm
# projects; target_include_directories would otherwise be preferable.
include_directories(${CMAKE_SOURCE_DIR})

# Emscripten-specific options:
#   - FORCE_FILESYSTEM + --preload-file bundle assets/ into a .data file
#     mounted at the FS root at runtime (hence the "@.").
#   - INITIAL_MEMORY / ALLOW_MEMORY_GROWTH give the ONNX models room to load.
set(MY_FLAGS " -s FORCE_FILESYSTEM=1 -s INITIAL_MEMORY=512MB -s ALLOW_MEMORY_GROWTH=1")
string(APPEND MY_FLAGS " -sSTACK_SIZE=10485760 ") # 10MB
string(APPEND MY_FLAGS " -sEXPORTED_FUNCTIONS=[_CopyHeap,_malloc,_free,${all_exported_functions}] ")
string(APPEND MY_FLAGS "--preload-file ${CMAKE_CURRENT_SOURCE_DIR}/assets@. ")
string(APPEND MY_FLAGS " -sEXPORTED_RUNTIME_METHODS=['ccall','stringToUTF8','setValue','getValue','lengthBytesUTF8','UTF8ToString'] ")

message(STATUS "MY_FLAGS: ${MY_FLAGS}")

set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${MY_FLAGS}")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${MY_FLAGS}")
# Fix: the variable name was misspelled as CMAKE_EXECUTBLE_LINKER_FLAGS, which
# CMake silently ignores; the real variable is CMAKE_EXE_LINKER_FLAGS.
set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${MY_FLAGS}")
# With the Emscripten toolchain the default executable suffix is .js; if it is
# not, the wrong toolchain file is in use.
if (NOT CMAKE_EXECUTABLE_SUFFIX STREQUAL ".js")
message(FATAL_ERROR "The default suffix for building executables should be .js!")
endif()

add_executable(sherpa-onnx-wasm-main-speaker-diarization sherpa-onnx-wasm-main-speaker-diarization.cc)
# Fix: use an explicit visibility keyword instead of the legacy keyword-less
# target_link_libraries signature.
target_link_libraries(sherpa-onnx-wasm-main-speaker-diarization PRIVATE sherpa-onnx-c-api)

install(TARGETS sherpa-onnx-wasm-main-speaker-diarization DESTINATION bin/wasm/speaker-diarization)
# Install the complete web app next to the artifacts produced by Emscripten:
#   - the generated .js loader and .wasm binary,
#   - the .data file holding the preloaded assets/ models,
#   - the static HTML/JS frontend files from the source tree.
install(
FILES
"$<TARGET_FILE_DIR:sherpa-onnx-wasm-main-speaker-diarization>/sherpa-onnx-wasm-main-speaker-diarization.js"
"index.html"
"sherpa-onnx-speaker-diarization.js"
"app-speaker-diarization.js"
"$<TARGET_FILE_DIR:sherpa-onnx-wasm-main-speaker-diarization>/sherpa-onnx-wasm-main-speaker-diarization.wasm"
"$<TARGET_FILE_DIR:sherpa-onnx-wasm-main-speaker-diarization>/sherpa-onnx-wasm-main-speaker-diarization.data"
DESTINATION
bin/wasm/speaker-diarization
)

View File

@@ -0,0 +1,124 @@
// DOM handles for the speaker-diarization demo page (see index.html).
const startBtn = document.getElementById('startBtn');
const hint = document.getElementById('hint');
const numClustersInput = document.getElementById('numClustersInputID');
const thresholdInput = document.getElementById('thresholdInputID');
const textArea = document.getElementById('text');
const fileSelectCtrl = document.getElementById('file');
// Set once the WASM runtime is initialized / a wave file has been decoded.
let sd = null;
let float32Samples = null;
// Emscripten module object; onRuntimeInitialized fires after the .wasm binary
// and the preloaded model files have been fetched.
Module = {};
Module.onRuntimeInitialized = function() {
console.log('Model files downloaded!');
console.log('Initializing speaker diarization ......');
// createOfflineSpeakerDiarization is defined in
// sherpa-onnx-speaker-diarization.js (loaded before this runs).
sd = createOfflineSpeakerDiarization(Module)
console.log('sampleRate', sd.sampleRate);
hint.innerText =
'Initialized! Please select a wave file and click the Start button.';
// File selection stays disabled until the model is ready.
fileSelectCtrl.disabled = false;
};
// Handler for the <input type="file"> onchange event (wired in index.html).
// Decodes the selected wave file — resampled by the browser to the model's
// sample rate — and stores its first channel in the global float32Samples.
function onFileChange() {
  const files = document.getElementById('file').files;

  if (files.length == 0) {
    console.log('No file selected');
    float32Samples = null;
    startBtn.disabled = true;
    return;
  }

  textArea.value = '';
  console.log('files: ' + files);
  const file = files[0];
  console.log(file);
  console.log('file.name ' + file.name);
  console.log('file.type ' + file.type);
  console.log('file.size ' + file.size);

  // Creating the AudioContext at sd.sampleRate makes decodeAudioData resample
  // the file to the rate expected by the diarization model.
  let audioCtx = new AudioContext({sampleRate: sd.sampleRate});
  let reader = new FileReader();
  reader.onload = function() {
    console.log('reading file!');
    audioCtx.decodeAudioData(reader.result, decodedDone);
  };

  function decodedDone(decoded) {
    // Fix: removed a dead `new Float32Array(decoded.length)` allocation that
    // was never used.
    float32Samples = decoded.getChannelData(0);
    startBtn.disabled = false;
  }

  reader.readAsArrayBuffer(file);
}
// Run speaker diarization on the decoded samples and print one line per
// segment: "start -- end speaker_N".
startBtn.onclick = function() {
  textArea.value = '';
  if (float32Samples == null) {
    alert('Empty audio samples!');
    startBtn.disabled = true;
    return;
  }

  let numClusters = numClustersInput.value;
  if (numClusters.trim().length == 0) {
    alert(
        'Please provide numClusters. Use -1 if you are not sure how many speakers are there');
    return;
  }

  // Fix: the previous pattern /^\d+$/ rejected negative numbers, so the
  // documented (and default) value -1 could never pass validation.
  if (!numClusters.match(/^-?\d+$/)) {
    alert(`number of clusters ${
        numClusters} is not an integer .\nPlease enter an integer`);
    return;
  }
  numClusters = parseInt(numClusters, 10);
  if (numClusters < -1) {
    alert(`Number of clusters should be >= -1`);
    return;
  }

  // The clustering threshold is only consulted when the number of speakers is
  // unknown (numClusters <= 0).
  let threshold = 0.5;
  if (numClusters <= 0) {
    threshold = thresholdInput.value;
    if (threshold.trim().length == 0) {
      alert('Please provide a threshold.');
      return;
    }

    threshold = parseFloat(threshold);

    if (threshold < 0) {
      // Fix: corrected the "Pleaser" typo in this user-facing message.
      alert(`Please enter a positive threshold`);
      return;
    }
  }

  let config = sd.config;
  config.clustering = {numClusters: numClusters, threshold: threshold};
  sd.setConfig(config);

  let segments = sd.process(float32Samples);
  if (segments == null) {
    textArea.value = 'No speakers detected';
    return;
  }

  let s = '';
  let sep = '';
  // Fix: declare the loop variable; `for (seg of ...)` leaked an implicit
  // global and throws a ReferenceError in strict mode.
  for (const seg of segments) {
    // clang-format off
    s += sep + `${seg.start.toFixed(2)} -- ${seg.end.toFixed(2)} speaker_${seg.speaker}`;
    // clang-format on
    sep = '\n';
  }

  textArea.value = s;
};

View File

@@ -0,0 +1,30 @@
# Introduction
Please refer to
https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models
to download a speaker segmentation model
and
refer to
https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-recongition-models
to download a speaker embedding extraction model.
Remember to rename the downloaded files.
The following is an example.
```bash
cd wasm/speaker-diarization/assets/
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
tar xvf sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
rm sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
cp sherpa-onnx-pyannote-segmentation-3-0/model.onnx ./segmentation.onnx
rm -rf sherpa-onnx-pyannote-segmentation-3-0
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx
mv 3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx ./embedding.onnx
```

View File

@@ -0,0 +1,48 @@
<html lang="en">

<head>
  <meta charset="utf-8">
  <meta name="viewport" content="width=device-width" />
  <title>Next-gen Kaldi WebAssembly with sherpa-onnx for Speaker Diarization</title>
  <style>
    h1,div {
      text-align: center;
    }
    textarea {
      width:100%;
    }
  </style>
</head>

<body>
  <h1>
    Next-gen Kaldi + WebAssembly<br/>
    Speaker Diarization <br> with <a href="https://github.com/k2-fsa/sherpa-onnx">sherpa-onnx</a>
  </h1>

  <div>
    <span id="hint">Loading model ... ...</span>
    <br/>
    <br/>

    <!-- Fix: the label's "for" must reference the input's actual id ("file"),
         and <input> is a void element, so the invalid </input> was removed.
         The other labels now also reference the real input ids. -->
    <label for="file">Choose a wav file:</label>
    <input type="file" id="file" accept=".wav" onchange="onFileChange()" disabled>
    <br/>
    <br/>

    <label for="numClustersInputID" id="numClustersID">Number of speakers: </label>
    <input type="text" id="numClustersInputID" name="numClusters" value="-1" />
    <br/>
    <br/>

    <label for="thresholdInputID" id="thresholdID">Clustering threshold: </label>
    <input type="text" id="thresholdInputID" name="clusteringThreshold" value="0.5" />
    <br/>
    <br/>

    <textarea id="text" rows="10" placeholder="If you know the actual number of speakers in the input wave file, please provide it via Number of speakers. Otherwise, please leave Number of speakers to -1 and provide Clustering threshold instead. A larger threshold leads to fewer clusters, i.e., fewer speakers; a smaller threshold leads to more clusters, i.e., more speakers."></textarea>
    <br/>
    <br/>
    <button id="startBtn" disabled>Start</button>
  </div>

  <!-- app-speaker-diarization.js must run first: it installs the Module
       object that the Emscripten loader below picks up. -->
  <script src="app-speaker-diarization.js"></script>
  <script src="sherpa-onnx-speaker-diarization.js"></script>
  <script src="sherpa-onnx-wasm-main-speaker-diarization.js"></script>
</body>

View File

@@ -0,0 +1,295 @@
// Recursively release every WASM-heap allocation owned by a config object
// produced by one of the init*Config() helpers below: first the string
// buffer (if any), then all nested sub-configs, finally the struct itself.
function freeConfig(config, Module) {
  if ('buffer' in config) {
    Module._free(config.buffer);
  }

  for (const key of ['config', 'segmentation', 'embedding', 'clustering']) {
    if (key in config) {
      freeConfig(config[key], Module);
    }
  }

  Module._free(config.ptr);
}
// Serialize a SherpaOnnxOfflineSpeakerSegmentationPyannoteModelConfig — a
// single `const char *model` field — into the WASM heap.
// Returns {buffer, ptr, len}: `ptr` points at the struct, `buffer` at the
// NUL-terminated string it references.
function initSherpaOnnxOfflineSpeakerSegmentationPyannoteModelConfig(
    config, Module) {
  const model = config.model || '';
  const modelLen = Module.lengthBytesUTF8(model) + 1;  // +1 for '\0'

  const buffer = Module._malloc(modelLen);

  const len = 1 * 4;  // one 32-bit pointer field
  const ptr = Module._malloc(len);

  Module.stringToUTF8(model, buffer, modelLen);
  Module.setValue(ptr, buffer, 'i8*');

  return {
    buffer: buffer,
    ptr: ptr,
    len: len,
  };
}
// Serialize a SherpaOnnxOfflineSpeakerSegmentationModelConfig:
//   { pyannote: {model}, num_threads, debug, provider }
// The nested pyannote struct is copied first, then the scalar fields, in the
// order pinned by the static_asserts in
// sherpa-onnx-wasm-main-speaker-diarization.cc.
function initSherpaOnnxOfflineSpeakerSegmentationModelConfig(config, Module) {
  if (!('pyannote' in config)) {
    config.pyannote = {
      model: '',
    };
  }

  const pyannote = initSherpaOnnxOfflineSpeakerSegmentationPyannoteModelConfig(
      config.pyannote, Module);

  const len = pyannote.len + 3 * 4;  // + numThreads + debug + provider
  const ptr = Module._malloc(len);

  let offset = 0;
  Module._CopyHeap(pyannote.ptr, pyannote.len, ptr + offset);
  offset += pyannote.len;

  Module.setValue(ptr + offset, config.numThreads || 1, 'i32');
  offset += 4;

  // Fix: `config.debug || 1` silently turned an explicit `debug: 0` into 1,
  // so debug output could never be switched off. Fall back to 1 only when
  // the field is not a number.
  const debug = typeof config.debug === 'number' ? config.debug : 1;
  Module.setValue(ptr + offset, debug, 'i32');
  offset += 4;

  const providerLen = Module.lengthBytesUTF8(config.provider || 'cpu') + 1;
  const buffer = Module._malloc(providerLen);
  Module.stringToUTF8(config.provider || 'cpu', buffer, providerLen);
  Module.setValue(ptr + offset, buffer, 'i8*');

  return {
    buffer: buffer,
    ptr: ptr,
    len: len,
    config: pyannote,  // freed recursively by freeConfig()
  };
}
// Serialize a SherpaOnnxSpeakerEmbeddingExtractorConfig:
//   { model, num_threads, debug, provider } — 4 * 4 bytes, matching the
// static_assert in sherpa-onnx-wasm-main-speaker-diarization.cc.
function initSherpaOnnxSpeakerEmbeddingExtractorConfig(config, Module) {
  const model = config.model || '';
  const provider = config.provider || 'cpu';

  const modelLen = Module.lengthBytesUTF8(model) + 1;
  const providerLen = Module.lengthBytesUTF8(provider) + 1;

  // One allocation holds both strings back to back: [model\0provider\0].
  const buffer = Module._malloc(modelLen + providerLen);

  const len = 4 * 4;
  const ptr = Module._malloc(len);

  Module.stringToUTF8(model, buffer, modelLen);
  Module.stringToUTF8(provider, buffer + modelLen, providerLen);

  let offset = 0;
  Module.setValue(ptr + offset, buffer, 'i8*');  // model
  offset += 4;

  Module.setValue(ptr + offset, config.numThreads || 1, 'i32');
  offset += 4;

  // Fix: `config.debug || 1` silently turned an explicit `debug: 0` into 1;
  // respect a numeric 0 and default only when the field is absent.
  const debug = typeof config.debug === 'number' ? config.debug : 1;
  Module.setValue(ptr + offset, debug, 'i32');
  offset += 4;

  Module.setValue(ptr + offset, buffer + modelLen, 'i8*');  // provider
  offset += 4;

  return {
    buffer: buffer,
    ptr: ptr,
    len: len,
  };
}
// Serialize a SherpaOnnxFastClusteringConfig:
//   { num_clusters: i32, threshold: float } — 2 * 4 bytes.
function initSherpaOnnxFastClusteringConfig(config, Module) {
  const len = 2 * 4;
  const ptr = Module._malloc(len);

  // Fix: the previous `config.x || default` pattern replaced a legitimate 0
  // (e.g. threshold 0, which the UI accepts) with the default; fall back
  // only when the field is not a number.
  const numClusters =
      typeof config.numClusters === 'number' ? config.numClusters : -1;
  const threshold =
      typeof config.threshold === 'number' ? config.threshold : 0.5;

  let offset = 0;
  Module.setValue(ptr + offset, numClusters, 'i32');
  offset += 4;

  Module.setValue(ptr + offset, threshold, 'float');
  offset += 4;

  return {
    ptr: ptr,
    len: len,
  };
}
// Serialize the top-level SherpaOnnxOfflineSpeakerDiarizationConfig into the
// WASM heap. The byte layout — segmentation struct, embedding struct,
// clustering struct, then min_duration_on/min_duration_off floats — matches
// the static_asserts in sherpa-onnx-wasm-main-speaker-diarization.cc.
// Missing sub-configs are filled in with defaults before serialization.
function initSherpaOnnxOfflineSpeakerDiarizationConfig(config, Module) {
if (!('segmentation' in config)) {
config.segmentation = {
pyannote: {model: ''},
numThreads: 1,
debug: 0,
provider: 'cpu',
};
}
if (!('embedding' in config)) {
config.embedding = {
model: '',
numThreads: 1,
debug: 0,
provider: 'cpu',
};
}
if (!('clustering' in config)) {
config.clustering = {
numClusters: -1,
threshold: 0.5,
};
}
// Serialize each sub-config separately, then memcpy them into one
// contiguous struct via the exported CopyHeap helper.
const segmentation = initSherpaOnnxOfflineSpeakerSegmentationModelConfig(
config.segmentation, Module);
const embedding =
initSherpaOnnxSpeakerEmbeddingExtractorConfig(config.embedding, Module);
const clustering =
initSherpaOnnxFastClusteringConfig(config.clustering, Module);
const len = segmentation.len + embedding.len + clustering.len + 2 * 4;
const ptr = Module._malloc(len);
let offset = 0;
Module._CopyHeap(segmentation.ptr, segmentation.len, ptr + offset);
offset += segmentation.len;
Module._CopyHeap(embedding.ptr, embedding.len, ptr + offset);
offset += embedding.len;
Module._CopyHeap(clustering.ptr, clustering.len, ptr + offset);
offset += clustering.len;
// NOTE(review): `|| 0.2` / `|| 0.5` replace an explicit 0 with the default —
// confirm whether a 0-second duration should be representable here.
Module.setValue(ptr + offset, config.minDurationOn || 0.2, 'float');
offset += 4;
Module.setValue(ptr + offset, config.minDurationOff || 0.5, 'float');
offset += 4;
// The sub-config objects are kept so freeConfig() can release them.
return {
ptr: ptr, len: len, segmentation: segmentation, embedding: embedding,
clustering: clustering,
}
}
// JS wrapper around the sherpa-onnx C API for offline speaker diarization.
class OfflineSpeakerDiarization {
  // configObj: plain JS config (see createOfflineSpeakerDiarization for the
  //            expected shape).
  // Module: the initialized Emscripten module.
  constructor(configObj, Module) {
    const config =
        initSherpaOnnxOfflineSpeakerDiarizationConfig(configObj, Module);
    // Module._MyPrint(config.ptr);

    const handle =
        Module._SherpaOnnxCreateOfflineSpeakerDiarization(config.ptr);

    // The C side copies what it needs, so the serialized config can be
    // released immediately.
    freeConfig(config, Module);

    this.handle = handle;
    this.sampleRate =
        Module._SherpaOnnxOfflineSpeakerDiarizationGetSampleRate(this.handle);
    this.Module = Module;
    this.config = configObj;
  }

  // Release the native handle; the object must not be used afterwards.
  free() {
    this.Module._SherpaOnnxDestroyOfflineSpeakerDiarization(this.handle);
    this.handle = 0;
  }

  // Update the clustering part (numClusters/threshold) of the config.
  // Objects without a `clustering` key are ignored.
  setConfig(configObj) {
    if (!('clustering' in configObj)) {
      return;
    }

    const config =
        initSherpaOnnxOfflineSpeakerDiarizationConfig(configObj, this.Module);
    this.Module._SherpaOnnxOfflineSpeakerDiarizationSetConfig(
        this.handle, config.ptr);

    // Fix: this previously referenced the global `Module` instead of
    // `this.Module`, which throws a ReferenceError under Node where no such
    // global exists.
    freeConfig(config, this.Module);

    this.config.clustering = configObj.clustering;
  }

  // Run diarization on a Float32Array of mono samples at this.sampleRate.
  // Returns an array of {start, end, speaker} sorted by start time.
  process(samples) {
    const pointer =
        this.Module._malloc(samples.length * samples.BYTES_PER_ELEMENT);
    this.Module.HEAPF32.set(samples, pointer / samples.BYTES_PER_ELEMENT);

    const r = this.Module._SherpaOnnxOfflineSpeakerDiarizationProcess(
        this.handle, pointer, samples.length);

    this.Module._free(pointer);

    const numSegments =
        this.Module._SherpaOnnxOfflineSpeakerDiarizationResultGetNumSegments(r);

    const segments =
        this.Module._SherpaOnnxOfflineSpeakerDiarizationResultSortByStartTime(
            r);

    const ans = [];
    // Each native segment is {float start; float end; int32 speaker}.
    const sizeOfSegment = 3 * 4;
    for (let i = 0; i < numSegments; ++i) {
      const p = segments + i * sizeOfSegment;
      const start = this.Module.HEAPF32[p / 4 + 0];
      const end = this.Module.HEAPF32[p / 4 + 1];
      const speaker = this.Module.HEAP32[p / 4 + 2];

      ans.push({start: start, end: end, speaker: speaker});
    }

    this.Module._SherpaOnnxOfflineSpeakerDiarizationDestroySegment(segments);
    this.Module._SherpaOnnxOfflineSpeakerDiarizationDestroyResult(r);

    return ans;
  }
}
// Factory: build an OfflineSpeakerDiarization from the Emscripten Module.
// If myConfig is provided, it completely replaces the default config below.
function createOfflineSpeakerDiarization(Module, myConfig) {
  // Fix: `config` was declared `const` and then reassigned when myConfig was
  // provided, which throws "Assignment to constant variable".
  let config = {
    segmentation: {
      pyannote: {model: './segmentation.onnx'},
    },
    embedding: {model: './embedding.onnx'},
    clustering: {numClusters: -1, threshold: 0.5},
    minDurationOn: 0.3,
    minDurationOff: 0.5,
  };

  if (myConfig) {
    config = myConfig;
  }

  return new OfflineSpeakerDiarization(config, Module);
}
// Export the factory when running under Node (detected via process.versions);
// in the browser this file is loaded as a plain script and the function is
// simply a global.
if (typeof process == 'object' && typeof process.versions == 'object' &&
typeof process.versions.node == 'string') {
module.exports = {
createOfflineSpeakerDiarization,
};
}

View File

@@ -0,0 +1,63 @@
// wasm/sherpa-onnx-wasm-main-speaker-diarization.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include <stdio.h>
#include <algorithm>
#include <memory>
#include "sherpa-onnx/c-api/c-api.h"
// see also
// https://emscripten.org/docs/porting/connecting_cpp_and_javascript/Interacting-with-code.html
extern "C" {
// The JS wrapper (sherpa-onnx-speaker-diarization.js) serializes these C
// structs by writing raw bytes at hard-coded offsets. The static_asserts pin
// the struct sizes so that any change to the C API is caught at compile time
// instead of silently corrupting the config at runtime.
static_assert(sizeof(SherpaOnnxOfflineSpeakerSegmentationPyannoteModelConfig) ==
1 * 4,
"");
static_assert(
sizeof(SherpaOnnxOfflineSpeakerSegmentationModelConfig) ==
sizeof(SherpaOnnxOfflineSpeakerSegmentationPyannoteModelConfig) + 3 * 4,
"");
static_assert(sizeof(SherpaOnnxFastClusteringConfig) == 2 * 4, "");
static_assert(sizeof(SherpaOnnxSpeakerEmbeddingExtractorConfig) == 4 * 4, "");
static_assert(sizeof(SherpaOnnxOfflineSpeakerDiarizationConfig) ==
sizeof(SherpaOnnxOfflineSpeakerSegmentationModelConfig) +
sizeof(SherpaOnnxSpeakerEmbeddingExtractorConfig) +
sizeof(SherpaOnnxFastClusteringConfig) + 2 * 4,
"");
// Debug helper: dump every field of the given diarization config to stdout.
// Exported to JS (see exported_functions in CMakeLists.txt) so the wrapper
// can verify that its hand-serialized config struct round-trips correctly.
void MyPrint(const SherpaOnnxOfflineSpeakerDiarizationConfig *sd_config) {
  const auto &seg = sd_config->segmentation;
  const auto &emb = sd_config->embedding;
  const auto &cluster = sd_config->clustering;

  fprintf(stdout, "----------segmentation config----------\n");
  fprintf(stdout, "pyannote model: %s\n", seg.pyannote.model);
  fprintf(stdout, "num threads: %d\n", seg.num_threads);
  fprintf(stdout, "debug: %d\n", seg.debug);
  fprintf(stdout, "provider: %s\n", seg.provider);

  fprintf(stdout, "----------embedding config----------\n");
  fprintf(stdout, "model: %s\n", emb.model);
  fprintf(stdout, "num threads: %d\n", emb.num_threads);
  fprintf(stdout, "debug: %d\n", emb.debug);
  fprintf(stdout, "provider: %s\n", emb.provider);

  fprintf(stdout, "----------clustering config----------\n");
  fprintf(stdout, "num_clusters: %d\n", cluster.num_clusters);
  fprintf(stdout, "threshold: %.3f\n", cluster.threshold);

  fprintf(stdout, "min_duration_on: %.3f\n", sd_config->min_duration_on);
  fprintf(stdout, "min_duration_off: %.3f\n", sd_config->min_duration_off);
}
// Copy num_bytes bytes from src to dst inside the WASM heap. Exported to JS
// so the wrapper can assemble nested config structs into one contiguous
// allocation.
void CopyHeap(const char *src, int32_t num_bytes, char *dst) {
  for (int32_t i = 0; i != num_bytes; ++i) {
    dst[i] = src[i];
  }
}
}