diff --git a/flutter/sherpa_onnx/lib/sherpa_onnx.dart b/flutter/sherpa_onnx/lib/sherpa_onnx.dart index b9fb7dd5..30206fc6 100644 --- a/flutter/sherpa_onnx/lib/sherpa_onnx.dart +++ b/flutter/sherpa_onnx/lib/sherpa_onnx.dart @@ -5,12 +5,13 @@ import 'dart:ffi'; export 'src/audio_tagging.dart'; export 'src/feature_config.dart'; export 'src/keyword_spotter.dart'; +export 'src/offline_punctuation.dart'; export 'src/offline_recognizer.dart'; export 'src/offline_speaker_diarization.dart'; export 'src/offline_stream.dart'; +export 'src/online_punctuation.dart'; export 'src/online_recognizer.dart'; export 'src/online_stream.dart'; -export 'src/punctuation.dart'; export 'src/speaker_identification.dart'; export 'src/tts.dart'; export 'src/vad.dart'; diff --git a/flutter/sherpa_onnx/lib/src/punctuation.dart b/flutter/sherpa_onnx/lib/src/offline_punctuation.dart similarity index 100% rename from flutter/sherpa_onnx/lib/src/punctuation.dart rename to flutter/sherpa_onnx/lib/src/offline_punctuation.dart diff --git a/flutter/sherpa_onnx/lib/src/online_punctuation.dart b/flutter/sherpa_onnx/lib/src/online_punctuation.dart new file mode 100644 index 00000000..118be8f5 --- /dev/null +++ b/flutter/sherpa_onnx/lib/src/online_punctuation.dart @@ -0,0 +1,99 @@ +import 'dart:ffi'; +import 'package:ffi/ffi.dart'; + +import './sherpa_onnx_bindings.dart'; + +class OnlinePunctuationModelConfig { + OnlinePunctuationModelConfig( + {required this.cnnBiLstm, + required this.bpeVocab, + this.numThreads = 1, + this.provider = 'cpu', + this.debug = true}); + + @override + String toString() { + return 'OnlinePunctuationModelConfig(cnnBiLstm: $cnnBiLstm, ' + 'bpeVocab: $bpeVocab, numThreads: $numThreads, ' + 'provider: $provider, debug: $debug)'; + } + + final String cnnBiLstm; + final String bpeVocab; + final int numThreads; + final String provider; + final bool debug; +} + +class OnlinePunctuationConfig { + OnlinePunctuationConfig({ + required this.model, + }); + + @override + String toString() { + return 'OnlinePunctuationConfig(model: $model)'; + } + + final OnlinePunctuationModelConfig model; +} + +class OnlinePunctuation { + OnlinePunctuation.fromPtr({required this.ptr, required this.config}); + + OnlinePunctuation._({required this.ptr, required this.config}); + + // The user has to invoke OnlinePunctuation.free() to avoid memory leak. + factory OnlinePunctuation({required OnlinePunctuationConfig config}) { + final c = calloc(); + + final cnnBiLstmPtr = config.model.cnnBiLstm.toNativeUtf8(); + final bpeVocabPtr = config.model.bpeVocab.toNativeUtf8(); + c.ref.model.cnnBiLstm = cnnBiLstmPtr; + c.ref.model.bpeVocab = bpeVocabPtr; + c.ref.model.numThreads = config.model.numThreads; + c.ref.model.debug = config.model.debug ? 1 : 0; + + final providerPtr = config.model.provider.toNativeUtf8(); + c.ref.model.provider = providerPtr; + + final ptr = SherpaOnnxBindings.sherpaOnnxCreateOnlinePunctuation?.call(c) ?? + nullptr; + + // Free the allocated strings and struct memory + calloc.free(providerPtr); + calloc.free(cnnBiLstmPtr); + calloc.free(bpeVocabPtr); + calloc.free(c); + + return OnlinePunctuation._(ptr: ptr, config: config); + } + + void free() { + SherpaOnnxBindings.sherpaOnnxDestroyOnlinePunctuation?.call(ptr); + ptr = nullptr; + } + + String addPunct(String text) { + final textPtr = text.toNativeUtf8(); + + final p = SherpaOnnxBindings.sherpaOnnxOnlinePunctuationAddPunct + ?.call(ptr, textPtr) ?? + nullptr; + + calloc.free(textPtr); + + if (p == nullptr) { + return ''; + } + + final ans = p.toDartString(); + + SherpaOnnxBindings.sherpaOnnxOnlinePunctuationFreeText?.call(p); + + return ans; + } + + Pointer ptr; + final OnlinePunctuationConfig config; +} diff --git a/flutter/sherpa_onnx/lib/src/sherpa_onnx_bindings.dart b/flutter/sherpa_onnx/lib/src/sherpa_onnx_bindings.dart index 1a069dc5..d59fb053 100644 --- a/flutter/sherpa_onnx/lib/src/sherpa_onnx_bindings.dart +++ b/flutter/sherpa_onnx/lib/src/sherpa_onnx_bindings.dart @@ -78,6 +78,20 @@ final class SherpaOnnxOfflinePunctuationConfig extends Struct { external SherpaOnnxOfflinePunctuationModelConfig model; } +final class SherpaOnnxOnlinePunctuationModelConfig extends Struct { + external Pointer cnnBiLstm; + external Pointer bpeVocab; + @Int32() + external int numThreads; + @Int32() + external int debug; + external Pointer provider; +} + +final class SherpaOnnxOnlinePunctuationConfig extends Struct { + external SherpaOnnxOnlinePunctuationModelConfig model; +} + final class SherpaOnnxOfflineZipformerAudioTaggingModelConfig extends Struct { external Pointer model; } @@ -469,6 +483,8 @@ final class SherpaOnnxKeywordSpotterConfig extends Struct { final class SherpaOnnxOfflinePunctuation extends Opaque {} +final class SherpaOnnxOnlinePunctuation extends Opaque {} + final class SherpaOnnxAudioTagging extends Opaque {} final class SherpaOnnxKeywordSpotter extends Opaque {} @@ -512,6 +528,10 @@ typedef SherpaOnnxCreateOfflinePunctuationNative = Pointer Function( Pointer); +typedef SherpaOnnxCreateOnlinePunctuationNative + = Pointer Function( + Pointer); + typedef SherpaOnnxOfflineSpeakerDiarizationGetSampleRateNative = Int32 Function( Pointer); @@ -605,6 +625,26 @@ typedef SherpaOfflinePunctuationFreeTextNative = Void Function(Pointer); typedef SherpaOfflinePunctuationFreeText = void Function(Pointer); +typedef SherpaOnnxCreateOnlinePunctuation + = SherpaOnnxCreateOnlinePunctuationNative; + +typedef SherpaOnnxDestroyOnlinePunctuationNative = Void Function( + Pointer); + +typedef SherpaOnnxDestroyOnlinePunctuation = void Function( + Pointer); + +typedef SherpaOnnxOnlinePunctuationAddPunctNative = Pointer Function( + Pointer, Pointer); + +typedef SherpaOnnxOnlinePunctuationAddPunct + = SherpaOnnxOnlinePunctuationAddPunctNative; + +typedef SherpaOnnxOnlinePunctuationFreeTextNative = Void Function( + Pointer); + +typedef SherpaOnnxOnlinePunctuationFreeText = void Function(Pointer); + typedef SherpaOnnxCreateAudioTaggingNative = Pointer Function(Pointer); @@ -1155,6 +1195,13 @@ class SherpaOnnxBindings { static SherpaOfflinePunctuationAddPunct? sherpaOfflinePunctuationAddPunct; static SherpaOfflinePunctuationFreeText? sherpaOfflinePunctuationFreeText; + static SherpaOnnxCreateOnlinePunctuation? sherpaOnnxCreateOnlinePunctuation; + static SherpaOnnxDestroyOnlinePunctuation? sherpaOnnxDestroyOnlinePunctuation; + static SherpaOnnxOnlinePunctuationAddPunct? + sherpaOnnxOnlinePunctuationAddPunct; + static SherpaOnnxOnlinePunctuationFreeText? + sherpaOnnxOnlinePunctuationFreeText; + static SherpaOnnxCreateAudioTagging? sherpaOnnxCreateAudioTagging; static SherpaOnnxDestroyAudioTagging? sherpaOnnxDestroyAudioTagging; static SherpaOnnxAudioTaggingCreateOfflineStream? @@ -1414,6 +1461,26 @@ class SherpaOnnxBindings { 'SherpaOfflinePunctuationFreeText') .asFunction(); + sherpaOnnxCreateOnlinePunctuation ??= dynamicLibrary + .lookup>( + 'SherpaOnnxCreateOnlinePunctuation') + .asFunction(); + + sherpaOnnxDestroyOnlinePunctuation ??= dynamicLibrary + .lookup>( + 'SherpaOnnxDestroyOnlinePunctuation') + .asFunction(); + + sherpaOnnxOnlinePunctuationAddPunct ??= dynamicLibrary + .lookup>( + 'SherpaOnnxOnlinePunctuationAddPunct') + .asFunction(); + + sherpaOnnxOnlinePunctuationFreeText ??= dynamicLibrary + .lookup>( + 'SherpaOnnxOnlinePunctuationFreeText') + .asFunction(); + sherpaOnnxCreateAudioTagging ??= dynamicLibrary .lookup>( 'SherpaOnnxCreateAudioTagging')