Add Kotlin and Java API for online punctuation models (#1936)

This commit is contained in:
Fangjun Kuang
2025-02-27 16:52:36 +08:00
committed by GitHub
parent 815ebac8f9
commit f5dfcf8d2f
16 changed files with 474 additions and 13 deletions

View File

@@ -53,6 +53,10 @@ java_files += OfflinePunctuationModelConfig.java
java_files += OfflinePunctuationConfig.java
java_files += OfflinePunctuation.java
java_files += OnlinePunctuationModelConfig.java
java_files += OnlinePunctuationConfig.java
java_files += OnlinePunctuation.java
java_files += OfflineZipformerAudioTaggingModelConfig.java
java_files += AudioTaggingModelConfig.java
java_files += AudioTaggingConfig.java

View File

@@ -0,0 +1,39 @@
// Copyright 2024 Xiaomi Corporation
package com.k2fsa.sherpa.onnx;
public class OnlinePunctuation {
static {
System.loadLibrary("sherpa-onnx-jni");
}
private long ptr = 0;
public OnlinePunctuation(OnlinePunctuationConfig config) {
ptr = newFromFile(config);
}
public String addPunctuation(String text) {
return addPunctuation(ptr, text);
}
@Override
protected void finalize() throws Throwable {
release();
}
// You'd better call it manually if it is not used anymore
public void release() {
if (this.ptr == 0) {
return;
}
delete(this.ptr);
this.ptr = 0;
}
private native void delete(long ptr);
private native long newFromFile(OnlinePunctuationConfig config);
private native String addPunctuation(long ptr, String text);
}

View File

@@ -0,0 +1,33 @@
// Copyright 2024 Xiaomi Corporation
package com.k2fsa.sherpa.onnx;
public class OnlinePunctuationConfig {
private final OnlinePunctuationModelConfig model;
private OnlinePunctuationConfig(Builder builder) {
this.model = builder.model;
}
public static Builder builder() {
return new Builder();
}
public OnlinePunctuationModelConfig getModel() {
return model;
}
public static class Builder {
private OnlinePunctuationModelConfig model = OnlinePunctuationModelConfig.builder().build();
public OnlinePunctuationConfig build() {
return new OnlinePunctuationConfig(this);
}
public Builder setModel(OnlinePunctuationModelConfig model) {
this.model = model;
return this;
}
}
}

View File

@@ -0,0 +1,68 @@
// Copyright 2024 Xiaomi Corporation
package com.k2fsa.sherpa.onnx;
public class OnlinePunctuationModelConfig {
private final String cnnBilstm;
private final String bpeVocab;
private final int numThreads;
private final boolean debug;
private final String provider;
private OnlinePunctuationModelConfig(Builder builder) {
this.cnnBilstm = builder.cnnBilstm;
this.bpeVocab = builder.bpeVocab;
this.numThreads = builder.numThreads;
this.debug = builder.debug;
this.provider = builder.provider;
}
public static Builder builder() {
return new Builder();
}
public String getCnnBilstm() {
return cnnBilstm;
}
public String getBpeVocab() {
return bpeVocab;
}
public static class Builder {
private String cnnBilstm = "";
private String bpeVocab = "";
private int numThreads = 1;
private boolean debug = true;
private String provider = "cpu";
public OnlinePunctuationModelConfig build() {
return new OnlinePunctuationModelConfig(this);
}
public Builder setCnnBilstm(String cnnBilstm) {
this.cnnBilstm = cnnBilstm;
return this;
}
public Builder setBpeVocab(String bpeVocab) {
this.bpeVocab = bpeVocab;
return this;
}
public Builder setNumThreads(int numThreads) {
this.numThreads = numThreads;
return this;
}
public Builder setDebug(boolean debug) {
this.debug = debug;
return this;
}
public Builder setProvider(String provider) {
this.provider = provider;
return this;
}
}
}

View File

@@ -17,6 +17,7 @@ set(sources
offline-punctuation.cc
offline-recognizer.cc
offline-stream.cc
online-punctuation.cc
online-recognizer.cc
online-stream.cc
speaker-embedding-extractor.cc

View File

@@ -0,0 +1,117 @@
// sherpa-onnx/jni/online-punctuation.cc
//
// Copyright (c) 2025 Xiaomi Corporation
#include "sherpa-onnx/csrc/online-punctuation.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/jni/common.h"
namespace sherpa_onnx {
static OnlinePunctuationConfig GetOnlinePunctuationConfig(JNIEnv *env,
jobject config) {
OnlinePunctuationConfig ans;
jclass cls = env->GetObjectClass(config);
jfieldID fid;
fid = env->GetFieldID(cls, "model",
"Lcom/k2fsa/sherpa/onnx/OnlinePunctuationModelConfig;");
jobject model_config = env->GetObjectField(config, fid);
jclass model_config_cls = env->GetObjectClass(model_config);
fid = env->GetFieldID(model_config_cls, "cnnBilstm", "Ljava/lang/String;");
jstring s = (jstring)env->GetObjectField(model_config, fid);
const char *p = env->GetStringUTFChars(s, nullptr);
ans.model.cnn_bilstm = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(model_config_cls, "bpeVocab", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(model_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model.bpe_vocab = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(model_config_cls, "numThreads", "I");
ans.model.num_threads = env->GetIntField(model_config, fid);
fid = env->GetFieldID(model_config_cls, "debug", "Z");
ans.model.debug = env->GetBooleanField(model_config, fid);
fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(model_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model.provider = p;
env->ReleaseStringUTFChars(s, p);
return ans;
}
} // namespace sherpa_onnx
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL
Java_com_k2fsa_sherpa_onnx_OnlinePunctuation_newFromAsset(JNIEnv *env,
jobject /*obj*/,
jobject asset_manager,
jobject _config) {
#if __ANDROID_API__ >= 9
AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
if (!mgr) {
SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
return 0;
}
#endif
auto config = sherpa_onnx::GetOnlinePunctuationConfig(env, _config);
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
auto model = new sherpa_onnx::OnlinePunctuation(
#if __ANDROID_API__ >= 9
mgr,
#endif
config);
return (jlong)model;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL
Java_com_k2fsa_sherpa_onnx_OnlinePunctuation_newFromFile(JNIEnv *env,
jobject /*obj*/,
jobject _config) {
auto config = sherpa_onnx::GetOnlinePunctuationConfig(env, _config);
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
if (!config.Validate()) {
SHERPA_ONNX_LOGE("Errors found in config!");
return 0;
}
auto model = new sherpa_onnx::OnlinePunctuation(config);
return (jlong)model;
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OnlinePunctuation_delete(
JNIEnv * /*env*/, jobject /*obj*/, jlong ptr) {
delete reinterpret_cast<sherpa_onnx::OnlinePunctuation *>(ptr);
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT jstring JNICALL
Java_com_k2fsa_sherpa_onnx_OnlinePunctuation_addPunctuation(JNIEnv *env,
jobject /*obj*/,
jlong ptr,
jstring text) {
auto punct = reinterpret_cast<const sherpa_onnx::OnlinePunctuation *>(ptr);
const char *ptext = env->GetStringUTFChars(text, nullptr);
std::string result = punct->AddPunctuationWithCase(ptext);
env->ReleaseStringUTFChars(text, ptext);
return env->NewStringUTF(result.c_str());
}

View File

@@ -0,0 +1,61 @@
package com.k2fsa.sherpa.onnx
import android.content.res.AssetManager
data class OnlinePunctuationModelConfig(
var cnnBilstm: String = "",
var bpeVocab: String = "",
var numThreads: Int = 1,
var debug: Boolean = false,
var provider: String = "cpu",
)
data class OnlinePunctuationConfig(
var model: OnlinePunctuationModelConfig,
)
class OnlinePunctuation(
assetManager: AssetManager? = null,
config: OnlinePunctuationConfig,
) {
private var ptr: Long
init {
ptr = if (assetManager != null) {
newFromAsset(assetManager, config)
} else {
newFromFile(config)
}
}
protected fun finalize() {
if (ptr != 0L) {
delete(ptr)
ptr = 0
}
}
fun release() = finalize()
fun addPunctuation(text: String) = addPunctuation(ptr, text)
private external fun delete(ptr: Long)
private external fun addPunctuation(ptr: Long, text: String): String
private external fun newFromAsset(
assetManager: AssetManager,
config: OnlinePunctuationConfig,
): Long
private external fun newFromFile(
config: OnlinePunctuationConfig,
): Long
companion object {
init {
System.loadLibrary("sherpa-onnx-jni")
}
}
}