118 lines
3.9 KiB
C++
118 lines
3.9 KiB
C++
// sherpa-onnx/jni/online-punctuation.cc
|
|
//
|
|
// Copyright (c) 2025 Xiaomi Corporation
|
|
|
|
#include "sherpa-onnx/csrc/online-punctuation.h"
|
|
|
|
#include "sherpa-onnx/csrc/macros.h"
|
|
#include "sherpa-onnx/jni/common.h"
|
|
|
|
namespace sherpa_onnx {
|
|
|
|
static OnlinePunctuationConfig GetOnlinePunctuationConfig(JNIEnv *env,
|
|
jobject config) {
|
|
OnlinePunctuationConfig ans;
|
|
|
|
jclass cls = env->GetObjectClass(config);
|
|
jfieldID fid;
|
|
|
|
fid = env->GetFieldID(cls, "model",
|
|
"Lcom/k2fsa/sherpa/onnx/OnlinePunctuationModelConfig;");
|
|
jobject model_config = env->GetObjectField(config, fid);
|
|
jclass model_config_cls = env->GetObjectClass(model_config);
|
|
|
|
fid = env->GetFieldID(model_config_cls, "cnnBilstm", "Ljava/lang/String;");
|
|
jstring s = (jstring)env->GetObjectField(model_config, fid);
|
|
const char *p = env->GetStringUTFChars(s, nullptr);
|
|
ans.model.cnn_bilstm = p;
|
|
env->ReleaseStringUTFChars(s, p);
|
|
|
|
fid = env->GetFieldID(model_config_cls, "bpeVocab", "Ljava/lang/String;");
|
|
s = (jstring)env->GetObjectField(model_config, fid);
|
|
p = env->GetStringUTFChars(s, nullptr);
|
|
ans.model.bpe_vocab = p;
|
|
env->ReleaseStringUTFChars(s, p);
|
|
|
|
fid = env->GetFieldID(model_config_cls, "numThreads", "I");
|
|
ans.model.num_threads = env->GetIntField(model_config, fid);
|
|
|
|
fid = env->GetFieldID(model_config_cls, "debug", "Z");
|
|
ans.model.debug = env->GetBooleanField(model_config, fid);
|
|
|
|
fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;");
|
|
s = (jstring)env->GetObjectField(model_config, fid);
|
|
p = env->GetStringUTFChars(s, nullptr);
|
|
ans.model.provider = p;
|
|
env->ReleaseStringUTFChars(s, p);
|
|
|
|
return ans;
|
|
}
|
|
|
|
} // namespace sherpa_onnx
|
|
|
|
SHERPA_ONNX_EXTERN_C
|
|
JNIEXPORT jlong JNICALL
|
|
Java_com_k2fsa_sherpa_onnx_OnlinePunctuation_newFromAsset(JNIEnv *env,
|
|
jobject /*obj*/,
|
|
jobject asset_manager,
|
|
jobject _config) {
|
|
#if __ANDROID_API__ >= 9
|
|
AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
|
|
if (!mgr) {
|
|
SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
|
|
return 0;
|
|
}
|
|
#endif
|
|
auto config = sherpa_onnx::GetOnlinePunctuationConfig(env, _config);
|
|
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
|
|
|
|
auto model = new sherpa_onnx::OnlinePunctuation(
|
|
#if __ANDROID_API__ >= 9
|
|
mgr,
|
|
#endif
|
|
config);
|
|
|
|
return (jlong)model;
|
|
}
|
|
|
|
SHERPA_ONNX_EXTERN_C
|
|
JNIEXPORT jlong JNICALL
|
|
Java_com_k2fsa_sherpa_onnx_OnlinePunctuation_newFromFile(JNIEnv *env,
|
|
jobject /*obj*/,
|
|
jobject _config) {
|
|
auto config = sherpa_onnx::GetOnlinePunctuationConfig(env, _config);
|
|
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
|
|
|
|
if (!config.Validate()) {
|
|
SHERPA_ONNX_LOGE("Errors found in config!");
|
|
return 0;
|
|
}
|
|
|
|
auto model = new sherpa_onnx::OnlinePunctuation(config);
|
|
|
|
return (jlong)model;
|
|
}
|
|
|
|
SHERPA_ONNX_EXTERN_C
|
|
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_OnlinePunctuation_delete(
|
|
JNIEnv * /*env*/, jobject /*obj*/, jlong ptr) {
|
|
delete reinterpret_cast<sherpa_onnx::OnlinePunctuation *>(ptr);
|
|
}
|
|
|
|
SHERPA_ONNX_EXTERN_C
|
|
JNIEXPORT jstring JNICALL
|
|
Java_com_k2fsa_sherpa_onnx_OnlinePunctuation_addPunctuation(JNIEnv *env,
|
|
jobject /*obj*/,
|
|
jlong ptr,
|
|
jstring text) {
|
|
auto punct = reinterpret_cast<const sherpa_onnx::OnlinePunctuation *>(ptr);
|
|
|
|
const char *ptext = env->GetStringUTFChars(text, nullptr);
|
|
|
|
std::string result = punct->AddPunctuationWithCase(ptext);
|
|
|
|
env->ReleaseStringUTFChars(text, ptext);
|
|
|
|
return env->NewStringUTF(result.c_str());
|
|
}
|