Add Java API for audio tagging (#820)

This commit is contained in:
Fangjun Kuang
2024-04-28 22:26:04 +08:00
committed by GitHub
parent 5407f880c0
commit 88202f05bb
39 changed files with 476 additions and 129 deletions

View File

@@ -0,0 +1,32 @@
// Copyright 2024 Xiaomi Corporation
package com.k2fsa.sherpa.onnx;
/**
 * One audio event: a name, an integer index and a score.
 *
 * <p>Instances are immutable; all three values are fixed at construction time.
 */
public class AudioEvent {
  private final String name;
  private final int index;
  private final float prob;

  /**
   * Creates an event from its three components.
   *
   * @param name name of the event
   * @param index integer index of the event (presumably its position in the label set; confirm
   *     against the native side)
   * @param prob score of the event (presumably a probability, as the name suggests)
   */
  public AudioEvent(String name, int index, float prob) {
    this.name = name;
    this.index = index;
    this.prob = prob;
  }

  /** Returns the event name passed to the constructor. */
  public String getName() {
    return name;
  }

  /** Returns the event index passed to the constructor. */
  public int getIndex() {
    return index;
  }

  /** Returns the event score passed to the constructor. */
  public float getProb() {
    return prob;
  }

  /**
   * Returns a one-line description of this event, with the score printed to three decimals.
   *
   * <p>The text ends with a newline character, so callers can print it without adding one.
   */
  @Override
  public String toString() {
    // The prefix previously read "AudioEven(", a misspelling of the class name.
    return String.format("AudioEvent(name=%s, index=%d, prob=%.3f)\n", name, index, prob);
  }
}

View File

@@ -0,0 +1,62 @@
// Copyright 2024 Xiaomi Corporation
package com.k2fsa.sherpa.onnx;
/**
 * Java front end for the audio tagger implemented in the {@code sherpa-onnx-jni} native library.
 *
 * <p>The object holds a native handle. Call {@link #release()} once the tagger is no longer
 * needed; {@link #finalize()} calls it as well, as a fallback.
 */
public class AudioTagging {
  static {
    System.loadLibrary("sherpa-onnx-jni");
  }

  // Handle returned by newFromFile(); 0 means there is nothing to delete.
  private long ptr = 0;

  /**
   * Creates the native tagger from the given configuration.
   *
   * @param config configuration handed to the native constructor
   */
  public AudioTagging(AudioTaggingConfig config) {
    this.ptr = newFromFile(config);
  }

  /**
   * Creates a new stream bound to this tagger.
   *
   * @return a stream wrapping the handle returned by the native side
   */
  public OfflineStream createStream() {
    return new OfflineStream(createStream(ptr));
  }

  /**
   * Runs tagging on {@code stream}, passing {@code -1} as {@code topK}.
   *
   * @param stream stream to tag
   * @return the events reported by the native side
   */
  public AudioEvent[] compute(OfflineStream stream) {
    // NOTE(review): -1 presumably tells the native side to use its configured topK; confirm.
    return compute(stream, -1);
  }

  /**
   * Runs tagging on {@code stream}.
   *
   * @param stream stream to tag
   * @param topK value forwarded unchanged to the native {@code compute}
   * @return one {@link AudioEvent} per entry returned by the native side, in the same order
   */
  public AudioEvent[] compute(OfflineStream stream, int topK) {
    Object[] rawEvents = compute(ptr, stream.getPtr(), topK);
    AudioEvent[] result = new AudioEvent[rawEvents.length];

    int pos = 0;
    for (Object raw : rawEvents) {
      // Each native entry is a triple: {String name, Integer index, Float prob}.
      Object[] fields = (Object[]) raw;
      result[pos++] = new AudioEvent((String) fields[0], (int) fields[1], (float) fields[2]);
    }
    return result;
  }

  /** Frees the native handle if {@link #release()} has not already done so. */
  @Override
  protected void finalize() throws Throwable {
    release();
  }

  /**
   * Deletes the native object and clears the handle.
   *
   * <p>Prefer calling this explicitly once the tagger is no longer used. Repeated calls are
   * harmless: after the first one the handle is 0 and nothing is deleted.
   */
  public void release() {
    if (ptr != 0) {
      delete(ptr);
      ptr = 0;
    }
  }

  private native void delete(long ptr);

  private native long newFromFile(AudioTaggingConfig config);

  private native long createStream(long ptr);

  private native Object[] compute(long ptr, long streamPtr, int topK);
}

View File

@@ -0,0 +1,44 @@
// Copyright 2024 Xiaomi Corporation
package com.k2fsa.sherpa.onnx;
/**
 * Immutable configuration for {@code AudioTagging}, created through {@link #builder()}.
 *
 * <p>NOTE(review): the field names are presumably looked up by name from the JNI layer; do not
 * rename them without checking the native code.
 */
public class AudioTaggingConfig {
  private final AudioTaggingModelConfig model;
  private final String labels;
  private final int topK;

  private AudioTaggingConfig(Builder builder) {
    this.model = builder.model;
    this.labels = builder.labels;
    this.topK = builder.topK;
  }

  /** Returns a builder pre-populated with the default values. */
  public static Builder builder() {
    return new AudioTaggingConfig.Builder();
  }

  /** Returns the model configuration. */
  public AudioTaggingModelConfig getModel() {
    return model;
  }

  /** Returns the labels setting (presumably a path to a labels file; confirm with the caller). */
  public String getLabels() {
    return labels;
  }

  /** Returns the configured {@code topK} value. */
  public int getTopK() {
    return topK;
  }

  /**
   * Builder for {@link AudioTaggingConfig}.
   *
   * <p>Defaults: a default-built model configuration, empty {@code labels}, and {@code topK} of 5.
   */
  public static class Builder {
    private AudioTaggingModelConfig model = AudioTaggingModelConfig.builder().build();
    private String labels = "";
    private int topK = 5;

    /** Creates the configuration from the values currently held by this builder. */
    public AudioTaggingConfig build() {
      return new AudioTaggingConfig(this);
    }

    /**
     * Sets the model configuration.
     *
     * @param model model configuration to use
     * @return this builder
     */
    public Builder setModel(AudioTaggingModelConfig model) {
      this.model = model;
      return this;
    }

    /**
     * Sets the labels setting.
     *
     * @param labels labels value to use
     * @return this builder
     */
    public Builder setLabels(String labels) {
      this.labels = labels;
      return this;
    }

    /**
     * Sets {@code topK}.
     *
     * @param topK value to use
     * @return this builder
     */
    public Builder setTopK(int topK) {
      this.topK = topK;
      return this;
    }
  }
}

View File

@@ -0,0 +1,60 @@
// Copyright 2024 Xiaomi Corporation
package com.k2fsa.sherpa.onnx;
/**
 * Immutable model settings for audio tagging, created through {@link #builder()}.
 *
 * <p>NOTE(review): the field names are presumably looked up by name from the JNI layer; do not
 * rename them without checking the native code.
 */
public class AudioTaggingModelConfig {
  private final OfflineZipformerAudioTaggingModelConfig zipformer;
  private final String ced;
  private final int numThreads;
  private final boolean debug;
  private final String provider;

  private AudioTaggingModelConfig(Builder builder) {
    this.zipformer = builder.zipformer;
    this.ced = builder.ced;
    this.numThreads = builder.numThreads;
    this.debug = builder.debug;
    this.provider = builder.provider;
  }

  /** Returns a builder pre-populated with the default values. */
  public static Builder builder() {
    return new Builder();
  }

  /** Returns the zipformer model configuration. */
  public OfflineZipformerAudioTaggingModelConfig getZipformer() {
    return zipformer;
  }

  /** Returns the {@code ced} setting (presumably a path to a CED model file; confirm). */
  public String getCED() {
    return ced;
  }

  /** Returns the configured number of threads. */
  public int getNumThreads() {
    return numThreads;
  }

  /** Returns the debug flag. */
  public boolean isDebug() {
    return debug;
  }

  /** Returns the configured provider name. */
  public String getProvider() {
    return provider;
  }

  /**
   * Builder for {@link AudioTaggingModelConfig}.
   *
   * <p>Defaults: a default-built zipformer configuration, empty {@code ced}, one thread, {@code
   * debug} enabled, and provider {@code "cpu"}.
   */
  public static class Builder {
    private OfflineZipformerAudioTaggingModelConfig zipformer =
        OfflineZipformerAudioTaggingModelConfig.builder().build();
    private String ced = "";
    private int numThreads = 1;
    private boolean debug = true;
    private String provider = "cpu";

    /** Creates the configuration from the values currently held by this builder. */
    public AudioTaggingModelConfig build() {
      return new AudioTaggingModelConfig(this);
    }

    /**
     * Sets the zipformer model configuration.
     *
     * @param zipformer configuration to use
     * @return this builder
     */
    public Builder setZipformer(OfflineZipformerAudioTaggingModelConfig zipformer) {
      this.zipformer = zipformer;
      return this;
    }

    /**
     * Sets the {@code ced} value.
     *
     * @param ced value to use
     * @return this builder
     */
    public Builder setCED(String ced) {
      this.ced = ced;
      return this;
    }

    /**
     * Sets the number of threads.
     *
     * @param numThreads value to use
     * @return this builder
     */
    public Builder setNumThreads(int numThreads) {
      this.numThreads = numThreads;
      return this;
    }

    /**
     * Sets the debug flag.
     *
     * @param debug value to use
     * @return this builder
     */
    public Builder setDebug(boolean debug) {
      this.debug = debug;
      return this;
    }

    /**
     * Sets the provider name.
     *
     * @param provider value to use
     * @return this builder
     */
    public Builder setProvider(String provider) {
      this.provider = provider;
      return this;
    }
  }
}

View File

@@ -1,5 +1,6 @@
// Copyright 2022-2023 by zhaoming
// Copyright 2024 Xiaomi Corporation
package com.k2fsa.sherpa.onnx;
public class EndpointRule {

View File

@@ -1,4 +1,5 @@
// Copyright 2024 Xiaomi Corporation
package com.k2fsa.sherpa.onnx;
public class OfflineModelConfig {

View File

@@ -1,4 +1,5 @@
// Copyright 2024 Xiaomi Corporation
package com.k2fsa.sherpa.onnx;
public class OfflineNemoEncDecCtcModelConfig {

View File

@@ -1,4 +1,5 @@
// Copyright 2024 Xiaomi Corporation
package com.k2fsa.sherpa.onnx;
public class OfflineParaformerModelConfig {

View File

@@ -7,7 +7,7 @@ public class OfflinePunctuation {
System.loadLibrary("sherpa-onnx-jni");
}
private long ptr = 0; // this is the asr engine ptrss
private long ptr = 0;
public OfflinePunctuation(OfflinePunctuationConfig config) {
ptr = newFromFile(config);

View File

@@ -1,4 +1,5 @@
// Copyright 2024 Xiaomi Corporation
package com.k2fsa.sherpa.onnx;
public class OfflineRecognizer {

View File

@@ -1,4 +1,5 @@
// Copyright 2024 Xiaomi Corporation
package com.k2fsa.sherpa.onnx;
public class OfflineRecognizerConfig {

View File

@@ -1,4 +1,5 @@
// Copyright 2024 Xiaomi Corporation
package com.k2fsa.sherpa.onnx;
public class OfflineRecognizerResult {

View File

@@ -1,4 +1,5 @@
// Copyright 2024 Xiaomi Corporation
package com.k2fsa.sherpa.onnx;
public class OfflineStream {

View File

@@ -1,4 +1,5 @@
// Copyright 2024 Xiaomi Corporation
package com.k2fsa.sherpa.onnx;
public class OfflineTransducerModelConfig {

View File

@@ -1,4 +1,5 @@
// Copyright 2024 Xiaomi Corporation
package com.k2fsa.sherpa.onnx;
public class OfflineWhisperModelConfig {

View File

@@ -0,0 +1,32 @@
// Copyright 2024 Xiaomi Corporation
package com.k2fsa.sherpa.onnx;
/**
 * Immutable holder for the zipformer audio-tagging {@code model} setting.
 *
 * <p>Obtain instances through {@link #builder()}.
 */
public class OfflineZipformerAudioTaggingModelConfig {
  private final String model;

  // Only the builder creates instances; it hands itself over so values are copied in one place.
  private OfflineZipformerAudioTaggingModelConfig(Builder source) {
    model = source.model;
  }

  /** Returns the {@code model} value this configuration was built with. */
  public String getModel() {
    return this.model;
  }

  /** Starts a new builder whose {@code model} is the empty string. */
  public static Builder builder() {
    return new OfflineZipformerAudioTaggingModelConfig.Builder();
  }

  /** Fluent builder for {@link OfflineZipformerAudioTaggingModelConfig}. */
  public static class Builder {
    private String model = "";

    /**
     * Stores the {@code model} value to use.
     *
     * @param model value to store
     * @return this builder, to allow chaining
     */
    public Builder setModel(String model) {
      this.model = model;
      return this;
    }

    /** Creates the configuration from the value currently stored in this builder. */
    public OfflineZipformerAudioTaggingModelConfig build() {
      return new OfflineZipformerAudioTaggingModelConfig(this);
    }
  }
}

View File

@@ -1,4 +1,5 @@
// Copyright 2024 Xiaomi Corporation
package com.k2fsa.sherpa.onnx;
public class OnlineCtcFstDecoderConfig {

View File

@@ -1,7 +1,7 @@
// Copyright 2022-2023 by zhaoming
// Copyright 2024 Xiaomi Corporation
package com.k2fsa.sherpa.onnx;
package com.k2fsa.sherpa.onnx;
public class OnlineRecognizer {
static {

View File

@@ -1,5 +1,6 @@
// Copyright 2022-2023 by zhaoming
// Copyright 2024 Xiaomi Corporation
package com.k2fsa.sherpa.onnx;
public class OnlineRecognizerConfig {

View File

@@ -1,4 +1,5 @@
// Copyright 2024 Xiaomi Corporation
package com.k2fsa.sherpa.onnx;
public class OnlineRecognizerResult {

View File

@@ -1,5 +1,6 @@
// Copyright 2022-2023 by zhaoming
// Copyright 2024 Xiaomi Corporation
package com.k2fsa.sherpa.onnx;
public class OnlineStream {

View File

@@ -1,4 +1,5 @@
// Copyright 2024 Xiaomi Corporation
package com.k2fsa.sherpa.onnx;
public class OnlineZipformer2CtcModelConfig {