diff --git a/scripts/go/sherpa_onnx.go b/scripts/go/sherpa_onnx.go index b8363049..1d459fff 100644 --- a/scripts/go/sherpa_onnx.go +++ b/scripts/go/sherpa_onnx.go @@ -479,6 +479,202 @@ type OfflineRecognizerResult struct { Event string } +func newCOfflineRecognizerConfig(config *OfflineRecognizerConfig) *C.struct_SherpaOnnxOfflineRecognizerConfig { + c := C.struct_SherpaOnnxOfflineRecognizerConfig{} + c.feat_config.sample_rate = C.int(config.FeatConfig.SampleRate) + c.feat_config.feature_dim = C.int(config.FeatConfig.FeatureDim) + + c.model_config.transducer.encoder = C.CString(config.ModelConfig.Transducer.Encoder) + c.model_config.transducer.decoder = C.CString(config.ModelConfig.Transducer.Decoder) + c.model_config.transducer.joiner = C.CString(config.ModelConfig.Transducer.Joiner) + + c.model_config.paraformer.model = C.CString(config.ModelConfig.Paraformer.Model) + + c.model_config.nemo_ctc.model = C.CString(config.ModelConfig.NemoCTC.Model) + + c.model_config.whisper.encoder = C.CString(config.ModelConfig.Whisper.Encoder) + c.model_config.whisper.decoder = C.CString(config.ModelConfig.Whisper.Decoder) + c.model_config.whisper.language = C.CString(config.ModelConfig.Whisper.Language) + c.model_config.whisper.task = C.CString(config.ModelConfig.Whisper.Task) + c.model_config.whisper.tail_paddings = C.int(config.ModelConfig.Whisper.TailPaddings) + + c.model_config.tdnn.model = C.CString(config.ModelConfig.Tdnn.Model) + + c.model_config.sense_voice.model = C.CString(config.ModelConfig.SenseVoice.Model) + c.model_config.sense_voice.language = C.CString(config.ModelConfig.SenseVoice.Language) + c.model_config.sense_voice.use_itn = C.int(config.ModelConfig.SenseVoice.UseInverseTextNormalization) + + c.model_config.moonshine.preprocessor = C.CString(config.ModelConfig.Moonshine.Preprocessor) + c.model_config.moonshine.encoder = C.CString(config.ModelConfig.Moonshine.Encoder) + c.model_config.moonshine.uncached_decoder = C.CString(config.ModelConfig.Moonshine.UncachedDecoder) + c.model_config.moonshine.cached_decoder = C.CString(config.ModelConfig.Moonshine.CachedDecoder) + + c.model_config.fire_red_asr.encoder = C.CString(config.ModelConfig.FireRedAsr.Encoder) + c.model_config.fire_red_asr.decoder = C.CString(config.ModelConfig.FireRedAsr.Decoder) + + c.model_config.tokens = C.CString(config.ModelConfig.Tokens) + + c.model_config.num_threads = C.int(config.ModelConfig.NumThreads) + + c.model_config.debug = C.int(config.ModelConfig.Debug) + + c.model_config.provider = C.CString(config.ModelConfig.Provider) + + c.model_config.model_type = C.CString(config.ModelConfig.ModelType) + + c.model_config.modeling_unit = C.CString(config.ModelConfig.ModelingUnit) + + c.model_config.bpe_vocab = C.CString(config.ModelConfig.BpeVocab) + + c.model_config.telespeech_ctc = C.CString(config.ModelConfig.TeleSpeechCtc) + + c.lm_config.model = C.CString(config.LmConfig.Model) + c.lm_config.scale = C.float(config.LmConfig.Scale) + + c.decoding_method = C.CString(config.DecodingMethod) + + c.max_active_paths = C.int(config.MaxActivePaths) + + c.hotwords_file = C.CString(config.HotwordsFile) + c.hotwords_score = C.float(config.HotwordsScore) + + c.blank_penalty = C.float(config.BlankPenalty) + + c.rule_fsts = C.CString(config.RuleFsts) + c.rule_fars = C.CString(config.RuleFars) + return &c +} +func freeCOfflineRecognizerConfig(c *C.struct_SherpaOnnxOfflineRecognizerConfig) { + if c.model_config.transducer.encoder != nil { + C.free(unsafe.Pointer(c.model_config.transducer.encoder)) + c.model_config.transducer.encoder = nil + } + if c.model_config.transducer.decoder != nil { + C.free(unsafe.Pointer(c.model_config.transducer.decoder)) + c.model_config.transducer.decoder = nil + } + if c.model_config.transducer.joiner != nil { + C.free(unsafe.Pointer(c.model_config.transducer.joiner)) + c.model_config.transducer.joiner = nil + } + + if c.model_config.paraformer.model != nil { + C.free(unsafe.Pointer(c.model_config.paraformer.model)) + c.model_config.paraformer.model = nil + } + + if c.model_config.nemo_ctc.model != nil { + C.free(unsafe.Pointer(c.model_config.nemo_ctc.model)) + c.model_config.nemo_ctc.model = nil + } + + if c.model_config.whisper.encoder != nil { + C.free(unsafe.Pointer(c.model_config.whisper.encoder)) + c.model_config.whisper.encoder = nil + } + if c.model_config.whisper.decoder != nil { + C.free(unsafe.Pointer(c.model_config.whisper.decoder)) + c.model_config.whisper.decoder = nil + } + if c.model_config.whisper.language != nil { + C.free(unsafe.Pointer(c.model_config.whisper.language)) + c.model_config.whisper.language = nil + } + if c.model_config.whisper.task != nil { + C.free(unsafe.Pointer(c.model_config.whisper.task)) + c.model_config.whisper.task = nil + } + + if c.model_config.tdnn.model != nil { + C.free(unsafe.Pointer(c.model_config.tdnn.model)) + c.model_config.tdnn.model = nil + } + + if c.model_config.sense_voice.model != nil { + C.free(unsafe.Pointer(c.model_config.sense_voice.model)) + c.model_config.sense_voice.model = nil + } + if c.model_config.sense_voice.language != nil { + C.free(unsafe.Pointer(c.model_config.sense_voice.language)) + c.model_config.sense_voice.language = nil + } + + if c.model_config.moonshine.preprocessor != nil { + C.free(unsafe.Pointer(c.model_config.moonshine.preprocessor)) + c.model_config.moonshine.preprocessor = nil + } + if c.model_config.moonshine.encoder != nil { + C.free(unsafe.Pointer(c.model_config.moonshine.encoder)) + c.model_config.moonshine.encoder = nil + } + if c.model_config.moonshine.uncached_decoder != nil { + C.free(unsafe.Pointer(c.model_config.moonshine.uncached_decoder)) + c.model_config.moonshine.uncached_decoder = nil + } + if c.model_config.moonshine.cached_decoder != nil { + C.free(unsafe.Pointer(c.model_config.moonshine.cached_decoder)) + c.model_config.moonshine.cached_decoder = nil + } + + if c.model_config.fire_red_asr.encoder != nil { + C.free(unsafe.Pointer(c.model_config.fire_red_asr.encoder)) + c.model_config.fire_red_asr.encoder = nil + } + if c.model_config.fire_red_asr.decoder != nil { + C.free(unsafe.Pointer(c.model_config.fire_red_asr.decoder)) + c.model_config.fire_red_asr.decoder = nil + } + + if c.model_config.tokens != nil { + C.free(unsafe.Pointer(c.model_config.tokens)) + c.model_config.tokens = nil + } + if c.model_config.provider != nil { + C.free(unsafe.Pointer(c.model_config.provider)) + c.model_config.provider = nil + } + if c.model_config.model_type != nil { + C.free(unsafe.Pointer(c.model_config.model_type)) + c.model_config.model_type = nil + } + if c.model_config.modeling_unit != nil { + C.free(unsafe.Pointer(c.model_config.modeling_unit)) + c.model_config.modeling_unit = nil + } + if c.model_config.bpe_vocab != nil { + C.free(unsafe.Pointer(c.model_config.bpe_vocab)) + c.model_config.bpe_vocab = nil + } + if c.model_config.telespeech_ctc != nil { + C.free(unsafe.Pointer(c.model_config.telespeech_ctc)) + c.model_config.telespeech_ctc = nil + } + + if c.lm_config.model != nil { + C.free(unsafe.Pointer(c.lm_config.model)) + c.lm_config.model = nil + } + + if c.decoding_method != nil { + C.free(unsafe.Pointer(c.decoding_method)) + c.decoding_method = nil + } + + if c.hotwords_file != nil { + C.free(unsafe.Pointer(c.hotwords_file)) + c.hotwords_file = nil + } + + if c.rule_fsts != nil { + C.free(unsafe.Pointer(c.rule_fsts)) + c.rule_fsts = nil + } + if c.rule_fars != nil { + C.free(unsafe.Pointer(c.rule_fars)) + c.rule_fars = nil + } +} + // Frees the internal pointer of the recognition to avoid memory leak. func DeleteOfflineRecognizer(recognizer *OfflineRecognizer) { C.SherpaOnnxDestroyOfflineRecognizer(recognizer.impl) @@ -488,114 +684,10 @@ func DeleteOfflineRecognizer(recognizer *OfflineRecognizer) { // The user is responsible to invoke [DeleteOfflineRecognizer]() to free // the returned recognizer to avoid memory leak func NewOfflineRecognizer(config *OfflineRecognizerConfig) *OfflineRecognizer { - c := C.struct_SherpaOnnxOfflineRecognizerConfig{} - c.feat_config.sample_rate = C.int(config.FeatConfig.SampleRate) - c.feat_config.feature_dim = C.int(config.FeatConfig.FeatureDim) + c := newCOfflineRecognizerConfig(config) + defer freeCOfflineRecognizerConfig(c) - c.model_config.transducer.encoder = C.CString(config.ModelConfig.Transducer.Encoder) - defer C.free(unsafe.Pointer(c.model_config.transducer.encoder)) - - c.model_config.transducer.decoder = C.CString(config.ModelConfig.Transducer.Decoder) - defer C.free(unsafe.Pointer(c.model_config.transducer.decoder)) - - c.model_config.transducer.joiner = C.CString(config.ModelConfig.Transducer.Joiner) - defer C.free(unsafe.Pointer(c.model_config.transducer.joiner)) - - c.model_config.paraformer.model = C.CString(config.ModelConfig.Paraformer.Model) - defer C.free(unsafe.Pointer(c.model_config.paraformer.model)) - - c.model_config.nemo_ctc.model = C.CString(config.ModelConfig.NemoCTC.Model) - defer C.free(unsafe.Pointer(c.model_config.nemo_ctc.model)) - - c.model_config.whisper.encoder = C.CString(config.ModelConfig.Whisper.Encoder) - defer C.free(unsafe.Pointer(c.model_config.whisper.encoder)) - - c.model_config.whisper.decoder = C.CString(config.ModelConfig.Whisper.Decoder) - defer C.free(unsafe.Pointer(c.model_config.whisper.decoder)) - - c.model_config.whisper.language = C.CString(config.ModelConfig.Whisper.Language) - defer C.free(unsafe.Pointer(c.model_config.whisper.language)) - - c.model_config.whisper.task = C.CString(config.ModelConfig.Whisper.Task) - defer C.free(unsafe.Pointer(c.model_config.whisper.task)) - - c.model_config.whisper.tail_paddings = C.int(config.ModelConfig.Whisper.TailPaddings) - - c.model_config.tdnn.model = C.CString(config.ModelConfig.Tdnn.Model) - defer C.free(unsafe.Pointer(c.model_config.tdnn.model)) - - c.model_config.sense_voice.model = C.CString(config.ModelConfig.SenseVoice.Model) - defer C.free(unsafe.Pointer(c.model_config.sense_voice.model)) - - c.model_config.sense_voice.language = C.CString(config.ModelConfig.SenseVoice.Language) - defer C.free(unsafe.Pointer(c.model_config.sense_voice.language)) - - c.model_config.sense_voice.use_itn = C.int(config.ModelConfig.SenseVoice.UseInverseTextNormalization) - - c.model_config.moonshine.preprocessor = C.CString(config.ModelConfig.Moonshine.Preprocessor) - defer C.free(unsafe.Pointer(c.model_config.moonshine.preprocessor)) - - c.model_config.moonshine.encoder = C.CString(config.ModelConfig.Moonshine.Encoder) - defer C.free(unsafe.Pointer(c.model_config.moonshine.encoder)) - - c.model_config.moonshine.uncached_decoder = C.CString(config.ModelConfig.Moonshine.UncachedDecoder) - defer C.free(unsafe.Pointer(c.model_config.moonshine.uncached_decoder)) - - c.model_config.moonshine.cached_decoder = C.CString(config.ModelConfig.Moonshine.CachedDecoder) - defer C.free(unsafe.Pointer(c.model_config.moonshine.cached_decoder)) - - c.model_config.fire_red_asr.encoder = C.CString(config.ModelConfig.FireRedAsr.Encoder) - defer C.free(unsafe.Pointer(c.model_config.fire_red_asr.encoder)) - - c.model_config.fire_red_asr.decoder = C.CString(config.ModelConfig.FireRedAsr.Decoder) - defer C.free(unsafe.Pointer(c.model_config.fire_red_asr.decoder)) - - c.model_config.tokens = C.CString(config.ModelConfig.Tokens) - defer C.free(unsafe.Pointer(c.model_config.tokens)) - - c.model_config.num_threads = C.int(config.ModelConfig.NumThreads) - - c.model_config.debug = C.int(config.ModelConfig.Debug) - - c.model_config.provider = C.CString(config.ModelConfig.Provider) - defer C.free(unsafe.Pointer(c.model_config.provider)) - - c.model_config.model_type = C.CString(config.ModelConfig.ModelType) - defer C.free(unsafe.Pointer(c.model_config.model_type)) - - c.model_config.modeling_unit = C.CString(config.ModelConfig.ModelingUnit) - defer C.free(unsafe.Pointer(c.model_config.modeling_unit)) - - c.model_config.bpe_vocab = C.CString(config.ModelConfig.BpeVocab) - defer C.free(unsafe.Pointer(c.model_config.bpe_vocab)) - - c.model_config.telespeech_ctc = C.CString(config.ModelConfig.TeleSpeechCtc) - defer C.free(unsafe.Pointer(c.model_config.telespeech_ctc)) - - c.lm_config.model = C.CString(config.LmConfig.Model) - defer C.free(unsafe.Pointer(c.lm_config.model)) - - c.lm_config.scale = C.float(config.LmConfig.Scale) - - c.decoding_method = C.CString(config.DecodingMethod) - defer C.free(unsafe.Pointer(c.decoding_method)) - - c.max_active_paths = C.int(config.MaxActivePaths) - - c.hotwords_file = C.CString(config.HotwordsFile) - defer C.free(unsafe.Pointer(c.hotwords_file)) - - c.hotwords_score = C.float(config.HotwordsScore) - - c.blank_penalty = C.float(config.BlankPenalty) - - c.rule_fsts = C.CString(config.RuleFsts) - defer C.free(unsafe.Pointer(c.rule_fsts)) - - c.rule_fars = C.CString(config.RuleFars) - defer C.free(unsafe.Pointer(c.rule_fars)) - - impl := C.SherpaOnnxCreateOfflineRecognizer(&c) + impl := C.SherpaOnnxCreateOfflineRecognizer(c) if impl == nil { return nil } @@ -605,6 +697,14 @@ func NewOfflineRecognizer(config *OfflineRecognizerConfig) *OfflineRecognizer { return recognizer } +// Set new config to replace +func (r *OfflineRecognizer) SetConfig(config *OfflineRecognizerConfig) { + c := newCOfflineRecognizerConfig(config) + defer freeCOfflineRecognizerConfig(c) + + C.SherpaOnnxOfflineRecognizerSetConfig(r.impl, c) +} + // Frees the internal pointer of the stream to avoid memory leak. func DeleteOfflineStream(stream *OfflineStream) { C.SherpaOnnxDestroyOfflineStream(stream.impl)