Fix modified beam search for iOS and android (#76)

* Use Int type for sampling rate

* Fix swift

* Fix iOS
This commit is contained in:
Fangjun Kuang
2023-03-03 15:18:31 +08:00
committed by GitHub
parent 7f72c13d9a
commit 5f31b22c12
15 changed files with 125 additions and 93 deletions

View File

@@ -76,7 +76,7 @@ SherpaOnnxOnlineStream *CreateOnlineStream(
void DestoryOnlineStream(SherpaOnnxOnlineStream *stream) { delete stream; }
void AcceptWaveform(SherpaOnnxOnlineStream *stream, float sample_rate,
void AcceptWaveform(SherpaOnnxOnlineStream *stream, int32_t sample_rate,
const float *samples, int32_t n) {
stream->impl->AcceptWaveform(sample_rate, samples, n);
}

View File

@@ -120,7 +120,7 @@ void DestoryOnlineStream(SherpaOnnxOnlineStream *stream);
/// @param samples A pointer to a 1-D array containing audio samples.
/// The range of samples has to be normalized to [-1, 1].
/// @param n Number of elements in the samples array.
void AcceptWaveform(SherpaOnnxOnlineStream *stream, float sample_rate,
void AcceptWaveform(SherpaOnnxOnlineStream *stream, int32_t sample_rate,
const float *samples, int32_t n);
/// Return 1 if there are enough number of feature frames for decoding.

View File

@@ -48,7 +48,7 @@ class FeatureExtractor::Impl {
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
}
void AcceptWaveform(float sampling_rate, const float *waveform, int32_t n) {
void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
std::lock_guard<std::mutex> lock(mutex_);
fbank_->AcceptWaveform(sampling_rate, waveform, n);
}
@@ -107,7 +107,7 @@ FeatureExtractor::FeatureExtractor(const FeatureExtractorConfig &config /*={}*/)
FeatureExtractor::~FeatureExtractor() = default;
void FeatureExtractor::AcceptWaveform(float sampling_rate,
void FeatureExtractor::AcceptWaveform(int32_t sampling_rate,
const float *waveform, int32_t n) {
impl_->AcceptWaveform(sampling_rate, waveform, n);
}

View File

@@ -14,7 +14,7 @@
namespace sherpa_onnx {
struct FeatureExtractorConfig {
float sampling_rate = 16000;
int32_t sampling_rate = 16000;
int32_t feature_dim = 80;
int32_t max_feature_vectors = -1;
@@ -34,7 +34,7 @@ class FeatureExtractor {
@param waveform Pointer to a 1-D array of size n
@param n Number of entries in waveform
*/
void AcceptWaveform(float sampling_rate, const float *waveform, int32_t n);
void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n);
/**
* InputFinished() tells the class you won't be providing any

View File

@@ -112,7 +112,7 @@ for a list of pre-trained models to download.
param.suggestedLatency = info->defaultLowInputLatency;
param.hostApiSpecificStreamInfo = nullptr;
const float sample_rate = 16000;
float sample_rate = 16000;
PaStream *stream;
PaError err =

View File

@@ -61,7 +61,7 @@ for a list of pre-trained models to download.
sherpa_onnx::OnlineRecognizer recognizer(config);
float expected_sampling_rate = config.feat_config.sampling_rate;
int32_t expected_sampling_rate = config.feat_config.sampling_rate;
bool is_ok = false;
std::vector<float> samples =
@@ -72,7 +72,7 @@ for a list of pre-trained models to download.
return -1;
}
float duration = samples.size() / expected_sampling_rate;
float duration = samples.size() / static_cast<float>(expected_sampling_rate);
fprintf(stderr, "wav filename: %s\n", wav_filename.c_str());
fprintf(stderr, "wav duration (s): %.3f\n", duration);

View File

@@ -40,19 +40,18 @@ class SherpaOnnx {
mgr,
#endif
config),
stream_(recognizer_.CreateStream()),
tail_padding_(16000 * 0.32, 0) {
stream_(recognizer_.CreateStream()) {
}
void DecodeSamples(float sample_rate, const float *samples, int32_t n) const {
void AcceptWaveform(int32_t sample_rate, const float *samples,
int32_t n) const {
stream_->AcceptWaveform(sample_rate, samples, n);
Decode();
}
void InputFinished() const {
stream_->AcceptWaveform(16000, tail_padding_.data(), tail_padding_.size());
std::vector<float> tail_padding(16000 * 0.32, 0);
stream_->AcceptWaveform(16000, tail_padding.data(), tail_padding.size());
stream_->InputFinished();
Decode();
}
const std::string GetText() const {
@@ -62,19 +61,15 @@ class SherpaOnnx {
bool IsEndpoint() const { return recognizer_.IsEndpoint(stream_.get()); }
bool IsReady() const { return recognizer_.IsReady(stream_.get()); }
void Reset() const { return recognizer_.Reset(stream_.get()); }
private:
void Decode() const {
while (recognizer_.IsReady(stream_.get())) {
recognizer_.DecodeStream(stream_.get());
}
}
void Decode() const { recognizer_.DecodeStream(stream_.get()); }
private:
sherpa_onnx::OnlineRecognizer recognizer_;
std::unique_ptr<sherpa_onnx::OnlineStream> stream_;
std::vector<float> tail_padding_;
};
static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
@@ -86,14 +81,24 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
// https://docs.oracle.com/javase/7/docs/technotes/guides/jni/spec/types.html
// https://courses.cs.washington.edu/courses/cse341/99wi/java/tutorial/native1.1/implementing/field.html
//---------- decoding ----------
fid = env->GetFieldID(cls, "decodingMethod", "Ljava/lang/String;");
jstring s = (jstring)env->GetObjectField(config, fid);
const char *p = env->GetStringUTFChars(s, nullptr);
ans.decoding_method = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(cls, "maxActivePaths", "I");
ans.max_active_paths = env->GetIntField(config, fid);
//---------- feat config ----------
fid = env->GetFieldID(cls, "featConfig",
"Lcom/k2fsa/sherpa/onnx/FeatureConfig;");
jobject feat_config = env->GetObjectField(config, fid);
jclass feat_config_cls = env->GetObjectClass(feat_config);
fid = env->GetFieldID(feat_config_cls, "sampleRate", "F");
ans.feat_config.sampling_rate = env->GetFloatField(feat_config, fid);
fid = env->GetFieldID(feat_config_cls, "sampleRate", "I");
ans.feat_config.sampling_rate = env->GetIntField(feat_config, fid);
fid = env->GetFieldID(feat_config_cls, "featureDim", "I");
ans.feat_config.feature_dim = env->GetIntField(feat_config, fid);
@@ -153,8 +158,8 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
jclass model_config_cls = env->GetObjectClass(model_config);
fid = env->GetFieldID(model_config_cls, "encoder", "Ljava/lang/String;");
jstring s = (jstring)env->GetObjectField(model_config, fid);
const char *p = env->GetStringUTFChars(s, nullptr);
s = (jstring)env->GetObjectField(model_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.encoder_filename = p;
env->ReleaseStringUTFChars(s, p);
@@ -198,6 +203,7 @@ JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_new(
#endif
auto config = sherpa_onnx::GetConfig(env, _config);
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
auto model = new sherpa_onnx::SherpaOnnx(
#if __ANDROID_API__ >= 9
mgr,
@@ -220,6 +226,13 @@ JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_reset(
model->Reset();
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_isReady(
JNIEnv *env, jobject /*obj*/, jlong ptr) {
auto model = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);
return model->IsReady();
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_isEndpoint(
JNIEnv *env, jobject /*obj*/, jlong ptr) {
@@ -228,15 +241,22 @@ JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_isEndpoint(
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_decodeSamples(
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_decode(
JNIEnv *env, jobject /*obj*/, jlong ptr) {
auto model = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);
model->Decode();
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_acceptWaveform(
JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples,
jfloat sample_rate) {
jint sample_rate) {
auto model = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);
jfloat *p = env->GetFloatArrayElements(samples, nullptr);
jsize n = env->GetArrayLength(samples);
model->DecodeSamples(sample_rate, p, n);
model->AcceptWaveform(sample_rate, p, n);
env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);
}