Fix model_type for jni, c# and iOS. (#216)

This commit is contained in:
Fangjun Kuang
2023-07-14 22:24:38 +08:00
committed by GitHub
parent 5a6b55c5a7
commit de2673680e
7 changed files with 38 additions and 8 deletions

View File

@@ -21,6 +21,8 @@ data class OnlineTransducerModelConfig(
var tokens: String, var tokens: String,
var numThreads: Int = 1, var numThreads: Int = 1,
var debug: Boolean = false, var debug: Boolean = false,
var provider: String = "cpu",
var modelType: String = "",
) )
data class OnlineLMConfig( data class OnlineLMConfig(
@@ -135,6 +137,7 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? {
decoder = "$modelDir/decoder-epoch-99-avg-1.onnx", decoder = "$modelDir/decoder-epoch-99-avg-1.onnx",
joiner = "$modelDir/joiner-epoch-99-avg-1.onnx", joiner = "$modelDir/joiner-epoch-99-avg-1.onnx",
tokens = "$modelDir/tokens.txt", tokens = "$modelDir/tokens.txt",
modelType = "zipformer",
) )
} }
1 -> { 1 -> {
@@ -144,6 +147,7 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? {
decoder = "$modelDir/decoder-epoch-11-avg-1.onnx", decoder = "$modelDir/decoder-epoch-11-avg-1.onnx",
joiner = "$modelDir/joiner-epoch-11-avg-1.onnx", joiner = "$modelDir/joiner-epoch-11-avg-1.onnx",
tokens = "$modelDir/tokens.txt", tokens = "$modelDir/tokens.txt",
modelType = "lstm",
) )
} }
@@ -154,6 +158,7 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? {
decoder = "$modelDir/decoder-epoch-99-avg-1.onnx", decoder = "$modelDir/decoder-epoch-99-avg-1.onnx",
joiner = "$modelDir/joiner-epoch-99-avg-1.onnx", joiner = "$modelDir/joiner-epoch-99-avg-1.onnx",
tokens = "$modelDir/tokens.txt", tokens = "$modelDir/tokens.txt",
modelType = "lstm",
) )
} }
@@ -164,6 +169,7 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? {
decoder = "$modelDir/exp/decoder-epoch-12-avg-4-chunk-16-left-128.onnx", decoder = "$modelDir/exp/decoder-epoch-12-avg-4-chunk-16-left-128.onnx",
joiner = "$modelDir/exp/joiner-epoch-12-avg-4-chunk-16-left-128.onnx", joiner = "$modelDir/exp/joiner-epoch-12-avg-4-chunk-16-left-128.onnx",
tokens = "$modelDir/data/lang_char/tokens.txt", tokens = "$modelDir/data/lang_char/tokens.txt",
modelType = "zipformer2",
) )
} }
@@ -174,6 +180,7 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? {
decoder = "$modelDir/exp/decoder-epoch-12-avg-4-chunk-16-left-128.onnx", decoder = "$modelDir/exp/decoder-epoch-12-avg-4-chunk-16-left-128.onnx",
joiner = "$modelDir/exp/joiner-epoch-12-avg-4-chunk-16-left-128.onnx", joiner = "$modelDir/exp/joiner-epoch-12-avg-4-chunk-16-left-128.onnx",
tokens = "$modelDir/data/lang_char/tokens.txt", tokens = "$modelDir/data/lang_char/tokens.txt",
modelType = "zipformer2",
) )
} }
} }

View File

@@ -26,7 +26,8 @@ func getBilingualStreamZhEnZipformer20230220() -> SherpaOnnxOnlineTransducerMode
decoder: decoder, decoder: decoder,
joiner: joiner, joiner: joiner,
tokens: tokens, tokens: tokens,
numThreads: 2 numThreads: 2,
modelType: "zipformer"
) )
} }
@@ -41,7 +42,8 @@ func getZhZipformer20230615() -> SherpaOnnxOnlineTransducerModelConfig {
decoder: decoder, decoder: decoder,
joiner: joiner, joiner: joiner,
tokens: tokens, tokens: tokens,
numThreads: 2 numThreads: 2,
modelType: "zipformer2"
) )
} }
@@ -56,7 +58,8 @@ func getZhZipformer20230615Int8() -> SherpaOnnxOnlineTransducerModelConfig {
decoder: decoder, decoder: decoder,
joiner: joiner, joiner: joiner,
tokens: tokens, tokens: tokens,
numThreads: 2 numThreads: 2,
modelType: "zipformer2"
) )
} }
@@ -71,7 +74,8 @@ func getEnZipformer20230626() -> SherpaOnnxOnlineTransducerModelConfig {
decoder: decoder, decoder: decoder,
joiner: joiner, joiner: joiner,
tokens: tokens, tokens: tokens,
numThreads: 2 numThreads: 2,
modelType: "zipformer2"
) )
} }

