diff --git a/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt b/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt index d8d7d4a9..121becc0 100644 --- a/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt +++ b/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/MainActivity.kt @@ -177,7 +177,7 @@ class MainActivity : AppCompatActivity() { // Please change getModelConfig() to add new models // See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html // for a list of available models - val type = 3 + val type = 5 println("Select model type ${type}") val config = OnlineRecognizerConfig( featConfig = getFeatureConfig(sampleRate = sampleRateInHz, featureDim = 80), @@ -185,8 +185,6 @@ class MainActivity : AppCompatActivity() { lmConfig = getOnlineLMConfig(type = type), endpointConfig = getEndpointConfig(), enableEndpoint = true, - decodingMethod = "modified_beam_search", - maxActivePaths = 4, ) model = SherpaOnnx( diff --git a/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/SherpaOnnx.kt b/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/SherpaOnnx.kt index 4ef48f93..aedaf0f6 100644 --- a/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/SherpaOnnx.kt +++ b/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/SherpaOnnx.kt @@ -15,9 +15,19 @@ data class EndpointConfig( ) data class OnlineTransducerModelConfig( - var encoder: String, - var decoder: String, - var joiner: String, + var encoder: String = "", + var decoder: String = "", + var joiner: String = "", +) + +data class OnlineParaformerModelConfig( + var encoder: String = "", + var decoder: String = "", +) + +data class OnlineModelConfig( + var transducer: OnlineTransducerModelConfig = OnlineTransducerModelConfig(), + var paraformer: OnlineParaformerModelConfig = OnlineParaformerModelConfig(), var tokens: String, var numThreads: Int = 1, var debug: Boolean = false, @@ -37,8 +47,8 @@ data 
class FeatureConfig( data class OnlineRecognizerConfig( var featConfig: FeatureConfig = FeatureConfig(), - var modelConfig: OnlineTransducerModelConfig, - var lmConfig : OnlineLMConfig, + var modelConfig: OnlineModelConfig, + var lmConfig: OnlineLMConfig, var endpointConfig: EndpointConfig = EndpointConfig(), var enableEndpoint: Boolean = true, var decodingMethod: String = "greedy_search", @@ -115,37 +125,47 @@ to add your own. (It should be straightforward to add a new model by following the code) @param type -0 - sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 (Bilingual, Chinese + English) +0 - sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 (Bilingual, Chinese + English) https://k2-fsa.github.io/sherpa/onnx/pretrained_models/zipformer-transducer-models.html#sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20-bilingual-chinese-english -1 - csukuangfj/sherpa-onnx-lstm-zh-2023-02-20 (Chinese) +1 - csukuangfj/sherpa-onnx-lstm-zh-2023-02-20 (Chinese) https://k2-fsa.github.io/sherpa/onnx/pretrained_models/lstm-transducer-models.html#csukuangfj-sherpa-onnx-lstm-zh-2023-02-20-chinese -2 - csukuangfj/sherpa-onnx-lstm-en-2023-02-17 (English) +2 - csukuangfj/sherpa-onnx-lstm-en-2023-02-17 (English) https://k2-fsa.github.io/sherpa/onnx/pretrained_models/lstm-transducer-models.html#csukuangfj-sherpa-onnx-lstm-en-2023-02-17-english -3 - pkufool/icefall-asr-zipformer-streaming-wenetspeech-20230615 +3,4 - pkufool/icefall-asr-zipformer-streaming-wenetspeech-20230615 https://huggingface.co/pkufool/icefall-asr-zipformer-streaming-wenetspeech-20230615 + 3 - int8 encoder + 4 - float32 encoder + +5 - csukuangfj/sherpa-onnx-streaming-paraformer-bilingual-zh-en + https://huggingface.co/csukuangfj/sherpa-onnx-streaming-paraformer-bilingual-zh-en + */ -fun getModelConfig(type: Int): OnlineTransducerModelConfig? { +fun getModelConfig(type: Int): OnlineModelConfig? 
{ when (type) { 0 -> { val modelDir = "sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20" - return OnlineTransducerModelConfig( - encoder = "$modelDir/encoder-epoch-99-avg-1.onnx", - decoder = "$modelDir/decoder-epoch-99-avg-1.onnx", - joiner = "$modelDir/joiner-epoch-99-avg-1.onnx", + return OnlineModelConfig( + transducer = OnlineTransducerModelConfig( + encoder = "$modelDir/encoder-epoch-99-avg-1.onnx", + decoder = "$modelDir/decoder-epoch-99-avg-1.onnx", + joiner = "$modelDir/joiner-epoch-99-avg-1.onnx", + ), tokens = "$modelDir/tokens.txt", modelType = "zipformer", ) } 1 -> { val modelDir = "sherpa-onnx-lstm-zh-2023-02-20" - return OnlineTransducerModelConfig( - encoder = "$modelDir/encoder-epoch-11-avg-1.onnx", - decoder = "$modelDir/decoder-epoch-11-avg-1.onnx", - joiner = "$modelDir/joiner-epoch-11-avg-1.onnx", + return OnlineModelConfig( + transducer = OnlineTransducerModelConfig( + encoder = "$modelDir/encoder-epoch-11-avg-1.onnx", + decoder = "$modelDir/decoder-epoch-11-avg-1.onnx", + joiner = "$modelDir/joiner-epoch-11-avg-1.onnx", + ), tokens = "$modelDir/tokens.txt", modelType = "lstm", ) @@ -153,10 +173,12 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? { 2 -> { val modelDir = "sherpa-onnx-lstm-en-2023-02-17" - return OnlineTransducerModelConfig( - encoder = "$modelDir/encoder-epoch-99-avg-1.onnx", - decoder = "$modelDir/decoder-epoch-99-avg-1.onnx", - joiner = "$modelDir/joiner-epoch-99-avg-1.onnx", + return OnlineModelConfig( + transducer = OnlineTransducerModelConfig( + encoder = "$modelDir/encoder-epoch-99-avg-1.onnx", + decoder = "$modelDir/decoder-epoch-99-avg-1.onnx", + joiner = "$modelDir/joiner-epoch-99-avg-1.onnx", + ), tokens = "$modelDir/tokens.txt", modelType = "lstm", ) @@ -164,10 +186,12 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? 
{ 3 -> { val modelDir = "icefall-asr-zipformer-streaming-wenetspeech-20230615" - return OnlineTransducerModelConfig( - encoder = "$modelDir/exp/encoder-epoch-12-avg-4-chunk-16-left-128.int8.onnx", - decoder = "$modelDir/exp/decoder-epoch-12-avg-4-chunk-16-left-128.onnx", - joiner = "$modelDir/exp/joiner-epoch-12-avg-4-chunk-16-left-128.onnx", + return OnlineModelConfig( + transducer = OnlineTransducerModelConfig( + encoder = "$modelDir/exp/encoder-epoch-12-avg-4-chunk-16-left-128.int8.onnx", + decoder = "$modelDir/exp/decoder-epoch-12-avg-4-chunk-16-left-128.onnx", + joiner = "$modelDir/exp/joiner-epoch-12-avg-4-chunk-16-left-128.onnx", + ), tokens = "$modelDir/data/lang_char/tokens.txt", modelType = "zipformer2", ) @@ -175,14 +199,28 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? { 4 -> { val modelDir = "icefall-asr-zipformer-streaming-wenetspeech-20230615" - return OnlineTransducerModelConfig( - encoder = "$modelDir/exp/encoder-epoch-12-avg-4-chunk-16-left-128.onnx", - decoder = "$modelDir/exp/decoder-epoch-12-avg-4-chunk-16-left-128.onnx", - joiner = "$modelDir/exp/joiner-epoch-12-avg-4-chunk-16-left-128.onnx", + return OnlineModelConfig( + transducer = OnlineTransducerModelConfig( + encoder = "$modelDir/exp/encoder-epoch-12-avg-4-chunk-16-left-128.onnx", + decoder = "$modelDir/exp/decoder-epoch-12-avg-4-chunk-16-left-128.onnx", + joiner = "$modelDir/exp/joiner-epoch-12-avg-4-chunk-16-left-128.onnx", + ), tokens = "$modelDir/data/lang_char/tokens.txt", modelType = "zipformer2", ) } + + 5 -> { + val modelDir = "sherpa-onnx-streaming-paraformer-bilingual-zh-en" + return OnlineModelConfig( + paraformer = OnlineParaformerModelConfig( + encoder = "$modelDir/encoder.int8.onnx", + decoder = "$modelDir/decoder.int8.onnx", + ), + tokens = "$modelDir/tokens.txt", + modelType = "paraformer", + ) + } } return null; } @@ -200,7 +238,7 @@ by following the code, https://github.com/k2-fsa/icefall/blob/master/icefall/rnn 0 - 
sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 (Bilingual, Chinese + English) https://k2-fsa.github.io/sherpa/onnx/pretrained_models/zipformer-transducer-models.html#sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20-bilingual-chinese-english */ -fun getOnlineLMConfig(type : Int): OnlineLMConfig { +fun getOnlineLMConfig(type: Int): OnlineLMConfig { when (type) { 0 -> { val modelDir = "sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20" diff --git a/sherpa-onnx/csrc/online-recognizer-paraformer-impl.h b/sherpa-onnx/csrc/online-recognizer-paraformer-impl.h index ae209633..d5034d13 100644 --- a/sherpa-onnx/csrc/online-recognizer-paraformer-impl.h +++ b/sherpa-onnx/csrc/online-recognizer-paraformer-impl.h @@ -190,7 +190,11 @@ class OnlineRecognizerParaformerImpl : public OnlineRecognizerImpl { OnlineParaformerDecoderResult r; s->SetParaformerResult(r); - // the internal model caches are not reset + s->GetStates().clear(); + s->GetParaformerEncoderOutCache().clear(); + s->GetParaformerAlphaCache().clear(); + + // s->GetParaformerFeatCache().clear(); // Note: We only update counters. The underlying audio samples // are not discarded. 
diff --git a/sherpa-onnx/jni/jni.cc b/sherpa-onnx/jni/jni.cc index d05140ef..1f93da8e 100644 --- a/sherpa-onnx/jni/jni.cc +++ b/sherpa-onnx/jni/jni.cc @@ -47,7 +47,7 @@ class SherpaOnnx { } void InputFinished() const { - std::vector<float> tail_padding(input_sample_rate_ * 0.32, 0); + std::vector<float> tail_padding(input_sample_rate_ * 0.6, 0); stream_->AcceptWaveform(input_sample_rate_, tail_padding.data(), tail_padding.size()); stream_->InputFinished(); @@ -158,48 +158,74 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) { //---------- model config ---------- fid = env->GetFieldID(cls, "modelConfig", - "Lcom/k2fsa/sherpa/onnx/OnlineTransducerModelConfig;"); - jobject transducer_config = env->GetObjectField(config, fid); - jclass model_config_cls = env->GetObjectClass(transducer_config); + "Lcom/k2fsa/sherpa/onnx/OnlineModelConfig;"); + jobject model_config = env->GetObjectField(config, fid); + jclass model_config_cls = env->GetObjectClass(model_config); - fid = env->GetFieldID(model_config_cls, "encoder", "Ljava/lang/String;"); + // transducer + fid = env->GetFieldID(model_config_cls, "transducer", + "Lcom/k2fsa/sherpa/onnx/OnlineTransducerModelConfig;"); + jobject transducer_config = env->GetObjectField(model_config, fid); + jclass transducer_config_cls = env->GetObjectClass(transducer_config); + + fid = env->GetFieldID(transducer_config_cls, "encoder", "Ljava/lang/String;"); s = (jstring)env->GetObjectField(transducer_config, fid); p = env->GetStringUTFChars(s, nullptr); ans.model_config.transducer.encoder = p; env->ReleaseStringUTFChars(s, p); - fid = env->GetFieldID(model_config_cls, "decoder", "Ljava/lang/String;"); + fid = env->GetFieldID(transducer_config_cls, "decoder", "Ljava/lang/String;"); s = (jstring)env->GetObjectField(transducer_config, fid); p = env->GetStringUTFChars(s, nullptr); ans.model_config.transducer.decoder = p; env->ReleaseStringUTFChars(s, p); - fid = env->GetFieldID(model_config_cls, "joiner", "Ljava/lang/String;"); + fid =
env->GetFieldID(transducer_config_cls, "joiner", "Ljava/lang/String;"); s = (jstring)env->GetObjectField(transducer_config, fid); p = env->GetStringUTFChars(s, nullptr); ans.model_config.transducer.joiner = p; env->ReleaseStringUTFChars(s, p); + // paraformer + fid = env->GetFieldID(model_config_cls, "paraformer", + "Lcom/k2fsa/sherpa/onnx/OnlineParaformerModelConfig;"); + jobject paraformer_config = env->GetObjectField(model_config, fid); + jclass paraformer_config_config_cls = env->GetObjectClass(paraformer_config); + + fid = env->GetFieldID(paraformer_config_config_cls, "encoder", + "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(paraformer_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.paraformer.encoder = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(paraformer_config_config_cls, "decoder", + "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(paraformer_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.paraformer.decoder = p; + env->ReleaseStringUTFChars(s, p); + fid = env->GetFieldID(model_config_cls, "tokens", "Ljava/lang/String;"); - s = (jstring)env->GetObjectField(transducer_config, fid); + s = (jstring)env->GetObjectField(model_config, fid); p = env->GetStringUTFChars(s, nullptr); ans.model_config.tokens = p; env->ReleaseStringUTFChars(s, p); fid = env->GetFieldID(model_config_cls, "numThreads", "I"); - ans.model_config.num_threads = env->GetIntField(transducer_config, fid); + ans.model_config.num_threads = env->GetIntField(model_config, fid); fid = env->GetFieldID(model_config_cls, "debug", "Z"); - ans.model_config.debug = env->GetBooleanField(transducer_config, fid); + ans.model_config.debug = env->GetBooleanField(model_config, fid); fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;"); - s = (jstring)env->GetObjectField(transducer_config, fid); + s = (jstring)env->GetObjectField(model_config, fid); p = env->GetStringUTFChars(s, nullptr); 
ans.model_config.provider = p; env->ReleaseStringUTFChars(s, p); fid = env->GetFieldID(model_config_cls, "modelType", "Ljava/lang/String;"); - s = (jstring)env->GetObjectField(transducer_config, fid); + s = (jstring)env->GetObjectField(model_config, fid); p = env->GetStringUTFChars(s, nullptr); ans.model_config.model_type = p; env->ReleaseStringUTFChars(s, p);