Add Dart API for audio tagging (#1181)

This commit is contained in:
Fangjun Kuang
2024-07-29 11:15:14 +08:00
committed by GitHub
parent 69b6b47d91
commit cd1fedaa49
30 changed files with 504 additions and 18 deletions

View File

@@ -0,0 +1,144 @@
// Copyright (c) 2024 Xiaomi Corporation
import 'dart:ffi';
import 'package:ffi/ffi.dart';
import './offline_stream.dart';
import './sherpa_onnx_bindings.dart';
class OfflineZipformerAudioTaggingModelConfig {
const OfflineZipformerAudioTaggingModelConfig({this.model = ''});
@override
String toString() {
return 'OfflineZipformerAudioTaggingModelConfig(model: $model)';
}
final String model;
}
class AudioTaggingModelConfig {
AudioTaggingModelConfig(
{this.zipformer = const OfflineZipformerAudioTaggingModelConfig(),
this.ced = '',
this.numThreads = 1,
this.provider = 'cpu',
this.debug = true});
@override
String toString() {
return 'AudioTaggingModelConfig(zipformer: $zipformer, ced: $ced, numThreads: $numThreads, provider: $provider, debug: $debug)';
}
final OfflineZipformerAudioTaggingModelConfig zipformer;
final String ced;
final int numThreads;
final String provider;
final bool debug;
}
class AudioTaggingConfig {
AudioTaggingConfig({required this.model, this.labels = ''});
@override
String toString() {
return 'AudioTaggingConfig(model: $model, labels: $labels)';
}
final AudioTaggingModelConfig model;
final String labels;
}
class AudioEvent {
AudioEvent({required this.name, required this.index, required this.prob});
@override
String toString() {
return 'AudioEvent(name: $name, index: $index, prob: $prob)';
}
final String name;
final int index;
final double prob;
}
class AudioTagging {
AudioTagging._({required this.ptr, required this.config});
// The user has to invoke AudioTagging.free() to avoid memory leak.
factory AudioTagging({required AudioTaggingConfig config}) {
final c = calloc<SherpaOnnxAudioTaggingConfig>();
final zipformerPtr = config.model.zipformer.model.toNativeUtf8();
c.ref.model.zipformer.model = zipformerPtr;
final cedPtr = config.model.ced.toNativeUtf8();
c.ref.model.ced = cedPtr;
c.ref.model.numThreads = config.model.numThreads;
final providerPtr = config.model.provider.toNativeUtf8();
c.ref.model.provider = providerPtr;
c.ref.model.debug = config.model.debug ? 1 : 0;
final labelsPtr = config.labels.toNativeUtf8();
c.ref.labels = labelsPtr;
final ptr =
SherpaOnnxBindings.sherpaOnnxCreateAudioTagging?.call(c) ?? nullptr;
calloc.free(labelsPtr);
calloc.free(providerPtr);
calloc.free(cedPtr);
calloc.free(zipformerPtr);
calloc.free(c);
return AudioTagging._(ptr: ptr, config: config);
}
void free() {
SherpaOnnxBindings.sherpaOnnxDestroyAudioTagging?.call(ptr);
ptr = nullptr;
}
/// The user has to invoke stream.free() on the returned instance
/// to avoid memory leak
OfflineStream createStream() {
final p = SherpaOnnxBindings.sherpaOnnxAudioTaggingCreateOfflineStream
?.call(ptr) ??
nullptr;
return OfflineStream(ptr: p);
}
List<AudioEvent> compute({required OfflineStream stream, required int topK}) {
final pp = SherpaOnnxBindings.sherpaOnnxAudioTaggingCompute
?.call(ptr, stream.ptr, topK) ??
nullptr;
final ans = <AudioEvent>[];
if (pp == nullptr) {
return ans;
}
var i = 0;
while (pp[i] != nullptr) {
final p = pp[i];
final name = p.ref.name.toDartString();
final index = p.ref.index;
final prob = p.ref.prob;
final e = AudioEvent(name: name, index: index, prob: prob);
ans.add(e);
i += 1;
}
SherpaOnnxBindings.sherpaOnnxAudioTaggingFreeResults?.call(pp);
return ans;
}
Pointer<SherpaOnnxAudioTagging> ptr;
final AudioTaggingConfig config;
}

View File

@@ -2,6 +2,41 @@
import 'dart:ffi';
import 'package:ffi/ffi.dart';
final class SherpaOnnxOfflineZipformerAudioTaggingModelConfig extends Struct {
external Pointer<Utf8> model;
}
final class SherpaOnnxAudioTaggingModelConfig extends Struct {
external SherpaOnnxOfflineZipformerAudioTaggingModelConfig zipformer;
external Pointer<Utf8> ced;
@Int32()
external int numThreads;
@Int32()
external int debug;
external Pointer<Utf8> provider;
}
final class SherpaOnnxAudioTaggingConfig extends Struct {
external SherpaOnnxAudioTaggingModelConfig model;
external Pointer<Utf8> labels;
@Int32()
external int topK;
}
final class SherpaOnnxAudioEvent extends Struct {
external Pointer<Utf8> name;
@Int32()
external int index;
@Float()
external double prob;
}
final class SherpaOnnxOfflineTtsVitsModelConfig extends Struct {
external Pointer<Utf8> model;
external Pointer<Utf8> lexicon;
@@ -303,6 +338,8 @@ final class SherpaOnnxKeywordSpotterConfig extends Struct {
external Pointer<Utf8> keywordsFile;
}
final class SherpaOnnxAudioTagging extends Opaque {}
final class SherpaOnnxKeywordSpotter extends Opaque {}
final class SherpaOnnxOfflineTts extends Opaque {}
@@ -323,6 +360,40 @@ final class SherpaOnnxSpeakerEmbeddingExtractor extends Opaque {}
final class SherpaOnnxSpeakerEmbeddingManager extends Opaque {}
typedef SherpaOnnxCreateAudioTaggingNative = Pointer<SherpaOnnxAudioTagging>
Function(Pointer<SherpaOnnxAudioTaggingConfig>);
typedef SherpaOnnxCreateAudioTagging = SherpaOnnxCreateAudioTaggingNative;
typedef SherpaOnnxDestroyAudioTaggingNative = Void Function(
Pointer<SherpaOnnxAudioTagging>);
typedef SherpaOnnxDestroyAudioTagging = void Function(
Pointer<SherpaOnnxAudioTagging>);
typedef SherpaOnnxAudioTaggingCreateOfflineStreamNative
= Pointer<SherpaOnnxOfflineStream> Function(
Pointer<SherpaOnnxAudioTagging>);
typedef SherpaOnnxAudioTaggingCreateOfflineStream
= SherpaOnnxAudioTaggingCreateOfflineStreamNative;
typedef SherpaOnnxAudioTaggingComputeNative
= Pointer<Pointer<SherpaOnnxAudioEvent>> Function(
Pointer<SherpaOnnxAudioTagging>,
Pointer<SherpaOnnxOfflineStream>,
Int32);
typedef SherpaOnnxAudioTaggingCompute
= Pointer<Pointer<SherpaOnnxAudioEvent>> Function(
Pointer<SherpaOnnxAudioTagging>, Pointer<SherpaOnnxOfflineStream>, int);
typedef SherpaOnnxAudioTaggingFreeResultsNative = Void Function(
Pointer<Pointer<SherpaOnnxAudioEvent>>);
typedef SherpaOnnxAudioTaggingFreeResults = void Function(
Pointer<Pointer<SherpaOnnxAudioEvent>>);
typedef CreateKeywordSpotterNative = Pointer<SherpaOnnxKeywordSpotter> Function(
Pointer<SherpaOnnxKeywordSpotterConfig>);
@@ -804,6 +875,13 @@ typedef SherpaOnnxFreeWaveNative = Void Function(Pointer<SherpaOnnxWave>);
typedef SherpaOnnxFreeWave = void Function(Pointer<SherpaOnnxWave>);
class SherpaOnnxBindings {
static SherpaOnnxCreateAudioTagging? sherpaOnnxCreateAudioTagging;
static SherpaOnnxDestroyAudioTagging? sherpaOnnxDestroyAudioTagging;
static SherpaOnnxAudioTaggingCreateOfflineStream?
sherpaOnnxAudioTaggingCreateOfflineStream;
static SherpaOnnxAudioTaggingCompute? sherpaOnnxAudioTaggingCompute;
static SherpaOnnxAudioTaggingFreeResults? sherpaOnnxAudioTaggingFreeResults;
static CreateKeywordSpotter? createKeywordSpotter;
static DestroyKeywordSpotter? destroyKeywordSpotter;
static CreateKeywordStream? createKeywordStream;
@@ -958,6 +1036,33 @@ class SherpaOnnxBindings {
static SherpaOnnxFreeWave? freeWave;
static void init(DynamicLibrary dynamicLibrary) {
sherpaOnnxCreateAudioTagging ??= dynamicLibrary
.lookup<NativeFunction<SherpaOnnxCreateAudioTaggingNative>>(
'SherpaOnnxCreateAudioTagging')
.asFunction();
sherpaOnnxDestroyAudioTagging ??= dynamicLibrary
.lookup<NativeFunction<SherpaOnnxDestroyAudioTaggingNative>>(
'SherpaOnnxDestroyAudioTagging')
.asFunction();
sherpaOnnxAudioTaggingCreateOfflineStream ??= dynamicLibrary
.lookup<
NativeFunction<
SherpaOnnxAudioTaggingCreateOfflineStreamNative>>(
'SherpaOnnxAudioTaggingCreateOfflineStream')
.asFunction();
sherpaOnnxAudioTaggingCompute ??= dynamicLibrary
.lookup<NativeFunction<SherpaOnnxAudioTaggingComputeNative>>(
'SherpaOnnxAudioTaggingCompute')
.asFunction();
sherpaOnnxAudioTaggingFreeResults ??= dynamicLibrary
.lookup<NativeFunction<SherpaOnnxAudioTaggingFreeResultsNative>>(
'SherpaOnnxAudioTaggingFreeResults')
.asFunction();
createKeywordSpotter ??= dynamicLibrary
.lookup<NativeFunction<CreateKeywordSpotterNative>>(
'SherpaOnnxCreateKeywordSpotter')