Swift API for speaker diarization (#1404)

This commit is contained in:
Fangjun Kuang
2024-10-09 23:25:39 +08:00
committed by GitHub
parent df681e9807
commit 1571344509
4 changed files with 209 additions and 0 deletions

View File

@@ -1078,3 +1078,116 @@ class SherpaOnnxOfflinePunctuationWrapper {
return ans
}
}
/// Builds the C config struct for a pyannote speaker segmentation model.
///
/// - Parameter model: Model string; it is converted with `toCPointer` before
///   being stored in the C struct (presumably a path to the model file —
///   confirm against the C API).
/// - Returns: A `SherpaOnnxOfflineSpeakerSegmentationPyannoteModelConfig`
///   whose `model` field holds the converted C string.
func sherpaOnnxOfflineSpeakerSegmentationPyannoteModelConfig(model: String)
  -> SherpaOnnxOfflineSpeakerSegmentationPyannoteModelConfig
{
  let modelCString = toCPointer(model)
  return SherpaOnnxOfflineSpeakerSegmentationPyannoteModelConfig(model: modelCString)
}
/// Builds the C config struct for the speaker segmentation model.
///
/// - Parameters:
///   - pyannote: Pyannote model config, stored as-is.
///   - numThreads: Thread count, narrowed to `Int32`. Defaults to 1.
///   - debug: Debug flag, narrowed to `Int32`. Defaults to 0.
///   - provider: Provider name, converted with `toCPointer`. Defaults to "cpu".
/// - Returns: The populated `SherpaOnnxOfflineSpeakerSegmentationModelConfig`.
func sherpaOnnxOfflineSpeakerSegmentationModelConfig(
  pyannote: SherpaOnnxOfflineSpeakerSegmentationPyannoteModelConfig,
  numThreads: Int = 1,
  debug: Int = 0,
  provider: String = "cpu"
) -> SherpaOnnxOfflineSpeakerSegmentationModelConfig {
  // Convert the Swift values to their C representations up front, in the
  // same order the original initializer arguments were evaluated.
  let threadCount = Int32(numThreads)
  let debugFlag = Int32(debug)
  let providerCString = toCPointer(provider)

  let config = SherpaOnnxOfflineSpeakerSegmentationModelConfig(
    pyannote: pyannote,
    num_threads: threadCount,
    debug: debugFlag,
    provider: providerCString
  )
  return config
}
/// Builds the C config struct for fast clustering.
///
/// - Parameters:
///   - numClusters: Number of clusters, narrowed to `Int32`. Defaults to -1
///     (presumably meaning "unknown, use `threshold` instead" — confirm
///     against the C API).
///   - threshold: Clustering threshold, passed through unchanged.
///     Defaults to 0.5.
/// - Returns: The populated `SherpaOnnxFastClusteringConfig`.
func sherpaOnnxFastClusteringConfig(numClusters: Int = -1, threshold: Float = 0.5)
  -> SherpaOnnxFastClusteringConfig
{
  let clusterCount = Int32(numClusters)
  return SherpaOnnxFastClusteringConfig(
    num_clusters: clusterCount,
    threshold: threshold
  )
}
/// Builds the C config struct for the speaker embedding extractor.
///
/// - Parameters:
///   - model: Model string, converted with `toCPointer`.
///   - numThreads: Thread count, narrowed to `Int32`. Defaults to 1.
///   - debug: Debug flag, narrowed to `Int32`. Defaults to 0.
///   - provider: Provider name, converted with `toCPointer`. Defaults to "cpu".
/// - Returns: The populated `SherpaOnnxSpeakerEmbeddingExtractorConfig`.
func sherpaOnnxSpeakerEmbeddingExtractorConfig(
  model: String,
  numThreads: Int = 1,
  debug: Int = 0,
  provider: String = "cpu"
) -> SherpaOnnxSpeakerEmbeddingExtractorConfig {
  // Conversions happen in the same order as the original argument list.
  let modelCString = toCPointer(model)
  let threadCount = Int32(numThreads)
  let debugFlag = Int32(debug)
  let providerCString = toCPointer(provider)

  let config = SherpaOnnxSpeakerEmbeddingExtractorConfig(
    model: modelCString,
    num_threads: threadCount,
    debug: debugFlag,
    provider: providerCString
  )
  return config
}
/// Assembles the top-level C config struct for offline speaker diarization.
///
/// All arguments are stored in the resulting struct unchanged.
///
/// - Parameters:
///   - segmentation: Segmentation model config.
///   - embedding: Speaker embedding extractor config.
///   - clustering: Clustering config.
///   - minDurationOn: Stored in `min_duration_on`. Defaults to 0.3.
///   - minDurationOff: Stored in `min_duration_off`. Defaults to 0.5.
/// - Returns: The populated `SherpaOnnxOfflineSpeakerDiarizationConfig`.
func sherpaOnnxOfflineSpeakerDiarizationConfig(
  segmentation: SherpaOnnxOfflineSpeakerSegmentationModelConfig,
  embedding: SherpaOnnxSpeakerEmbeddingExtractorConfig,
  clustering: SherpaOnnxFastClusteringConfig,
  minDurationOn: Float = 0.3,
  minDurationOff: Float = 0.5
) -> SherpaOnnxOfflineSpeakerDiarizationConfig {
  let config = SherpaOnnxOfflineSpeakerDiarizationConfig(
    segmentation: segmentation,
    embedding: embedding,
    clustering: clustering,
    min_duration_on: minDurationOn,
    min_duration_off: minDurationOff
  )
  return config
}
/// A single diarization segment, copied out of the C result into a Swift value.
struct SherpaOnnxOfflineSpeakerDiarizationSegmentWrapper {
  /// Start of the segment (presumably in seconds — confirm against the C API).
  var start: Float = 0.0

