Fix model_type for jni, c# and iOS. (#216)
This commit is contained in:
@@ -21,6 +21,8 @@ data class OnlineTransducerModelConfig(
|
|||||||
var tokens: String,
|
var tokens: String,
|
||||||
var numThreads: Int = 1,
|
var numThreads: Int = 1,
|
||||||
var debug: Boolean = false,
|
var debug: Boolean = false,
|
||||||
|
var provider: String = "cpu",
|
||||||
|
var modelType: String = "",
|
||||||
)
|
)
|
||||||
|
|
||||||
data class OnlineLMConfig(
|
data class OnlineLMConfig(
|
||||||
@@ -135,6 +137,7 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? {
|
|||||||
decoder = "$modelDir/decoder-epoch-99-avg-1.onnx",
|
decoder = "$modelDir/decoder-epoch-99-avg-1.onnx",
|
||||||
joiner = "$modelDir/joiner-epoch-99-avg-1.onnx",
|
joiner = "$modelDir/joiner-epoch-99-avg-1.onnx",
|
||||||
tokens = "$modelDir/tokens.txt",
|
tokens = "$modelDir/tokens.txt",
|
||||||
|
modelType = "zipformer",
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
1 -> {
|
1 -> {
|
||||||
@@ -144,6 +147,7 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? {
|
|||||||
decoder = "$modelDir/decoder-epoch-11-avg-1.onnx",
|
decoder = "$modelDir/decoder-epoch-11-avg-1.onnx",
|
||||||
joiner = "$modelDir/joiner-epoch-11-avg-1.onnx",
|
joiner = "$modelDir/joiner-epoch-11-avg-1.onnx",
|
||||||
tokens = "$modelDir/tokens.txt",
|
tokens = "$modelDir/tokens.txt",
|
||||||
|
modelType = "lstm",
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -154,6 +158,7 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? {
|
|||||||
decoder = "$modelDir/decoder-epoch-99-avg-1.onnx",
|
decoder = "$modelDir/decoder-epoch-99-avg-1.onnx",
|
||||||
joiner = "$modelDir/joiner-epoch-99-avg-1.onnx",
|
joiner = "$modelDir/joiner-epoch-99-avg-1.onnx",
|
||||||
tokens = "$modelDir/tokens.txt",
|
tokens = "$modelDir/tokens.txt",
|
||||||
|
modelType = "lstm",
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -164,6 +169,7 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? {
|
|||||||
decoder = "$modelDir/exp/decoder-epoch-12-avg-4-chunk-16-left-128.onnx",
|
decoder = "$modelDir/exp/decoder-epoch-12-avg-4-chunk-16-left-128.onnx",
|
||||||
joiner = "$modelDir/exp/joiner-epoch-12-avg-4-chunk-16-left-128.onnx",
|
joiner = "$modelDir/exp/joiner-epoch-12-avg-4-chunk-16-left-128.onnx",
|
||||||
tokens = "$modelDir/data/lang_char/tokens.txt",
|
tokens = "$modelDir/data/lang_char/tokens.txt",
|
||||||
|
modelType = "zipformer2",
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -174,6 +180,7 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? {
|
|||||||
decoder = "$modelDir/exp/decoder-epoch-12-avg-4-chunk-16-left-128.onnx",
|
decoder = "$modelDir/exp/decoder-epoch-12-avg-4-chunk-16-left-128.onnx",
|
||||||
joiner = "$modelDir/exp/joiner-epoch-12-avg-4-chunk-16-left-128.onnx",
|
joiner = "$modelDir/exp/joiner-epoch-12-avg-4-chunk-16-left-128.onnx",
|
||||||
tokens = "$modelDir/data/lang_char/tokens.txt",
|
tokens = "$modelDir/data/lang_char/tokens.txt",
|
||||||
|
modelType = "zipformer2",
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -26,7 +26,8 @@ func getBilingualStreamZhEnZipformer20230220() -> SherpaOnnxOnlineTransducerMode
|
|||||||
decoder: decoder,
|
decoder: decoder,
|
||||||
joiner: joiner,
|
joiner: joiner,
|
||||||
tokens: tokens,
|
tokens: tokens,
|
||||||
numThreads: 2
|
numThreads: 2,
|
||||||
|
modelType: "zipformer"
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -41,7 +42,8 @@ func getZhZipformer20230615() -> SherpaOnnxOnlineTransducerModelConfig {
|
|||||||
decoder: decoder,
|
decoder: decoder,
|
||||||
joiner: joiner,
|
joiner: joiner,
|
||||||
tokens: tokens,
|
tokens: tokens,
|
||||||
numThreads: 2
|
numThreads: 2,
|
||||||
|
modelType: "zipformer2"
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -56,7 +58,8 @@ func getZhZipformer20230615Int8() -> SherpaOnnxOnlineTransducerModelConfig {
|
|||||||
decoder: decoder,
|
decoder: decoder,
|
||||||
joiner: joiner,
|
joiner: joiner,
|
||||||
tokens: tokens,
|
tokens: tokens,
|
||||||
numThreads: 2
|
numThreads: 2,
|
||||||
|
modelType: "zipformer2"
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -71,7 +74,8 @@ func getEnZipformer20230626() -> SherpaOnnxOnlineTransducerModelConfig {
|
|||||||
decoder: decoder,
|
decoder: decoder,
|
||||||
joiner: joiner,
|
joiner: joiner,
|
||||||
tokens: tokens,
|
tokens: tokens,
|
||||||
numThreads: 2
|
numThreads: 2,
|
||||||
|
modelType: "zipformer2"
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -26,7 +26,8 @@ func getBilingualStreamZhEnZipformer20230220() -> SherpaOnnxOnlineTransducerMode
|
|||||||
decoder: decoder,
|
decoder: decoder,
|
||||||
joiner: joiner,
|
joiner: joiner,
|
||||||
tokens: tokens,
|
tokens: tokens,
|
||||||
numThreads: 2
|
numThreads: 2,
|
||||||
|
modelType: "zipformer"
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ namespace SherpaOnnx
|
|||||||
NumThreads = 1;
|
NumThreads = 1;
|
||||||
Provider = "cpu";
|
Provider = "cpu";
|
||||||
Debug = 0;
|
Debug = 0;
|
||||||
|
ModelType = "";
|
||||||
}
|
}
|
||||||
[MarshalAs(UnmanagedType.LPStr)]
|
[MarshalAs(UnmanagedType.LPStr)]
|
||||||
public string Encoder;
|
public string Encoder;
|
||||||
@@ -47,6 +48,9 @@ namespace SherpaOnnx
|
|||||||
|
|
||||||
/// true to print debug information of the model
|
/// true to print debug information of the model
|
||||||
public int Debug;
|
public int Debug;
|
||||||
|
|
||||||
|
[MarshalAs(UnmanagedType.LPStr)]
|
||||||
|
public string ModelType;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// It expects 16 kHz 16-bit single channel wave format.
|
/// It expects 16 kHz 16-bit single channel wave format.
|
||||||
|
|||||||
@@ -53,8 +53,8 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlineTransducerModelConfig {
|
|||||||
const char *tokens;
|
const char *tokens;
|
||||||
int32_t num_threads;
|
int32_t num_threads;
|
||||||
const char *provider;
|
const char *provider;
|
||||||
const char *model_type;
|
|
||||||
int32_t debug; // true to print debug information of the model
|
int32_t debug; // true to print debug information of the model
|
||||||
|
const char *model_type;
|
||||||
} SherpaOnnxOnlineTransducerModelConfig;
|
} SherpaOnnxOnlineTransducerModelConfig;
|
||||||
|
|
||||||
/// It expects 16 kHz 16-bit single channel wave format.
|
/// It expects 16 kHz 16-bit single channel wave format.
|
||||||
|
|||||||
@@ -187,6 +187,18 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
|
|||||||
fid = env->GetFieldID(model_config_cls, "debug", "Z");
|
fid = env->GetFieldID(model_config_cls, "debug", "Z");
|
||||||
ans.model_config.debug = env->GetBooleanField(model_config, fid);
|
ans.model_config.debug = env->GetBooleanField(model_config, fid);
|
||||||
|
|
||||||
|
fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;");
|
||||||
|
s = (jstring)env->GetObjectField(model_config, fid);
|
||||||
|
p = env->GetStringUTFChars(s, nullptr);
|
||||||
|
ans.model_config.provider = p;
|
||||||
|
env->ReleaseStringUTFChars(s, p);
|
||||||
|
|
||||||
|
fid = env->GetFieldID(model_config_cls, "modelType", "Ljava/lang/String;");
|
||||||
|
s = (jstring)env->GetObjectField(model_config, fid);
|
||||||
|
p = env->GetStringUTFChars(s, nullptr);
|
||||||
|
ans.model_config.model_type = p;
|
||||||
|
env->ReleaseStringUTFChars(s, p);
|
||||||
|
|
||||||
//---------- rnn lm model config ----------
|
//---------- rnn lm model config ----------
|
||||||
fid = env->GetFieldID(cls, "lmConfig",
|
fid = env->GetFieldID(cls, "lmConfig",
|
||||||
"Lcom/k2fsa/sherpa/onnx/OnlineLMConfig;");
|
"Lcom/k2fsa/sherpa/onnx/OnlineLMConfig;");
|
||||||
|
|||||||
@@ -36,7 +36,8 @@ func sherpaOnnxOnlineTransducerModelConfig(
|
|||||||
tokens: String,
|
tokens: String,
|
||||||
numThreads: Int = 2,
|
numThreads: Int = 2,
|
||||||
provider: String = "cpu",
|
provider: String = "cpu",
|
||||||
debug: Int = 0
|
debug: Int = 0,
|
||||||
|
modelType: String = ""
|
||||||
) -> SherpaOnnxOnlineTransducerModelConfig {
|
) -> SherpaOnnxOnlineTransducerModelConfig {
|
||||||
return SherpaOnnxOnlineTransducerModelConfig(
|
return SherpaOnnxOnlineTransducerModelConfig(
|
||||||
encoder: toCPointer(encoder),
|
encoder: toCPointer(encoder),
|
||||||
@@ -45,7 +46,8 @@ func sherpaOnnxOnlineTransducerModelConfig(
|
|||||||
tokens: toCPointer(tokens),
|
tokens: toCPointer(tokens),
|
||||||
num_threads: Int32(numThreads),
|
num_threads: Int32(numThreads),
|
||||||
provider: toCPointer(provider),
|
provider: toCPointer(provider),
|
||||||
debug: Int32(debug)
|
debug: Int32(debug),
|
||||||
|
model_type: toCPointer(modelType)
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user