Add Go API for offline punctuation models (#1434)
It is contributed by a community user from [our QQ group](https://k2-fsa.github.io/sherpa/social-groups.html#qq).
This commit is contained in:
@@ -1283,7 +1283,7 @@ func (sd *OfflineSpeakerDiarization) SetConfig(config *OfflineSpeakerDiarization
|
|||||||
c.clustering.num_clusters = C.int(config.Clustering.NumClusters)
|
c.clustering.num_clusters = C.int(config.Clustering.NumClusters)
|
||||||
c.clustering.threshold = C.float(config.Clustering.Threshold)
|
c.clustering.threshold = C.float(config.Clustering.Threshold)
|
||||||
|
|
||||||
SherpaOnnxOfflineSpeakerDiarizationSetConfig(sd.impl, &c)
|
C.SherpaOnnxOfflineSpeakerDiarizationSetConfig(sd.impl, &c)
|
||||||
}
|
}
|
||||||
|
|
||||||
type OfflineSpeakerDiarizationSegment struct {
|
type OfflineSpeakerDiarizationSegment struct {
|
||||||
@@ -1317,3 +1317,51 @@ func (sd *OfflineSpeakerDiarization) Process(samples []float32) []OfflineSpeaker
|
|||||||
|
|
||||||
return ans
|
return ans
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ============================================================
|
||||||
|
// For punctuation
|
||||||
|
// ============================================================
|
||||||
|
type OfflinePunctuationModelConfig struct {
|
||||||
|
Ct_transformer string
|
||||||
|
Num_threads C.int
|
||||||
|
Debug C.int // true to print debug information of the model
|
||||||
|
Provider string
|
||||||
|
}
|
||||||
|
|
||||||
|
type OfflinePunctuationConfig struct {
|
||||||
|
Model OfflinePunctuationModelConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
type OfflinePunctuation struct {
|
||||||
|
impl *C.struct_SherpaOnnxOfflinePunctuation
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewOfflinePunctuation(config *OfflinePunctuationConfig) *OfflinePunctuation {
|
||||||
|
cfg := C.struct_SherpaOnnxOfflinePunctuationConfig{}
|
||||||
|
cfg.model.ct_transformer = C.CString(config.Model.Ct_transformer)
|
||||||
|
defer C.free(unsafe.Pointer(cfg.model.ct_transformer))
|
||||||
|
|
||||||
|
cfg.model.num_threads = config.Model.Num_threads
|
||||||
|
cfg.model.debug = config.Model.Debug
|
||||||
|
cfg.model.provider = C.CString(config.Model.Provider)
|
||||||
|
defer C.free(unsafe.Pointer(cfg.model.provider))
|
||||||
|
|
||||||
|
punc := &OfflinePunctuation{}
|
||||||
|
punc.impl = C.SherpaOnnxCreateOfflinePunctuation(&cfg)
|
||||||
|
|
||||||
|
return punc
|
||||||
|
}
|
||||||
|
|
||||||
|
func DeleteOfflinePunc(punc *OfflinePunctuation) {
|
||||||
|
C.SherpaOnnxDestroyOfflinePunctuation(punc.impl)
|
||||||
|
punc.impl = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (punc *OfflinePunctuation) AddPunct(text string) string {
|
||||||
|
p := C.SherpaOfflinePunctuationAddPunct(punc.impl, C.CString(text))
|
||||||
|
defer C.free(unsafe.Pointer(p))
|
||||||
|
|
||||||
|
text_with_punct := C.GoString(p)
|
||||||
|
|
||||||
|
return text_with_punct
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user