Support paraformer on Android (#264)
This commit is contained in:
@@ -177,7 +177,7 @@ class MainActivity : AppCompatActivity() {
|
|||||||
// Please change getModelConfig() to add new models
|
// Please change getModelConfig() to add new models
|
||||||
// See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
|
// See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
|
||||||
// for a list of available models
|
// for a list of available models
|
||||||
val type = 3
|
val type = 5
|
||||||
println("Select model type ${type}")
|
println("Select model type ${type}")
|
||||||
val config = OnlineRecognizerConfig(
|
val config = OnlineRecognizerConfig(
|
||||||
featConfig = getFeatureConfig(sampleRate = sampleRateInHz, featureDim = 80),
|
featConfig = getFeatureConfig(sampleRate = sampleRateInHz, featureDim = 80),
|
||||||
@@ -185,8 +185,6 @@ class MainActivity : AppCompatActivity() {
|
|||||||
lmConfig = getOnlineLMConfig(type = type),
|
lmConfig = getOnlineLMConfig(type = type),
|
||||||
endpointConfig = getEndpointConfig(),
|
endpointConfig = getEndpointConfig(),
|
||||||
enableEndpoint = true,
|
enableEndpoint = true,
|
||||||
decodingMethod = "modified_beam_search",
|
|
||||||
maxActivePaths = 4,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
model = SherpaOnnx(
|
model = SherpaOnnx(
|
||||||
|
|||||||
@@ -15,9 +15,19 @@ data class EndpointConfig(
|
|||||||
)
|
)
|
||||||
|
|
||||||
data class OnlineTransducerModelConfig(
|
data class OnlineTransducerModelConfig(
|
||||||
var encoder: String,
|
var encoder: String = "",
|
||||||
var decoder: String,
|
var decoder: String = "",
|
||||||
var joiner: String,
|
var joiner: String = "",
|
||||||
|
)
|
||||||
|
|
||||||
|
data class OnlineParaformerModelConfig(
|
||||||
|
var encoder: String = "",
|
||||||
|
var decoder: String = "",
|
||||||
|
)
|
||||||
|
|
||||||
|
data class OnlineModelConfig(
|
||||||
|
var transducer: OnlineTransducerModelConfig = OnlineTransducerModelConfig(),
|
||||||
|
var paraformer: OnlineParaformerModelConfig = OnlineParaformerModelConfig(),
|
||||||
var tokens: String,
|
var tokens: String,
|
||||||
var numThreads: Int = 1,
|
var numThreads: Int = 1,
|
||||||
var debug: Boolean = false,
|
var debug: Boolean = false,
|
||||||
@@ -37,8 +47,8 @@ data class FeatureConfig(
|
|||||||
|
|
||||||
data class OnlineRecognizerConfig(
|
data class OnlineRecognizerConfig(
|
||||||
var featConfig: FeatureConfig = FeatureConfig(),
|
var featConfig: FeatureConfig = FeatureConfig(),
|
||||||
var modelConfig: OnlineTransducerModelConfig,
|
var modelConfig: OnlineModelConfig,
|
||||||
var lmConfig : OnlineLMConfig,
|
var lmConfig: OnlineLMConfig,
|
||||||
var endpointConfig: EndpointConfig = EndpointConfig(),
|
var endpointConfig: EndpointConfig = EndpointConfig(),
|
||||||
var enableEndpoint: Boolean = true,
|
var enableEndpoint: Boolean = true,
|
||||||
var decodingMethod: String = "greedy_search",
|
var decodingMethod: String = "greedy_search",
|
||||||
@@ -115,37 +125,47 @@ to add your own. (It should be straightforward to add a new model
|
|||||||
by following the code)
|
by following the code)
|
||||||
|
|
||||||
@param type
|
@param type
|
||||||
0 - sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 (Bilingual, Chinese + English)
|
0 - sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 (Bilingual, Chinese + English)
|
||||||
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/zipformer-transducer-models.html#sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20-bilingual-chinese-english
|
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/zipformer-transducer-models.html#sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20-bilingual-chinese-english
|
||||||
|
|
||||||
1 - csukuangfj/sherpa-onnx-lstm-zh-2023-02-20 (Chinese)
|
1 - csukuangfj/sherpa-onnx-lstm-zh-2023-02-20 (Chinese)
|
||||||
|
|
||||||
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/lstm-transducer-models.html#csukuangfj-sherpa-onnx-lstm-zh-2023-02-20-chinese
|
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/lstm-transducer-models.html#csukuangfj-sherpa-onnx-lstm-zh-2023-02-20-chinese
|
||||||
|
|
||||||
2 - csukuangfj/sherpa-onnx-lstm-en-2023-02-17 (English)
|
2 - csukuangfj/sherpa-onnx-lstm-en-2023-02-17 (English)
|
||||||
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/lstm-transducer-models.html#csukuangfj-sherpa-onnx-lstm-en-2023-02-17-english
|
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/lstm-transducer-models.html#csukuangfj-sherpa-onnx-lstm-en-2023-02-17-english
|
||||||
|
|
||||||
3 - pkufool/icefall-asr-zipformer-streaming-wenetspeech-20230615
|
3,4 - pkufool/icefall-asr-zipformer-streaming-wenetspeech-20230615
|
||||||
https://huggingface.co/pkufool/icefall-asr-zipformer-streaming-wenetspeech-20230615
|
https://huggingface.co/pkufool/icefall-asr-zipformer-streaming-wenetspeech-20230615
|
||||||
|
3 - int8 encoder
|
||||||
|
4 - float32 encoder
|
||||||
|
|
||||||
|
5 - csukuangfj/sherpa-onnx-streaming-paraformer-bilingual-zh-en
|
||||||
|
https://huggingface.co/csukuangfj/sherpa-onnx-streaming-paraformer-bilingual-zh-en
|
||||||
|
|
||||||
*/
|
*/
|
||||||
fun getModelConfig(type: Int): OnlineTransducerModelConfig? {
|
fun getModelConfig(type: Int): OnlineModelConfig? {
|
||||||
when (type) {
|
when (type) {
|
||||||
0 -> {
|
0 -> {
|
||||||
val modelDir = "sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20"
|
val modelDir = "sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20"
|
||||||
return OnlineTransducerModelConfig(
|
return OnlineModelConfig(
|
||||||
encoder = "$modelDir/encoder-epoch-99-avg-1.onnx",
|
transducer = OnlineTransducerModelConfig(
|
||||||
decoder = "$modelDir/decoder-epoch-99-avg-1.onnx",
|
encoder = "$modelDir/encoder-epoch-99-avg-1.onnx",
|
||||||
joiner = "$modelDir/joiner-epoch-99-avg-1.onnx",
|
decoder = "$modelDir/decoder-epoch-99-avg-1.onnx",
|
||||||
|
joiner = "$modelDir/joiner-epoch-99-avg-1.onnx",
|
||||||
|
),
|
||||||
tokens = "$modelDir/tokens.txt",
|
tokens = "$modelDir/tokens.txt",
|
||||||
modelType = "zipformer",
|
modelType = "zipformer",
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
1 -> {
|
1 -> {
|
||||||
val modelDir = "sherpa-onnx-lstm-zh-2023-02-20"
|
val modelDir = "sherpa-onnx-lstm-zh-2023-02-20"
|
||||||
return OnlineTransducerModelConfig(
|
return OnlineModelConfig(
|
||||||
encoder = "$modelDir/encoder-epoch-11-avg-1.onnx",
|
transducer = OnlineTransducerModelConfig(
|
||||||
decoder = "$modelDir/decoder-epoch-11-avg-1.onnx",
|
encoder = "$modelDir/encoder-epoch-11-avg-1.onnx",
|
||||||
joiner = "$modelDir/joiner-epoch-11-avg-1.onnx",
|
decoder = "$modelDir/decoder-epoch-11-avg-1.onnx",
|
||||||
|
joiner = "$modelDir/joiner-epoch-11-avg-1.onnx",
|
||||||
|
),
|
||||||
tokens = "$modelDir/tokens.txt",
|
tokens = "$modelDir/tokens.txt",
|
||||||
modelType = "lstm",
|
modelType = "lstm",
|
||||||
)
|
)
|
||||||
@@ -153,10 +173,12 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? {
|
|||||||
|
|
||||||
2 -> {
|
2 -> {
|
||||||
val modelDir = "sherpa-onnx-lstm-en-2023-02-17"
|
val modelDir = "sherpa-onnx-lstm-en-2023-02-17"
|
||||||
return OnlineTransducerModelConfig(
|
return OnlineModelConfig(
|
||||||
encoder = "$modelDir/encoder-epoch-99-avg-1.onnx",
|
transducer = OnlineTransducerModelConfig(
|
||||||
decoder = "$modelDir/decoder-epoch-99-avg-1.onnx",
|
encoder = "$modelDir/encoder-epoch-99-avg-1.onnx",
|
||||||
joiner = "$modelDir/joiner-epoch-99-avg-1.onnx",
|
decoder = "$modelDir/decoder-epoch-99-avg-1.onnx",
|
||||||
|
joiner = "$modelDir/joiner-epoch-99-avg-1.onnx",
|
||||||
|
),
|
||||||
tokens = "$modelDir/tokens.txt",
|
tokens = "$modelDir/tokens.txt",
|
||||||
modelType = "lstm",
|
modelType = "lstm",
|
||||||
)
|
)
|
||||||
@@ -164,10 +186,12 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? {
|
|||||||
|
|
||||||
3 -> {
|
3 -> {
|
||||||
val modelDir = "icefall-asr-zipformer-streaming-wenetspeech-20230615"
|
val modelDir = "icefall-asr-zipformer-streaming-wenetspeech-20230615"
|
||||||
return OnlineTransducerModelConfig(
|
return OnlineModelConfig(
|
||||||
encoder = "$modelDir/exp/encoder-epoch-12-avg-4-chunk-16-left-128.int8.onnx",
|
transducer = OnlineTransducerModelConfig(
|
||||||
decoder = "$modelDir/exp/decoder-epoch-12-avg-4-chunk-16-left-128.onnx",
|
encoder = "$modelDir/exp/encoder-epoch-12-avg-4-chunk-16-left-128.int8.onnx",
|
||||||
joiner = "$modelDir/exp/joiner-epoch-12-avg-4-chunk-16-left-128.onnx",
|
decoder = "$modelDir/exp/decoder-epoch-12-avg-4-chunk-16-left-128.onnx",
|
||||||
|
joiner = "$modelDir/exp/joiner-epoch-12-avg-4-chunk-16-left-128.onnx",
|
||||||
|
),
|
||||||
tokens = "$modelDir/data/lang_char/tokens.txt",
|
tokens = "$modelDir/data/lang_char/tokens.txt",
|
||||||
modelType = "zipformer2",
|
modelType = "zipformer2",
|
||||||
)
|
)
|
||||||
@@ -175,14 +199,28 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? {
|
|||||||
|
|
||||||
4 -> {
|
4 -> {
|
||||||
val modelDir = "icefall-asr-zipformer-streaming-wenetspeech-20230615"
|
val modelDir = "icefall-asr-zipformer-streaming-wenetspeech-20230615"
|
||||||
return OnlineTransducerModelConfig(
|
return OnlineModelConfig(
|
||||||
encoder = "$modelDir/exp/encoder-epoch-12-avg-4-chunk-16-left-128.onnx",
|
transducer = OnlineTransducerModelConfig(
|
||||||
decoder = "$modelDir/exp/decoder-epoch-12-avg-4-chunk-16-left-128.onnx",
|
encoder = "$modelDir/exp/encoder-epoch-12-avg-4-chunk-16-left-128.onnx",
|
||||||
joiner = "$modelDir/exp/joiner-epoch-12-avg-4-chunk-16-left-128.onnx",
|
decoder = "$modelDir/exp/decoder-epoch-12-avg-4-chunk-16-left-128.onnx",
|
||||||
|
joiner = "$modelDir/exp/joiner-epoch-12-avg-4-chunk-16-left-128.onnx",
|
||||||
|
),
|
||||||
tokens = "$modelDir/data/lang_char/tokens.txt",
|
tokens = "$modelDir/data/lang_char/tokens.txt",
|
||||||
modelType = "zipformer2",
|
modelType = "zipformer2",
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
5 -> {
|
||||||
|
val modelDir = "sherpa-onnx-streaming-paraformer-bilingual-zh-en"
|
||||||
|
return OnlineModelConfig(
|
||||||
|
paraformer = OnlineParaformerModelConfig(
|
||||||
|
encoder = "$modelDir/encoder.int8.onnx",
|
||||||
|
decoder = "$modelDir/decoder.int8.onnx",
|
||||||
|
),
|
||||||
|
tokens = "$modelDir/tokens.txt",
|
||||||
|
modelType = "paraformer",
|
||||||
|
)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
@@ -200,7 +238,7 @@ by following the code, https://github.com/k2-fsa/icefall/blob/master/icefall/rnn
|
|||||||
0 - sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 (Bilingual, Chinese + English)
|
0 - sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 (Bilingual, Chinese + English)
|
||||||
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/zipformer-transducer-models.html#sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20-bilingual-chinese-english
|
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/zipformer-transducer-models.html#sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20-bilingual-chinese-english
|
||||||
*/
|
*/
|
||||||
fun getOnlineLMConfig(type : Int): OnlineLMConfig {
|
fun getOnlineLMConfig(type: Int): OnlineLMConfig {
|
||||||
when (type) {
|
when (type) {
|
||||||
0 -> {
|
0 -> {
|
||||||
val modelDir = "sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20"
|
val modelDir = "sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20"
|
||||||
|
|||||||
@@ -190,7 +190,11 @@ class OnlineRecognizerParaformerImpl : public OnlineRecognizerImpl {
|
|||||||
OnlineParaformerDecoderResult r;
|
OnlineParaformerDecoderResult r;
|
||||||
s->SetParaformerResult(r);
|
s->SetParaformerResult(r);
|
||||||
|
|
||||||
// the internal model caches are not reset
|
s->GetStates().clear();
|
||||||
|
s->GetParaformerEncoderOutCache().clear();
|
||||||
|
s->GetParaformerAlphaCache().clear();
|
||||||
|
|
||||||
|
// s->GetParaformerFeatCache().clear();
|
||||||
|
|
||||||
// Note: We only update counters. The underlying audio samples
|
// Note: We only update counters. The underlying audio samples
|
||||||
// are not discarded.
|
// are not discarded.
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ class SherpaOnnx {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void InputFinished() const {
|
void InputFinished() const {
|
||||||
std::vector<float> tail_padding(input_sample_rate_ * 0.32, 0);
|
std::vector<float> tail_padding(input_sample_rate_ * 0.6, 0);
|
||||||
stream_->AcceptWaveform(input_sample_rate_, tail_padding.data(),
|
stream_->AcceptWaveform(input_sample_rate_, tail_padding.data(),
|
||||||
tail_padding.size());
|
tail_padding.size());
|
||||||
stream_->InputFinished();
|
stream_->InputFinished();
|
||||||
@@ -158,48 +158,74 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
|
|||||||
|
|
||||||
//---------- model config ----------
|
//---------- model config ----------
|
||||||
fid = env->GetFieldID(cls, "modelConfig",
|
fid = env->GetFieldID(cls, "modelConfig",
|
||||||
"Lcom/k2fsa/sherpa/onnx/OnlineTransducerModelConfig;");
|
"Lcom/k2fsa/sherpa/onnx/OnlineModelConfig;");
|
||||||
jobject transducer_config = env->GetObjectField(config, fid);
|
jobject model_config = env->GetObjectField(config, fid);
|
||||||
jclass model_config_cls = env->GetObjectClass(transducer_config);
|
jclass model_config_cls = env->GetObjectClass(model_config);
|
||||||
|
|
||||||
fid = env->GetFieldID(model_config_cls, "encoder", "Ljava/lang/String;");
|
// transducer
|
||||||
|
fid = env->GetFieldID(model_config_cls, "transducer",
|
||||||
|
"Lcom/k2fsa/sherpa/onnx/OnlineTransducerModelConfig;");
|
||||||
|
jobject transducer_config = env->GetObjectField(model_config, fid);
|
||||||
|
jclass transducer_config_cls = env->GetObjectClass(transducer_config);
|
||||||
|
|
||||||
|
fid = env->GetFieldID(transducer_config_cls, "encoder", "Ljava/lang/String;");
|
||||||
s = (jstring)env->GetObjectField(transducer_config, fid);
|
s = (jstring)env->GetObjectField(transducer_config, fid);
|
||||||
p = env->GetStringUTFChars(s, nullptr);
|
p = env->GetStringUTFChars(s, nullptr);
|
||||||
ans.model_config.transducer.encoder = p;
|
ans.model_config.transducer.encoder = p;
|
||||||
env->ReleaseStringUTFChars(s, p);
|
env->ReleaseStringUTFChars(s, p);
|
||||||
|
|
||||||
fid = env->GetFieldID(model_config_cls, "decoder", "Ljava/lang/String;");
|
fid = env->GetFieldID(transducer_config_cls, "decoder", "Ljava/lang/String;");
|
||||||
s = (jstring)env->GetObjectField(transducer_config, fid);
|
s = (jstring)env->GetObjectField(transducer_config, fid);
|
||||||
p = env->GetStringUTFChars(s, nullptr);
|
p = env->GetStringUTFChars(s, nullptr);
|
||||||
ans.model_config.transducer.decoder = p;
|
ans.model_config.transducer.decoder = p;
|
||||||
env->ReleaseStringUTFChars(s, p);
|
env->ReleaseStringUTFChars(s, p);
|
||||||
|
|
||||||
fid = env->GetFieldID(model_config_cls, "joiner", "Ljava/lang/String;");
|
fid = env->GetFieldID(transducer_config_cls, "joiner", "Ljava/lang/String;");
|
||||||
s = (jstring)env->GetObjectField(transducer_config, fid);
|
s = (jstring)env->GetObjectField(transducer_config, fid);
|
||||||
p = env->GetStringUTFChars(s, nullptr);
|
p = env->GetStringUTFChars(s, nullptr);
|
||||||
ans.model_config.transducer.joiner = p;
|
ans.model_config.transducer.joiner = p;
|
||||||
env->ReleaseStringUTFChars(s, p);
|
env->ReleaseStringUTFChars(s, p);
|
||||||
|
|
||||||
|
// paraformer
|
||||||
|
fid = env->GetFieldID(model_config_cls, "paraformer",
|
||||||
|
"Lcom/k2fsa/sherpa/onnx/OnlineParaformerModelConfig;");
|
||||||
|
jobject paraformer_config = env->GetObjectField(model_config, fid);
|
||||||
|
jclass paraformer_config_config_cls = env->GetObjectClass(paraformer_config);
|
||||||
|
|
||||||
|
fid = env->GetFieldID(paraformer_config_config_cls, "encoder",
|
||||||
|
"Ljava/lang/String;");
|
||||||
|
s = (jstring)env->GetObjectField(paraformer_config, fid);
|
||||||
|
p = env->GetStringUTFChars(s, nullptr);
|
||||||
|
ans.model_config.paraformer.encoder = p;
|
||||||
|
env->ReleaseStringUTFChars(s, p);
|
||||||
|
|
||||||
|
fid = env->GetFieldID(paraformer_config_config_cls, "decoder",
|
||||||
|
"Ljava/lang/String;");
|
||||||
|
s = (jstring)env->GetObjectField(paraformer_config, fid);
|
||||||
|
p = env->GetStringUTFChars(s, nullptr);
|
||||||
|
ans.model_config.paraformer.decoder = p;
|
||||||
|
env->ReleaseStringUTFChars(s, p);
|
||||||
|
|
||||||
fid = env->GetFieldID(model_config_cls, "tokens", "Ljava/lang/String;");
|
fid = env->GetFieldID(model_config_cls, "tokens", "Ljava/lang/String;");
|
||||||
s = (jstring)env->GetObjectField(transducer_config, fid);
|
s = (jstring)env->GetObjectField(model_config, fid);
|
||||||
p = env->GetStringUTFChars(s, nullptr);
|
p = env->GetStringUTFChars(s, nullptr);
|
||||||
ans.model_config.tokens = p;
|
ans.model_config.tokens = p;
|
||||||
env->ReleaseStringUTFChars(s, p);
|
env->ReleaseStringUTFChars(s, p);
|
||||||
|
|
||||||
fid = env->GetFieldID(model_config_cls, "numThreads", "I");
|
fid = env->GetFieldID(model_config_cls, "numThreads", "I");
|
||||||
ans.model_config.num_threads = env->GetIntField(transducer_config, fid);
|
ans.model_config.num_threads = env->GetIntField(model_config, fid);
|
||||||
|
|
||||||
fid = env->GetFieldID(model_config_cls, "debug", "Z");
|
fid = env->GetFieldID(model_config_cls, "debug", "Z");
|
||||||
ans.model_config.debug = env->GetBooleanField(transducer_config, fid);
|
ans.model_config.debug = env->GetBooleanField(model_config, fid);
|
||||||
|
|
||||||
fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;");
|
fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;");
|
||||||
s = (jstring)env->GetObjectField(transducer_config, fid);
|
s = (jstring)env->GetObjectField(model_config, fid);
|
||||||
p = env->GetStringUTFChars(s, nullptr);
|
p = env->GetStringUTFChars(s, nullptr);
|
||||||
ans.model_config.provider = p;
|
ans.model_config.provider = p;
|
||||||
env->ReleaseStringUTFChars(s, p);
|
env->ReleaseStringUTFChars(s, p);
|
||||||
|
|
||||||
fid = env->GetFieldID(model_config_cls, "modelType", "Ljava/lang/String;");
|
fid = env->GetFieldID(model_config_cls, "modelType", "Ljava/lang/String;");
|
||||||
s = (jstring)env->GetObjectField(transducer_config, fid);
|
s = (jstring)env->GetObjectField(model_config, fid);
|
||||||
p = env->GetStringUTFChars(s, nullptr);
|
p = env->GetStringUTFChars(s, nullptr);
|
||||||
ans.model_config.model_type = p;
|
ans.model_config.model_type = p;
|
||||||
env->ReleaseStringUTFChars(s, p);
|
env->ReleaseStringUTFChars(s, p);
|
||||||
|
|||||||
Reference in New Issue
Block a user