Add Dart API for audio tagging (#1181)
This commit is contained in:
144
flutter/sherpa_onnx/lib/src/audio_tagging.dart
Normal file
144
flutter/sherpa_onnx/lib/src/audio_tagging.dart
Normal file
@@ -0,0 +1,144 @@
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
import 'dart:ffi';
|
||||
import 'package:ffi/ffi.dart';
|
||||
|
||||
import './offline_stream.dart';
|
||||
import './sherpa_onnx_bindings.dart';
|
||||
|
||||
class OfflineZipformerAudioTaggingModelConfig {
|
||||
const OfflineZipformerAudioTaggingModelConfig({this.model = ''});
|
||||
|
||||
@override
|
||||
String toString() {
|
||||
return 'OfflineZipformerAudioTaggingModelConfig(model: $model)';
|
||||
}
|
||||
|
||||
final String model;
|
||||
}
|
||||
|
||||
class AudioTaggingModelConfig {
|
||||
AudioTaggingModelConfig(
|
||||
{this.zipformer = const OfflineZipformerAudioTaggingModelConfig(),
|
||||
this.ced = '',
|
||||
this.numThreads = 1,
|
||||
this.provider = 'cpu',
|
||||
this.debug = true});
|
||||
|
||||
@override
|
||||
String toString() {
|
||||
return 'AudioTaggingModelConfig(zipformer: $zipformer, ced: $ced, numThreads: $numThreads, provider: $provider, debug: $debug)';
|
||||
}
|
||||
|
||||
final OfflineZipformerAudioTaggingModelConfig zipformer;
|
||||
final String ced;
|
||||
final int numThreads;
|
||||
final String provider;
|
||||
final bool debug;
|
||||
}
|
||||
|
||||
class AudioTaggingConfig {
|
||||
AudioTaggingConfig({required this.model, this.labels = ''});
|
||||
|
||||
@override
|
||||
String toString() {
|
||||
return 'AudioTaggingConfig(model: $model, labels: $labels)';
|
||||
}
|
||||
|
||||
final AudioTaggingModelConfig model;
|
||||
final String labels;
|
||||
}
|
||||
|
||||
class AudioEvent {
|
||||
AudioEvent({required this.name, required this.index, required this.prob});
|
||||
|
||||
@override
|
||||
String toString() {
|
||||
return 'AudioEvent(name: $name, index: $index, prob: $prob)';
|
||||
}
|
||||
|
||||
final String name;
|
||||
final int index;
|
||||
final double prob;
|
||||
}
|
||||
|
||||
class AudioTagging {
|
||||
AudioTagging._({required this.ptr, required this.config});
|
||||
|
||||
// The user has to invoke AudioTagging.free() to avoid memory leak.
|
||||
factory AudioTagging({required AudioTaggingConfig config}) {
|
||||
final c = calloc<SherpaOnnxAudioTaggingConfig>();
|
||||
|
||||
final zipformerPtr = config.model.zipformer.model.toNativeUtf8();
|
||||
c.ref.model.zipformer.model = zipformerPtr;
|
||||
|
||||
final cedPtr = config.model.ced.toNativeUtf8();
|
||||
c.ref.model.ced = cedPtr;
|
||||
|
||||
c.ref.model.numThreads = config.model.numThreads;
|
||||
|
||||
final providerPtr = config.model.provider.toNativeUtf8();
|
||||
c.ref.model.provider = providerPtr;
|
||||
|
||||
c.ref.model.debug = config.model.debug ? 1 : 0;
|
||||
|
||||
final labelsPtr = config.labels.toNativeUtf8();
|
||||
c.ref.labels = labelsPtr;
|
||||
|
||||
final ptr =
|
||||
SherpaOnnxBindings.sherpaOnnxCreateAudioTagging?.call(c) ?? nullptr;
|
||||
|
||||
calloc.free(labelsPtr);
|
||||
calloc.free(providerPtr);
|
||||
calloc.free(cedPtr);
|
||||
calloc.free(zipformerPtr);
|
||||
calloc.free(c);
|
||||
|
||||
return AudioTagging._(ptr: ptr, config: config);
|
||||
}
|
||||
|
||||
void free() {
|
||||
SherpaOnnxBindings.sherpaOnnxDestroyAudioTagging?.call(ptr);
|
||||
ptr = nullptr;
|
||||
}
|
||||
|
||||
/// The user has to invoke stream.free() on the returned instance
|
||||
/// to avoid memory leak
|
||||
OfflineStream createStream() {
|
||||
final p = SherpaOnnxBindings.sherpaOnnxAudioTaggingCreateOfflineStream
|
||||
?.call(ptr) ??
|
||||
nullptr;
|
||||
return OfflineStream(ptr: p);
|
||||
}
|
||||
|
||||
List<AudioEvent> compute({required OfflineStream stream, required int topK}) {
|
||||
final pp = SherpaOnnxBindings.sherpaOnnxAudioTaggingCompute
|
||||
?.call(ptr, stream.ptr, topK) ??
|
||||
nullptr;
|
||||
|
||||
final ans = <AudioEvent>[];
|
||||
|
||||
if (pp == nullptr) {
|
||||
return ans;
|
||||
}
|
||||
|
||||
var i = 0;
|
||||
while (pp[i] != nullptr) {
|
||||
final p = pp[i];
|
||||
|
||||
final name = p.ref.name.toDartString();
|
||||
final index = p.ref.index;
|
||||
final prob = p.ref.prob;
|
||||
final e = AudioEvent(name: name, index: index, prob: prob);
|
||||
ans.add(e);
|
||||
|
||||
i += 1;
|
||||
}
|
||||
|
||||
SherpaOnnxBindings.sherpaOnnxAudioTaggingFreeResults?.call(pp);
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
Pointer<SherpaOnnxAudioTagging> ptr;
|
||||
final AudioTaggingConfig config;
|
||||
}
|
||||
@@ -2,6 +2,41 @@
|
||||
import 'dart:ffi';
|
||||
import 'package:ffi/ffi.dart';
|
||||
|
||||
final class SherpaOnnxOfflineZipformerAudioTaggingModelConfig extends Struct {
|
||||
external Pointer<Utf8> model;
|
||||
}
|
||||
|
||||
final class SherpaOnnxAudioTaggingModelConfig extends Struct {
|
||||
external SherpaOnnxOfflineZipformerAudioTaggingModelConfig zipformer;
|
||||
external Pointer<Utf8> ced;
|
||||
|
||||
@Int32()
|
||||
external int numThreads;
|
||||
|
||||
@Int32()
|
||||
external int debug;
|
||||
|
||||
external Pointer<Utf8> provider;
|
||||
}
|
||||
|
||||
final class SherpaOnnxAudioTaggingConfig extends Struct {
|
||||
external SherpaOnnxAudioTaggingModelConfig model;
|
||||
external Pointer<Utf8> labels;
|
||||
|
||||
@Int32()
|
||||
external int topK;
|
||||
}
|
||||
|
||||
final class SherpaOnnxAudioEvent extends Struct {
|
||||
external Pointer<Utf8> name;
|
||||
|
||||
@Int32()
|
||||
external int index;
|
||||
|
||||
@Float()
|
||||
external double prob;
|
||||
}
|
||||
|
||||
final class SherpaOnnxOfflineTtsVitsModelConfig extends Struct {
|
||||
external Pointer<Utf8> model;
|
||||
external Pointer<Utf8> lexicon;
|
||||
@@ -303,6 +338,8 @@ final class SherpaOnnxKeywordSpotterConfig extends Struct {
|
||||
external Pointer<Utf8> keywordsFile;
|
||||
}
|
||||
|
||||
final class SherpaOnnxAudioTagging extends Opaque {}
|
||||
|
||||
final class SherpaOnnxKeywordSpotter extends Opaque {}
|
||||
|
||||
final class SherpaOnnxOfflineTts extends Opaque {}
|
||||
@@ -323,6 +360,40 @@ final class SherpaOnnxSpeakerEmbeddingExtractor extends Opaque {}
|
||||
|
||||
final class SherpaOnnxSpeakerEmbeddingManager extends Opaque {}
|
||||
|
||||
typedef SherpaOnnxCreateAudioTaggingNative = Pointer<SherpaOnnxAudioTagging>
|
||||
Function(Pointer<SherpaOnnxAudioTaggingConfig>);
|
||||
|
||||
typedef SherpaOnnxCreateAudioTagging = SherpaOnnxCreateAudioTaggingNative;
|
||||
|
||||
typedef SherpaOnnxDestroyAudioTaggingNative = Void Function(
|
||||
Pointer<SherpaOnnxAudioTagging>);
|
||||
|
||||
typedef SherpaOnnxDestroyAudioTagging = void Function(
|
||||
Pointer<SherpaOnnxAudioTagging>);
|
||||
|
||||
typedef SherpaOnnxAudioTaggingCreateOfflineStreamNative
|
||||
= Pointer<SherpaOnnxOfflineStream> Function(
|
||||
Pointer<SherpaOnnxAudioTagging>);
|
||||
|
||||
typedef SherpaOnnxAudioTaggingCreateOfflineStream
|
||||
= SherpaOnnxAudioTaggingCreateOfflineStreamNative;
|
||||
|
||||
typedef SherpaOnnxAudioTaggingComputeNative
|
||||
= Pointer<Pointer<SherpaOnnxAudioEvent>> Function(
|
||||
Pointer<SherpaOnnxAudioTagging>,
|
||||
Pointer<SherpaOnnxOfflineStream>,
|
||||
Int32);
|
||||
|
||||
typedef SherpaOnnxAudioTaggingCompute
|
||||
= Pointer<Pointer<SherpaOnnxAudioEvent>> Function(
|
||||
Pointer<SherpaOnnxAudioTagging>, Pointer<SherpaOnnxOfflineStream>, int);
|
||||
|
||||
typedef SherpaOnnxAudioTaggingFreeResultsNative = Void Function(
|
||||
Pointer<Pointer<SherpaOnnxAudioEvent>>);
|
||||
|
||||
typedef SherpaOnnxAudioTaggingFreeResults = void Function(
|
||||
Pointer<Pointer<SherpaOnnxAudioEvent>>);
|
||||
|
||||
typedef CreateKeywordSpotterNative = Pointer<SherpaOnnxKeywordSpotter> Function(
|
||||
Pointer<SherpaOnnxKeywordSpotterConfig>);
|
||||
|
||||
@@ -804,6 +875,13 @@ typedef SherpaOnnxFreeWaveNative = Void Function(Pointer<SherpaOnnxWave>);
|
||||
typedef SherpaOnnxFreeWave = void Function(Pointer<SherpaOnnxWave>);
|
||||
|
||||
class SherpaOnnxBindings {
|
||||
static SherpaOnnxCreateAudioTagging? sherpaOnnxCreateAudioTagging;
|
||||
static SherpaOnnxDestroyAudioTagging? sherpaOnnxDestroyAudioTagging;
|
||||
static SherpaOnnxAudioTaggingCreateOfflineStream?
|
||||
sherpaOnnxAudioTaggingCreateOfflineStream;
|
||||
static SherpaOnnxAudioTaggingCompute? sherpaOnnxAudioTaggingCompute;
|
||||
static SherpaOnnxAudioTaggingFreeResults? sherpaOnnxAudioTaggingFreeResults;
|
||||
|
||||
static CreateKeywordSpotter? createKeywordSpotter;
|
||||
static DestroyKeywordSpotter? destroyKeywordSpotter;
|
||||
static CreateKeywordStream? createKeywordStream;
|
||||
@@ -958,6 +1036,33 @@ class SherpaOnnxBindings {
|
||||
static SherpaOnnxFreeWave? freeWave;
|
||||
|
||||
static void init(DynamicLibrary dynamicLibrary) {
|
||||
sherpaOnnxCreateAudioTagging ??= dynamicLibrary
|
||||
.lookup<NativeFunction<SherpaOnnxCreateAudioTaggingNative>>(
|
||||
'SherpaOnnxCreateAudioTagging')
|
||||
.asFunction();
|
||||
|
||||
sherpaOnnxDestroyAudioTagging ??= dynamicLibrary
|
||||
.lookup<NativeFunction<SherpaOnnxDestroyAudioTaggingNative>>(
|
||||
'SherpaOnnxDestroyAudioTagging')
|
||||
.asFunction();
|
||||
|
||||
sherpaOnnxAudioTaggingCreateOfflineStream ??= dynamicLibrary
|
||||
.lookup<
|
||||
NativeFunction<
|
||||
SherpaOnnxAudioTaggingCreateOfflineStreamNative>>(
|
||||
'SherpaOnnxAudioTaggingCreateOfflineStream')
|
||||
.asFunction();
|
||||
|
||||
sherpaOnnxAudioTaggingCompute ??= dynamicLibrary
|
||||
.lookup<NativeFunction<SherpaOnnxAudioTaggingComputeNative>>(
|
||||
'SherpaOnnxAudioTaggingCompute')
|
||||
.asFunction();
|
||||
|
||||
sherpaOnnxAudioTaggingFreeResults ??= dynamicLibrary
|
||||
.lookup<NativeFunction<SherpaOnnxAudioTaggingFreeResultsNative>>(
|
||||
'SherpaOnnxAudioTaggingFreeResults')
|
||||
.asFunction();
|
||||
|
||||
createKeywordSpotter ??= dynamicLibrary
|
||||
.lookup<NativeFunction<CreateKeywordSpotterNative>>(
|
||||
'SherpaOnnxCreateKeywordSpotter')
|
||||
|
||||
Reference in New Issue
Block a user