diff --git a/.github/workflows/flutter-linux.yaml b/.github/workflows/flutter-linux.yaml index eb2bdb41..ab78fd3a 100644 --- a/.github/workflows/flutter-linux.yaml +++ b/.github/workflows/flutter-linux.yaml @@ -164,6 +164,19 @@ jobs: cd example/assets curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_vad.onnx + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2 + tar xvf sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2 + rm sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2 + cd sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 + rm encoder-epoch-99-avg-1.onnx + rm decoder-epoch-99-avg-1.int8.onnx + rm joiner-epoch-99-avg-1.onnx + rm README.md + rm bpe.model + rm bpe.vocab + rm -rf test_wavs + ls -lh + cd .. - name: Build flutter shell: bash diff --git a/.github/workflows/flutter-macos.yaml b/.github/workflows/flutter-macos.yaml index 3448978b..6f215e57 100644 --- a/.github/workflows/flutter-macos.yaml +++ b/.github/workflows/flutter-macos.yaml @@ -132,6 +132,19 @@ jobs: # curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_vad.onnx # git clone https://github.com/csukuangfj/sr-data + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2 + tar xvf sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2 + rm sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2 + cd sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 + rm encoder-epoch-99-avg-1.onnx + rm decoder-epoch-99-avg-1.int8.onnx + rm joiner-epoch-99-avg-1.onnx + rm README.md + rm bpe.model + rm bpe.vocab + rm -rf test_wavs + ls -lh + 
cd .. rm -rf sr-data/.git popd diff --git a/.github/workflows/flutter-windows-x64.yaml b/.github/workflows/flutter-windows-x64.yaml index fba0c9fc..95eb8329 100644 --- a/.github/workflows/flutter-windows-x64.yaml +++ b/.github/workflows/flutter-windows-x64.yaml @@ -27,7 +27,7 @@ on: workflow_dispatch: concurrency: - group: flutter-windows-x64${{ github.ref }} + group: flutter-windows-x64-${{ github.ref }} cancel-in-progress: true jobs: @@ -115,6 +115,19 @@ jobs: cd example/assets curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_vad.onnx + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2 + tar xvf sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2 + rm sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20.tar.bz2 + cd sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 + rm encoder-epoch-99-avg-1.onnx + rm decoder-epoch-99-avg-1.int8.onnx + rm joiner-epoch-99-avg-1.onnx + rm README.md + rm bpe.model + rm bpe.vocab + rm -rf test_wavs + ls -lh + cd .. 
- name: Build flutter shell: bash diff --git a/.github/workflows/test-dot-net-nuget.yaml b/.github/workflows/test-dot-net-nuget.yaml index 0e7f21b1..f472cb42 100644 --- a/.github/workflows/test-dot-net-nuget.yaml +++ b/.github/workflows/test-dot-net-nuget.yaml @@ -13,7 +13,7 @@ on: - cron: "50 23 * * *" concurrency: - group: test-dot-net-nuget + group: test-dot-net-nuget-${{ github.ref }} cancel-in-progress: true permissions: diff --git a/.github/workflows/test-dot-net.yaml b/.github/workflows/test-dot-net.yaml index a62e00b5..d0214d43 100644 --- a/.github/workflows/test-dot-net.yaml +++ b/.github/workflows/test-dot-net.yaml @@ -26,7 +26,7 @@ on: workflow_dispatch: concurrency: - group: test-dot-net + group: test-dot-net-${{ github.ref }} cancel-in-progress: true permissions: @@ -61,7 +61,15 @@ jobs: mkdir build cd build - cmake -DBUILD_SHARED_LIBS=ON -DCMAKE_INSTALL_PREFIX=./install -DCMAKE_BUILD_TYPE=Release .. + cmake \ + -DBUILD_SHARED_LIBS=ON \ + -DCMAKE_INSTALL_PREFIX=./install \ + -DCMAKE_BUILD_TYPE=Release \ + -DSHERPA_ONNX_ENABLE_WEBSOCKET=OFF \ + -DBUILD_ESPEAK_NG_EXE=OFF \ + -DSHERPA_ONNX_ENABLE_BINARY=OFF \ + .. + cmake --build . --target install --config Release - name: Build sherpa-onnx for windows x86 @@ -74,7 +82,15 @@ jobs: mkdir build-win32 cd build-win32 - cmake -A Win32 -DBUILD_SHARED_LIBS=ON -DCMAKE_INSTALL_PREFIX=./install -DCMAKE_BUILD_TYPE=Release .. + cmake \ + -A Win32 \ + -DBUILD_SHARED_LIBS=ON \ + -DCMAKE_INSTALL_PREFIX=./install \ + -DCMAKE_BUILD_TYPE=Release \ + -DSHERPA_ONNX_ENABLE_WEBSOCKET=OFF \ + -DBUILD_ESPEAK_NG_EXE=OFF \ + -DSHERPA_ONNX_ENABLE_BINARY=OFF \ + .. cmake --build . 
--target install --config Release - uses: actions/upload-artifact@v4 diff --git a/sherpa-onnx/flutter/example/assets/streaming-asr.ico b/sherpa-onnx/flutter/example/assets/streaming-asr.ico new file mode 100644 index 00000000..9ce1bc23 Binary files /dev/null and b/sherpa-onnx/flutter/example/assets/streaming-asr.ico differ diff --git a/sherpa-onnx/flutter/example/assets/vad.ico b/sherpa-onnx/flutter/example/assets/vad.ico new file mode 100644 index 00000000..c46ba9b4 Binary files /dev/null and b/sherpa-onnx/flutter/example/assets/vad.ico differ diff --git a/sherpa-onnx/flutter/example/lib/main.dart b/sherpa-onnx/flutter/example/lib/main.dart index 7db1245d..25306057 100644 --- a/sherpa-onnx/flutter/example/lib/main.dart +++ b/sherpa-onnx/flutter/example/lib/main.dart @@ -2,9 +2,8 @@ import 'package:sherpa_onnx/sherpa_onnx.dart' as sherpa_onnx; import 'package:flutter/material.dart'; -import "./speaker_identification_test.dart"; -import "./vad_test.dart"; -import './home.dart'; +import './vad.dart'; +import './streaming_asr.dart'; import './info.dart'; void main() { @@ -20,7 +19,7 @@ class MyApp extends StatelessWidget { theme: ThemeData( primarySwatch: Colors.blue, ), - home: const MyHomePage(title: 'Next-gen Kaldi: VAD demo'), + home: const MyHomePage(title: 'Next-gen Kaldi Demo'), ); } } @@ -35,7 +34,8 @@ class MyHomePage extends StatefulWidget { class _MyHomePageState extends State { int _currentIndex = 0; final List _tabs = [ - HomeScreen(), + StreamingAsrScreen(), + VadScreen(), InfoScreen(), ]; @override @@ -52,10 +52,15 @@ class _MyHomePageState extends State { _currentIndex = index; }); }, + // https://www.xiconeditor.com/ items: [ BottomNavigationBarItem( - icon: Icon(Icons.home), - label: 'Home', + icon: new Image.asset("assets/streaming-asr.ico"), + label: '', + ), + BottomNavigationBarItem( + icon: new Image.asset("assets/vad.ico"), + label: '', ), BottomNavigationBarItem( icon: Icon(Icons.info), diff --git 
a/sherpa-onnx/flutter/example/lib/streaming_asr.dart b/sherpa-onnx/flutter/example/lib/streaming_asr.dart new file mode 100644 index 00000000..3d74b364 --- /dev/null +++ b/sherpa-onnx/flutter/example/lib/streaming_asr.dart @@ -0,0 +1,259 @@ +// Copyright (c) 2024 Xiaomi Corporation +import 'dart:async'; + +import 'package:flutter/foundation.dart'; +import 'package:flutter/material.dart'; +import 'package:path/path.dart' as p; +import 'package:path_provider/path_provider.dart'; +import 'package:record/record.dart'; + +import 'package:sherpa_onnx/sherpa_onnx.dart' as sherpa_onnx; + +import './utils.dart'; + +import './streaming_transducer_asr_test.dart'; // TODO(fangjun): remove it + +Future createOnlineRecognizer() async { + var encoder = + 'assets/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.int8.onnx'; + var decoder = + 'assets/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx'; + var joiner = + 'assets/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.int8.onnx'; + var tokens = + 'assets/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt'; + + encoder = await copyAssetFile(src: encoder, dst: 'encoder.onnx'); + decoder = await copyAssetFile(src: decoder, dst: 'decoder.onnx'); + joiner = await copyAssetFile(src: joiner, dst: 'joiner.onnx'); + tokens = await copyAssetFile(src: tokens, dst: 'tokens.txt'); + + final transducer = sherpa_onnx.OnlineTransducerModelConfig( + encoder: encoder, + decoder: decoder, + joiner: joiner, + ); + + final modelConfig = sherpa_onnx.OnlineModelConfig( + transducer: transducer, + tokens: tokens, + modelType: 'zipformer', + ); + + final config = sherpa_onnx.OnlineRecognizerConfig(model: modelConfig); + return sherpa_onnx.OnlineRecognizer(config); +} + +class StreamingAsrScreen extends StatefulWidget { + const StreamingAsrScreen({super.key}); + + @override + State createState() => _StreamingAsrScreenState(); +} + 
+class _StreamingAsrScreenState extends State { + late final TextEditingController _controller; + late final AudioRecorder _audioRecorder; + + String _title = 'Real-time speech recognition'; + String _last = ''; + int _index = 0; + bool _isInitialized = false; + + sherpa_onnx.OnlineRecognizer? _recognizer; + sherpa_onnx.OnlineStream? _stream; + int _sampleRate = 16000; + + StreamSubscription? _recordSub; + RecordState _recordState = RecordState.stop; + + @override + void initState() { + super.initState(); // run the inherited initialization before any subclass setup + + _audioRecorder = AudioRecorder(); + _controller = TextEditingController(); + + _recordSub = _audioRecorder.onStateChanged().listen((recordState) { + _updateRecordState(recordState); + }); + } + + Future _start() async { + if (!_isInitialized) { // bindings and recognizer are created once, on the first start + sherpa_onnx.initBindings(); + _recognizer = await createOnlineRecognizer(); + _stream = _recognizer?.createStream(); + + _isInitialized = true; + } + + try { + if (await _audioRecorder.hasPermission()) { + const encoder = AudioEncoder.pcm16bits; + + if (!await _isEncoderSupported(encoder)) { + return; + } + + final devs = await _audioRecorder.listInputDevices(); + debugPrint(devs.toString()); + + const config = RecordConfig( + encoder: encoder, + sampleRate: 16000, + numChannels: 1, + ); + + final stream = await _audioRecorder.startStream(config); + + stream.listen( + (data) { + final samplesFloat32 = + convertBytesToFloat32(Uint8List.fromList(data)); + + _stream!.acceptWaveform( + samples: samplesFloat32, sampleRate: _sampleRate); + while (_recognizer!.isReady(_stream!)) { + _recognizer!.decode(_stream!); + } + final text = _recognizer!.getResult(_stream!).text; + String textToDisplay = _last; // start from the already finalized segments + if (text != '') { + if (_last == '') { + textToDisplay = '$_index: $text'; + } else { + textToDisplay = '$_index: $text\n$_last'; + } + } + + if (_recognizer!.isEndpoint(_stream!)) { + _recognizer!.reset(_stream!); + if (text != '') { + _last = textToDisplay; + _index += 1; + } + } + print('text: $textToDisplay'); + 
_controller.value = TextEditingValue( + text: textToDisplay, + selection: TextSelection.collapsed(offset: textToDisplay.length), + ); + }, + onDone: () { + print('stream stopped.'); + }, + ); + } + } catch (e) { + print(e); + } + } + + Future _stop() async { + _stream!.free(); + _stream = _recognizer!.createStream(); + + await _audioRecorder.stop(); + } + + Future _pause() => _audioRecorder.pause(); + + Future _resume() => _audioRecorder.resume(); + + void _updateRecordState(RecordState recordState) { + setState(() => _recordState = recordState); + } + + Future _isEncoderSupported(AudioEncoder encoder) async { + final isSupported = await _audioRecorder.isEncoderSupported( + encoder, + ); + + if (!isSupported) { + debugPrint('${encoder.name} is not supported on this platform.'); + debugPrint('Supported encoders are:'); + + for (final e in AudioEncoder.values) { + if (await _audioRecorder.isEncoderSupported(e)) { + debugPrint('- ${e.name}'); // print the encoder that passed the check, not the rejected one + } + } + } + + return isSupported; + } + + @override + Widget build(BuildContext context) { + return MaterialApp( + home: Scaffold( + body: Column( + mainAxisAlignment: MainAxisAlignment.center, + children: [ + Text(_title), + const SizedBox(height: 50), + TextField( + maxLines: 5, + controller: _controller, + readOnly: true, + ), + const SizedBox(height: 50), + Row( + mainAxisAlignment: MainAxisAlignment.center, + children: [ + _buildRecordStopControl(), + const SizedBox(width: 20), + _buildText(), + ], + ), + ], + ), + ), + ); + } + + @override + void dispose() { + _recordSub?.cancel(); + _audioRecorder.dispose(); + _stream?.free(); + _recognizer?.free(); + super.dispose(); + } + + Widget _buildRecordStopControl() { + late Icon icon; + late Color color; + + if (_recordState != RecordState.stop) { + icon = const Icon(Icons.stop, color: Colors.red, size: 30); + color = Colors.red.withOpacity(0.1); + } else { + final theme = Theme.of(context); + icon = Icon(Icons.mic, color: theme.primaryColor, size: 30); + color = 
theme.primaryColor.withOpacity(0.1); + } + + return ClipOval( + child: Material( + color: color, + child: InkWell( + child: SizedBox(width: 56, height: 56, child: icon), + onTap: () { + (_recordState != RecordState.stop) ? _stop() : _start(); + }, + ), + ), + ); + } + + Widget _buildText() { + if (_recordState == RecordState.stop) { + return const Text("Start"); + } else { + return const Text("Stop"); + } + } +} diff --git a/sherpa-onnx/flutter/example/lib/streaming_transducer_asr_test.dart b/sherpa-onnx/flutter/example/lib/streaming_transducer_asr_test.dart new file mode 100644 index 00000000..4d9e3eb8 --- /dev/null +++ b/sherpa-onnx/flutter/example/lib/streaming_transducer_asr_test.dart @@ -0,0 +1,61 @@ +// Copyright (c) 2024 Xiaomi Corporation +import 'package:path/path.dart'; +import 'package:path_provider/path_provider.dart'; +import 'package:flutter/services.dart' show rootBundle; +import 'dart:typed_data'; +import "dart:io"; + +import 'package:sherpa_onnx/sherpa_onnx.dart' as sherpa_onnx; +import './utils.dart'; + +Future testStreamingTransducerAsr() async { + var encoder = + 'assets/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/encoder-epoch-99-avg-1.int8.onnx'; + var decoder = + 'assets/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/decoder-epoch-99-avg-1.onnx'; + var joiner = + 'assets/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/joiner-epoch-99-avg-1.int8.onnx'; + var tokens = + 'assets/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/tokens.txt'; + + var testWave = + 'assets/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/0.wav'; + + encoder = await copyAssetFile(src: encoder, dst: 'encoder.onnx'); + decoder = await copyAssetFile(src: decoder, dst: 'decoder.onnx'); + joiner = await copyAssetFile(src: joiner, dst: 'joiner.onnx'); + tokens = await copyAssetFile(src: tokens, dst: 'tokens.txt'); + testWave = await copyAssetFile(src: testWave, dst: 'test.wav'); + + final transducer = 
sherpa_onnx.OnlineTransducerModelConfig( + encoder: encoder, + decoder: decoder, + joiner: joiner, + ); + + final modelConfig = sherpa_onnx.OnlineModelConfig( + transducer: transducer, + tokens: tokens, + modelType: 'zipformer', + ); + + final config = sherpa_onnx.OnlineRecognizerConfig(model: modelConfig); + print(config); + final recognizer = sherpa_onnx.OnlineRecognizer(config); + + final waveData = sherpa_onnx.readWave(testWave); + final stream = recognizer.createStream(); + + stream.acceptWaveform( + samples: waveData.samples, sampleRate: waveData.sampleRate); + while (recognizer.isReady(stream)) { + recognizer.decode(stream); + } + + final result = recognizer.getResult(stream); + print('result is: ${result}'); + + print('recognizer: ${recognizer.ptr}'); + stream.free(); + recognizer.free(); +} diff --git a/sherpa-onnx/flutter/example/lib/home.dart b/sherpa-onnx/flutter/example/lib/vad.dart similarity index 97% rename from sherpa-onnx/flutter/example/lib/home.dart rename to sherpa-onnx/flutter/example/lib/vad.dart index 291e3249..7df0c9e5 100644 --- a/sherpa-onnx/flutter/example/lib/home.dart +++ b/sherpa-onnx/flutter/example/lib/vad.dart @@ -11,14 +11,14 @@ import 'package:sherpa_onnx/sherpa_onnx.dart' as sherpa_onnx; import './utils.dart'; -class HomeScreen extends StatefulWidget { - const HomeScreen({super.key}); +class VadScreen extends StatefulWidget { + const VadScreen({super.key}); @override - State createState() => _HomeScreenState(); + State createState() => _VadScreenState(); } -class _HomeScreenState extends State { +class _VadScreenState extends State { late final AudioRecorder _audioRecorder; bool _printed = false; diff --git a/sherpa-onnx/flutter/example/pubspec.yaml b/sherpa-onnx/flutter/example/pubspec.yaml index a509309a..73681081 100644 --- a/sherpa-onnx/flutter/example/pubspec.yaml +++ b/sherpa-onnx/flutter/example/pubspec.yaml @@ -73,6 +73,7 @@ flutter: # To add assets to your application, add an assets section, like this: assets: - assets/ 
+ - assets/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/ # - assets/sr-data/enroll/ # - assets/sr-data/test/ # - images/a_dot_ham.jpeg diff --git a/sherpa-onnx/flutter/lib/sherpa_onnx.dart b/sherpa-onnx/flutter/lib/sherpa_onnx.dart index 410a0339..fe29e17c 100644 --- a/sherpa-onnx/flutter/lib/sherpa_onnx.dart +++ b/sherpa-onnx/flutter/lib/sherpa_onnx.dart @@ -2,6 +2,7 @@ import 'dart:io'; import 'dart:ffi'; +export 'src/online_recognizer.dart'; export 'src/online_stream.dart'; export 'src/speaker_identification.dart'; export 'src/vad.dart'; diff --git a/sherpa-onnx/flutter/lib/src/online_recognizer.dart b/sherpa-onnx/flutter/lib/src/online_recognizer.dart new file mode 100644 index 00000000..4bfae0b7 --- /dev/null +++ b/sherpa-onnx/flutter/lib/src/online_recognizer.dart @@ -0,0 +1,291 @@ +// Copyright (c) 2024 Xiaomi Corporation +import 'dart:convert'; +import 'dart:ffi'; +import 'dart:typed_data'; + +import 'package:ffi/ffi.dart'; + +import './online_stream.dart'; +import './sherpa_onnx_bindings.dart'; + +class FeatureConfig { + const FeatureConfig({this.sampleRate = 16000, this.featureDim = 80}); + + @override + String toString() { + return 'FeatureConfig(sampleRate: $sampleRate, featureDim: $featureDim)'; + } + + final int sampleRate; + final int featureDim; +} + +class OnlineTransducerModelConfig { + const OnlineTransducerModelConfig({ + this.encoder = '', + this.decoder = '', + this.joiner = '', + }); + + @override + String toString() { + return 'OnlineTransducerModelConfig(encoder: $encoder, decoder: $decoder, joiner: $joiner)'; + } + + final String encoder; + final String decoder; + final String joiner; +} + +class OnlineParaformerModelConfig { + const OnlineParaformerModelConfig({this.encoder = '', this.decoder = ''}); + + @override + String toString() { + return 'OnlineParaformerModelConfig(encoder: $encoder, decoder: $decoder)'; + } + + final String encoder; + final String decoder; +} + +class OnlineZipformer2CtcModelConfig { + const 
OnlineZipformer2CtcModelConfig({this.model = ''}); + + @override + String toString() { + return 'OnlineZipformer2CtcModelConfig(model: $model)'; + } + + final String model; +} + +class OnlineModelConfig { + const OnlineModelConfig({ + this.transducer = const OnlineTransducerModelConfig(), + this.paraformer = const OnlineParaformerModelConfig(), + this.zipformer2Ctc = const OnlineZipformer2CtcModelConfig(), + required this.tokens, + this.numThreads = 1, + this.provider = 'cpu', + this.debug = true, + this.modelType = '', + }); + + @override + String toString() { + return 'OnlineModelConfig(transducer: $transducer, paraformer: $paraformer, zipformer2Ctc: $zipformer2Ctc, tokens: $tokens, numThreads: $numThreads, provider: $provider, debug: $debug, modelType: $modelType)'; + } + + final OnlineTransducerModelConfig transducer; + final OnlineParaformerModelConfig paraformer; + final OnlineZipformer2CtcModelConfig zipformer2Ctc; + + final String tokens; + + final int numThreads; + + final String provider; + + final bool debug; + + final String modelType; +} + +class OnlineCtcFstDecoderConfig { + const OnlineCtcFstDecoderConfig({this.graph = '', this.maxActive = 3000}); + + @override + String toString() { + return 'OnlineCtcFstDecoderConfig(graph: $graph, maxActive: $maxActive)'; + } + + final String graph; + final int maxActive; +} + +class OnlineRecognizerConfig { + const OnlineRecognizerConfig({ + this.feat = const FeatureConfig(), + required this.model, + this.decodingMethod = 'greedy_search', + this.maxActivePaths = 4, + this.enableEndpoint = true, + this.rule1MinTrailingSilence = 2.4, + this.rule2MinTrailingSilence = 1.2, + this.rule3MinUtteranceLength = 20, + this.hotwordsFile = '', + this.hotwordsScore = 1.5, + this.ctcFstDecoderConfig = const OnlineCtcFstDecoderConfig(), + }); + + @override + String toString() { + return 'OnlineRecognizerConfig(feat: $feat, model: $model, decodingMethod: $decodingMethod, maxActivePaths: $maxActivePaths, enableEndpoint: 
$enableEndpoint, rule1MinTrailingSilence: $rule1MinTrailingSilence, rule2MinTrailingSilence: $rule2MinTrailingSilence, rule3MinUtteranceLength: $rule3MinUtteranceLength, hotwordsFile: $hotwordsFile, hotwordsScore: $hotwordsScore, ctcFstDecoderConfig: $ctcFstDecoderConfig)'; + } + + final FeatureConfig feat; + final OnlineModelConfig model; + final String decodingMethod; + + final int maxActivePaths; + + final bool enableEndpoint; + + final double rule1MinTrailingSilence; + + final double rule2MinTrailingSilence; + + final double rule3MinUtteranceLength; + + final String hotwordsFile; + + final double hotwordsScore; + + final OnlineCtcFstDecoderConfig ctcFstDecoderConfig; +} + +class OnlineRecognizerResult { + OnlineRecognizerResult( + {required this.text, required this.tokens, required this.timestamps}); + + @override + String toString() { + return 'OnlineRecognizerResult(text: $text, tokens: $tokens, timestamps: $timestamps)'; + } + + final String text; + final List tokens; + final List timestamps; +} + +class OnlineRecognizer { + OnlineRecognizer._({required this.ptr, required this.config}); + + /// The user is responsible to call the OnlineRecognizer.free() + /// method of the returned instance to avoid memory leak. 
+ factory OnlineRecognizer(OnlineRecognizerConfig config) { + final c = calloc(); + c.ref.feat.sampleRate = config.feat.sampleRate; + c.ref.feat.featureDim = config.feat.featureDim; + + // transducer + c.ref.model.transducer.encoder = + config.model.transducer.encoder.toNativeUtf8(); + c.ref.model.transducer.decoder = + config.model.transducer.decoder.toNativeUtf8(); + c.ref.model.transducer.joiner = + config.model.transducer.joiner.toNativeUtf8(); + + // paraformer + c.ref.model.paraformer.encoder = + config.model.paraformer.encoder.toNativeUtf8(); + c.ref.model.paraformer.decoder = + config.model.paraformer.decoder.toNativeUtf8(); + + // zipformer2Ctc + c.ref.model.zipformer2Ctc.model = + config.model.zipformer2Ctc.model.toNativeUtf8(); + + c.ref.model.tokens = config.model.tokens.toNativeUtf8(); + c.ref.model.numThreads = config.model.numThreads; + c.ref.model.provider = config.model.provider.toNativeUtf8(); + c.ref.model.debug = config.model.debug ? 1 : 0; + c.ref.model.modelType = config.model.modelType.toNativeUtf8(); + + c.ref.decodingMethod = config.decodingMethod.toNativeUtf8(); + c.ref.maxActivePaths = config.maxActivePaths; + c.ref.enableEndpoint = config.enableEndpoint ? 1 : 0; + c.ref.rule1MinTrailingSilence = config.rule1MinTrailingSilence; + c.ref.rule2MinTrailingSilence = config.rule2MinTrailingSilence; + c.ref.rule3MinUtteranceLength = config.rule3MinUtteranceLength; + c.ref.hotwordsFile = config.hotwordsFile.toNativeUtf8(); + c.ref.hotwordsScore = config.hotwordsScore; + + c.ref.ctcFstDecoderConfig.graph = + config.ctcFstDecoderConfig.graph.toNativeUtf8(); + c.ref.ctcFstDecoderConfig.maxActive = config.ctcFstDecoderConfig.maxActive; + + final ptr = SherpaOnnxBindings.createOnlineRecognizer?.call(c) ?? 
nullptr; + + calloc.free(c.ref.ctcFstDecoderConfig.graph); + calloc.free(c.ref.hotwordsFile); + calloc.free(c.ref.decodingMethod); + calloc.free(c.ref.model.modelType); + calloc.free(c.ref.model.provider); + calloc.free(c.ref.model.tokens); + calloc.free(c.ref.model.zipformer2Ctc.model); + calloc.free(c.ref.model.paraformer.encoder); + calloc.free(c.ref.model.paraformer.decoder); + + calloc.free(c.ref.model.transducer.encoder); + calloc.free(c.ref.model.transducer.decoder); + calloc.free(c.ref.model.transducer.joiner); + calloc.free(c); + + return OnlineRecognizer._(ptr: ptr, config: config); + } + + void free() { + SherpaOnnxBindings.destroyOnlineRecognizer?.call(ptr); + ptr = nullptr; + } + + /// The user has to invoke stream.free() on the returned instance + /// to avoid memory leak + OnlineStream createStream({String hotwords = ''}) { + if (hotwords == '') { + final p = SherpaOnnxBindings.createOnlineStream?.call(ptr) ?? nullptr; + return OnlineStream(ptr: p); + } + + final utf8 = hotwords.toNativeUtf8(); + final p = + SherpaOnnxBindings.createOnlineStreamWithHotwords?.call(ptr, utf8) ?? + nullptr; + calloc.free(utf8); + return OnlineStream(ptr: p); + } + + bool isReady(OnlineStream stream) { + int ready = + SherpaOnnxBindings.isOnlineStreamReady?.call(ptr, stream.ptr) ?? 0; + + return ready == 1; + } + + OnlineRecognizerResult getResult(OnlineStream stream) { + final json = + SherpaOnnxBindings.getOnlineStreamResultAsJson?.call(ptr, stream.ptr) ?? 
+ nullptr; + if (json == nullptr) { // json falls back to nullptr above and is never Dart null, so compare against nullptr + return OnlineRecognizerResult(text: '', tokens: [], timestamps: []); + } + + final parsedJson = jsonDecode(json.toDartString()); + + SherpaOnnxBindings.destroyOnlineStreamResultJson?.call(json); // release the native JSON buffer once it has been decoded + + return OnlineRecognizerResult( + text: parsedJson['text'], + tokens: List.from(parsedJson['tokens']), + timestamps: List.from(parsedJson['timestamps'])); + } + + void reset(OnlineStream stream) { + SherpaOnnxBindings.reset?.call(ptr, stream.ptr); + } + + void decode(OnlineStream stream) { + SherpaOnnxBindings.decodeOnlineStream?.call(ptr, stream.ptr); + } + + bool isEndpoint(OnlineStream stream) { + int yes = SherpaOnnxBindings.isEndpoint?.call(ptr, stream.ptr) ?? 0; + + return yes == 1; + } + + Pointer ptr; + OnlineRecognizerConfig config; +} diff --git a/sherpa-onnx/flutter/lib/src/sherpa_onnx_bindings.dart b/sherpa-onnx/flutter/lib/src/sherpa_onnx_bindings.dart index 82b0bd75..b635819a 100644 --- a/sherpa-onnx/flutter/lib/src/sherpa_onnx_bindings.dart +++ b/sherpa-onnx/flutter/lib/src/sherpa_onnx_bindings.dart @@ -2,6 +2,82 @@ import 'dart:ffi'; import 'package:ffi/ffi.dart'; +final class SherpaOnnxFeatureConfig extends Struct { + @Int32() + external int sampleRate; + + @Int32() + external int featureDim; +} + +final class SherpaOnnxOnlineTransducerModelConfig extends Struct { + external Pointer encoder; + external Pointer decoder; + external Pointer joiner; +} + +final class SherpaOnnxOnlineParaformerModelConfig extends Struct { + external Pointer encoder; + external Pointer decoder; +} + +final class SherpaOnnxOnlineZipformer2CtcModelConfig extends Struct { + external Pointer model; +} + +final class SherpaOnnxOnlineModelConfig extends Struct { + external SherpaOnnxOnlineTransducerModelConfig transducer; + external SherpaOnnxOnlineParaformerModelConfig paraformer; + external SherpaOnnxOnlineZipformer2CtcModelConfig zipformer2Ctc; + + external Pointer tokens; + + @Int32() + external int numThreads; + + external Pointer provider; + 
+ @Int32() + external int debug; + + external Pointer modelType; +} + +final class SherpaOnnxOnlineCtcFstDecoderConfig extends Struct { + external Pointer graph; + + @Int32() + external int maxActive; +} + +final class SherpaOnnxOnlineRecognizerConfig extends Struct { + external SherpaOnnxFeatureConfig feat; + external SherpaOnnxOnlineModelConfig model; + external Pointer decodingMethod; + + @Int32() + external int maxActivePaths; + + @Int32() + external int enableEndpoint; + + @Float() + external double rule1MinTrailingSilence; + + @Float() + external double rule2MinTrailingSilence; + + @Float() + external double rule3MinUtteranceLength; + + external Pointer hotwordsFile; + + @Float() + external double hotwordsScore; + + external SherpaOnnxOnlineCtcFstDecoderConfig ctcFstDecoderConfig; +} + final class SherpaOnnxSileroVadModelConfig extends Struct { external Pointer model; @@ -71,10 +147,66 @@ final class SherpaOnnxVoiceActivityDetector extends Opaque {} final class SherpaOnnxOnlineStream extends Opaque {} +final class SherpaOnnxOnlineRecognizer extends Opaque {} + final class SherpaOnnxSpeakerEmbeddingExtractor extends Opaque {} final class SherpaOnnxSpeakerEmbeddingManager extends Opaque {} +typedef CreateOnlineRecognizerNative = Pointer + Function(Pointer); + +typedef CreateOnlineRecognizer = CreateOnlineRecognizerNative; + +typedef DestroyOnlineRecognizerNative = Void Function( + Pointer); + +typedef DestroyOnlineRecognizer = void Function( + Pointer); + +typedef CreateOnlineStreamNative = Pointer Function( + Pointer); + +typedef CreateOnlineStream = CreateOnlineStreamNative; + +typedef CreateOnlineStreamWithHotwordsNative = Pointer + Function(Pointer, Pointer); + +typedef CreateOnlineStreamWithHotwords = CreateOnlineStreamWithHotwordsNative; + +typedef IsOnlineStreamReadyNative = Int32 Function( + Pointer, Pointer); + +typedef IsOnlineStreamReady = int Function( + Pointer, Pointer); + +typedef DecodeOnlineStreamNative = Void Function( + Pointer, Pointer); + 
+typedef DecodeOnlineStream = void Function( + Pointer, Pointer); + +typedef GetOnlineStreamResultAsJsonNative = Pointer Function( + Pointer, Pointer); + +typedef GetOnlineStreamResultAsJson = GetOnlineStreamResultAsJsonNative; + +typedef ResetNative = Void Function( + Pointer, Pointer); + +typedef Reset = void Function( + Pointer, Pointer); + +typedef IsEndpointNative = Int32 Function( + Pointer, Pointer); + +typedef IsEndpoint = int Function( + Pointer, Pointer); + +typedef DestroyOnlineStreamResultJsonNative = Void Function(Pointer); + +typedef DestroyOnlineStreamResultJson = void Function(Pointer); + typedef SherpaOnnxCreateVoiceActivityDetectorNative = Pointer Function( Pointer, Float); @@ -356,6 +488,26 @@ typedef SherpaOnnxFreeWaveNative = Void Function(Pointer); typedef SherpaOnnxFreeWave = void Function(Pointer); class SherpaOnnxBindings { + static CreateOnlineRecognizer? createOnlineRecognizer; + + static DestroyOnlineRecognizer? destroyOnlineRecognizer; + + static CreateOnlineStream? createOnlineStream; + + static CreateOnlineStreamWithHotwords? createOnlineStreamWithHotwords; + + static IsOnlineStreamReady? isOnlineStreamReady; + + static DecodeOnlineStream? decodeOnlineStream; + + static GetOnlineStreamResultAsJson? getOnlineStreamResultAsJson; + + static Reset? reset; + + static IsEndpoint? isEndpoint; + + static DestroyOnlineStreamResultJson? destroyOnlineStreamResultJson; + static SherpaOnnxCreateVoiceActivityDetector? createVoiceActivityDetector; static SherpaOnnxDestroyVoiceActivityDetector? destroyVoiceActivityDetector; @@ -459,6 +611,52 @@ class SherpaOnnxBindings { static SherpaOnnxFreeWave? 
freeWave; static void init(DynamicLibrary dynamicLibrary) { + createOnlineRecognizer ??= dynamicLibrary + .lookup>( + 'CreateOnlineRecognizer') + .asFunction(); + + destroyOnlineRecognizer ??= dynamicLibrary + .lookup>( + 'DestroyOnlineRecognizer') + .asFunction(); + + createOnlineStream ??= dynamicLibrary + .lookup>('CreateOnlineStream') + .asFunction(); + + createOnlineStreamWithHotwords ??= dynamicLibrary + .lookup>( + 'CreateOnlineStreamWithHotwords') + .asFunction(); + + isOnlineStreamReady ??= dynamicLibrary + .lookup>( + 'IsOnlineStreamReady') + .asFunction(); + + decodeOnlineStream ??= dynamicLibrary + .lookup>('DecodeOnlineStream') + .asFunction(); + + getOnlineStreamResultAsJson ??= dynamicLibrary + .lookup>( + 'GetOnlineStreamResultAsJson') + .asFunction(); + + reset ??= dynamicLibrary + .lookup>('Reset') + .asFunction(); + + isEndpoint ??= dynamicLibrary + .lookup>('IsEndpoint') + .asFunction(); + + destroyOnlineStreamResultJson ??= dynamicLibrary + .lookup>( + 'DestroyOnlineStreamResultJson') + .asFunction(); + createVoiceActivityDetector ??= dynamicLibrary .lookup>( 'SherpaOnnxCreateVoiceActivityDetector')