  /// End of the segment (same unit as `start`).
  var end: Float = 0.0

  /// Speaker label assigned to this segment.
  var speaker: Int = 0
}
/// Swift wrapper that owns a C offline speaker diarization object and
/// destroys it on deinitialization.
class SherpaOnnxOfflineSpeakerDiarizationWrapper {
  /// A pointer to the underlying counterpart in C
  let impl: OpaquePointer!

  /// Creates the underlying C object from `config`.
  ///
  /// - Parameter config: Pointer to the diarization config, forwarded to
  ///   `SherpaOnnxCreateOfflineSpeakerDiarization`. If creation yields nil,
  ///   `impl` is nil.
  init(
    config: UnsafePointer<SherpaOnnxOfflineSpeakerDiarizationConfig>!
  ) {
    impl = SherpaOnnxCreateOfflineSpeakerDiarization(config)
  }

  deinit {
    if let impl {
      SherpaOnnxDestroyOfflineSpeakerDiarization(impl)
    }
  }

  /// Sample rate reported by the underlying C object, widened to `Int`.
  var sampleRate: Int {
    return Int(SherpaOnnxOfflineSpeakerDiarizationGetSampleRate(impl))
  }

  /// Runs diarization on `samples` and returns the segments sorted by start
  /// time.
  ///
  /// - Parameter samples: Audio samples passed to the C API together with
  ///   their count (presumably expected at `sampleRate` — confirm against
  ///   the C API).
  /// - Returns: One wrapper per segment, in the order produced by
  ///   `SherpaOnnxOfflineSpeakerDiarizationResultSortByStartTime`; an empty
  ///   array if processing or sorting yields nil.
  func process(samples: [Float]) -> [SherpaOnnxOfflineSpeakerDiarizationSegmentWrapper] {
    guard
      let result = SherpaOnnxOfflineSpeakerDiarizationProcess(
        impl, samples, Int32(samples.count))
    else {
      return []
    }
    // Release the C result on every exit path. Previously the early return
    // taken when sorting yielded nil skipped this call and leaked `result`.
    defer { SherpaOnnxOfflineSpeakerDiarizationDestroyResult(result) }

    let numSegments = Int(SherpaOnnxOfflineSpeakerDiarizationResultGetNumSegments(result))

    guard
      let sorted: UnsafePointer<SherpaOnnxOfflineSpeakerDiarizationSegment> =
        SherpaOnnxOfflineSpeakerDiarizationResultSortByStartTime(result)
    else {
      return []
    }
    // Registered after the result's defer, so it runs first: the segment
    // array is destroyed before the result, as before.
    defer { SherpaOnnxOfflineSpeakerDiarizationDestroySegment(sorted) }

    // The returned array is fully built (values copied out of C memory)
    // before the deferred destroy calls run.
    return UnsafeBufferPointer(start: sorted, count: numSegments).map { segment in
      SherpaOnnxOfflineSpeakerDiarizationSegmentWrapper(
        start: segment.start, end: segment.end, speaker: Int(segment.speaker))
    }
  }
}