Add lm rescore to online-modified-beam-search (#133)

This commit is contained in:
PF Luo
2023-05-05 21:23:54 +08:00
committed by GitHub
parent 3b9c3db31d
commit 8c6a6768d5
26 changed files with 495 additions and 39 deletions

View File

@@ -182,9 +182,10 @@ class MainActivity : AppCompatActivity() {
val config = OnlineRecognizerConfig(
featConfig = getFeatureConfig(sampleRate = sampleRateInHz, featureDim = 80),
modelConfig = getModelConfig(type = type)!!,
lmConfig = getOnlineLMConfig(type = type),
endpointConfig = getEndpointConfig(),
enableEndpoint = true,
decodingMethod = "greedy_search",
decodingMethod = "modified_beam_search",
maxActivePaths = 4,
)

View File

@@ -23,6 +23,11 @@ data class OnlineTransducerModelConfig(
var debug: Boolean = false,
)
/**
 * Configuration of the neural-network language model (LM).
 *
 * @property model Path of the LM model file. Empty by default.
 * @property scale Scale associated with the LM. 0.5 by default.
 *   NOTE(review): presumably the weight given to LM scores during
 *   rescoring - confirm against the consumer of this config.
 */
data class OnlineLMConfig(
    var model: String = "",
    var scale: Float = 0.5f,
)
data class FeatureConfig(
var sampleRate: Int = 16000,
var featureDim: Int = 80,
@@ -31,6 +36,7 @@ data class FeatureConfig(
data class OnlineRecognizerConfig(
var featConfig: FeatureConfig = FeatureConfig(),
var modelConfig: OnlineTransducerModelConfig,
var lmConfig : OnlineLMConfig,
var endpointConfig: EndpointConfig = EndpointConfig(),
var enableEndpoint: Boolean = true,
var decodingMethod: String = "greedy_search",
@@ -151,6 +157,32 @@ fun getModelConfig(type: Int): OnlineTransducerModelConfig? {
return null;
}
/**
 * Returns the language-model (LM) configuration that belongs to the
 * pre-trained model selected by [type].
 *
 * Please see
 * https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
 * for a list of pre-trained models.
 *
 * Only a few models are listed here. Edit this function to add your own
 * LM model. (Training a new NN LM model should be straightforward by
 * following the code at
 * https://github.com/k2-fsa/icefall/blob/master/icefall/rnn_lm/train.py)
 *
 * @param type
 * 0 - sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20 (Bilingual, Chinese + English)
 *     https://k2-fsa.github.io/sherpa/onnx/pretrained_models/zipformer-transducer-models.html#sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20-bilingual-chinese-english
 * @return the configuration for [type]; any value not listed above yields
 *   an [OnlineLMConfig] built entirely from its defaults.
 */
fun getOnlineLMConfig(type: Int): OnlineLMConfig = when (type) {
    0 -> {
        val lmDir = "sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20"
        OnlineLMConfig(
            model = "$lmDir/with-state-epoch-99-avg-1.int8.onnx",
            scale = 0.5f,
        )
    }
    // Types without a dedicated entry fall back to the default configuration.
    else -> OnlineLMConfig()
}
fun getEndpointConfig(): EndpointConfig {
return EndpointConfig(
rule1 = EndpointRule(false, 2.4f, 0.0f),