diff --git a/.github/workflows/test-go-package.yaml b/.github/workflows/test-go-package.yaml index 6ef4235f..d761be4f 100644 --- a/.github/workflows/test-go-package.yaml +++ b/.github/workflows/test-go-package.yaml @@ -66,6 +66,12 @@ jobs: run: | gcc --version + - name: Test speaker identification + shell: bash + run: | + cd go-api-examples/speaker-identification + ./run.sh + - name: Test non-streaming TTS (Linux/macOS) if: matrix.os != 'windows-latest' shell: bash diff --git a/.github/workflows/test-go.yaml b/.github/workflows/test-go.yaml index ae9b6f00..298403ec 100644 --- a/.github/workflows/test-go.yaml +++ b/.github/workflows/test-go.yaml @@ -74,6 +74,12 @@ jobs: go mod tidy go build + - name: Test speaker identification + shell: bash + run: | + cd scripts/go/_internal/speaker-identification/ + ./run.sh + - name: Test non-streaming TTS (macOS) shell: bash run: | diff --git a/.gitignore b/.gitignore index ea1d57e9..c2c87424 100644 --- a/.gitignore +++ b/.gitignore @@ -88,3 +88,5 @@ vits-mms-* *.tar.bz2 sherpa-onnx-paraformer-trilingual-zh-cantonese-en sr-data +*xcworkspace/xcuserdata/* + diff --git a/go-api-examples/README.md b/go-api-examples/README.md index 51a44b38..91f2c76e 100644 --- a/go-api-examples/README.md +++ b/go-api-examples/README.md @@ -26,4 +26,8 @@ for details. - [./vad-spoken-language-identification](./vad-spoken-language-identification) It shows how to use silero VAD + Whisper for spoken language identification. +- [./speaker-identification](./speaker-identification) It shows how to use Go API for speaker identification. + +- [./vad-speaker-identification](./vad-speaker-identification) It shows how to use Go API for VAD + speaker identification. + [sherpa-onnx]: https://github.com/k2-fsa/sherpa-onnx diff --git a/go-api-examples/speaker-identification/go.mod b/go-api-examples/speaker-identification/go.mod new file mode 100644 index 00000000..188a7b98 --- /dev/null +++ b/go-api-examples/speaker-identification/go.mod @@ -0,0 +1,3 @@ +module speaker-identification + +go 1.12 diff --git a/go-api-examples/speaker-identification/main.go b/go-api-examples/speaker-identification/main.go new file mode 100644 index 00000000..7f2c791f --- /dev/null +++ b/go-api-examples/speaker-identification/main.go @@ -0,0 +1,146 @@ +package main + +import ( + sherpa "github.com/k2-fsa/sherpa-onnx-go/sherpa_onnx" + "log" +) + +func createSpeakerEmbeddingExtractor() *sherpa.SpeakerEmbeddingExtractor { + config := sherpa.SpeakerEmbeddingExtractorConfig{} + + // Please download the model from + // https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_campplus_sv_zh-cn_16k-common.onnx + // + // You can find more models at + // https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-recongition-models + + config.Model = "./3dspeaker_speech_campplus_sv_zh-cn_16k-common.onnx" + config.NumThreads = 1 + config.Debug = 1 + config.Provider = "cpu" + + ex := sherpa.NewSpeakerEmbeddingExtractor(&config) + return ex +} + +func computeEmbeddings(ex *sherpa.SpeakerEmbeddingExtractor, files []string) [][]float32 { + embeddings := make([][]float32, len(files)) + + for i, f := range files { + wave := sherpa.ReadWave(f) + + stream := ex.CreateStream() + defer sherpa.DeleteOnlineStream(stream) + stream.AcceptWaveform(wave.SampleRate, wave.Samples) + stream.InputFinished() + embeddings[i] = ex.Compute(stream) + } + + return embeddings + +} + +func registerSpeakers(ex *sherpa.SpeakerEmbeddingExtractor, manager *sherpa.SpeakerEmbeddingManager) { + // Please download the test data from + // https://github.com/csukuangfj/sr-data + spk1_files := []string{ + "./sr-data/enroll/fangjun-sr-1.wav", + "./sr-data/enroll/fangjun-sr-2.wav", + "./sr-data/enroll/fangjun-sr-3.wav", + } + + spk2_files := []string{ + "./sr-data/enroll/leijun-sr-1.wav", + "./sr-data/enroll/leijun-sr-2.wav", + } + + spk1_embeddings := computeEmbeddings(ex, spk1_files) + spk2_embeddings := computeEmbeddings(ex, spk2_files) + + ok := manager.RegisterV("fangjun", spk1_embeddings) + if !ok { + panic("Failed to register fangjun") + } + + ok = manager.RegisterV("leijun", spk2_embeddings) + if !ok { + panic("Failed to register leijun") + } + + if !manager.Contains("fangjun") { + panic("Failed to find fangjun") + } + + if !manager.Contains("leijun") { + panic("Failed to find leijun") + } + + if manager.NumSpeakers() != 2 { + panic("There should be only 2 speakers") + } + + all_speakers := manager.AllSpeakers() + log.Printf("All speakers: %v\n", all_speakers) +} + +func main() { + log.SetFlags(log.LstdFlags | log.Lmicroseconds) + + ex := createSpeakerEmbeddingExtractor() + defer sherpa.DeleteSpeakerEmbeddingExtractor(ex) + + manager := sherpa.NewSpeakerEmbeddingManager(ex.Dim()) + defer sherpa.DeleteSpeakerEmbeddingManager(manager) + registerSpeakers(ex, manager) + + // Please download the test data from + // https://github.com/csukuangfj/sr-data + test1 := "./sr-data/test/fangjun-test-sr-1.wav" + embeddings := computeEmbeddings(ex, []string{test1})[0] + threshold := float32(0.6) + name := manager.Search(embeddings, threshold) + if len(name) > 0 { + log.Printf("%v matches %v", test1, name) + } else { + log.Printf("No matches found for %v", test1) + } + + test2 := "./sr-data/test/leijun-test-sr-1.wav" + embeddings = computeEmbeddings(ex, []string{test2})[0] + name = manager.Search(embeddings, threshold) + if len(name) > 0 { + log.Printf("%v matches %v", test2, name) + } else { + log.Printf("No matches found for %v", test2) + } + + test3 := "./sr-data/test/liudehua-test-sr-1.wav" + embeddings = computeEmbeddings(ex, []string{test3})[0] + name = manager.Search(embeddings, threshold) + if len(name) > 0 { + log.Printf("%v matches %v", test3, name) + } else { + log.Printf("No matches found for %v", test3) + } + + if !manager.Remove("fangjun") { + panic("Failed to deregister fangjun") + } else { + log.Print("fangjun deregistered\n") + } + + test1 = "./sr-data/test/fangjun-test-sr-1.wav" + embeddings = computeEmbeddings(ex, []string{test1})[0] + name = manager.Search(embeddings, threshold) + if len(name) > 0 { + log.Printf("%v matches %v", test1, name) + } else { + log.Printf("No matches found for %v", test1) + } +} + +func chk(err error) { + if err != nil { + panic(err) + } +} diff --git a/go-api-examples/speaker-identification/run.sh b/go-api-examples/speaker-identification/run.sh new file mode 100755 index 00000000..3829eee9 --- /dev/null +++ b/go-api-examples/speaker-identification/run.sh @@ -0,0 +1,13 @@ +#!/usr/bin/env bash + +if [ ! -f ./3dspeaker_speech_campplus_sv_zh-cn_16k-common.onnx ]; then + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_campplus_sv_zh-cn_16k-common.onnx +fi + +if [ ! -f ./sr-data/enroll/fangjun-sr-1.wav ]; then + git clone https://github.com/csukuangfj/sr-data +fi + +go mod tidy +go build +./speaker-identification diff --git a/go-api-examples/vad-asr-paraformer/main.go b/go-api-examples/vad-asr-paraformer/main.go index 25515da6..54e1ed1c 100644 --- a/go-api-examples/vad-asr-paraformer/main.go +++ b/go-api-examples/vad-asr-paraformer/main.go @@ -104,7 +104,7 @@ func main() { duration := float32(len(speechSegment.Samples)) / float32(config.SampleRate) - audio := &sherpa.GeneratedAudio{} + audio := &sherpa.Wave{} audio.Samples = speechSegment.Samples audio.SampleRate = config.SampleRate @@ -120,7 +120,7 @@ func main() { chk(s.Stop()) } -func decode(recognizer *sherpa.OfflineRecognizer, audio *sherpa.GeneratedAudio, id int) { +func decode(recognizer *sherpa.OfflineRecognizer, audio *sherpa.Wave, id int) { stream := sherpa.NewOfflineStream(recognizer) defer sherpa.DeleteOfflineStream(stream) stream.AcceptWaveform(audio.SampleRate, audio.Samples) diff --git a/go-api-examples/vad-asr-whisper/main.go b/go-api-examples/vad-asr-whisper/main.go index 55ed7c86..85c675ef 100644 --- a/go-api-examples/vad-asr-whisper/main.go +++ b/go-api-examples/vad-asr-whisper/main.go @@ -102,7 +102,7 @@ func main() { duration := float32(len(speechSegment.Samples)) / float32(config.SampleRate) - audio := &sherpa.GeneratedAudio{} + audio := &sherpa.Wave{} audio.Samples = speechSegment.Samples audio.SampleRate = config.SampleRate @@ -118,7 +118,7 @@ func main() { chk(s.Stop()) } -func decode(recognizer *sherpa.OfflineRecognizer, audio *sherpa.GeneratedAudio, id int) { +func decode(recognizer *sherpa.OfflineRecognizer, audio *sherpa.Wave, id int) { stream := sherpa.NewOfflineStream(recognizer) defer sherpa.DeleteOfflineStream(stream) stream.AcceptWaveform(audio.SampleRate, audio.Samples) diff --git a/go-api-examples/vad-speaker-identification/go.mod b/go-api-examples/vad-speaker-identification/go.mod new file mode 100644 index 00000000..9b1f1563 --- /dev/null +++ b/go-api-examples/vad-speaker-identification/go.mod @@ -0,0 +1,3 @@ +module vad-speaker-identification + +go 1.12 diff --git a/go-api-examples/vad-speaker-identification/main.go b/go-api-examples/vad-speaker-identification/main.go new file mode 100644 index 00000000..f2b69d09 --- /dev/null +++ b/go-api-examples/vad-speaker-identification/main.go @@ -0,0 +1,221 @@ +package main + +import ( + "fmt" + "github.com/gordonklaus/portaudio" + sherpa "github.com/k2-fsa/sherpa-onnx-go/sherpa_onnx" + "log" +) + +func createSpeakerEmbeddingExtractor() *sherpa.SpeakerEmbeddingExtractor { + config := sherpa.SpeakerEmbeddingExtractorConfig{} + + // Please download the model from + // https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_campplus_sv_zh-cn_16k-common.onnx + // + // You can find more models at + // https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-recongition-models + + config.Model = "./3dspeaker_speech_campplus_sv_zh-cn_16k-common.onnx" + config.NumThreads = 2 + config.Debug = 1 + config.Provider = "cpu" + + ex := sherpa.NewSpeakerEmbeddingExtractor(&config) + return ex +} + +func computeEmbeddings(ex *sherpa.SpeakerEmbeddingExtractor, files []string) [][]float32 { + embeddings := make([][]float32, len(files)) + + for i, f := range files { + wave := sherpa.ReadWave(f) + + stream := ex.CreateStream() + defer sherpa.DeleteOnlineStream(stream) + stream.AcceptWaveform(wave.SampleRate, wave.Samples) + stream.InputFinished() + embeddings[i] = ex.Compute(stream) + } + + return embeddings + +} + +func registerSpeakers(ex *sherpa.SpeakerEmbeddingExtractor, manager *sherpa.SpeakerEmbeddingManager) { + // Please download the test data from + // https://github.com/csukuangfj/sr-data + spk1_files := []string{ + "./sr-data/enroll/fangjun-sr-1.wav", + "./sr-data/enroll/fangjun-sr-2.wav", + "./sr-data/enroll/fangjun-sr-3.wav", + } + + spk2_files := []string{ + "./sr-data/enroll/leijun-sr-1.wav", + "./sr-data/enroll/leijun-sr-2.wav", + } + + spk1_embeddings := computeEmbeddings(ex, spk1_files) + spk2_embeddings := computeEmbeddings(ex, spk2_files) + + ok := manager.RegisterV("fangjun", spk1_embeddings) + if !ok { + panic("Failed to register fangjun") + } + + ok = manager.RegisterV("leijun", spk2_embeddings) + if !ok { + panic("Failed to register leijun") + } + + if !manager.Contains("fangjun") { + panic("Failed to find fangjun") + } + + if !manager.Contains("leijun") { + panic("Failed to find leijun") + } + + if manager.NumSpeakers() != 2 { + panic("There should be only 2 speakers") + } + + all_speakers := manager.AllSpeakers() + log.Printf("All speakers: %v\n", all_speakers) +} + +func createVad() *sherpa.VoiceActivityDetector { + config := sherpa.VadModelConfig{} + + // Please download silero_vad.onnx from + // https://github.com/snakers4/silero-vad/blob/master/files/silero_vad.onnx + + config.SileroVad.Model = "./silero_vad.onnx" + config.SileroVad.Threshold = 0.5 + config.SileroVad.MinSilenceDuration = 0.5 + config.SileroVad.MinSpeechDuration = 0.5 + config.SileroVad.WindowSize = 512 + config.SampleRate = 16000 + config.NumThreads = 1 + config.Provider = "cpu" + config.Debug = 1 + + var bufferSizeInSeconds float32 = 20 + + vad := sherpa.NewVoiceActivityDetector(&config, bufferSizeInSeconds) + return vad +} + +func main() { + log.SetFlags(log.LstdFlags | log.Lmicroseconds) + + vad := createVad() + defer sherpa.DeleteVoiceActivityDetector(vad) + + ex := createSpeakerEmbeddingExtractor() + defer sherpa.DeleteSpeakerEmbeddingExtractor(ex) + + manager := sherpa.NewSpeakerEmbeddingManager(ex.Dim()) + defer sherpa.DeleteSpeakerEmbeddingManager(manager) + registerSpeakers(ex, manager) + + err := portaudio.Initialize() + if err != nil { + log.Fatalf("Unable to initialize portaudio: %v\n", err) + } + defer portaudio.Terminate() + + default_device, err := portaudio.DefaultInputDevice() + if err != nil { + log.Fatal("Failed to get default input device: %v\n", err) + } + log.Printf("Selected default input device: %s\n", default_device.Name) + param := portaudio.StreamParameters{} + param.Input.Device = default_device + param.Input.Channels = 1 + param.Input.Latency = default_device.DefaultHighInputLatency + + param.SampleRate = 16000 + param.FramesPerBuffer = 0 + param.Flags = portaudio.ClipOff + + // you can choose another value for 0.1 if you want + samplesPerCall := int32(param.SampleRate * 0.1) // 0.1 second + samples := make([]float32, samplesPerCall) + + s, err := portaudio.OpenStream(param, samples) + if err != nil { + log.Fatalf("Failed to open the stream") + } + + defer s.Close() + chk(s.Start()) + + log.Print("Started! Please speak") + printed := false + + k := 0 + for { + chk(s.Read()) + vad.AcceptWaveform(samples) + + if vad.IsSpeech() && !printed { + printed = true + log.Print("Detected speech\n") + } + + if !vad.IsSpeech() { + printed = false + } + + for !vad.IsEmpty() { + speechSegment := vad.Front() + vad.Pop() + + audio := &sherpa.Wave{} + audio.Samples = speechSegment.Samples + audio.SampleRate = 16000 + + // Now decode it + go decode(ex, manager, audio, k) + + k += 1 + } + } + + chk(s.Stop()) + +} + +func chk(err error) { + if err != nil { + panic(err) + } +} + +func decode(ex *sherpa.SpeakerEmbeddingExtractor, manager *sherpa.SpeakerEmbeddingManager, audio *sherpa.GeneratedAudio, id int) { + stream := ex.CreateStream() + defer sherpa.DeleteOnlineStream(stream) + + stream.AcceptWaveform(audio.SampleRate, audio.Samples) + stream.InputFinished() + embeddings := ex.Compute(stream) + threshold := float32(0.5) + name := manager.Search(embeddings, threshold) + if len(name) > 0 { + log.Printf("Found speaker: %v\n", name) + } else { + log.Print("Unknown speaker\n") + name = "Unknown" + } + + duration := float32(len(audio.Samples)) / float32(audio.SampleRate) + + filename := fmt.Sprintf("seg-%d-%.2f-seconds-%s.wav", id, duration, name) + ok := audio.Save(filename) + if ok { + log.Printf("Saved to %s", filename) + } + log.Print("----------\n") +} diff --git a/go-api-examples/vad-speaker-identification/run.sh b/go-api-examples/vad-speaker-identification/run.sh new file mode 100755 index 00000000..3829eee9 --- /dev/null +++ b/go-api-examples/vad-speaker-identification/run.sh @@ -0,0 +1,13 @@ +#!/usr/bin/env bash + +if [ ! -f ./3dspeaker_speech_campplus_sv_zh-cn_16k-common.onnx ]; then + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_campplus_sv_zh-cn_16k-common.onnx +fi + +if [ ! -f ./sr-data/enroll/fangjun-sr-1.wav ]; then + git clone https://github.com/csukuangfj/sr-data +fi + +go mod tidy +go build +./speaker-identification diff --git a/go-api-examples/vad-spoken-language-identification/main.go b/go-api-examples/vad-spoken-language-identification/main.go index 5db250e8..71c37532 100644 --- a/go-api-examples/vad-spoken-language-identification/main.go +++ b/go-api-examples/vad-spoken-language-identification/main.go @@ -99,7 +99,7 @@ func main() { duration := float32(len(speechSegment.Samples)) / float32(config.SampleRate) - audio := &sherpa.GeneratedAudio{} + audio := &sherpa.Wave{} audio.Samples = speechSegment.Samples audio.SampleRate = config.SampleRate @@ -115,7 +115,7 @@ func main() { chk(s.Stop()) } -func decode(slid *sherpa.SpokenLanguageIdentification, audio *sherpa.GeneratedAudio, id int) { +func decode(slid *sherpa.SpokenLanguageIdentification, audio *sherpa.Wave, id int) { stream := slid.CreateStream() defer sherpa.DeleteOfflineStream(stream) diff --git a/ios-swift/SherpaOnnx/SherpaOnnx.xcodeproj/project.xcworkspace/xcuserdata/fangjun.xcuserdatad/UserInterfaceState.xcuserstate b/ios-swift/SherpaOnnx/SherpaOnnx.xcodeproj/project.xcworkspace/xcuserdata/fangjun.xcuserdatad/UserInterfaceState.xcuserstate deleted file mode 100644 index e6383ab0..00000000 Binary files a/ios-swift/SherpaOnnx/SherpaOnnx.xcodeproj/project.xcworkspace/xcuserdata/fangjun.xcuserdatad/UserInterfaceState.xcuserstate and /dev/null differ diff --git a/ios-swift/SherpaOnnx/SherpaOnnx.xcodeproj/xcuserdata/fangjun.xcuserdatad/xcschemes/xcschememanagement.plist b/ios-swift/SherpaOnnx/SherpaOnnx.xcodeproj/xcuserdata/fangjun.xcuserdatad/xcschemes/xcschememanagement.plist deleted file mode 100644 index 4f6d363e..00000000 --- a/ios-swift/SherpaOnnx/SherpaOnnx.xcodeproj/xcuserdata/fangjun.xcuserdatad/xcschemes/xcschememanagement.plist +++ /dev/null @@ -1,14 +0,0 @@ - - - - - SchemeUserState - - SherpaOnnx.xcscheme_^#shared#^_ - - orderHint - 0 - - - - diff --git a/ios-swiftui/SherpaOnnx/SherpaOnnx.xcodeproj/project.xcworkspace/xcuserdata/fangjun.xcuserdatad/UserInterfaceState.xcuserstate b/ios-swiftui/SherpaOnnx/SherpaOnnx.xcodeproj/project.xcworkspace/xcuserdata/fangjun.xcuserdatad/UserInterfaceState.xcuserstate deleted file mode 100644 index 86affab0..00000000 Binary files a/ios-swiftui/SherpaOnnx/SherpaOnnx.xcodeproj/project.xcworkspace/xcuserdata/fangjun.xcuserdatad/UserInterfaceState.xcuserstate and /dev/null differ diff --git a/scripts/go/_internal/speaker-identification/.gitignore b/scripts/go/_internal/speaker-identification/.gitignore new file mode 100644 index 00000000..41c46990 --- /dev/null +++ b/scripts/go/_internal/speaker-identification/.gitignore @@ -0,0 +1 @@ +speaker-identification diff --git a/scripts/go/_internal/speaker-identification/go.mod b/scripts/go/_internal/speaker-identification/go.mod new file mode 100644 index 00000000..38c62bc3 --- /dev/null +++ b/scripts/go/_internal/speaker-identification/go.mod @@ -0,0 +1,5 @@ +module speaker-identification + +go 1.12 + +replace github.com/k2-fsa/sherpa-onnx-go/sherpa_onnx => ../ diff --git a/scripts/go/_internal/speaker-identification/main.go b/scripts/go/_internal/speaker-identification/main.go new file mode 120000 index 00000000..90623708 --- /dev/null +++ b/scripts/go/_internal/speaker-identification/main.go @@ -0,0 +1 @@ +../../../../go-api-examples/speaker-identification/main.go \ No newline at end of file diff --git a/scripts/go/_internal/speaker-identification/run.sh b/scripts/go/_internal/speaker-identification/run.sh new file mode 120000 index 00000000..41164f3e --- /dev/null +++ b/scripts/go/_internal/speaker-identification/run.sh @@ -0,0 +1 @@ +../../../../go-api-examples/speaker-identification/run.sh \ No newline at end of file diff --git a/scripts/go/_internal/vad-asr-paraformer/go.mod b/scripts/go/_internal/vad-asr-paraformer/go.mod index d0130405..7b763cc8 100644 --- a/scripts/go/_internal/vad-asr-paraformer/go.mod +++ b/scripts/go/_internal/vad-asr-paraformer/go.mod @@ -3,8 +3,3 @@ module vad-asr-paraformer go 1.12 replace github.com/k2-fsa/sherpa-onnx-go/sherpa_onnx => ../ - -require ( - github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5 - github.com/k2-fsa/sherpa-onnx-go/sherpa_onnx v0.0.0-00010101000000-000000000000 -) diff --git a/scripts/go/_internal/vad-speaker-identification/go.mod b/scripts/go/_internal/vad-speaker-identification/go.mod new file mode 100644 index 00000000..faee85d3 --- /dev/null +++ b/scripts/go/_internal/vad-speaker-identification/go.mod @@ -0,0 +1,5 @@ +module vad-speaker-identification + +go 1.12 + +replace github.com/k2-fsa/sherpa-onnx-go/sherpa_onnx => ../ diff --git a/scripts/go/_internal/vad-speaker-identification/main.go b/scripts/go/_internal/vad-speaker-identification/main.go new file mode 120000 index 00000000..3109e8b0 --- /dev/null +++ b/scripts/go/_internal/vad-speaker-identification/main.go @@ -0,0 +1 @@ +../../../../go-api-examples/vad-speaker-identification/main.go \ No newline at end of file diff --git a/scripts/go/_internal/vad-speaker-identification/run.sh b/scripts/go/_internal/vad-speaker-identification/run.sh new file mode 120000 index 00000000..d3520535 --- /dev/null +++ b/scripts/go/_internal/vad-speaker-identification/run.sh @@ -0,0 +1 @@ +../../../../go-api-examples/vad-speaker-identification/run.sh \ No newline at end of file diff --git a/scripts/go/sherpa_onnx.go b/scripts/go/sherpa_onnx.go index 01dc5948..1b4c60ab 100644 --- a/scripts/go/sherpa_onnx.go +++ b/scripts/go/sherpa_onnx.go @@ -746,11 +746,11 @@ func (vad *VoiceActivityDetector) AcceptWaveform(samples []float32) { } func (vad *VoiceActivityDetector) IsEmpty() bool { - return 1 == int(C.SherpaOnnxVoiceActivityDetectorEmpty(vad.impl)) + return int(C.SherpaOnnxVoiceActivityDetectorEmpty(vad.impl)) == 1 } func (vad *VoiceActivityDetector) IsSpeech() bool { - return 1 == int(C.SherpaOnnxVoiceActivityDetectorDetected(vad.impl)) + return int(C.SherpaOnnxVoiceActivityDetectorDetected(vad.impl)) == 1 } func (vad *VoiceActivityDetector) Pop() { @@ -852,3 +852,204 @@ func (slid *SpokenLanguageIdentification) Compute(stream *OfflineStream) *Spoken return ans } + +// ============================================================ +// For speaker embedding extraction +// ============================================================ + +type SpeakerEmbeddingExtractorConfig struct { + Model string + NumThreads int + Debug int + Provider string +} + +type SpeakerEmbeddingExtractor struct { + impl *C.struct_SherpaOnnxSpeakerEmbeddingExtractor +} + +// The user has to invoke [DeleteSpeakerEmbeddingExtractor]() to free the returned value +// to avoid memory leak +func NewSpeakerEmbeddingExtractor(config *SpeakerEmbeddingExtractorConfig) *SpeakerEmbeddingExtractor { + c := C.struct_SherpaOnnxSpeakerEmbeddingExtractorConfig{} + + c.model = C.CString(config.Model) + defer C.free(unsafe.Pointer(c.model)) + + c.num_threads = C.int(config.NumThreads) + c.debug = C.int(config.Debug) + + c.provider = C.CString(config.Provider) + defer C.free(unsafe.Pointer(c.provider)) + + ex := &SpeakerEmbeddingExtractor{} + ex.impl = C.SherpaOnnxCreateSpeakerEmbeddingExtractor(&c) + + return ex +} + +func DeleteSpeakerEmbeddingExtractor(ex *SpeakerEmbeddingExtractor) { + C.SherpaOnnxDestroySpeakerEmbeddingExtractor(ex.impl) + ex.impl = nil +} + +func (ex *SpeakerEmbeddingExtractor) Dim() int { + return int(C.SherpaOnnxSpeakerEmbeddingExtractorDim(ex.impl)) +} + +// The user is responsible to invoke [DeleteOnlineStream]() to free +// the returned stream to avoid memory leak +func (ex *SpeakerEmbeddingExtractor) CreateStream() *OnlineStream { + stream := &OnlineStream{} + stream.impl = C.SherpaOnnxSpeakerEmbeddingExtractorCreateStream(ex.impl) + return stream +} + +func (ex *SpeakerEmbeddingExtractor) IsReady(stream *OnlineStream) bool { + return int(C.SherpaOnnxSpeakerEmbeddingExtractorIsReady(ex.impl, stream.impl)) == 1 +} + +func (ex *SpeakerEmbeddingExtractor) Compute(stream *OnlineStream) []float32 { + embedding := C.SherpaOnnxSpeakerEmbeddingExtractorComputeEmbedding(ex.impl, stream.impl) + defer C.SherpaOnnxSpeakerEmbeddingExtractorDestroyEmbedding(embedding) + + n := ex.Dim() + ans := make([]float32, n) + + // see https://stackoverflow.com/questions/48756732/what-does-1-30c-yourtype-do-exactly-in-cgo + // :n:n means 0:n:n, means low:high:capacity + c := (*[1 << 28]C.float)(unsafe.Pointer(embedding))[:n:n] + + for i := 0; i < n; i++ { + ans[i] = float32(c[i]) + } + + return ans +} + +type SpeakerEmbeddingManager struct { + impl *C.struct_SherpaOnnxSpeakerEmbeddingManager +} + +// The user has to invoke [DeleteSpeakerEmbeddingManager]() to free the returned +// value to avoid memory leak +func NewSpeakerEmbeddingManager(dim int) *SpeakerEmbeddingManager { + m := &SpeakerEmbeddingManager{} + m.impl = C.SherpaOnnxCreateSpeakerEmbeddingManager(C.int(dim)) + return m +} + +func DeleteSpeakerEmbeddingManager(m *SpeakerEmbeddingManager) { + C.SherpaOnnxDestroySpeakerEmbeddingManager(m.impl) + m.impl = nil +} + +func (m *SpeakerEmbeddingManager) Register(name string, embedding []float32) bool { + s := C.CString(name) + defer C.free(unsafe.Pointer(s)) + + return C.int(C.SherpaOnnxSpeakerEmbeddingManagerAdd(m.impl, s, (*C.float)(&embedding[0]))) == 1 +} + +func (m *SpeakerEmbeddingManager) RegisterV(name string, embeddings [][]float32) bool { + s := C.CString(name) + defer C.free(unsafe.Pointer(s)) + + if len(embeddings) == 0 { + return false + } + + dim := len(embeddings[0]) + v := make([]float32, 0, dim*len(embeddings)) + for _, embedding := range embeddings { + v = append(v, embedding...) + } + + return C.int(C.SherpaOnnxSpeakerEmbeddingManagerAddListFlattened(m.impl, s, (*C.float)(&v[0]), C.int(len(embeddings)))) == 1 +} + +func (m *SpeakerEmbeddingManager) Remove(name string) bool { + s := C.CString(name) + defer C.free(unsafe.Pointer(s)) + + return C.int(C.SherpaOnnxSpeakerEmbeddingManagerRemove(m.impl, s)) == 1 +} + +func (m *SpeakerEmbeddingManager) Search(embedding []float32, threshold float32) string { + var s string + + name := C.SherpaOnnxSpeakerEmbeddingManagerSearch(m.impl, (*C.float)(&embedding[0]), C.float(threshold)) + defer C.SherpaOnnxSpeakerEmbeddingManagerFreeSearch(name) + + if name != nil { + s = C.GoString(name) + } + + return s +} + +func (m *SpeakerEmbeddingManager) Verify(name string, embedding []float32, threshold float32) bool { + s := C.CString(name) + defer C.free(unsafe.Pointer(s)) + + return C.int(C.SherpaOnnxSpeakerEmbeddingManagerVerify(m.impl, s, (*C.float)(&embedding[0]), C.float(threshold))) == 1 +} + +func (m *SpeakerEmbeddingManager) Contains(name string) bool { + s := C.CString(name) + defer C.free(unsafe.Pointer(s)) + + return C.int(C.SherpaOnnxSpeakerEmbeddingManagerContains(m.impl, s)) == 1 +} + +func (m *SpeakerEmbeddingManager) NumSpeakers() int { + return int(C.SherpaOnnxSpeakerEmbeddingManagerNumSpeakers(m.impl)) +} + +func (m *SpeakerEmbeddingManager) AllSpeakers() []string { + all_speakers := C.SherpaOnnxSpeakerEmbeddingManagerGetAllSpeakers(m.impl) + defer C.SherpaOnnxSpeakerEmbeddingManagerFreeAllSpeakers(all_speakers) + + n := m.NumSpeakers() + if n == 0 { + return nil + } + + // https://stackoverflow.com/questions/62012070/convert-array-of-strings-from-cgo-in-go + p := (*[1 << 28]*C.char)(unsafe.Pointer(all_speakers))[:n:n] + + ans := make([]string, n) + + for i := 0; i < n; i++ { + ans[i] = C.GoString(p[i]) + } + + return ans +} + +// Wave + +// single channel wave +type Wave = GeneratedAudio + +func ReadWave(filename string) *Wave { + s := C.CString(filename) + defer C.free(unsafe.Pointer(s)) + + w := C.SherpaOnnxReadWave(s) + defer C.SherpaOnnxFreeWave(w) + + n := int(w.num_samples) + + ans := &Wave{} + ans.SampleRate = int(w.sample_rate) + samples := (*[1 << 28]C.float)(unsafe.Pointer(w.samples))[:n:n] + + ans.Samples = make([]float32, n) + + for i := 0; i < n; i++ { + ans.Samples[i] = float32(samples[i]) + } + + return ans +}