Fix model_type for jni, c# and iOS. (#216)
This commit is contained in:
@@ -21,6 +21,8 @@ data class OnlineTransducerModelConfig(
|
||||
var tokens: String,
|
||||
var numThreads: Int = 1,
|
||||
var debug: Boolean = false,
|
||||
var provider: String = "cpu",
|
||||
var modelType: String = "",
|
||||
)
|
||||
|
||||
data class OnlineLMConfig(
|
||||
@@ -135,6 +137,7 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? {
|
||||
decoder = "$modelDir/decoder-epoch-99-avg-1.onnx",
|
||||
joiner = "$modelDir/joiner-epoch-99-avg-1.onnx",
|
||||
tokens = "$modelDir/tokens.txt",
|
||||
modelType = "zipformer",
|
||||
)
|
||||
}
|
||||
1 -> {
|
||||
@@ -144,6 +147,7 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? {
|
||||
decoder = "$modelDir/decoder-epoch-11-avg-1.onnx",
|
||||
joiner = "$modelDir/joiner-epoch-11-avg-1.onnx",
|
||||
tokens = "$modelDir/tokens.txt",
|
||||
modelType = "lstm",
|
||||
)
|
||||
}
|
||||
|
||||
@@ -154,6 +158,7 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? {
|
||||
decoder = "$modelDir/decoder-epoch-99-avg-1.onnx",
|
||||
joiner = "$modelDir/joiner-epoch-99-avg-1.onnx",
|
||||
tokens = "$modelDir/tokens.txt",
|
||||
modelType = "lstm",
|
||||
)
|
||||
}
|
||||
|
||||
@@ -164,6 +169,7 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? {
|
||||
decoder = "$modelDir/exp/decoder-epoch-12-avg-4-chunk-16-left-128.onnx",
|
||||
joiner = "$modelDir/exp/joiner-epoch-12-avg-4-chunk-16-left-128.onnx",
|
||||
tokens = "$modelDir/data/lang_char/tokens.txt",
|
||||
modelType = "zipformer2",
|
||||
)
|
||||
}
|
||||
|
||||
@@ -174,6 +180,7 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? {
|
||||
decoder = "$modelDir/exp/decoder-epoch-12-avg-4-chunk-16-left-128.onnx",
|
||||
joiner = "$modelDir/exp/joiner-epoch-12-avg-4-chunk-16-left-128.onnx",
|
||||
tokens = "$modelDir/data/lang_char/tokens.txt",
|
||||
modelType = "zipformer2",
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -26,7 +26,8 @@ func getBilingualStreamZhEnZipformer20230220() -> SherpaOnnxOnlineTransducerMode
|
||||
decoder: decoder,
|
||||
joiner: joiner,
|
||||
tokens: tokens,
|
||||
numThreads: 2
|
||||
numThreads: 2,
|
||||
modelType: "zipformer"
|
||||
)
|
||||
}
|
||||
|
||||
@@ -41,7 +42,8 @@ func getZhZipformer20230615() -> SherpaOnnxOnlineTransducerModelConfig {
|
||||
decoder: decoder,
|
||||
joiner: joiner,
|
||||
tokens: tokens,
|
||||
numThreads: 2
|
||||
numThreads: 2,
|
||||
modelType: "zipformer2"
|
||||
)
|
||||
}
|
||||
|
||||
@@ -56,7 +58,8 @@ func getZhZipformer20230615Int8() -> SherpaOnnxOnlineTransducerModelConfig {
|
||||
decoder: decoder,
|
||||
joiner: joiner,
|
||||
tokens: tokens,
|
||||
numThreads: 2
|
||||
numThreads: 2,
|
||||
modelType: "zipformer2"
|
||||
)
|
||||
}
|
||||
|
||||
@@ -71,7 +74,8 @@ func getEnZipformer20230626() -> SherpaOnnxOnlineTransducerModelConfig {
|
||||
decoder: decoder,
|
||||
joiner: joiner,
|
||||
tokens: tokens,
|
||||
numThreads: 2
|
||||
numThreads: 2,
|
||||
modelType: "zipformer2"
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -26,7 +26,8 @@ func getBilingualStreamZhEnZipformer20230220() -> SherpaOnnxOnlineTransducerMode
|
||||
decoder: decoder,
|
||||
joiner: joiner,
|
||||
tokens: tokens,
|
||||
numThreads: 2
|
||||
numThreads: 2,
|
||||
modelType: "zipformer"
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ namespace SherpaOnnx
|
||||
NumThreads = 1;
|
||||
Provider = "cpu";
|
||||
Debug = 0;
|
||||
ModelType = "";
|
||||
}
|
||||
[MarshalAs(UnmanagedType.LPStr)]
|
||||
public string Encoder;
|
||||
@@ -47,6 +48,9 @@ namespace SherpaOnnx
|
||||
|
||||
/// true to print debug information of the model
|
||||
public int Debug;
|
||||
|
||||
[MarshalAs(UnmanagedType.LPStr)]
|
||||
public string ModelType;
|
||||
}
|
||||
|
||||
/// It expects 16 kHz 16-bit single channel wave format.
|
||||
|
||||
@@ -53,8 +53,8 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlineTransducerModelConfig {
|
||||
const char *tokens;
|
||||
int32_t num_threads;
|
||||
const char *provider;
|
||||
const char *model_type;
|
||||
int32_t debug; // true to print debug information of the model
|
||||
const char *model_type;
|
||||
} SherpaOnnxOnlineTransducerModelConfig;
|
||||
|
||||
/// It expects 16 kHz 16-bit single channel wave format.
|
||||
|
||||
@@ -187,6 +187,18 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
|
||||
fid = env->GetFieldID(model_config_cls, "debug", "Z");
|
||||
ans.model_config.debug = env->GetBooleanField(model_config, fid);
|
||||
|
||||
fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;");
|
||||
s = (jstring)env->GetObjectField(model_config, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.model_config.provider = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
fid = env->GetFieldID(model_config_cls, "modelType", "Ljava/lang/String;");
|
||||
s = (jstring)env->GetObjectField(model_config, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.model_config.model_type = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
//---------- rnn lm model config ----------
|
||||
fid = env->GetFieldID(cls, "lmConfig",
|
||||
"Lcom/k2fsa/sherpa/onnx/OnlineLMConfig;");
|
||||
|
||||
@@ -36,7 +36,8 @@ func sherpaOnnxOnlineTransducerModelConfig(
|
||||
tokens: String,
|
||||
numThreads: Int = 2,
|
||||
provider: String = "cpu",
|
||||
debug: Int = 0
|
||||
debug: Int = 0,
|
||||
modelType: String = ""
|
||||
) -> SherpaOnnxOnlineTransducerModelConfig {
|
||||
return SherpaOnnxOnlineTransducerModelConfig(
|
||||
encoder: toCPointer(encoder),
|
||||
@@ -45,7 +46,8 @@ func sherpaOnnxOnlineTransducerModelConfig(
|
||||
tokens: toCPointer(tokens),
|
||||
num_threads: Int32(numThreads),
|
||||
provider: toCPointer(provider),
|
||||
debug: Int32(debug)
|
||||
debug: Int32(debug),
|
||||
model_type: toCPointer(modelType)
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user