Fix modified beam search for iOS and android (#76)
* Use Int type for sampling rate * Fix swift * Fix iOS
This commit is contained in:
@@ -40,19 +40,18 @@ class SherpaOnnx {
|
||||
mgr,
|
||||
#endif
|
||||
config),
|
||||
stream_(recognizer_.CreateStream()),
|
||||
tail_padding_(16000 * 0.32, 0) {
|
||||
stream_(recognizer_.CreateStream()) {
|
||||
}
|
||||
|
||||
void DecodeSamples(float sample_rate, const float *samples, int32_t n) const {
|
||||
void AcceptWaveform(int32_t sample_rate, const float *samples,
|
||||
int32_t n) const {
|
||||
stream_->AcceptWaveform(sample_rate, samples, n);
|
||||
Decode();
|
||||
}
|
||||
|
||||
void InputFinished() const {
|
||||
stream_->AcceptWaveform(16000, tail_padding_.data(), tail_padding_.size());
|
||||
std::vector<float> tail_padding(16000 * 0.32, 0);
|
||||
stream_->AcceptWaveform(16000, tail_padding.data(), tail_padding.size());
|
||||
stream_->InputFinished();
|
||||
Decode();
|
||||
}
|
||||
|
||||
const std::string GetText() const {
|
||||
@@ -62,19 +61,15 @@ class SherpaOnnx {
|
||||
|
||||
bool IsEndpoint() const { return recognizer_.IsEndpoint(stream_.get()); }
|
||||
|
||||
bool IsReady() const { return recognizer_.IsReady(stream_.get()); }
|
||||
|
||||
void Reset() const { return recognizer_.Reset(stream_.get()); }
|
||||
|
||||
private:
|
||||
void Decode() const {
|
||||
while (recognizer_.IsReady(stream_.get())) {
|
||||
recognizer_.DecodeStream(stream_.get());
|
||||
}
|
||||
}
|
||||
void Decode() const { recognizer_.DecodeStream(stream_.get()); }
|
||||
|
||||
private:
|
||||
sherpa_onnx::OnlineRecognizer recognizer_;
|
||||
std::unique_ptr<sherpa_onnx::OnlineStream> stream_;
|
||||
std::vector<float> tail_padding_;
|
||||
};
|
||||
|
||||
static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
|
||||
@@ -86,14 +81,24 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
|
||||
// https://docs.oracle.com/javase/7/docs/technotes/guides/jni/spec/types.html
|
||||
// https://courses.cs.washington.edu/courses/cse341/99wi/java/tutorial/native1.1/implementing/field.html
|
||||
|
||||
//---------- decoding ----------
|
||||
fid = env->GetFieldID(cls, "decodingMethod", "Ljava/lang/String;");
|
||||
jstring s = (jstring)env->GetObjectField(config, fid);
|
||||
const char *p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.decoding_method = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
fid = env->GetFieldID(cls, "maxActivePaths", "I");
|
||||
ans.max_active_paths = env->GetIntField(config, fid);
|
||||
|
||||
//---------- feat config ----------
|
||||
fid = env->GetFieldID(cls, "featConfig",
|
||||
"Lcom/k2fsa/sherpa/onnx/FeatureConfig;");
|
||||
jobject feat_config = env->GetObjectField(config, fid);
|
||||
jclass feat_config_cls = env->GetObjectClass(feat_config);
|
||||
|
||||
fid = env->GetFieldID(feat_config_cls, "sampleRate", "F");
|
||||
ans.feat_config.sampling_rate = env->GetFloatField(feat_config, fid);
|
||||
fid = env->GetFieldID(feat_config_cls, "sampleRate", "I");
|
||||
ans.feat_config.sampling_rate = env->GetIntField(feat_config, fid);
|
||||
|
||||
fid = env->GetFieldID(feat_config_cls, "featureDim", "I");
|
||||
ans.feat_config.feature_dim = env->GetIntField(feat_config, fid);
|
||||
@@ -153,8 +158,8 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
|
||||
jclass model_config_cls = env->GetObjectClass(model_config);
|
||||
|
||||
fid = env->GetFieldID(model_config_cls, "encoder", "Ljava/lang/String;");
|
||||
jstring s = (jstring)env->GetObjectField(model_config, fid);
|
||||
const char *p = env->GetStringUTFChars(s, nullptr);
|
||||
s = (jstring)env->GetObjectField(model_config, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.model_config.encoder_filename = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
@@ -198,6 +203,7 @@ JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_new(
|
||||
#endif
|
||||
|
||||
auto config = sherpa_onnx::GetConfig(env, _config);
|
||||
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
|
||||
auto model = new sherpa_onnx::SherpaOnnx(
|
||||
#if __ANDROID_API__ >= 9
|
||||
mgr,
|
||||
@@ -220,6 +226,13 @@ JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_reset(
|
||||
model->Reset();
|
||||
}
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
|
||||
JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_isReady(
|
||||
JNIEnv *env, jobject /*obj*/, jlong ptr) {
|
||||
auto model = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);
|
||||
return model->IsReady();
|
||||
}
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
|
||||
JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_isEndpoint(
|
||||
JNIEnv *env, jobject /*obj*/, jlong ptr) {
|
||||
@@ -228,15 +241,22 @@ JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_isEndpoint(
|
||||
}
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
|
||||
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_decodeSamples(
|
||||
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_decode(
|
||||
JNIEnv *env, jobject /*obj*/, jlong ptr) {
|
||||
auto model = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);
|
||||
model->Decode();
|
||||
}
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
|
||||
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_acceptWaveform(
|
||||
JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples,
|
||||
jfloat sample_rate) {
|
||||
jint sample_rate) {
|
||||
auto model = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);
|
||||
|
||||
jfloat *p = env->GetFloatArrayElements(samples, nullptr);
|
||||
jsize n = env->GetArrayLength(samples);
|
||||
|
||||
model->DecodeSamples(sample_rate, p, n);
|
||||
model->AcceptWaveform(sample_rate, p, n);
|
||||
|
||||
env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user