Add JNI (#57)
This commit is contained in:
2
.github/scripts/.gitignore
vendored
Normal file
2
.github/scripts/.gitignore
vendored
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
Makefile
|
||||||
|
*.jar
|
||||||
4
.github/scripts/AssetManager.kt
vendored
Normal file
4
.github/scripts/AssetManager.kt
vendored
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
package android.content.res
|
||||||
|
|
||||||
|
// a dummy class for testing only
|
||||||
|
class AssetManager
|
||||||
45
.github/scripts/Main.kt
vendored
Normal file
45
.github/scripts/Main.kt
vendored
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
package com.k2fsa.sherpa.onnx
|
||||||
|
|
||||||
|
import android.content.res.AssetManager
|
||||||
|
|
||||||
|
fun main() {
|
||||||
|
var featConfig = FeatureConfig(
|
||||||
|
sampleRate=16000.0f,
|
||||||
|
featureDim=80,
|
||||||
|
)
|
||||||
|
|
||||||
|
var modelConfig = OnlineTransducerModelConfig(
|
||||||
|
encoder="./sherpa-onnx-streaming-zipformer-en-2023-02-21/encoder-epoch-99-avg-1.onnx",
|
||||||
|
decoder="./sherpa-onnx-streaming-zipformer-en-2023-02-21/decoder-epoch-99-avg-1.onnx",
|
||||||
|
joiner="./sherpa-onnx-streaming-zipformer-en-2023-02-21/joiner-epoch-99-avg-1.onnx",
|
||||||
|
numThreads=4,
|
||||||
|
debug=false,
|
||||||
|
)
|
||||||
|
|
||||||
|
var endpointConfig = EndpointConfig()
|
||||||
|
|
||||||
|
var config = OnlineRecognizerConfig(
|
||||||
|
modelConfig=modelConfig,
|
||||||
|
featConfig=featConfig,
|
||||||
|
endpointConfig=endpointConfig,
|
||||||
|
tokens="./sherpa-onnx-streaming-zipformer-en-2023-02-21/tokens.txt",
|
||||||
|
enableEndpoint=true,
|
||||||
|
)
|
||||||
|
|
||||||
|
var model = SherpaOnnx(
|
||||||
|
assetManager = AssetManager(),
|
||||||
|
config = config,
|
||||||
|
)
|
||||||
|
var samples = WaveReader.readWave(
|
||||||
|
assetManager = AssetManager(),
|
||||||
|
filename = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/test_wavs/1089-134686-0001.wav",
|
||||||
|
)
|
||||||
|
|
||||||
|
model.decodeSamples(samples!!)
|
||||||
|
|
||||||
|
var tail_paddings = FloatArray(8000) // 0.5 seconds
|
||||||
|
model.decodeSamples(tail_paddings)
|
||||||
|
|
||||||
|
model.inputFinished()
|
||||||
|
println(model.text)
|
||||||
|
}
|
||||||
89
.github/scripts/SherpaOnnx.kt
vendored
Normal file
89
.github/scripts/SherpaOnnx.kt
vendored
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
package com.k2fsa.sherpa.onnx
|
||||||
|
|
||||||
|
import android.content.res.AssetManager
|
||||||
|
|
||||||
|
data class EndpointRule(
|
||||||
|
var mustContainNonSilence: Boolean,
|
||||||
|
var minTrailingSilence: Float,
|
||||||
|
var minUtteranceLength: Float,
|
||||||
|
)
|
||||||
|
|
||||||
|
data class EndpointConfig(
|
||||||
|
var rule1: EndpointRule = EndpointRule(false, 2.4f, 0.0f),
|
||||||
|
var rule2: EndpointRule = EndpointRule(true, 1.4f, 0.0f),
|
||||||
|
var rule3: EndpointRule = EndpointRule(false, 0.0f, 20.0f)
|
||||||
|
)
|
||||||
|
|
||||||
|
data class OnlineTransducerModelConfig(
|
||||||
|
var encoder: String,
|
||||||
|
var decoder: String,
|
||||||
|
var joiner: String,
|
||||||
|
var numThreads: Int = 4,
|
||||||
|
var debug: Boolean = false,
|
||||||
|
)
|
||||||
|
|
||||||
|
data class FeatureConfig(
|
||||||
|
var sampleRate: Float = 16000.0f,
|
||||||
|
var featureDim: Int = 80,
|
||||||
|
)
|
||||||
|
|
||||||
|
data class OnlineRecognizerConfig(
|
||||||
|
var featConfig: FeatureConfig = FeatureConfig(),
|
||||||
|
var modelConfig: OnlineTransducerModelConfig,
|
||||||
|
var tokens: String,
|
||||||
|
var endpointConfig: EndpointConfig = EndpointConfig(),
|
||||||
|
var enableEndpoint: Boolean,
|
||||||
|
)
|
||||||
|
|
||||||
|
class SherpaOnnx(
|
||||||
|
assetManager: AssetManager,
|
||||||
|
var config: OnlineRecognizerConfig
|
||||||
|
) {
|
||||||
|
private val ptr: Long
|
||||||
|
|
||||||
|
init {
|
||||||
|
ptr = new(assetManager, config)
|
||||||
|
}
|
||||||
|
|
||||||
|
protected fun finalize() {
|
||||||
|
delete(ptr)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
fun decodeSamples(samples: FloatArray) =
|
||||||
|
decodeSamples(ptr, samples, sampleRate = config.featConfig.sampleRate)
|
||||||
|
|
||||||
|
fun inputFinished() = inputFinished(ptr)
|
||||||
|
fun reset() = reset(ptr)
|
||||||
|
fun isEndpoint(): Boolean = isEndpoint(ptr)
|
||||||
|
|
||||||
|
val text: String
|
||||||
|
get() = getText(ptr)
|
||||||
|
|
||||||
|
private external fun delete(ptr: Long)
|
||||||
|
|
||||||
|
private external fun new(
|
||||||
|
assetManager: AssetManager,
|
||||||
|
config: OnlineRecognizerConfig,
|
||||||
|
): Long
|
||||||
|
|
||||||
|
private external fun decodeSamples(ptr: Long, samples: FloatArray, sampleRate: Float)
|
||||||
|
private external fun inputFinished(ptr: Long)
|
||||||
|
private external fun getText(ptr: Long): String
|
||||||
|
private external fun reset(ptr: Long)
|
||||||
|
private external fun isEndpoint(ptr: Long): Boolean
|
||||||
|
|
||||||
|
companion object {
|
||||||
|
init {
|
||||||
|
System.loadLibrary("sherpa-onnx-jni")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fun getFeatureConfig(): FeatureConfig {
|
||||||
|
val featConfig = FeatureConfig()
|
||||||
|
featConfig.sampleRate = 16000.0f
|
||||||
|
featConfig.featureDim = 80
|
||||||
|
|
||||||
|
return featConfig
|
||||||
|
}
|
||||||
17
.github/scripts/WaveReader.kt
vendored
Normal file
17
.github/scripts/WaveReader.kt
vendored
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
package com.k2fsa.sherpa.onnx
|
||||||
|
|
||||||
|
import android.content.res.AssetManager
|
||||||
|
|
||||||
|
class WaveReader {
|
||||||
|
companion object {
|
||||||
|
// Read a mono wave file.
|
||||||
|
// No resampling is made.
|
||||||
|
external fun readWave(
|
||||||
|
assetManager: AssetManager, filename: String, expected_sample_rate: Float = 16000.0f
|
||||||
|
): FloatArray?
|
||||||
|
|
||||||
|
init {
|
||||||
|
System.loadLibrary("sherpa-onnx-jni")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
33
.github/scripts/test-jni.sh
vendored
Executable file
33
.github/scripts/test-jni.sh
vendored
Executable file
@@ -0,0 +1,33 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
|
set -e
|
||||||
|
|
||||||
|
mkdir -p build
|
||||||
|
cd build
|
||||||
|
|
||||||
|
cmake \
|
||||||
|
-DSHERPA_ONNX_ENABLE_PYTHON=OFF \
|
||||||
|
-DSHERPA_ONNX_ENABLE_TESTS=OFF \
|
||||||
|
-DSHERPA_ONNX_ENABLE_CHECK=OFF \
|
||||||
|
-DBUILD_SHARED_LIBS=ON \
|
||||||
|
-DSHERPA_ONNX_ENABLE_PORTAUDIO=OFF \
|
||||||
|
-DSHERPA_ONNX_ENABLE_JNI=ON \
|
||||||
|
..
|
||||||
|
|
||||||
|
make -j4
|
||||||
|
ls -lh lib
|
||||||
|
|
||||||
|
cd ..
|
||||||
|
|
||||||
|
export LD_LIBRARY_PATH=$PWD/build/lib:$LD_LIBRARY_PATH
|
||||||
|
|
||||||
|
cd .github/scripts/
|
||||||
|
|
||||||
|
git lfs install
|
||||||
|
git clone https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-en-2023-02-21
|
||||||
|
|
||||||
|
kotlinc-jvm -include-runtime -d main.jar Main.kt WaveReader.kt SherpaOnnx.kt AssetManager.kt
|
||||||
|
|
||||||
|
ls -lh main.jar
|
||||||
|
|
||||||
|
java -Djava.library.path=../../build/lib -jar main.jar
|
||||||
59
.github/workflows/jni.yaml
vendored
Normal file
59
.github/workflows/jni.yaml
vendored
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
name: jni
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- master
|
||||||
|
paths:
|
||||||
|
- '.github/workflows/jni.yaml'
|
||||||
|
- 'CMakeLists.txt'
|
||||||
|
- 'cmake/**'
|
||||||
|
- 'sherpa-onnx/csrc/*'
|
||||||
|
- 'sherpa-onnx/jni/*'
|
||||||
|
- '.github/scripts/test-jni.sh'
|
||||||
|
pull_request:
|
||||||
|
branches:
|
||||||
|
- master
|
||||||
|
paths:
|
||||||
|
- '.github/workflows/jni.yaml'
|
||||||
|
- 'CMakeLists.txt'
|
||||||
|
- 'cmake/**'
|
||||||
|
- 'sherpa-onnx/csrc/*'
|
||||||
|
- 'sherpa-onnx/jni/*'
|
||||||
|
- '.github/scripts/test-jni.sh'
|
||||||
|
|
||||||
|
concurrency:
|
||||||
|
group: jni-${{ github.ref }}
|
||||||
|
cancel-in-progress: true
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
jni:
|
||||||
|
runs-on: ${{ matrix.os }}
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
os: [ubuntu-latest, macos-latest]
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v2
|
||||||
|
with:
|
||||||
|
fetch-depth: 0
|
||||||
|
|
||||||
|
- name: Display kotlin version
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
kotlinc -version
|
||||||
|
|
||||||
|
- name: Display java version
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
java -version
|
||||||
|
echo "JAVA_HOME is: ${JAVA_HOME}"
|
||||||
|
|
||||||
|
- name: Run JNI test
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
.github/scripts/test-jni.sh
|
||||||
@@ -16,6 +16,7 @@ option(SHERPA_ONNX_ENABLE_TESTS "Whether to build tests" OFF)
|
|||||||
option(SHERPA_ONNX_ENABLE_CHECK "Whether to build with assert" ON)
|
option(SHERPA_ONNX_ENABLE_CHECK "Whether to build with assert" ON)
|
||||||
option(BUILD_SHARED_LIBS "Whether to build shared libraries" OFF)
|
option(BUILD_SHARED_LIBS "Whether to build shared libraries" OFF)
|
||||||
option(SHERPA_ONNX_ENABLE_PORTAUDIO "Whether to build with portaudio" ON)
|
option(SHERPA_ONNX_ENABLE_PORTAUDIO "Whether to build with portaudio" ON)
|
||||||
|
option(SHERPA_ONNX_ENABLE_JNI "Whether to build JNI internface" OFF)
|
||||||
|
|
||||||
set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
|
set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
|
||||||
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
|
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
|
||||||
@@ -44,6 +45,11 @@ if(NOT CMAKE_BUILD_TYPE)
|
|||||||
set(CMAKE_BUILD_TYPE Release)
|
set(CMAKE_BUILD_TYPE Release)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
if(DEFINED ANDROID_ABI)
|
||||||
|
message(STATUS "Set SHERPA_ONNX_ENABLE_JNI to ON for Android")
|
||||||
|
set(SHERPA_ONNX_ENABLE_JNI ON CACHE BOOL "" FORCE)
|
||||||
|
endif()
|
||||||
|
|
||||||
message(STATUS "CMAKE_BUILD_TYPE: ${CMAKE_BUILD_TYPE}")
|
message(STATUS "CMAKE_BUILD_TYPE: ${CMAKE_BUILD_TYPE}")
|
||||||
message(STATUS "CMAKE_INSTALL_PREFIX: ${CMAKE_INSTALL_PREFIX}")
|
message(STATUS "CMAKE_INSTALL_PREFIX: ${CMAKE_INSTALL_PREFIX}")
|
||||||
message(STATUS "BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS}")
|
message(STATUS "BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS}")
|
||||||
@@ -51,6 +57,7 @@ message(STATUS "SHERPA_ONNX_ENABLE_PYTHON ${SHERPA_ONNX_ENABLE_PYTHON}")
|
|||||||
message(STATUS "SHERPA_ONNX_ENABLE_TESTS ${SHERPA_ONNX_ENABLE_TESTS}")
|
message(STATUS "SHERPA_ONNX_ENABLE_TESTS ${SHERPA_ONNX_ENABLE_TESTS}")
|
||||||
message(STATUS "SHERPA_ONNX_ENABLE_CHECK ${SHERPA_ONNX_ENABLE_CHECK}")
|
message(STATUS "SHERPA_ONNX_ENABLE_CHECK ${SHERPA_ONNX_ENABLE_CHECK}")
|
||||||
message(STATUS "SHERPA_ONNX_ENABLE_PORTAUDIO ${SHERPA_ONNX_ENABLE_PORTAUDIO}")
|
message(STATUS "SHERPA_ONNX_ENABLE_PORTAUDIO ${SHERPA_ONNX_ENABLE_PORTAUDIO}")
|
||||||
|
message(STATUS "SHERPA_ONNX_ENABLE_JNI ${SHERPA_ONNX_ENABLE_JNI}")
|
||||||
|
|
||||||
set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.")
|
set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.")
|
||||||
set(CMAKE_CXX_EXTENSIONS OFF)
|
set(CMAKE_CXX_EXTENSIONS OFF)
|
||||||
|
|||||||
@@ -2,3 +2,7 @@ add_subdirectory(csrc)
|
|||||||
if(SHERPA_ONNX_ENABLE_PYTHON)
|
if(SHERPA_ONNX_ENABLE_PYTHON)
|
||||||
add_subdirectory(python)
|
add_subdirectory(python)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
if(SHERPA_ONNX_ENABLE_JNI)
|
||||||
|
add_subdirectory(jni)
|
||||||
|
endif()
|
||||||
|
|||||||
14
sherpa-onnx/jni/CMakeLists.txt
Normal file
14
sherpa-onnx/jni/CMakeLists.txt
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
include_directories(${CMAKE_SOURCE_DIR})
|
||||||
|
|
||||||
|
if(NOT DEFINED ANDROID_ABI)
|
||||||
|
if(NOT DEFINED ENV{JAVA_HOME})
|
||||||
|
message(FATAL_ERROR "Please set the environment variable JAVA_HOME")
|
||||||
|
endif()
|
||||||
|
include_directories($ENV{JAVA_HOME}/include)
|
||||||
|
include_directories($ENV{JAVA_HOME}/include/linux)
|
||||||
|
include_directories($ENV{JAVA_HOME}/include/darwin)
|
||||||
|
endif()
|
||||||
|
|
||||||
|
add_library(sherpa-onnx-jni jni.cc)
|
||||||
|
target_link_libraries(sherpa-onnx-jni sherpa-onnx-core)
|
||||||
|
install(TARGETS sherpa-onnx-jni DESTINATION lib)
|
||||||
315
sherpa-onnx/jni/jni.cc
Normal file
315
sherpa-onnx/jni/jni.cc
Normal file
@@ -0,0 +1,315 @@
|
|||||||
|
// sherpa-onnx/jni/jni.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2022-2023 Xiaomi Corporation
|
||||||
|
// 2022 Pingfeng Luo
|
||||||
|
|
||||||
|
// TODO(fangjun): Add documentation to functions/methods in this file
|
||||||
|
// and also show how to use them with kotlin, possibly with java.
|
||||||
|
|
||||||
|
// If you use ndk, you can find "jni.h" inside
|
||||||
|
// android-ndk/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/include
|
||||||
|
#include "jni.h" // NOLINT
|
||||||
|
|
||||||
|
#include <strstream>
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
#include "android/asset_manager.h"
|
||||||
|
#include "android/asset_manager_jni.h"
|
||||||
|
#else
|
||||||
|
#include <fstream>
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 8
|
||||||
|
#include <android/log.h>
|
||||||
|
#define SHERPA_ONNX_LOGE(...) \
|
||||||
|
do { \
|
||||||
|
fprintf(stderr, ##__VA_ARGS__); \
|
||||||
|
fprintf(stderr, "\n"); \
|
||||||
|
__android_log_print(ANDROID_LOG_WARN, "sherpa-onnx", ##__VA_ARGS__); \
|
||||||
|
} while (0)
|
||||||
|
#else
|
||||||
|
#define SHERPA_ONNX_LOGE(...) \
|
||||||
|
do { \
|
||||||
|
fprintf(stderr, ##__VA_ARGS__); \
|
||||||
|
fprintf(stderr, "\n"); \
|
||||||
|
} while (0)
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/online-recognizer.h"
|
||||||
|
#include "sherpa-onnx/csrc/wave-reader.h"
|
||||||
|
|
||||||
|
#define SHERPA_ONNX_EXTERN_C extern "C"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
class SherpaOnnx {
|
||||||
|
public:
|
||||||
|
SherpaOnnx(
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
AAssetManager *mgr,
|
||||||
|
#endif
|
||||||
|
const sherpa_onnx::OnlineRecognizerConfig &config)
|
||||||
|
: recognizer_(
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
mgr,
|
||||||
|
#endif
|
||||||
|
config),
|
||||||
|
stream_(recognizer_.CreateStream()),
|
||||||
|
tail_padding_(16000 * 0.32, 0) {
|
||||||
|
}
|
||||||
|
|
||||||
|
void DecodeSamples(float sample_rate, const float *samples, int32_t n) const {
|
||||||
|
stream_->AcceptWaveform(sample_rate, samples, n);
|
||||||
|
Decode();
|
||||||
|
}
|
||||||
|
|
||||||
|
void InputFinished() const {
|
||||||
|
stream_->AcceptWaveform(16000, tail_padding_.data(), tail_padding_.size());
|
||||||
|
stream_->InputFinished();
|
||||||
|
Decode();
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::string GetText() const {
|
||||||
|
auto result = recognizer_.GetResult(stream_.get());
|
||||||
|
return result.text;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool IsEndpoint() const { return recognizer_.IsEndpoint(stream_.get()); }
|
||||||
|
|
||||||
|
void Reset() const { return recognizer_.Reset(stream_.get()); }
|
||||||
|
|
||||||
|
private:
|
||||||
|
void Decode() const {
|
||||||
|
while (recognizer_.IsReady(stream_.get())) {
|
||||||
|
recognizer_.DecodeStream(stream_.get());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
sherpa_onnx::OnlineRecognizer recognizer_;
|
||||||
|
std::unique_ptr<sherpa_onnx::OnlineStream> stream_;
|
||||||
|
std::vector<float> tail_padding_;
|
||||||
|
};
|
||||||
|
|
||||||
|
static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
|
||||||
|
OnlineRecognizerConfig ans;
|
||||||
|
|
||||||
|
jclass cls = env->GetObjectClass(config);
|
||||||
|
jfieldID fid;
|
||||||
|
|
||||||
|
// https://docs.oracle.com/javase/7/docs/technotes/guides/jni/spec/types.html
|
||||||
|
// https://courses.cs.washington.edu/courses/cse341/99wi/java/tutorial/native1.1/implementing/field.html
|
||||||
|
|
||||||
|
//---------- feat config ----------
|
||||||
|
fid = env->GetFieldID(cls, "featConfig",
|
||||||
|
"Lcom/k2fsa/sherpa/onnx/FeatureConfig;");
|
||||||
|
jobject feat_config = env->GetObjectField(config, fid);
|
||||||
|
jclass feat_config_cls = env->GetObjectClass(feat_config);
|
||||||
|
|
||||||
|
fid = env->GetFieldID(feat_config_cls, "sampleRate", "F");
|
||||||
|
ans.feat_config.sampling_rate = env->GetFloatField(feat_config, fid);
|
||||||
|
|
||||||
|
fid = env->GetFieldID(feat_config_cls, "featureDim", "I");
|
||||||
|
ans.feat_config.feature_dim = env->GetIntField(feat_config, fid);
|
||||||
|
|
||||||
|
//---------- enable endpoint ----------
|
||||||
|
fid = env->GetFieldID(cls, "enableEndpoint", "Z");
|
||||||
|
ans.enable_endpoint = env->GetBooleanField(config, fid);
|
||||||
|
|
||||||
|
//---------- endpoint_config ----------
|
||||||
|
|
||||||
|
fid = env->GetFieldID(cls, "endpointConfig",
|
||||||
|
"Lcom/k2fsa/sherpa/onnx/EndpointConfig;");
|
||||||
|
jobject endpoint_config = env->GetObjectField(config, fid);
|
||||||
|
jclass endpoint_config_cls = env->GetObjectClass(endpoint_config);
|
||||||
|
|
||||||
|
fid = env->GetFieldID(endpoint_config_cls, "rule1",
|
||||||
|
"Lcom/k2fsa/sherpa/onnx/EndpointRule;");
|
||||||
|
jobject rule1 = env->GetObjectField(endpoint_config, fid);
|
||||||
|
jclass rule_class = env->GetObjectClass(rule1);
|
||||||
|
|
||||||
|
fid = env->GetFieldID(endpoint_config_cls, "rule2",
|
||||||
|
"Lcom/k2fsa/sherpa/onnx/EndpointRule;");
|
||||||
|
jobject rule2 = env->GetObjectField(endpoint_config, fid);
|
||||||
|
|
||||||
|
fid = env->GetFieldID(endpoint_config_cls, "rule3",
|
||||||
|
"Lcom/k2fsa/sherpa/onnx/EndpointRule;");
|
||||||
|
jobject rule3 = env->GetObjectField(endpoint_config, fid);
|
||||||
|
|
||||||
|
fid = env->GetFieldID(rule_class, "mustContainNonSilence", "Z");
|
||||||
|
ans.endpoint_config.rule1.must_contain_nonsilence =
|
||||||
|
env->GetBooleanField(rule1, fid);
|
||||||
|
ans.endpoint_config.rule2.must_contain_nonsilence =
|
||||||
|
env->GetBooleanField(rule2, fid);
|
||||||
|
ans.endpoint_config.rule3.must_contain_nonsilence =
|
||||||
|
env->GetBooleanField(rule3, fid);
|
||||||
|
|
||||||
|
fid = env->GetFieldID(rule_class, "minTrailingSilence", "F");
|
||||||
|
ans.endpoint_config.rule1.min_trailing_silence =
|
||||||
|
env->GetFloatField(rule1, fid);
|
||||||
|
ans.endpoint_config.rule2.min_trailing_silence =
|
||||||
|
env->GetFloatField(rule2, fid);
|
||||||
|
ans.endpoint_config.rule3.min_trailing_silence =
|
||||||
|
env->GetFloatField(rule3, fid);
|
||||||
|
|
||||||
|
fid = env->GetFieldID(rule_class, "minUtteranceLength", "F");
|
||||||
|
ans.endpoint_config.rule1.min_utterance_length =
|
||||||
|
env->GetFloatField(rule1, fid);
|
||||||
|
ans.endpoint_config.rule2.min_utterance_length =
|
||||||
|
env->GetFloatField(rule2, fid);
|
||||||
|
ans.endpoint_config.rule3.min_utterance_length =
|
||||||
|
env->GetFloatField(rule3, fid);
|
||||||
|
|
||||||
|
//---------- tokens ----------
|
||||||
|
|
||||||
|
fid = env->GetFieldID(cls, "tokens", "Ljava/lang/String;");
|
||||||
|
jstring s = (jstring)env->GetObjectField(config, fid);
|
||||||
|
const char *p = env->GetStringUTFChars(s, nullptr);
|
||||||
|
ans.tokens = p;
|
||||||
|
env->ReleaseStringUTFChars(s, p);
|
||||||
|
|
||||||
|
//---------- model config ----------
|
||||||
|
fid = env->GetFieldID(cls, "modelConfig",
|
||||||
|
"Lcom/k2fsa/sherpa/onnx/OnlineTransducerModelConfig;");
|
||||||
|
jobject model_config = env->GetObjectField(config, fid);
|
||||||
|
jclass model_config_cls = env->GetObjectClass(model_config);
|
||||||
|
|
||||||
|
fid = env->GetFieldID(model_config_cls, "encoder", "Ljava/lang/String;");
|
||||||
|
s = (jstring)env->GetObjectField(model_config, fid);
|
||||||
|
p = env->GetStringUTFChars(s, nullptr);
|
||||||
|
ans.model_config.encoder_filename = p;
|
||||||
|
env->ReleaseStringUTFChars(s, p);
|
||||||
|
|
||||||
|
fid = env->GetFieldID(model_config_cls, "decoder", "Ljava/lang/String;");
|
||||||
|
s = (jstring)env->GetObjectField(model_config, fid);
|
||||||
|
p = env->GetStringUTFChars(s, nullptr);
|
||||||
|
ans.model_config.decoder_filename = p;
|
||||||
|
env->ReleaseStringUTFChars(s, p);
|
||||||
|
|
||||||
|
fid = env->GetFieldID(model_config_cls, "joiner", "Ljava/lang/String;");
|
||||||
|
s = (jstring)env->GetObjectField(model_config, fid);
|
||||||
|
p = env->GetStringUTFChars(s, nullptr);
|
||||||
|
ans.model_config.joiner_filename = p;
|
||||||
|
env->ReleaseStringUTFChars(s, p);
|
||||||
|
|
||||||
|
fid = env->GetFieldID(model_config_cls, "numThreads", "I");
|
||||||
|
ans.model_config.num_threads = env->GetIntField(model_config, fid);
|
||||||
|
|
||||||
|
fid = env->GetFieldID(model_config_cls, "debug", "Z");
|
||||||
|
ans.model_config.debug = env->GetBooleanField(model_config, fid);
|
||||||
|
|
||||||
|
return ans;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
SHERPA_ONNX_EXTERN_C
|
||||||
|
JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_new(
|
||||||
|
JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) {
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
|
||||||
|
if (!mgr) {
|
||||||
|
SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
auto config = sherpa_onnx::GetConfig(env, _config);
|
||||||
|
auto model = new sherpa_onnx::SherpaOnnx(
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
mgr,
|
||||||
|
#endif
|
||||||
|
config);
|
||||||
|
|
||||||
|
return (jlong)model;
|
||||||
|
}
|
||||||
|
|
||||||
|
SHERPA_ONNX_EXTERN_C
|
||||||
|
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_delete(
|
||||||
|
JNIEnv *env, jobject /*obj*/, jlong ptr) {
|
||||||
|
SHERPA_ONNX_LOGE("freed!");
|
||||||
|
delete reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
SHERPA_ONNX_EXTERN_C
|
||||||
|
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_reset(
|
||||||
|
JNIEnv *env, jobject /*obj*/, jlong ptr) {
|
||||||
|
auto model = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);
|
||||||
|
model->Reset();
|
||||||
|
}
|
||||||
|
|
||||||
|
SHERPA_ONNX_EXTERN_C
|
||||||
|
JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_isEndpoint(
|
||||||
|
JNIEnv *env, jobject /*obj*/, jlong ptr) {
|
||||||
|
auto model = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);
|
||||||
|
return model->IsEndpoint();
|
||||||
|
}
|
||||||
|
|
||||||
|
SHERPA_ONNX_EXTERN_C
|
||||||
|
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_decodeSamples(
|
||||||
|
JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples,
|
||||||
|
jfloat sample_rate) {
|
||||||
|
auto model = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);
|
||||||
|
|
||||||
|
jfloat *p = env->GetFloatArrayElements(samples, nullptr);
|
||||||
|
jsize n = env->GetArrayLength(samples);
|
||||||
|
|
||||||
|
model->DecodeSamples(sample_rate, p, n);
|
||||||
|
|
||||||
|
env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);
|
||||||
|
}
|
||||||
|
|
||||||
|
SHERPA_ONNX_EXTERN_C
|
||||||
|
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_inputFinished(
|
||||||
|
JNIEnv *env, jobject /*obj*/, jlong ptr) {
|
||||||
|
reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr)->InputFinished();
|
||||||
|
}
|
||||||
|
|
||||||
|
SHERPA_ONNX_EXTERN_C
|
||||||
|
JNIEXPORT jstring JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_getText(
|
||||||
|
JNIEnv *env, jobject /*obj*/, jlong ptr) {
|
||||||
|
// see
|
||||||
|
// https://stackoverflow.com/questions/11621449/send-c-string-to-java-via-jni
|
||||||
|
auto text = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr)->GetText();
|
||||||
|
return env->NewStringUTF(text.c_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
SHERPA_ONNX_EXTERN_C
|
||||||
|
JNIEXPORT jfloatArray JNICALL
|
||||||
|
Java_com_k2fsa_sherpa_onnx_WaveReader_00024Companion_readWave(
|
||||||
|
JNIEnv *env, jclass /*cls*/, jobject asset_manager, jstring filename,
|
||||||
|
jfloat expected_sample_rate) {
|
||||||
|
const char *p_filename = env->GetStringUTFChars(filename, nullptr);
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
|
||||||
|
if (!mgr) {
|
||||||
|
SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
AAsset *asset = AAssetManager_open(mgr, p_filename, AASSET_MODE_BUFFER);
|
||||||
|
size_t asset_length = AAsset_getLength(asset);
|
||||||
|
std::vector<char> buffer(asset_length);
|
||||||
|
AAsset_read(asset, buffer.data(), asset_length);
|
||||||
|
|
||||||
|
std::istrstream is(buffer.data(), asset_length);
|
||||||
|
#else
|
||||||
|
std::ifstream is(p_filename, std::ios::binary);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
bool is_ok = false;
|
||||||
|
std::vector<float> samples =
|
||||||
|
sherpa_onnx::ReadWave(is, expected_sample_rate, &is_ok);
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
AAsset_close(asset);
|
||||||
|
#endif
|
||||||
|
env->ReleaseStringUTFChars(filename, p_filename);
|
||||||
|
|
||||||
|
if (!is_ok) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
jfloatArray ans = env->NewFloatArray(samples.size());
|
||||||
|
env->SetFloatArrayRegion(ans, 0, samples.size(), samples.data());
|
||||||
|
return ans;
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user