Support paraformer on Android (#264)

This commit is contained in:
Fangjun Kuang
2023-08-14 12:26:15 +08:00
committed by GitHub
parent 6038e2aa62
commit 35526e26e1
4 changed files with 113 additions and 47 deletions

View File

@@ -177,7 +177,7 @@ class MainActivity : AppCompatActivity() {
// Please change getModelConfig() to add new models // Please change getModelConfig() to add new models
// See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html // See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
// for a list of available models // for a list of available models
val type = 3 val type = 5
println("Select model type ${type}") println("Select model type ${type}")
val config = OnlineRecognizerConfig( val config = OnlineRecognizerConfig(
featConfig = getFeatureConfig(sampleRate = sampleRateInHz, featureDim = 80), featConfig = getFeatureConfig(sampleRate = sampleRateInHz, featureDim = 80),
@@ -185,8 +185,6 @@ class MainActivity : AppCompatActivity() {
lmConfig = getOnlineLMConfig(type = type), lmConfig = getOnlineLMConfig(type = type),
endpointConfig = getEndpointConfig(), endpointConfig = getEndpointConfig(),
enableEndpoint = true, enableEndpoint = true,
decodingMethod = "modified_beam_search",
maxActivePaths = 4,
) )
model = SherpaOnnx( model = SherpaOnnx(

View File

@@ -15,9 +15,19 @@ data class EndpointConfig(
) )
data class OnlineTransducerModelConfig( data class OnlineTransducerModelConfig(
var encoder: String, var encoder: String = "",
var decoder: String, var decoder: String = "",
var joiner: String, var joiner: String = "",
)
data class OnlineParaformerModelConfig(
var encoder: String = "",
var decoder: String = "",
)
data class OnlineModelConfig(
var transducer: OnlineTransducerModelConfig = OnlineTransducerModelConfig(),
var paraformer: OnlineParaformerModelConfig = OnlineParaformerModelConfig(),
var tokens: String, var tokens: String,
var numThreads: Int = 1, var numThreads: Int = 1,
var debug: Boolean = false, var debug: Boolean = false,
@@ -37,8 +47,8 @@ data class FeatureConfig(
data class OnlineRecognizerConfig( data class OnlineRecognizerConfig(
var featConfig: FeatureConfig = FeatureConfig(), var featConfig: FeatureConfig = FeatureConfig(),
var modelConfig: OnlineTransducerModelConfig, var modelConfig: OnlineModelConfig,
var lmConfig : OnlineLMConfig, var lmConfig: OnlineLMConfig,
var endpointConfig: EndpointConfig = EndpointConfig(), var endpointConfig: EndpointConfig = EndpointConfig(),
var enableEndpoint: Boolean = true, var enableEndpoint: Boolean = true,
var decodingMethod: String = "greedy_search", var decodingMethod: String = "greedy_search",
@@ -115,37 +125,47 @@ to add your own. (It should be straightforward to add a new model
by following the code) by following the code)
@param type @param type
0 - sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 (Bilingual, Chinese + English) 0 - sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 (Bilingual, Chinese + English)
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/zipformer-transducer-models.html#sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20-bilingual-chinese-english https://k2-fsa.github.io/sherpa/onnx/pretrained_models/zipformer-transducer-models.html#sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20-bilingual-chinese-english
1 - csukuangfj/sherpa-onnx-lstm-zh-2023-02-20 (Chinese) 1 - csukuangfj/sherpa-onnx-lstm-zh-2023-02-20 (Chinese)
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/lstm-transducer-models.html#csukuangfj-sherpa-onnx-lstm-zh-2023-02-20-chinese https://k2-fsa.github.io/sherpa/onnx/pretrained_models/lstm-transducer-models.html#csukuangfj-sherpa-onnx-lstm-zh-2023-02-20-chinese
2 - csukuangfj/sherpa-onnx-lstm-en-2023-02-17 (English) 2 - csukuangfj/sherpa-onnx-lstm-en-2023-02-17 (English)
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/lstm-transducer-models.html#csukuangfj-sherpa-onnx-lstm-en-2023-02-17-english https://k2-fsa.github.io/sherpa/onnx/pretrained_models/lstm-transducer-models.html#csukuangfj-sherpa-onnx-lstm-en-2023-02-17-english
3 - pkufool/icefall-asr-zipformer-streaming-wenetspeech-20230615 3,4 - pkufool/icefall-asr-zipformer-streaming-wenetspeech-20230615
https://huggingface.co/pkufool/icefall-asr-zipformer-streaming-wenetspeech-20230615 https://huggingface.co/pkufool/icefall-asr-zipformer-streaming-wenetspeech-20230615
3 - int8 encoder
4 - float32 encoder
5 - csukuangfj/sherpa-onnx-streaming-paraformer-bilingual-zh-en
https://huggingface.co/csukuangfj/sherpa-onnx-streaming-paraformer-bilingual-zh-en
*/ */
fun getModelConfig(type: Int): OnlineTransducerModelConfig? { fun getModelConfig(type: Int): OnlineModelConfig? {
when (type) { when (type) {
0 -> { 0 -> {
val modelDir = "sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20" val modelDir = "sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20"
return OnlineTransducerModelConfig( return OnlineModelConfig(
encoder = "$modelDir/encoder-epoch-99-avg-1.onnx", transducer = OnlineTransducerModelConfig(
decoder = "$modelDir/decoder-epoch-99-avg-1.onnx", encoder = "$modelDir/encoder-epoch-99-avg-1.onnx",
joiner = "$modelDir/joiner-epoch-99-avg-1.onnx", decoder = "$modelDir/decoder-epoch-99-avg-1.onnx",
joiner = "$modelDir/joiner-epoch-99-avg-1.onnx",
),
tokens = "$modelDir/tokens.txt", tokens = "$modelDir/tokens.txt",
modelType = "zipformer", modelType = "zipformer",
) )
} }
1 -> { 1 -> {
val modelDir = "sherpa-onnx-lstm-zh-2023-02-20" val modelDir = "sherpa-onnx-lstm-zh-2023-02-20"
return OnlineTransducerModelConfig( return OnlineModelConfig(
encoder = "$modelDir/encoder-epoch-11-avg-1.onnx", transducer = OnlineTransducerModelConfig(
decoder = "$modelDir/decoder-epoch-11-avg-1.onnx", encoder = "$modelDir/encoder-epoch-11-avg-1.onnx",
joiner = "$modelDir/joiner-epoch-11-avg-1.onnx", decoder = "$modelDir/decoder-epoch-11-avg-1.onnx",
joiner = "$modelDir/joiner-epoch-11-avg-1.onnx",
),
tokens = "$modelDir/tokens.txt", tokens = "$modelDir/tokens.txt",
modelType = "lstm", modelType = "lstm",
) )
@@ -153,10 +173,12 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? {
2 -> { 2 -> {
val modelDir = "sherpa-onnx-lstm-en-2023-02-17" val modelDir = "sherpa-onnx-lstm-en-2023-02-17"
return OnlineTransducerModelConfig( return OnlineModelConfig(
encoder = "$modelDir/encoder-epoch-99-avg-1.onnx", transducer = OnlineTransducerModelConfig(
decoder = "$modelDir/decoder-epoch-99-avg-1.onnx", encoder = "$modelDir/encoder-epoch-99-avg-1.onnx",
joiner = "$modelDir/joiner-epoch-99-avg-1.onnx", decoder = "$modelDir/decoder-epoch-99-avg-1.onnx",
joiner = "$modelDir/joiner-epoch-99-avg-1.onnx",
),
tokens = "$modelDir/tokens.txt", tokens = "$modelDir/tokens.txt",
modelType = "lstm", modelType = "lstm",
) )
@@ -164,10 +186,12 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? {
3 -> { 3 -> {
val modelDir = "icefall-asr-zipformer-streaming-wenetspeech-20230615" val modelDir = "icefall-asr-zipformer-streaming-wenetspeech-20230615"
return OnlineTransducerModelConfig( return OnlineModelConfig(
encoder = "$modelDir/exp/encoder-epoch-12-avg-4-chunk-16-left-128.int8.onnx", transducer = OnlineTransducerModelConfig(
decoder = "$modelDir/exp/decoder-epoch-12-avg-4-chunk-16-left-128.onnx", encoder = "$modelDir/exp/encoder-epoch-12-avg-4-chunk-16-left-128.int8.onnx",
joiner = "$modelDir/exp/joiner-epoch-12-avg-4-chunk-16-left-128.onnx", decoder = "$modelDir/exp/decoder-epoch-12-avg-4-chunk-16-left-128.onnx",
joiner = "$modelDir/exp/joiner-epoch-12-avg-4-chunk-16-left-128.onnx",
),
tokens = "$modelDir/data/lang_char/tokens.txt", tokens = "$modelDir/data/lang_char/tokens.txt",
modelType = "zipformer2", modelType = "zipformer2",
) )
@@ -175,14 +199,28 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? {
4 -> { 4 -> {
val modelDir = "icefall-asr-zipformer-streaming-wenetspeech-20230615" val modelDir = "icefall-asr-zipformer-streaming-wenetspeech-20230615"
return OnlineTransducerModelConfig( return OnlineModelConfig(
encoder = "$modelDir/exp/encoder-epoch-12-avg-4-chunk-16-left-128.onnx", transducer = OnlineTransducerModelConfig(
decoder = "$modelDir/exp/decoder-epoch-12-avg-4-chunk-16-left-128.onnx", encoder = "$modelDir/exp/encoder-epoch-12-avg-4-chunk-16-left-128.onnx",
joiner = "$modelDir/exp/joiner-epoch-12-avg-4-chunk-16-left-128.onnx", decoder = "$modelDir/exp/decoder-epoch-12-avg-4-chunk-16-left-128.onnx",
joiner = "$modelDir/exp/joiner-epoch-12-avg-4-chunk-16-left-128.onnx",
),
tokens = "$modelDir/data/lang_char/tokens.txt", tokens = "$modelDir/data/lang_char/tokens.txt",
modelType = "zipformer2", modelType = "zipformer2",
) )
} }
5 -> {
val modelDir = "sherpa-onnx-streaming-paraformer-bilingual-zh-en"
return OnlineModelConfig(
paraformer = OnlineParaformerModelConfig(
encoder = "$modelDir/encoder.int8.onnx",
decoder = "$modelDir/decoder.int8.onnx",
),
tokens = "$modelDir/tokens.txt",
modelType = "paraformer",
)
}
} }
return null; return null;
} }
@@ -200,7 +238,7 @@ by following the code, https://github.com/k2-fsa/icefall/blob/master/icefall/rnn
0 - sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 (Bilingual, Chinese + English) 0 - sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 (Bilingual, Chinese + English)
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/zipformer-transducer-models.html#sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20-bilingual-chinese-english https://k2-fsa.github.io/sherpa/onnx/pretrained_models/zipformer-transducer-models.html#sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20-bilingual-chinese-english
*/ */
fun getOnlineLMConfig(type : Int): OnlineLMConfig { fun getOnlineLMConfig(type: Int): OnlineLMConfig {
when (type) { when (type) {
0 -> { 0 -> {
val modelDir = "sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20" val modelDir = "sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20"

View File

@@ -190,7 +190,11 @@ class OnlineRecognizerParaformerImpl : public OnlineRecognizerImpl {
OnlineParaformerDecoderResult r; OnlineParaformerDecoderResult r;
s->SetParaformerResult(r); s->SetParaformerResult(r);
// the internal model caches are not reset s->GetStates().clear();
s->GetParaformerEncoderOutCache().clear();
s->GetParaformerAlphaCache().clear();
// s->GetParaformerFeatCache().clear();
// Note: We only update counters. The underlying audio samples // Note: We only update counters. The underlying audio samples
// are not discarded. // are not discarded.

View File

@@ -47,7 +47,7 @@ class SherpaOnnx {
} }
void InputFinished() const { void InputFinished() const {
std::vector<float> tail_padding(input_sample_rate_ * 0.32, 0); std::vector<float> tail_padding(input_sample_rate_ * 0.6, 0);
stream_->AcceptWaveform(input_sample_rate_, tail_padding.data(), stream_->AcceptWaveform(input_sample_rate_, tail_padding.data(),
tail_padding.size()); tail_padding.size());
stream_->InputFinished(); stream_->InputFinished();
@@ -158,48 +158,74 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
//---------- model config ---------- //---------- model config ----------
fid = env->GetFieldID(cls, "modelConfig", fid = env->GetFieldID(cls, "modelConfig",
"Lcom/k2fsa/sherpa/onnx/OnlineTransducerModelConfig;"); "Lcom/k2fsa/sherpa/onnx/OnlineModelConfig;");
jobject transducer_config = env->GetObjectField(config, fid); jobject model_config = env->GetObjectField(config, fid);
jclass model_config_cls = env->GetObjectClass(transducer_config); jclass model_config_cls = env->GetObjectClass(model_config);
fid = env->GetFieldID(model_config_cls, "encoder", "Ljava/lang/String;"); // transducer
fid = env->GetFieldID(model_config_cls, "transducer",
"Lcom/k2fsa/sherpa/onnx/OnlineTransducerModelConfig;");
jobject transducer_config = env->GetObjectField(model_config, fid);
jclass transducer_config_cls = env->GetObjectClass(transducer_config);
fid = env->GetFieldID(transducer_config_cls, "encoder", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(transducer_config, fid); s = (jstring)env->GetObjectField(transducer_config, fid);
p = env->GetStringUTFChars(s, nullptr); p = env->GetStringUTFChars(s, nullptr);
ans.model_config.transducer.encoder = p; ans.model_config.transducer.encoder = p;
env->ReleaseStringUTFChars(s, p); env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(model_config_cls, "decoder", "Ljava/lang/String;"); fid = env->GetFieldID(transducer_config_cls, "decoder", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(transducer_config, fid); s = (jstring)env->GetObjectField(transducer_config, fid);
p = env->GetStringUTFChars(s, nullptr); p = env->GetStringUTFChars(s, nullptr);
ans.model_config.transducer.decoder = p; ans.model_config.transducer.decoder = p;
env->ReleaseStringUTFChars(s, p); env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(model_config_cls, "joiner", "Ljava/lang/String;"); fid = env->GetFieldID(transducer_config_cls, "joiner", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(transducer_config, fid); s = (jstring)env->GetObjectField(transducer_config, fid);
p = env->GetStringUTFChars(s, nullptr); p = env->GetStringUTFChars(s, nullptr);
ans.model_config.transducer.joiner = p; ans.model_config.transducer.joiner = p;
env->ReleaseStringUTFChars(s, p); env->ReleaseStringUTFChars(s, p);
// paraformer
fid = env->GetFieldID(model_config_cls, "paraformer",
"Lcom/k2fsa/sherpa/onnx/OnlineParaformerModelConfig;");
jobject paraformer_config = env->GetObjectField(model_config, fid);
jclass paraformer_config_config_cls = env->GetObjectClass(paraformer_config);
fid = env->GetFieldID(paraformer_config_config_cls, "encoder",
"Ljava/lang/String;");
s = (jstring)env->GetObjectField(paraformer_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.paraformer.encoder = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(paraformer_config_config_cls, "decoder",
"Ljava/lang/String;");
s = (jstring)env->GetObjectField(paraformer_config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model_config.paraformer.decoder = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(model_config_cls, "tokens", "Ljava/lang/String;"); fid = env->GetFieldID(model_config_cls, "tokens", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(transducer_config, fid); s = (jstring)env->GetObjectField(model_config, fid);
p = env->GetStringUTFChars(s, nullptr); p = env->GetStringUTFChars(s, nullptr);
ans.model_config.tokens = p; ans.model_config.tokens = p;
env->ReleaseStringUTFChars(s, p); env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(model_config_cls, "numThreads", "I"); fid = env->GetFieldID(model_config_cls, "numThreads", "I");
ans.model_config.num_threads = env->GetIntField(transducer_config, fid); ans.model_config.num_threads = env->GetIntField(model_config, fid);
fid = env->GetFieldID(model_config_cls, "debug", "Z"); fid = env->GetFieldID(model_config_cls, "debug", "Z");
ans.model_config.debug = env->GetBooleanField(transducer_config, fid); ans.model_config.debug = env->GetBooleanField(model_config, fid);
fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;"); fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(transducer_config, fid); s = (jstring)env->GetObjectField(model_config, fid);
p = env->GetStringUTFChars(s, nullptr); p = env->GetStringUTFChars(s, nullptr);
ans.model_config.provider = p; ans.model_config.provider = p;
env->ReleaseStringUTFChars(s, p); env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(model_config_cls, "modelType", "Ljava/lang/String;"); fid = env->GetFieldID(model_config_cls, "modelType", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(transducer_config, fid); s = (jstring)env->GetObjectField(model_config, fid);
p = env->GetStringUTFChars(s, nullptr); p = env->GetStringUTFChars(s, nullptr);
ans.model_config.model_type = p; ans.model_config.model_type = p;
env->ReleaseStringUTFChars(s, p); env->ReleaseStringUTFChars(s, p);