Go API for speaker diarization (#1403)
This commit is contained in:
@@ -0,0 +1,5 @@
|
||||
module non-streaming-speaker-diarization
|
||||
|
||||
go 1.12
|
||||
|
||||
replace github.com/k2-fsa/sherpa-onnx-go/sherpa_onnx => ../
|
||||
1
scripts/go/_internal/non-streaming-speaker-diarization/main.go
Symbolic link
1
scripts/go/_internal/non-streaming-speaker-diarization/main.go
Symbolic link
@@ -0,0 +1 @@
|
||||
../../../../go-api-examples/non-streaming-speaker-diarization/main.go
|
||||
1
scripts/go/_internal/non-streaming-speaker-diarization/run.sh
Symbolic link
1
scripts/go/_internal/non-streaming-speaker-diarization/run.sh
Symbolic link
@@ -0,0 +1 @@
|
||||
../../../../go-api-examples/non-streaming-speaker-diarization/run.sh
|
||||
@@ -1175,7 +1175,14 @@ func ReadWave(filename string) *Wave {
|
||||
w := C.SherpaOnnxReadWave(s)
|
||||
defer C.SherpaOnnxFreeWave(w)
|
||||
|
||||
if w == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
n := int(w.num_samples)
|
||||
if n == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
ans := &Wave{}
|
||||
ans.SampleRate = int(w.sample_rate)
|
||||
@@ -1189,3 +1196,114 @@ func ReadWave(filename string) *Wave {
|
||||
|
||||
return ans
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// For offline speaker diarization
|
||||
// ============================================================
|
||||
// OfflineSpeakerSegmentationPyannoteModelConfig holds the settings of a
// pyannote speaker segmentation model.
type OfflineSpeakerSegmentationPyannoteModelConfig struct {
	// Model is copied into the C config as segmentation.pyannote.model.
	// NOTE(review): presumably the path to the model file; confirm
	// against the C API documentation.
	Model string
}
|
||||
|
||||
type OfflineSpeakerSegmentationModelConfig struct {
|
||||
Pyannote OfflineSpeakerSegmentationPyannoteModelConfig
|
||||
NumThreads int
|
||||
Debug int
|
||||
Provider string
|
||||
}
|
||||
|
||||
// FastClusteringConfig configures the clustering stage of offline speaker
// diarization.
type FastClusteringConfig struct {
	// NumClusters is forwarded to the C config as clustering.num_clusters.
	// NOTE(review): presumably the number of speakers when it is known in
	// advance; confirm.
	NumClusters int

	// Threshold is forwarded to the C config as clustering.threshold.
	// NOTE(review): presumably only consulted when NumClusters is not
	// set; confirm.
	Threshold float32
}
|
||||
|
||||
type OfflineSpeakerDiarizationConfig struct {
|
||||
Segmentation OfflineSpeakerSegmentationModelConfig
|
||||
Embedding SpeakerEmbeddingExtractorConfig
|
||||
Clustering FastClusteringConfig
|
||||
MinDurationOn float32
|
||||
MinDurationOff float32
|
||||
}
|
||||
|
||||
type OfflineSpeakerDiarization struct {
|
||||
impl *C.struct_SherpaOnnxOfflineSpeakerDiarization
|
||||
}
|
||||
|
||||
func DeleteOfflineSpeakerDiarization(sd *OfflineSpeakerDiarization) {
|
||||
C.SherpaOnnxDestroyOfflineSpeakerDiarization(sd.impl)
|
||||
sd.impl = nil
|
||||
}
|
||||
|
||||
func NewOfflineSpeakerDiarization(config *OfflineSpeakerDiarizationConfig) *OfflineSpeakerDiarization {
|
||||
c := C.struct_SherpaOnnxOfflineSpeakerDiarizationConfig{}
|
||||
c.segmentation.pyannote.model = C.CString(config.Segmentation.Pyannote.Model)
|
||||
defer C.free(unsafe.Pointer(c.segmentation.pyannote.model))
|
||||
|
||||
c.segmentation.num_threads = C.int(config.Segmentation.NumThreads)
|
||||
|
||||
c.segmentation.debug = C.int(config.Segmentation.Debug)
|
||||
|
||||
c.segmentation.provider = C.CString(config.Segmentation.Provider)
|
||||
defer C.free(unsafe.Pointer(c.segmentation.provider))
|
||||
|
||||
c.embedding.model = C.CString(config.Embedding.Model)
|
||||
defer C.free(unsafe.Pointer(c.embedding.model))
|
||||
|
||||
c.embedding.num_threads = C.int(config.Embedding.NumThreads)
|
||||
|
||||
c.embedding.debug = C.int(config.Embedding.Debug)
|
||||
|
||||
c.embedding.provider = C.CString(config.Embedding.Provider)
|
||||
defer C.free(unsafe.Pointer(c.embedding.provider))
|
||||
|
||||
c.clustering.num_clusters = C.int(config.Clustering.NumClusters)
|
||||
c.clustering.threshold = C.float(config.Clustering.Threshold)
|
||||
c.min_duration_on = C.float(config.MinDurationOn)
|
||||
c.min_duration_off = C.float(config.MinDurationOff)
|
||||
|
||||
p := C.SherpaOnnxCreateOfflineSpeakerDiarization(&c)
|
||||
|
||||
if p == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
sd := &OfflineSpeakerDiarization{}
|
||||
sd.impl = p
|
||||
|
||||
return sd
|
||||
}
|
||||
|
||||
func (sd *OfflineSpeakerDiarization) SampleRate() int {
|
||||
return int(C.SherpaOnnxOfflineSpeakerDiarizationGetSampleRate(sd.impl))
|
||||
}
|
||||
|
||||
// OfflineSpeakerDiarizationSegment is one entry of the result returned by
// (*OfflineSpeakerDiarization).Process.
type OfflineSpeakerDiarizationSegment struct {
	// Start and End are copied from the start and end fields of the C
	// segment.
	// NOTE(review): presumably times in seconds from the beginning of
	// the input; confirm.
	Start, End float32

	// Speaker is copied from the speaker field of the C segment.
	// NOTE(review): presumably a zero-based speaker index; confirm.
	Speaker int
}
|
||||
|
||||
func (sd *OfflineSpeakerDiarization) Process(samples []float32) []OfflineSpeakerDiarizationSegment {
|
||||
r := C.SherpaOnnxOfflineSpeakerDiarizationProcess(sd.impl, (*C.float)(&samples[0]), C.int(len(samples)))
|
||||
defer C.SherpaOnnxOfflineSpeakerDiarizationDestroyResult(r)
|
||||
|
||||
n := int(C.SherpaOnnxOfflineSpeakerDiarizationResultGetNumSegments(r))
|
||||
|
||||
if n == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
s := C.SherpaOnnxOfflineSpeakerDiarizationResultSortByStartTime(r)
|
||||
defer C.SherpaOnnxOfflineSpeakerDiarizationDestroySegment(s)
|
||||
|
||||
ans := make([]OfflineSpeakerDiarizationSegment, n)
|
||||
|
||||
p := (*[1 << 28]C.struct_SherpaOnnxOfflineSpeakerDiarizationSegment)(unsafe.Pointer(s))[:n:n]
|
||||
|
||||
for i := 0; i < n; i++ {
|
||||
ans[i].Start = float32(p[i].start)
|
||||
ans[i].End = float32(p[i].end)
|
||||
ans[i].Speaker = int(p[i].speaker)
|
||||
}
|
||||
|
||||
return ans
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user