diff --git a/.github/workflows/test-go-package.yaml b/.github/workflows/test-go-package.yaml index 2634e5ca..6839b0ba 100644 --- a/.github/workflows/test-go-package.yaml +++ b/.github/workflows/test-go-package.yaml @@ -68,6 +68,50 @@ jobs: run: | gcc --version + - name: Test non-streaming speaker diarization + if: matrix.os != 'windows-latest' + shell: bash + run: | + cd go-api-examples/non-streaming-speaker-diarization/ + ./run.sh + + - name: Test non-streaming speaker diarization + if: matrix.os == 'windows-latest' && matrix.arch == 'x64' + shell: bash + run: | + cd go-api-examples/non-streaming-speaker-diarization/ + go mod tidy + cat go.mod + go build + + echo $PWD + ls -lh /C/Users/runneradmin/go/pkg/mod/github.com/k2-fsa/ + ls -lh /C/Users/runneradmin/go/pkg/mod/github.com/k2-fsa/* + cp -v /C/Users/runneradmin/go/pkg/mod/github.com/k2-fsa/sherpa-onnx-go-windows*/lib/x86_64-pc-windows-gnu/*.dll . + + ./run.sh + + - name: Test non-streaming speaker diarization + if: matrix.os == 'windows-latest' && matrix.arch == 'x86' + shell: bash + run: | + cd go-api-examples/non-streaming-speaker-diarization/ + + go env GOARCH + go env -w GOARCH=386 + go env -w CGO_ENABLED=1 + + go mod tidy + cat go.mod + go build + + echo $PWD + ls -lh /C/Users/runneradmin/go/pkg/mod/github.com/k2-fsa/ + ls -lh /C/Users/runneradmin/go/pkg/mod/github.com/k2-fsa/* + cp -v /C/Users/runneradmin/go/pkg/mod/github.com/k2-fsa/sherpa-onnx-go-windows*/lib/i686-pc-windows-gnu/*.dll . + + ./run.sh + - name: Test streaming HLG decoding (Linux/macOS) if: matrix.os != 'windows-latest' shell: bash diff --git a/.github/workflows/test-go.yaml b/.github/workflows/test-go.yaml index 65c72e17..ccb3eb8d 100644 --- a/.github/workflows/test-go.yaml +++ b/.github/workflows/test-go.yaml @@ -134,6 +134,12 @@ jobs: name: ${{ matrix.os }}-libs path: to-upload/ + - name: Test non-streaming speaker diarization + shell: bash + run: | + cd scripts/go/_internal/non-streaming-speaker-diarization/ + ./run.sh + - name: Test speaker identification shell: bash run: | diff --git a/go-api-examples/non-streaming-speaker-diarization/go.mod b/go-api-examples/non-streaming-speaker-diarization/go.mod new file mode 100644 index 00000000..39edcecf --- /dev/null +++ b/go-api-examples/non-streaming-speaker-diarization/go.mod @@ -0,0 +1,3 @@ +module non-streaming-speaker-diarization + +go 1.12 diff --git a/go-api-examples/non-streaming-speaker-diarization/main.go b/go-api-examples/non-streaming-speaker-diarization/main.go new file mode 100644 index 00000000..7b975bf6 --- /dev/null +++ b/go-api-examples/non-streaming-speaker-diarization/main.go @@ -0,0 +1,87 @@ +package main + +import ( + sherpa "github.com/k2-fsa/sherpa-onnx-go/sherpa_onnx" + "log" +) + +/* +Usage: + +Step 1: Download a speaker segmentation model + +Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models +for a list of available models. The following is an example + + wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2 + tar xvf sherpa-onnx-pyannote-segmentation-3-0.tar.bz2 + rm sherpa-onnx-pyannote-segmentation-3-0.tar.bz2 + +Step 2: Download a speaker embedding extractor model + +Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-recongition-models +for a list of available models. The following is an example + + wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx + +Step 3. Download test wave files + +Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models +for a list of available test wave files. The following is an example + + wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/0-four-speakers-zh.wav + +Step 4. Run it +*/ + +func initSpeakerDiarization() *sherpa.OfflineSpeakerDiarization { + config := sherpa.OfflineSpeakerDiarizationConfig{} + + config.Segmentation.Pyannote.Model = "./sherpa-onnx-pyannote-segmentation-3-0/model.onnx" + config.Embedding.Model = "./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx" + + // The test wave file contains 4 speakers, so we use 4 here + config.Clustering.NumClusters = 4 + + // if you don't know the actual numbers in the wave file, + // then please don't set NumClusters; you need to use + // + // config.Clustering.Threshold = 0.5 + // + + // A larger Threshold leads to fewer clusters + // A smaller Threshold leads to more clusters + + sd := sherpa.NewOfflineSpeakerDiarization(&config) + return sd +} + +func main() { + wave_filename := "./0-four-speakers-zh.wav" + wave := sherpa.ReadWave(wave_filename) + if wave == nil { + log.Printf("Failed to read %v", wave_filename) + return + } + + sd := initSpeakerDiarization() + if sd == nil { + log.Printf("Please check your config") + return + } + + defer sherpa.DeleteOfflineSpeakerDiarization(sd) + + if wave.SampleRate != sd.SampleRate() { + log.Printf("Expected sample rate: %v, given: %d\n", sd.SampleRate(), wave.SampleRate) + return + } + + log.Println("Started") + segments := sd.Process(wave.Samples) + n := len(segments) + + for i := 0; i < n; i++ { + log.Printf("%.3f -- %.3f speaker_%02d\n", segments[i].Start, segments[i].End, segments[i].Speaker) + } +} diff --git a/go-api-examples/non-streaming-speaker-diarization/run.sh b/go-api-examples/non-streaming-speaker-diarization/run.sh new file mode 100755 index 00000000..1ebfd4aa --- /dev/null +++ b/go-api-examples/non-streaming-speaker-diarization/run.sh @@ -0,0 +1,20 @@ +#!/usr/bin/env bash + + +if [ ! -f ./sherpa-onnx-pyannote-segmentation-3-0/model.onnx ]; then + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2 + tar xvf sherpa-onnx-pyannote-segmentation-3-0.tar.bz2 + rm sherpa-onnx-pyannote-segmentation-3-0.tar.bz2 +fi + +if [ ! -f ./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx ]; then + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx +fi + +if [ ! -f ./0-four-speakers-zh.wav ]; then + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/0-four-speakers-zh.wav +fi + +go mod tidy +go build +./non-streaming-speaker-diarization diff --git a/scripts/go/_internal/non-streaming-speaker-diarization/go.mod b/scripts/go/_internal/non-streaming-speaker-diarization/go.mod new file mode 100644 index 00000000..38ae36d4 --- /dev/null +++ b/scripts/go/_internal/non-streaming-speaker-diarization/go.mod @@ -0,0 +1,5 @@ +module non-streaming-speaker-diarization + +go 1.12 + +replace github.com/k2-fsa/sherpa-onnx-go/sherpa_onnx => ../ diff --git a/scripts/go/_internal/non-streaming-speaker-diarization/main.go b/scripts/go/_internal/non-streaming-speaker-diarization/main.go new file mode 120000 index 00000000..2e4da65e --- /dev/null +++ b/scripts/go/_internal/non-streaming-speaker-diarization/main.go @@ -0,0 +1 @@ +../../../../go-api-examples/non-streaming-speaker-diarization/main.go \ No newline at end of file diff --git a/scripts/go/_internal/non-streaming-speaker-diarization/run.sh b/scripts/go/_internal/non-streaming-speaker-diarization/run.sh new file mode 120000 index 00000000..0746440f --- /dev/null +++ b/scripts/go/_internal/non-streaming-speaker-diarization/run.sh @@ -0,0 +1 @@ +../../../../go-api-examples/non-streaming-speaker-diarization/run.sh \ No newline at end of file diff --git a/scripts/go/sherpa_onnx.go b/scripts/go/sherpa_onnx.go index 85ab8e0b..b8b9e6ee 100644 --- a/scripts/go/sherpa_onnx.go +++ b/scripts/go/sherpa_onnx.go @@ -1175,7 +1175,14 @@ func ReadWave(filename string) *Wave { w := C.SherpaOnnxReadWave(s) defer C.SherpaOnnxFreeWave(w) + if w == nil { + return nil + } + n := int(w.num_samples) + if n == 0 { + return nil + } ans := &Wave{} ans.SampleRate = int(w.sample_rate) @@ -1189,3 +1196,114 @@ func ReadWave(filename string) *Wave { return ans } + +// ============================================================ +// For offline speaker diarization +// ============================================================ +type OfflineSpeakerSegmentationPyannoteModelConfig struct { + Model string +} + +type OfflineSpeakerSegmentationModelConfig struct { + Pyannote OfflineSpeakerSegmentationPyannoteModelConfig + NumThreads int + Debug int + Provider string +} + +type FastClusteringConfig struct { + NumClusters int + Threshold float32 +} + +type OfflineSpeakerDiarizationConfig struct { + Segmentation OfflineSpeakerSegmentationModelConfig + Embedding SpeakerEmbeddingExtractorConfig + Clustering FastClusteringConfig + MinDurationOn float32 + MinDurationOff float32 +} + +type OfflineSpeakerDiarization struct { + impl *C.struct_SherpaOnnxOfflineSpeakerDiarization +} + +func DeleteOfflineSpeakerDiarization(sd *OfflineSpeakerDiarization) { + C.SherpaOnnxDestroyOfflineSpeakerDiarization(sd.impl) + sd.impl = nil +} + +func NewOfflineSpeakerDiarization(config *OfflineSpeakerDiarizationConfig) *OfflineSpeakerDiarization { + c := C.struct_SherpaOnnxOfflineSpeakerDiarizationConfig{} + c.segmentation.pyannote.model = C.CString(config.Segmentation.Pyannote.Model) + defer C.free(unsafe.Pointer(c.segmentation.pyannote.model)) + + c.segmentation.num_threads = C.int(config.Segmentation.NumThreads) + + c.segmentation.debug = C.int(config.Segmentation.Debug) + + c.segmentation.provider = C.CString(config.Segmentation.Provider) + defer C.free(unsafe.Pointer(c.segmentation.provider)) + + c.embedding.model = C.CString(config.Embedding.Model) + defer C.free(unsafe.Pointer(c.embedding.model)) + + c.embedding.num_threads = C.int(config.Embedding.NumThreads) + + c.embedding.debug = C.int(config.Embedding.Debug) + + c.embedding.provider = C.CString(config.Embedding.Provider) + defer C.free(unsafe.Pointer(c.embedding.provider)) + + c.clustering.num_clusters = C.int(config.Clustering.NumClusters) + c.clustering.threshold = C.float(config.Clustering.Threshold) + c.min_duration_on = C.float(config.MinDurationOn) + c.min_duration_off = C.float(config.MinDurationOff) + + p := C.SherpaOnnxCreateOfflineSpeakerDiarization(&c) + + if p == nil { + return nil + } + + sd := &OfflineSpeakerDiarization{} + sd.impl = p + + return sd +} + +func (sd *OfflineSpeakerDiarization) SampleRate() int { + return int(C.SherpaOnnxOfflineSpeakerDiarizationGetSampleRate(sd.impl)) +} + +type OfflineSpeakerDiarizationSegment struct { + Start float32 + End float32 + Speaker int +} + +func (sd *OfflineSpeakerDiarization) Process(samples []float32) []OfflineSpeakerDiarizationSegment { + r := C.SherpaOnnxOfflineSpeakerDiarizationProcess(sd.impl, (*C.float)(&samples[0]), C.int(len(samples))) + defer C.SherpaOnnxOfflineSpeakerDiarizationDestroyResult(r) + + n := int(C.SherpaOnnxOfflineSpeakerDiarizationResultGetNumSegments(r)) + + if n == 0 { + return nil + } + + s := C.SherpaOnnxOfflineSpeakerDiarizationResultSortByStartTime(r) + defer C.SherpaOnnxOfflineSpeakerDiarizationDestroySegment(s) + + ans := make([]OfflineSpeakerDiarizationSegment, n) + + p := (*[1 << 28]C.struct_SherpaOnnxOfflineSpeakerDiarizationSegment)(unsafe.Pointer(s))[:n:n] + + for i := 0; i < n; i++ { + ans[i].Start = float32(p[i].start) + ans[i].End = float32(p[i].end) + ans[i].Speaker = int(p[i].speaker) + } + + return ans +}