diff --git a/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/SherpaOnnx.kt b/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/SherpaOnnx.kt index 0fb5fe24..4ef48f93 100644 --- a/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/SherpaOnnx.kt +++ b/android/SherpaOnnx/app/src/main/java/com/k2fsa/sherpa/onnx/SherpaOnnx.kt @@ -21,6 +21,8 @@ data class OnlineTransducerModelConfig( var tokens: String, var numThreads: Int = 1, var debug: Boolean = false, + var provider: String = "cpu", + var modelType: String = "", ) data class OnlineLMConfig( @@ -135,6 +137,7 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? { decoder = "$modelDir/decoder-epoch-99-avg-1.onnx", joiner = "$modelDir/joiner-epoch-99-avg-1.onnx", tokens = "$modelDir/tokens.txt", + modelType = "zipformer", ) } 1 -> { @@ -144,6 +147,7 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? { decoder = "$modelDir/decoder-epoch-11-avg-1.onnx", joiner = "$modelDir/joiner-epoch-11-avg-1.onnx", tokens = "$modelDir/tokens.txt", + modelType = "lstm", ) } @@ -154,6 +158,7 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? { decoder = "$modelDir/decoder-epoch-99-avg-1.onnx", joiner = "$modelDir/joiner-epoch-99-avg-1.onnx", tokens = "$modelDir/tokens.txt", + modelType = "lstm", ) } @@ -164,6 +169,7 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? { decoder = "$modelDir/exp/decoder-epoch-12-avg-4-chunk-16-left-128.onnx", joiner = "$modelDir/exp/joiner-epoch-12-avg-4-chunk-16-left-128.onnx", tokens = "$modelDir/data/lang_char/tokens.txt", + modelType = "zipformer2", ) } @@ -174,6 +180,7 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? { decoder = "$modelDir/exp/decoder-epoch-12-avg-4-chunk-16-left-128.onnx", joiner = "$modelDir/exp/joiner-epoch-12-avg-4-chunk-16-left-128.onnx", tokens = "$modelDir/data/lang_char/tokens.txt", + modelType = "zipformer2", ) } } diff --git a/ios-swift/SherpaOnnx/SherpaOnnx/Model.swift b/ios-swift/SherpaOnnx/SherpaOnnx/Model.swift index 4aeb97a9..5e6c30d6 100644 --- a/ios-swift/SherpaOnnx/SherpaOnnx/Model.swift +++ b/ios-swift/SherpaOnnx/SherpaOnnx/Model.swift @@ -26,7 +26,8 @@ func getBilingualStreamZhEnZipformer20230220() -> SherpaOnnxOnlineTransducerMode decoder: decoder, joiner: joiner, tokens: tokens, - numThreads: 2 + numThreads: 2, + modelType: "zipformer" ) } @@ -41,7 +42,8 @@ func getZhZipformer20230615() -> SherpaOnnxOnlineTransducerModelConfig { decoder: decoder, joiner: joiner, tokens: tokens, - numThreads: 2 + numThreads: 2, + modelType: "zipformer2" ) } @@ -56,7 +58,8 @@ func getZhZipformer20230615Int8() -> SherpaOnnxOnlineTransducerModelConfig { decoder: decoder, joiner: joiner, tokens: tokens, - numThreads: 2 + numThreads: 2, + modelType: "zipformer2" ) } @@ -71,7 +74,8 @@ func getEnZipformer20230626() -> SherpaOnnxOnlineTransducerModelConfig { decoder: decoder, joiner: joiner, tokens: tokens, - numThreads: 2 + numThreads: 2, + modelType: "zipformer2" ) } diff --git a/ios-swiftui/SherpaOnnx/SherpaOnnx/Model.swift b/ios-swiftui/SherpaOnnx/SherpaOnnx/Model.swift index 6c5b5999..569b62c8 100644 --- a/ios-swiftui/SherpaOnnx/SherpaOnnx/Model.swift +++ b/ios-swiftui/SherpaOnnx/SherpaOnnx/Model.swift @@ -26,7 +26,8 @@ func getBilingualStreamZhEnZipformer20230220() -> SherpaOnnxOnlineTransducerMode decoder: decoder, joiner: joiner, tokens: tokens, - numThreads: 2 + numThreads: 2, + modelType: "zipformer" ) } diff --git a/scripts/dotnet/online.cs b/scripts/dotnet/online.cs index de30f459..d423ee36 100644 --- a/scripts/dotnet/online.cs +++ b/scripts/dotnet/online.cs @@ -26,6 +26,7 @@ namespace SherpaOnnx NumThreads = 1; Provider = "cpu"; Debug = 0; + ModelType = ""; } [MarshalAs(UnmanagedType.LPStr)] public string Encoder; @@ -47,6 +48,9 @@ namespace SherpaOnnx /// true to print debug information of the model public int Debug; + + [MarshalAs(UnmanagedType.LPStr)] + public string ModelType; } /// It expects 16 kHz 16-bit single channel wave format. diff --git a/sherpa-onnx/c-api/c-api.h b/sherpa-onnx/c-api/c-api.h index 698f7f38..aeb68b43 100644 --- a/sherpa-onnx/c-api/c-api.h +++ b/sherpa-onnx/c-api/c-api.h @@ -53,8 +53,8 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlineTransducerModelConfig { const char *tokens; int32_t num_threads; const char *provider; - const char *model_type; int32_t debug; // true to print debug information of the model + const char *model_type; } SherpaOnnxOnlineTransducerModelConfig; /// It expects 16 kHz 16-bit single channel wave format. diff --git a/sherpa-onnx/jni/jni.cc b/sherpa-onnx/jni/jni.cc index ab5cec5d..95cc5294 100644 --- a/sherpa-onnx/jni/jni.cc +++ b/sherpa-onnx/jni/jni.cc @@ -187,6 +187,18 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) { fid = env->GetFieldID(model_config_cls, "debug", "Z"); ans.model_config.debug = env->GetBooleanField(model_config, fid); + fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(model_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.provider = p; + env->ReleaseStringUTFChars(s, p); + + fid = env->GetFieldID(model_config_cls, "modelType", "Ljava/lang/String;"); + s = (jstring)env->GetObjectField(model_config, fid); + p = env->GetStringUTFChars(s, nullptr); + ans.model_config.model_type = p; + env->ReleaseStringUTFChars(s, p); + //---------- rnn lm model config ---------- fid = env->GetFieldID(cls, "lmConfig", "Lcom/k2fsa/sherpa/onnx/OnlineLMConfig;"); diff --git a/swift-api-examples/SherpaOnnx.swift b/swift-api-examples/SherpaOnnx.swift index 7310d94f..c22c938a 100644 --- a/swift-api-examples/SherpaOnnx.swift +++ b/swift-api-examples/SherpaOnnx.swift @@ -36,7 +36,8 @@ func sherpaOnnxOnlineTransducerModelConfig( tokens: String, numThreads: Int = 2, provider: String = "cpu", - debug: Int = 0 + debug: Int = 0, + modelType: String = "" ) -> SherpaOnnxOnlineTransducerModelConfig { return SherpaOnnxOnlineTransducerModelConfig( encoder: toCPointer(encoder), @@ -45,7 +46,8 @@ func sherpaOnnxOnlineTransducerModelConfig( tokens: toCPointer(tokens), num_threads: Int32(numThreads), provider: toCPointer(provider), - debug: Int32(debug) + debug: Int32(debug), + model_type: toCPointer(modelType) ) }