View File

@@ -26,7 +26,8 @@ func getBilingualStreamZhEnZipformer20230220() -> SherpaOnnxOnlineTransducerMode
decoder: decoder, decoder: decoder,
joiner: joiner, joiner: joiner,
tokens: tokens, tokens: tokens,
numThreads: 2 numThreads: 2,
modelType: "zipformer"
) )
} }

View File

@@ -26,6 +26,7 @@ namespace SherpaOnnx
NumThreads = 1; NumThreads = 1;
Provider = "cpu"; Provider = "cpu";
Debug = 0; Debug = 0;
ModelType = "";
} }
[MarshalAs(UnmanagedType.LPStr)] [MarshalAs(UnmanagedType.LPStr)]
public string Encoder; public string Encoder;
@@ -47,6 +48,9 @@ namespace SherpaOnnx
/// true to print debug information of the model /// true to print debug information of the model
public int Debug; public int Debug;
[MarshalAs(UnmanagedType.LPStr)]
public string ModelType;
} }
/// It expects 16 kHz 16-bit single channel wave format. /// It expects 16 kHz 16-bit single channel wave format.

View File

@@ -53,8 +53,8 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlineTransducerModelConfig {
const char *tokens; const char *tokens;
int32_t num_threads; int32_t num_threads;
const char *provider; const char *provider;
const char *model_type;
int32_t debug; // true to print debug information of the model int32_t debug; // true to print debug information of the model
const char *model_type;
} SherpaOnnxOnlineTransducerModelConfig; } SherpaOnnxOnlineTransducerModelConfig;
/// It expects 16 kHz 16-bit single channel wave format. /// It expects 16 kHz 16-bit single channel wave format.

View File

@@ -187,6 +187,18 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
fid = env->GetFieldID(model_config_cls, "debug", "Z"); fid = env->GetFieldID(model_config_cls, "debug", "Z");
ans.model_config.debug = env->GetBooleanField(model_config, fid); ans.model_config.debug = env->GetBooleanField(model_config, fid);
fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(model_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.provider = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(model_config_cls, "modelType", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(model_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.model_type = p;
env->ReleaseStringUTFChars(s, p);
//---------- rnn lm model config ---------- //---------- rnn lm model config ----------
fid = env->GetFieldID(cls, "lmConfig", fid = env->GetFieldID(cls, "lmConfig",
"Lcom/k2fsa/sherpa/onnx/OnlineLMConfig;"); "Lcom/k2fsa/sherpa/onnx/OnlineLMConfig;");

View File

@@ -36,7 +36,8 @@ func sherpaOnnxOnlineTransducerModelConfig(
tokens: String, tokens: String,
numThreads: Int = 2, numThreads: Int = 2,
provider: String = "cpu", provider: String = "cpu",
debug: Int = 0 debug: Int = 0,
modelType: String = ""
) -> SherpaOnnxOnlineTransducerModelConfig { ) -> SherpaOnnxOnlineTransducerModelConfig {
return SherpaOnnxOnlineTransducerModelConfig( return SherpaOnnxOnlineTransducerModelConfig(
encoder: toCPointer(encoder), encoder: toCPointer(encoder),
@@ -45,7 +46,8 @@ func sherpaOnnxOnlineTransducerModelConfig(
tokens: toCPointer(tokens), tokens: toCPointer(tokens),
num_threads: Int32(numThreads), num_threads: Int32(numThreads),
provider: toCPointer(provider), provider: toCPointer(provider),
debug: Int32(debug) debug: Int32(debug),
model_type: toCPointer(modelType)
) )
} }