diff --git a/.github/scripts/test-dart.sh b/.github/scripts/test-dart.sh index 05b8fce0..f392d204 100755 --- a/.github/scripts/test-dart.sh +++ b/.github/scripts/test-dart.sh @@ -4,6 +4,10 @@ set -ex cd dart-api-examples +pushd keyword-spotter +./run-zh.sh +popd + pushd non-streaming-asr echo '----------SenseVoice----------' diff --git a/.github/workflows/test-dart.yaml b/.github/workflows/test-dart.yaml index fa4ee5ba..387cdebc 100644 --- a/.github/workflows/test-dart.yaml +++ b/.github/workflows/test-dart.yaml @@ -108,6 +108,7 @@ jobs: cp scripts/dart/non-streaming-asr-pubspec.yaml dart-api-examples/non-streaming-asr/pubspec.yaml cp scripts/dart/streaming-asr-pubspec.yaml dart-api-examples/streaming-asr/pubspec.yaml cp scripts/dart/tts-pubspec.yaml dart-api-examples/tts/pubspec.yaml + cp scripts/dart/kws-pubspec.yaml dart-api-examples/keyword-spotter/pubspec.yaml cp scripts/dart/sherpa-onnx-pubspec.yaml flutter/sherpa_onnx/pubspec.yaml .github/scripts/test-dart.sh diff --git a/CHANGELOG.md b/CHANGELOG.md index b432e375..337ac7de 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,7 @@ ## 1.10.17 * Support SenseVoice CTC models. +* Add Dart API for keyword spotter. ## 1.10.16 diff --git a/dart-api-examples/.gitignore b/dart-api-examples/.gitignore index 248f032f..a204dfd4 100644 --- a/dart-api-examples/.gitignore +++ b/dart-api-examples/.gitignore @@ -1 +1,28 @@ !run*.sh +# See https://www.dartlang.org/guides/libraries/private-files + +# Files and directories created by pub +.dart_tool/ +.packages +build/ +# If you're building an application, you may want to check-in your pubspec.lock +pubspec.lock + +# Directory created by dartdoc +# If you don't generate documentation locally you can remove this line. +doc/api/ + +# dotenv environment variables file +.env* + +# Avoid committing generated Javascript files: +*.dart.js +*.info.json # Produced by the --dump-info flag. +*.js # When generated by dart2js. 
Don't specify *.js if your + # project includes source files written in JavaScript. +*.js_ +*.js.deps +*.js.map + +.flutter-plugins +.flutter-plugins-dependencies diff --git a/dart-api-examples/keyword-spotter/.gitignore b/dart-api-examples/keyword-spotter/.gitignore new file mode 100644 index 00000000..3a857904 --- /dev/null +++ b/dart-api-examples/keyword-spotter/.gitignore @@ -0,0 +1,3 @@ +# https://dart.dev/guides/libraries/private-files +# Created by `dart pub` +.dart_tool/ diff --git a/dart-api-examples/keyword-spotter/CHANGELOG.md b/dart-api-examples/keyword-spotter/CHANGELOG.md new file mode 100644 index 00000000..effe43c8 --- /dev/null +++ b/dart-api-examples/keyword-spotter/CHANGELOG.md @@ -0,0 +1,3 @@ +## 1.0.0 + +- Initial version. diff --git a/dart-api-examples/keyword-spotter/README.md b/dart-api-examples/keyword-spotter/README.md new file mode 100644 index 00000000..d632a268 --- /dev/null +++ b/dart-api-examples/keyword-spotter/README.md @@ -0,0 +1,4 @@ +# Introduction + +This directory contains keyword spotting examples using the +Dart API from [sherpa-onnx](https://github.com/k2-fsa/sherpa-onnx). diff --git a/dart-api-examples/keyword-spotter/analysis_options.yaml b/dart-api-examples/keyword-spotter/analysis_options.yaml new file mode 100644 index 00000000..dee8927a --- /dev/null +++ b/dart-api-examples/keyword-spotter/analysis_options.yaml @@ -0,0 +1,30 @@ +# This file configures the static analysis results for your project (errors, +# warnings, and lints). +# +# This enables the 'recommended' set of lints from `package:lints`. +# This set helps identify many issues that may lead to problems when running +# or consuming Dart code, and enforces writing Dart using a single, idiomatic +# style and format. +# +# If you want a smaller set of lints you can change this to specify +# 'package:lints/core.yaml'. These are just the most critical lints +# (the recommended set includes the core lints).
+# The core lints are also what is used by pub.dev for scoring packages. + +include: package:lints/recommended.yaml + +# Uncomment the following section to specify additional rules. + +# linter: +# rules: +# - camel_case_types + +# analyzer: +# exclude: +# - path/to/excluded/files/** + +# For more information about the core and recommended set of lints, see +# https://dart.dev/go/core-lints + +# For additional information about configuring this file, see +# https://dart.dev/guides/language/analysis-options diff --git a/dart-api-examples/keyword-spotter/bin/init.dart b/dart-api-examples/keyword-spotter/bin/init.dart new file mode 120000 index 00000000..48508cfd --- /dev/null +++ b/dart-api-examples/keyword-spotter/bin/init.dart @@ -0,0 +1 @@ +../../vad/bin/init.dart \ No newline at end of file diff --git a/dart-api-examples/keyword-spotter/bin/zipformer-transducer.dart b/dart-api-examples/keyword-spotter/bin/zipformer-transducer.dart new file mode 100644 index 00000000..ebef1fd7 --- /dev/null +++ b/dart-api-examples/keyword-spotter/bin/zipformer-transducer.dart @@ -0,0 +1,98 @@ +// Copyright (c) 2024 Xiaomi Corporation +import 'dart:io'; +import 'dart:typed_data'; + +import 'package:args/args.dart'; +import 'package:sherpa_onnx/sherpa_onnx.dart' as sherpa_onnx; + +import './init.dart'; + +void main(List arguments) async { + await initSherpaOnnx(); + + final parser = ArgParser() + ..addOption('encoder', help: 'Path to the encoder model') + ..addOption('decoder', help: 'Path to decoder model') + ..addOption('joiner', help: 'Path to joiner model') + ..addOption('tokens', help: 'Path to tokens.txt') + ..addOption('keywords-file', help: 'Path to keywords.txt') + ..addOption('input-wav', help: 'Path to input.wav to transcribe'); + + final res = parser.parse(arguments); + if (res['encoder'] == null || + res['decoder'] == null || + res['joiner'] == null || + res['tokens'] == null || + res['keywords-file'] == null || + res['input-wav'] == null) { + print(parser.usage); + 
+ exit(1); + } + + final encoder = res['encoder'] as String; + final decoder = res['decoder'] as String; + final joiner = res['joiner'] as String; + final tokens = res['tokens'] as String; + final keywordsFile = res['keywords-file'] as String; + final inputWav = res['input-wav'] as String; + + final transducer = sherpa_onnx.OnlineTransducerModelConfig( + encoder: encoder, + decoder: decoder, + joiner: joiner, + ); + + final modelConfig = sherpa_onnx.OnlineModelConfig( + transducer: transducer, + tokens: tokens, + debug: true, + numThreads: 1, + ); + final config = sherpa_onnx.KeywordSpotterConfig( + model: modelConfig, + keywordsFile: keywordsFile, + ); + final spotter = sherpa_onnx.KeywordSpotter(config); + + final waveData = sherpa_onnx.readWave(inputWav); + var stream = spotter.createStream(); + + // Simulate streaming. You can choose an arbitrary chunk size. + // A chunkSize of a single sample is also OK, i.e., chunkSize = 1. + final chunkSize = 1600; // 0.1 second for 16kHz + final numChunks = waveData.samples.length ~/ chunkSize; + + for (int i = 0; i != numChunks; ++i) { + int start = i * chunkSize; + stream.acceptWaveform( + samples: + Float32List.sublistView(waveData.samples, start, start + chunkSize), + sampleRate: waveData.sampleRate, + ); + while (spotter.isReady(stream)) { + spotter.decode(stream); + final result = spotter.getResult(stream); + if (result.keyword != '') { + print('Detected: ${result.keyword}'); + } + } + } + + // 0.5 seconds of zero-valued tail padding, assuming the sample rate is 16kHz + final tailPaddings = Float32List(8000); + stream.acceptWaveform( + samples: tailPaddings, + sampleRate: waveData.sampleRate, + ); + + while (spotter.isReady(stream)) { + spotter.decode(stream); + final result = spotter.getResult(stream); + if (result.keyword != '') { + print('Detected: ${result.keyword}'); + } + } + + stream.free(); + spotter.free(); +} diff --git a/dart-api-examples/keyword-spotter/pubspec.yaml b/dart-api-examples/keyword-spotter/pubspec.yaml new file mode 100644 index
00000000..b95dcf72 --- /dev/null +++ b/dart-api-examples/keyword-spotter/pubspec.yaml @@ -0,0 +1,19 @@ +name: keyword_spotter + +description: > + This example demonstrates how to use the Dart API for keyword spotting + +version: 1.0.0 + +environment: + sdk: ^3.4.0 + +dependencies: + sherpa_onnx: ^1.10.17 + # sherpa_onnx: + # path: ../../flutter/sherpa_onnx + path: ^1.9.0 + args: ^2.5.0 + +dev_dependencies: + lints: ^3.0.0 diff --git a/dart-api-examples/keyword-spotter/run-zh.sh b/dart-api-examples/keyword-spotter/run-zh.sh new file mode 100755 index 00000000..5d71360f --- /dev/null +++ b/dart-api-examples/keyword-spotter/run-zh.sh @@ -0,0 +1,21 @@ +#!/usr/bin/env bash + +set -ex + +dart pub get + +if [ ! -f ./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/tokens.txt ]; then + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/kws-models/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01.tar.bz2 + tar xvf sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01.tar.bz2 + rm sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01.tar.bz2 +fi + +dart run \ + ./bin/zipformer-transducer.dart \ + --encoder ./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/encoder-epoch-12-avg-2-chunk-16-left-64.onnx \ + --decoder ./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/decoder-epoch-12-avg-2-chunk-16-left-64.onnx \ + --joiner ./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/joiner-epoch-12-avg-2-chunk-16-left-64.onnx \ + --tokens ./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/tokens.txt \ + --keywords-file ./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/test_keywords.txt \ + --input-wav ./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/3.wav + diff --git a/flutter/sherpa_onnx/lib/sherpa_onnx.dart b/flutter/sherpa_onnx/lib/sherpa_onnx.dart index 3382406f..4cc1e762 100644 --- a/flutter/sherpa_onnx/lib/sherpa_onnx.dart +++ b/flutter/sherpa_onnx/lib/sherpa_onnx.dart @@ -3,6 +3,7 @@ import 'dart:io'; 
import 'dart:ffi'; export 'src/feature_config.dart'; +export 'src/keyword_spotter.dart'; export 'src/offline_recognizer.dart'; export 'src/offline_stream.dart'; export 'src/online_recognizer.dart'; diff --git a/flutter/sherpa_onnx/lib/src/keyword_spotter.dart b/flutter/sherpa_onnx/lib/src/keyword_spotter.dart new file mode 100644 index 00000000..724acd6f --- /dev/null +++ b/flutter/sherpa_onnx/lib/src/keyword_spotter.dart @@ -0,0 +1,164 @@ +// Copyright (c) 2024 Xiaomi Corporation +import 'dart:convert'; +import 'dart:ffi'; + +import 'package:ffi/ffi.dart'; + +import './feature_config.dart'; +import './online_stream.dart'; +import './online_recognizer.dart'; +import './sherpa_onnx_bindings.dart'; +import './utils.dart'; + +class KeywordSpotterConfig { + const KeywordSpotterConfig({ + this.feat = const FeatureConfig(), + required this.model, + this.maxActivePaths = 4, + this.numTrailingBlanks = 1, + this.keywordsScore = 1.0, + this.keywordsThreshold = 0.25, + this.keywordsFile = '', + }); + + @override + String toString() { + return 'KeywordSpotterConfig(feat: $feat, model: $model, maxActivePaths: $maxActivePaths, numTrailingBlanks: $numTrailingBlanks, keywordsScore: $keywordsScore, keywordsThreshold: $keywordsThreshold, keywordsFile: $keywordsFile)'; + } + + final FeatureConfig feat; + final OnlineModelConfig model; + + final int maxActivePaths; + final int numTrailingBlanks; + + final double keywordsScore; + final double keywordsThreshold; + final String keywordsFile; +} + +class KeywordResult { + KeywordResult({required this.keyword}); + + @override + String toString() { + return 'KeywordResult(keyword: $keyword)'; + } + + final String keyword; +} + +class KeywordSpotter { + KeywordSpotter._({required this.ptr, required this.config}); + + /// The user is responsible for calling the KeywordSpotter.free() + /// method of the returned instance to avoid a memory leak.
+ factory KeywordSpotter(KeywordSpotterConfig config) { + final c = calloc(); + c.ref.feat.sampleRate = config.feat.sampleRate; + c.ref.feat.featureDim = config.feat.featureDim; + + // transducer + c.ref.model.transducer.encoder = + config.model.transducer.encoder.toNativeUtf8(); + c.ref.model.transducer.decoder = + config.model.transducer.decoder.toNativeUtf8(); + c.ref.model.transducer.joiner = + config.model.transducer.joiner.toNativeUtf8(); + + // paraformer + c.ref.model.paraformer.encoder = + config.model.paraformer.encoder.toNativeUtf8(); + c.ref.model.paraformer.decoder = + config.model.paraformer.decoder.toNativeUtf8(); + + // zipformer2Ctc + c.ref.model.zipformer2Ctc.model = + config.model.zipformer2Ctc.model.toNativeUtf8(); + + c.ref.model.tokens = config.model.tokens.toNativeUtf8(); + c.ref.model.numThreads = config.model.numThreads; + c.ref.model.provider = config.model.provider.toNativeUtf8(); + c.ref.model.debug = config.model.debug ? 1 : 0; + c.ref.model.modelType = config.model.modelType.toNativeUtf8(); + c.ref.model.modelingUnit = config.model.modelingUnit.toNativeUtf8(); + c.ref.model.bpeVocab = config.model.bpeVocab.toNativeUtf8(); + + c.ref.maxActivePaths = config.maxActivePaths; + c.ref.numTrailingBlanks = config.numTrailingBlanks; + c.ref.keywordsScore = config.keywordsScore; + c.ref.keywordsThreshold = config.keywordsThreshold; + c.ref.keywordsFile = config.keywordsFile.toNativeUtf8(); + + final ptr = SherpaOnnxBindings.createKeywordSpotter?.call(c) ?? 
nullptr; + + calloc.free(c.ref.keywordsFile); + calloc.free(c.ref.model.bpeVocab); + calloc.free(c.ref.model.modelingUnit); + calloc.free(c.ref.model.modelType); + calloc.free(c.ref.model.provider); + calloc.free(c.ref.model.tokens); + calloc.free(c.ref.model.zipformer2Ctc.model); + calloc.free(c.ref.model.paraformer.encoder); + calloc.free(c.ref.model.paraformer.decoder); + + calloc.free(c.ref.model.transducer.encoder); + calloc.free(c.ref.model.transducer.decoder); + calloc.free(c.ref.model.transducer.joiner); + calloc.free(c); + + return KeywordSpotter._(ptr: ptr, config: config); + } + + void free() { + SherpaOnnxBindings.destroyKeywordSpotter?.call(ptr); + ptr = nullptr; + } + + /// The user has to invoke stream.free() on the returned instance + /// to avoid memory leak + OnlineStream createStream({String keywords = ''}) { + if (keywords == '') { + final p = SherpaOnnxBindings.createKeywordStream?.call(ptr) ?? nullptr; + return OnlineStream(ptr: p); + } + + final utf8 = keywords.toNativeUtf8(); + final p = + SherpaOnnxBindings.createKeywordStreamWithKeywords?.call(ptr, utf8) ?? + nullptr; + calloc.free(utf8); + return OnlineStream(ptr: p); + } + + bool isReady(OnlineStream stream) { + int ready = + SherpaOnnxBindings.isKeywordStreamReady?.call(ptr, stream.ptr) ?? 0; + + return ready == 1; + } + + KeywordResult getResult(OnlineStream stream) { + final json = + SherpaOnnxBindings.getKeywordResultAsJson?.call(ptr, stream.ptr) ?? 
+ nullptr; + if (json == nullptr) { + return KeywordResult(keyword: ''); + } + + final parsedJson = jsonDecode(toDartString(json)); + + SherpaOnnxBindings.freeKeywordResultJson?.call(json); + + return KeywordResult( + keyword: parsedJson['keyword'], + ); + } + + void decode(OnlineStream stream) { + SherpaOnnxBindings.decodeKeywordStream?.call(ptr, stream.ptr); + } + + Pointer ptr; + KeywordSpotterConfig config; +} diff --git a/flutter/sherpa_onnx/lib/src/sherpa_onnx_bindings.dart b/flutter/sherpa_onnx/lib/src/sherpa_onnx_bindings.dart index 72dfc96b..0e369009 100644 --- a/flutter/sherpa_onnx/lib/src/sherpa_onnx_bindings.dart +++ b/flutter/sherpa_onnx/lib/src/sherpa_onnx_bindings.dart @@ -283,6 +283,28 @@ final class SherpaOnnxSpeakerEmbeddingExtractorConfig extends Struct { external Pointer provider; } +final class SherpaOnnxKeywordSpotterConfig extends Struct { + external SherpaOnnxFeatureConfig feat; + + external SherpaOnnxOnlineModelConfig model; + + @Int32() + external int maxActivePaths; + + @Int32() + external int numTrailingBlanks; + + @Float() + external double keywordsScore; + + @Float() + external double keywordsThreshold; + + external Pointer keywordsFile; +} + +final class SherpaOnnxKeywordSpotter extends Opaque {} + final class SherpaOnnxOfflineTts extends Opaque {} final class SherpaOnnxCircularBuffer extends Opaque {} @@ -301,6 +323,48 @@ final class SherpaOnnxSpeakerEmbeddingExtractor extends Opaque {} final class SherpaOnnxSpeakerEmbeddingManager extends Opaque {} +typedef CreateKeywordSpotterNative = Pointer Function( + Pointer); + +typedef CreateKeywordSpotter = CreateKeywordSpotterNative; + +typedef DestroyKeywordSpotterNative = Void Function( + Pointer); + +typedef DestroyKeywordSpotter = void Function( + Pointer); + +typedef CreateKeywordStreamNative = Pointer Function( + Pointer); + +typedef CreateKeywordStream = CreateKeywordStreamNative; + +typedef CreateKeywordStreamWithKeywordsNative = Pointer + Function(Pointer, Pointer); + +typedef 
CreateKeywordStreamWithKeywords = CreateKeywordStreamWithKeywordsNative; + +typedef IsKeywordStreamReadyNative = Int32 Function( + Pointer, Pointer); + +typedef IsKeywordStreamReady = int Function( + Pointer, Pointer); + +typedef DecodeKeywordStreamNative = Void Function( + Pointer, Pointer); + +typedef DecodeKeywordStream = void Function( + Pointer, Pointer); + +typedef GetKeywordResultAsJsonNative = Pointer Function( + Pointer, Pointer); + +typedef GetKeywordResultAsJson = GetKeywordResultAsJsonNative; + +typedef FreeKeywordResultJsonNative = Void Function(Pointer); + +typedef FreeKeywordResultJson = void Function(Pointer); + typedef SherpaOnnxCreateOfflineTtsNative = Pointer Function(Pointer); @@ -735,6 +799,15 @@ typedef SherpaOnnxFreeWaveNative = Void Function(Pointer); typedef SherpaOnnxFreeWave = void Function(Pointer); class SherpaOnnxBindings { + static CreateKeywordSpotter? createKeywordSpotter; + static DestroyKeywordSpotter? destroyKeywordSpotter; + static CreateKeywordStream? createKeywordStream; + static CreateKeywordStreamWithKeywords? createKeywordStreamWithKeywords; + static IsKeywordStreamReady? isKeywordStreamReady; + static DecodeKeywordStream? decodeKeywordStream; + static GetKeywordResultAsJson? getKeywordResultAsJson; + static FreeKeywordResultJson? freeKeywordResultJson; + static SherpaOnnxCreateOfflineTts? createOfflineTts; static SherpaOnnxDestroyOfflineTts? destroyOfflineTts; static SherpaOnnxOfflineTtsSampleRate? offlineTtsSampleRate; @@ -879,6 +952,46 @@ class SherpaOnnxBindings { static SherpaOnnxFreeWave? 
freeWave; static void init(DynamicLibrary dynamicLibrary) { + createKeywordSpotter ??= dynamicLibrary + .lookup>( + 'CreateKeywordSpotter') + .asFunction(); + + destroyKeywordSpotter ??= dynamicLibrary + .lookup>( + 'DestroyKeywordSpotter') + .asFunction(); + + createKeywordStream ??= dynamicLibrary + .lookup>( + 'CreateKeywordStream') + .asFunction(); + + createKeywordStreamWithKeywords ??= dynamicLibrary + .lookup>( + 'CreateKeywordStreamWithKeywords') + .asFunction(); + + isKeywordStreamReady ??= dynamicLibrary + .lookup>( + 'IsKeywordStreamReady') + .asFunction(); + + decodeKeywordStream ??= dynamicLibrary + .lookup>( + 'DecodeKeywordStream') + .asFunction(); + + getKeywordResultAsJson ??= dynamicLibrary + .lookup>( + 'GetKeywordResultAsJson') + .asFunction(); + + freeKeywordResultJson ??= dynamicLibrary + .lookup>( + 'FreeKeywordResultJson') + .asFunction(); + createOfflineTts ??= dynamicLibrary .lookup>( 'SherpaOnnxCreateOfflineTts') diff --git a/flutter/sherpa_onnx/pubspec.yaml b/flutter/sherpa_onnx/pubspec.yaml index 0064476d..fa3a1f79 100644 --- a/flutter/sherpa_onnx/pubspec.yaml +++ b/flutter/sherpa_onnx/pubspec.yaml @@ -31,20 +31,24 @@ dependencies: sdk: flutter sherpa_onnx_android: ^1.10.17 - # path: ../sherpa_onnx_android + # sherpa_onnx_android: + # path: ../sherpa_onnx_android sherpa_onnx_macos: ^1.10.17 - # path: ../sherpa_onnx_macos + # sherpa_onnx_macos: + # path: ../sherpa_onnx_macos sherpa_onnx_linux: ^1.10.17 - # path: ../sherpa_onnx_linux + # sherpa_onnx_linux: + # path: ../sherpa_onnx_linux # sherpa_onnx_windows: ^1.10.17 - # path: ../sherpa_onnx_windows + # sherpa_onnx_windows: + # path: ../sherpa_onnx_windows sherpa_onnx_ios: ^1.10.17 # sherpa_onnx_ios: - # path: ../sherpa_onnx_ios + # path: ../sherpa_onnx_ios dev_dependencies: flutter_lints: ^3.0.0 diff --git a/scripts/dart/kws-pubspec.yaml b/scripts/dart/kws-pubspec.yaml new file mode 100644 index 00000000..6a9c2652 --- /dev/null +++ b/scripts/dart/kws-pubspec.yaml @@ -0,0 +1,19 @@ +name: 
keyword_spotter + +description: > + This example demonstrates how to use the Dart API for keyword spotting + +version: 1.0.0 + +environment: + sdk: ^3.4.0 + +dependencies: + # sherpa_onnx: ^1.10.17 + sherpa_onnx: + path: ../../flutter/sherpa_onnx + path: ^1.9.0 + args: ^2.5.0 + +dev_dependencies: + lints: ^3.0.0