Add Kotlin and Java API for online punctuation models (#1936)
This commit is contained in:
@@ -53,6 +53,10 @@ java_files += OfflinePunctuationModelConfig.java
|
||||
java_files += OfflinePunctuationConfig.java
|
||||
java_files += OfflinePunctuation.java
|
||||
|
||||
java_files += OnlinePunctuationModelConfig.java
|
||||
java_files += OnlinePunctuationConfig.java
|
||||
java_files += OnlinePunctuation.java
|
||||
|
||||
java_files += OfflineZipformerAudioTaggingModelConfig.java
|
||||
java_files += AudioTaggingModelConfig.java
|
||||
java_files += AudioTaggingConfig.java
|
||||
|
||||
@@ -0,0 +1,39 @@
|
||||
// Copyright 2024 Xiaomi Corporation
|
||||
|
||||
package com.k2fsa.sherpa.onnx;
|
||||
|
||||
public class OnlinePunctuation {
|
||||
static {
|
||||
System.loadLibrary("sherpa-onnx-jni");
|
||||
}
|
||||
|
||||
private long ptr = 0;
|
||||
|
||||
public OnlinePunctuation(OnlinePunctuationConfig config) {
|
||||
ptr = newFromFile(config);
|
||||
}
|
||||
|
||||
public String addPunctuation(String text) {
|
||||
return addPunctuation(ptr, text);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void finalize() throws Throwable {
|
||||
release();
|
||||
}
|
||||
|
||||
// You'd better call it manually if it is not used anymore
|
||||
public void release() {
|
||||
if (this.ptr == 0) {
|
||||
return;
|
||||
}
|
||||
delete(this.ptr);
|
||||
this.ptr = 0;
|
||||
}
|
||||
|
||||
private native void delete(long ptr);
|
||||
|
||||
private native long newFromFile(OnlinePunctuationConfig config);
|
||||
|
||||
private native String addPunctuation(long ptr, String text);
|
||||
}
|
||||
@@ -0,0 +1,33 @@
|
||||
// Copyright 2024 Xiaomi Corporation
|
||||
|
||||
package com.k2fsa.sherpa.onnx;
|
||||
|
||||
public class OnlinePunctuationConfig {
|
||||
private final OnlinePunctuationModelConfig model;
|
||||
|
||||
private OnlinePunctuationConfig(Builder builder) {
|
||||
this.model = builder.model;
|
||||
}
|
||||
|
||||
public static Builder builder() {
|
||||
return new Builder();
|
||||
}
|
||||
|
||||
public OnlinePunctuationModelConfig getModel() {
|
||||
return model;
|
||||
}
|
||||
|
||||
|
||||
public static class Builder {
|
||||
private OnlinePunctuationModelConfig model = OnlinePunctuationModelConfig.builder().build();
|
||||
|
||||
public OnlinePunctuationConfig build() {
|
||||
return new OnlinePunctuationConfig(this);
|
||||
}
|
||||
|
||||
public Builder setModel(OnlinePunctuationModelConfig model) {
|
||||
this.model = model;
|
||||
return this;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,68 @@
|
||||
// Copyright 2024 Xiaomi Corporation
|
||||
|
||||
package com.k2fsa.sherpa.onnx;
|
||||
|
||||
/**
 * Immutable model settings for online punctuation.
 *
 * <p>Instances are created through {@link #builder()}. The JNI layer reads
 * the private fields of this class by name, so the field names and types
 * must stay in sync with the native code.
 */
public class OnlinePunctuationModelConfig {
    // Copied by native code into model.cnn_bilstm.
    private final String cnnBilstm;
    // Copied by native code into model.bpe_vocab.
    private final String bpeVocab;
    // Copied by native code into model.num_threads.
    private final int numThreads;
    // Copied by native code into model.debug.
    private final boolean debug;
    // Copied by native code into model.provider.
    private final String provider;

    private OnlinePunctuationModelConfig(Builder builder) {
        this.cnnBilstm = builder.cnnBilstm;
        this.bpeVocab = builder.bpeVocab;
        this.numThreads = builder.numThreads;
        this.debug = builder.debug;
        this.provider = builder.provider;
    }

    /** Returns a new builder populated with the default values. */
    public static Builder builder() {
        return new Builder();
    }

    /** Returns the {@code cnnBilstm} value (default: empty string). */
    public String getCnnBilstm() {
        return cnnBilstm;
    }

    /** Returns the {@code bpeVocab} value (default: empty string). */
    public String getBpeVocab() {
        return bpeVocab;
    }

    /** Returns the configured number of threads (default: 1). */
    public int getNumThreads() {
        return numThreads;
    }

    /** Returns the debug flag (default: {@code true}). */
    public boolean isDebug() {
        return debug;
    }

    /** Returns the provider name (default: {@code "cpu"}). */
    public String getProvider() {
        return provider;
    }

    /** Builder for {@link OnlinePunctuationModelConfig}. */
    public static class Builder {
        private String cnnBilstm = "";
        private String bpeVocab = "";
        private int numThreads = 1;
        // NOTE(review): the Kotlin API defaults debug to false; confirm that
        // true is the intended default here.
        private boolean debug = true;
        private String provider = "cpu";

        /** Creates the configuration from the current builder state. */
        public OnlinePunctuationModelConfig build() {
            return new OnlinePunctuationModelConfig(this);
        }

        /** Sets {@code cnnBilstm} and returns this builder. */
        public Builder setCnnBilstm(String cnnBilstm) {
            this.cnnBilstm = cnnBilstm;
            return this;
        }

        /** Sets {@code bpeVocab} and returns this builder. */
        public Builder setBpeVocab(String bpeVocab) {
            this.bpeVocab = bpeVocab;
            return this;
        }

        /** Sets the number of threads and returns this builder. */
        public Builder setNumThreads(int numThreads) {
            this.numThreads = numThreads;
            return this;
        }

        /** Sets the debug flag and returns this builder. */
        public Builder setDebug(boolean debug) {
            this.debug = debug;
            return this;
        }

        /** Sets the provider name and returns this builder. */
        public Builder setProvider(String provider) {
            this.provider = provider;
            return this;
        }
    }
}
|
||||
@@ -17,6 +17,7 @@ set(sources
|
||||
offline-punctuation.cc
|
||||
offline-recognizer.cc
|
||||
offline-stream.cc
|
||||
online-punctuation.cc
|
||||
online-recognizer.cc
|
||||
online-stream.cc
|
||||
speaker-embedding-extractor.cc
|
||||
|
||||
117
sherpa-onnx/jni/online-punctuation.cc
Normal file
117
sherpa-onnx/jni/online-punctuation.cc
Normal file
@@ -0,0 +1,117 @@
|
||||
// sherpa-onnx/jni/online-punctuation.cc
|
||||
//
|
||||
// Copyright (c) 2025 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/online-punctuation.h"
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/jni/common.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// Builds a native OnlinePunctuationConfig by reading, through JNI, the
// `model` field of the given Java/Kotlin config object and then the
// cnnBilstm, bpeVocab, numThreads, debug and provider fields of that model
// object.
static OnlinePunctuationConfig GetOnlinePunctuationConfig(JNIEnv *env,
                                                          jobject config) {
  // Copies the java.lang.String field `name` of `obj` into a std::string.
  auto read_string = [env](jobject obj, jclass obj_cls,
                           const char *name) -> std::string {
    jfieldID id = env->GetFieldID(obj_cls, name, "Ljava/lang/String;");
    jstring js = (jstring)env->GetObjectField(obj, id);
    const char *utf = env->GetStringUTFChars(js, nullptr);
    std::string value = utf;
    env->ReleaseStringUTFChars(js, utf);
    return value;
  };

  jclass config_cls = env->GetObjectClass(config);
  jfieldID model_fid =
      env->GetFieldID(config_cls, "model",
                      "Lcom/k2fsa/sherpa/onnx/OnlinePunctuationModelConfig;");
  jobject model_config = env->GetObjectField(config, model_fid);
  jclass model_config_cls = env->GetObjectClass(model_config);

  OnlinePunctuationConfig ans;

  ans.model.cnn_bilstm =
      read_string(model_config, model_config_cls, "cnnBilstm");
  ans.model.bpe_vocab = read_string(model_config, model_config_cls, "bpeVocab");

  jfieldID num_threads_fid =
      env->GetFieldID(model_config_cls, "numThreads", "I");
  ans.model.num_threads = env->GetIntField(model_config, num_threads_fid);

  jfieldID debug_fid = env->GetFieldID(model_config_cls, "debug", "Z");
  ans.model.debug = env->GetBooleanField(model_config, debug_fid);

  ans.model.provider = read_string(model_config, model_config_cls, "provider");

  return ans;
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
// JNI entry point for OnlinePunctuation.newFromAsset().
//
// Converts the Java/Kotlin config, logs it, and creates a native
// OnlinePunctuation. When __ANDROID_API__ >= 9 the asset manager is resolved
// first and passed to the constructor. Returns the native pointer as a
// jlong, or 0 if the asset manager cannot be obtained. This function does
// not call config.Validate().
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL
Java_com_k2fsa_sherpa_onnx_OnlinePunctuation_newFromAsset(
    JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) {
#if __ANDROID_API__ >= 9
  AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
  if (mgr == nullptr) {
    SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
    return 0;
  }
#endif

  auto config = sherpa_onnx::GetOnlinePunctuationConfig(env, _config);
  SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());

  auto *punct = new sherpa_onnx::OnlinePunctuation(
#if __ANDROID_API__ >= 9
      mgr,
#endif
      config);

  return reinterpret_cast<jlong>(punct);
}
|
||||
|
||||
// JNI entry point for OnlinePunctuation.newFromFile().
//
// Converts the Java/Kotlin config, logs it, and validates it. Returns the
// pointer to a newly created native OnlinePunctuation as a jlong, or 0 when
// validation fails.
SHERPA_ONNX_EXTERN_C
JNIEXPORT jlong JNICALL
Java_com_k2fsa_sherpa_onnx_OnlinePunctuation_newFromFile(JNIEnv *env,
                                                         jobject /*obj*/,
                                                         jobject _config) {
  auto config = sherpa_onnx::GetOnlinePunctuationConfig(env, _config);
  SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());

  const bool is_valid = config.Validate();
  if (!is_valid) {
    SHERPA_ONNX_LOGE("Errors found in config!");
    return 0;
  }

  auto *punct = new sherpa_onnx::OnlinePunctuation(config);
  return reinterpret_cast<jlong>(punct);
}
|
||||
|
||||
// JNI entry point for OnlinePunctuation.delete(): destroys the native
// OnlinePunctuation object that `ptr` refers to.
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OnlinePunctuation_delete(
    JNIEnv * /*env*/, jobject /*obj*/, jlong ptr) {
  auto *punct = reinterpret_cast<sherpa_onnx::OnlinePunctuation *>(ptr);
  delete punct;
}
|
||||
|
||||
// JNI entry point for OnlinePunctuation.addPunctuation().
//
// Calls AddPunctuationWithCase() on the native object that `ptr` refers to
// and returns the result as a new Java string.
//
// Instead of crashing the process, this function raises a Java exception and
// returns nullptr when:
//   - `ptr` is 0 (IllegalStateException), e.g. the caller already released
//     the native object or its creation returned 0;
//   - `text` is null (NullPointerException);
//   - GetStringUTFChars() fails (the exception it raised stays pending).
SHERPA_ONNX_EXTERN_C
JNIEXPORT jstring JNICALL
Java_com_k2fsa_sherpa_onnx_OnlinePunctuation_addPunctuation(JNIEnv *env,
                                                            jobject /*obj*/,
                                                            jlong ptr,
                                                            jstring text) {
  if (ptr == 0) {
    jclass ex = env->FindClass("java/lang/IllegalStateException");
    if (ex != nullptr) {
      env->ThrowNew(ex,
                    "OnlinePunctuation: native handle is 0 (creation failed "
                    "or it was already released)");
    }
    return nullptr;
  }

  if (text == nullptr) {
    jclass ex = env->FindClass("java/lang/NullPointerException");
    if (ex != nullptr) {
      env->ThrowNew(ex, "OnlinePunctuation: text is null");
    }
    return nullptr;
  }

  auto punct = reinterpret_cast<const sherpa_onnx::OnlinePunctuation *>(ptr);

  const char *ptext = env->GetStringUTFChars(text, nullptr);
  if (ptext == nullptr) {
    // GetStringUTFChars failed; a Java exception is already pending.
    return nullptr;
  }

  std::string result = punct->AddPunctuationWithCase(ptext);

  env->ReleaseStringUTFChars(text, ptext);

  return env->NewStringUTF(result.c_str());
}
|
||||
61
sherpa-onnx/kotlin-api/OnlinePunctuation.kt
Normal file
61
sherpa-onnx/kotlin-api/OnlinePunctuation.kt
Normal file
@@ -0,0 +1,61 @@
|
||||
package com.k2fsa.sherpa.onnx
|
||||
|
||||
import android.content.res.AssetManager
|
||||
|
||||
/**
 * Model settings for online punctuation.
 *
 * The JNI layer reads these properties by field name, so the names and types
 * must stay in sync with the native code.
 *
 * @property cnnBilstm copied by native code into `model.cnn_bilstm`
 * @property bpeVocab copied by native code into `model.bpe_vocab`
 * @property numThreads copied by native code into `model.num_threads`
 * @property debug copied by native code into `model.debug`
 * @property provider copied by native code into `model.provider`
 */
data class OnlinePunctuationModelConfig(
    var cnnBilstm: String = "",
    var bpeVocab: String = "",
    var numThreads: Int = 1,
    var debug: Boolean = false,
    var provider: String = "cpu",
)
|
||||
|
||||
|
||||
/**
 * Top-level configuration for [OnlinePunctuation].
 *
 * @property model model settings; the JNI layer looks this property up by the
 *   field name `model`
 */
data class OnlinePunctuationConfig(
    var model: OnlinePunctuationModelConfig,
)
|
||||
|
||||
/**
 * Kotlin wrapper around the native online punctuation object provided by the
 * `sherpa-onnx-jni` library.
 *
 * When [assetManager] is non-null the native object is created with
 * `newFromAsset`, otherwise with `newFromFile`. Call [release] once the
 * instance is no longer needed.
 */
class OnlinePunctuation(
    assetManager: AssetManager? = null,
    config: OnlinePunctuationConfig,
) {
    // Native handle; 0 means there is no native object (creation returned 0
    // or release() was called).
    private var ptr: Long

    init {
        ptr = if (assetManager != null) {
            newFromAsset(assetManager, config)
        } else {
            newFromFile(config)
        }
    }

    /** Frees the native object if it has not been freed yet. */
    protected fun finalize() {
        if (ptr != 0L) {
            delete(ptr)
            ptr = 0
        }
    }

    /** Frees the native object. Calling it more than once is a no-op. */
    fun release() = finalize()

    /**
     * Runs the native punctuation model on [text] and returns its result.
     *
     * @throws IllegalStateException if there is no native object, i.e. its
     *   creation failed or [release] has already been called
     */
    fun addPunctuation(text: String): String {
        // A zero handle would be dereferenced as a null pointer by native
        // code, so fail here with an exception instead.
        check(ptr != 0L) {
            "OnlinePunctuation has no native object: creation failed or release() was called"
        }
        return addPunctuation(ptr, text)
    }

    private external fun delete(ptr: Long)

    private external fun addPunctuation(ptr: Long, text: String): String

    private external fun newFromAsset(
        assetManager: AssetManager,
        config: OnlinePunctuationConfig,
    ): Long

    private external fun newFromFile(
        config: OnlinePunctuationConfig,
    ): Long

    companion object {
        init {
            System.loadLibrary("sherpa-onnx-jni")
        }
    }
}
|
||||
Reference in New Issue
Block a user