Add non-streaming ASR (#92)
This commit is contained in:
@@ -11,6 +11,7 @@
|
||||
#include "jni.h" // NOLINT
|
||||
|
||||
#include <strstream>
|
||||
#include <utility>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
@@ -43,14 +44,18 @@ class SherpaOnnx {
|
||||
stream_(recognizer_.CreateStream()) {
|
||||
}
|
||||
|
||||
void AcceptWaveform(int32_t sample_rate, const float *samples,
|
||||
int32_t n) const {
|
||||
void AcceptWaveform(int32_t sample_rate, const float *samples, int32_t n) {
|
||||
if (input_sample_rate_ == -1) {
|
||||
input_sample_rate_ = sample_rate;
|
||||
}
|
||||
|
||||
stream_->AcceptWaveform(sample_rate, samples, n);
|
||||
}
|
||||
|
||||
void InputFinished() const {
|
||||
std::vector<float> tail_padding(16000 * 0.32, 0);
|
||||
stream_->AcceptWaveform(16000, tail_padding.data(), tail_padding.size());
|
||||
std::vector<float> tail_padding(input_sample_rate_ * 0.32, 0);
|
||||
stream_->AcceptWaveform(input_sample_rate_, tail_padding.data(),
|
||||
tail_padding.size());
|
||||
stream_->InputFinished();
|
||||
}
|
||||
|
||||
@@ -70,6 +75,7 @@ class SherpaOnnx {
|
||||
private:
|
||||
sherpa_onnx::OnlineRecognizer recognizer_;
|
||||
std::unique_ptr<sherpa_onnx::OnlineStream> stream_;
|
||||
int32_t input_sample_rate_ = -1;
|
||||
};
|
||||
|
||||
static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
|
||||
@@ -276,17 +282,24 @@ JNIEXPORT jstring JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_getText(
|
||||
return env->NewStringUTF(text.c_str());
|
||||
}
|
||||
|
||||
// see
|
||||
// https://stackoverflow.com/questions/29043872/android-jni-return-multiple-variables
|
||||
static jobject NewInteger(JNIEnv *env, int32_t value) {
|
||||
jclass cls = env->FindClass("java/lang/Integer");
|
||||
jmethodID constructor = env->GetMethodID(cls, "<init>", "(I)V");
|
||||
return env->NewObject(cls, constructor, value);
|
||||
}
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
|
||||
JNIEXPORT jfloatArray JNICALL
|
||||
JNIEXPORT jobjectArray JNICALL
|
||||
Java_com_k2fsa_sherpa_onnx_WaveReader_00024Companion_readWave(
|
||||
JNIEnv *env, jclass /*cls*/, jobject asset_manager, jstring filename,
|
||||
jfloat expected_sample_rate) {
|
||||
JNIEnv *env, jclass /*cls*/, jobject asset_manager, jstring filename) {
|
||||
const char *p_filename = env->GetStringUTFChars(filename, nullptr);
|
||||
#if __ANDROID_API__ >= 9
|
||||
AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
|
||||
if (!mgr) {
|
||||
SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
|
||||
return nullptr;
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
std::vector<char> buffer = sherpa_onnx::ReadFile(mgr, p_filename);
|
||||
@@ -297,16 +310,25 @@ Java_com_k2fsa_sherpa_onnx_WaveReader_00024Companion_readWave(
|
||||
#endif
|
||||
|
||||
bool is_ok = false;
|
||||
int32_t sampling_rate = -1;
|
||||
std::vector<float> samples =
|
||||
sherpa_onnx::ReadWave(is, expected_sample_rate, &is_ok);
|
||||
sherpa_onnx::ReadWave(is, &sampling_rate, &is_ok);
|
||||
|
||||
env->ReleaseStringUTFChars(filename, p_filename);
|
||||
|
||||
if (!is_ok) {
|
||||
return nullptr;
|
||||
SHERPA_ONNX_LOGE("Failed to read %s", p_filename);
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
jfloatArray ans = env->NewFloatArray(samples.size());
|
||||
env->SetFloatArrayRegion(ans, 0, samples.size(), samples.data());
|
||||
return ans;
|
||||
|
||||
jobjectArray obj_arr = (jobjectArray)env->NewObjectArray(
|
||||
2, env->FindClass("java/lang/Object"), nullptr);
|
||||
|
||||
env->SetObjectArrayElement(obj_arr, 0, ans);
|
||||
env->SetObjectArrayElement(obj_arr, 1, NewInteger(env, sampling_rate));
|
||||
|
||||
return obj_arr;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user