Fix modified beam search for iOS and android (#76)

* Use Int type for sampling rate

* Fix swift

* Fix iOS
This commit is contained in:
Fangjun Kuang
2023-03-03 15:18:31 +08:00
committed by GitHub
parent 7f72c13d9a
commit 5f31b22c12
15 changed files with 125 additions and 93 deletions

View File

@@ -121,7 +121,10 @@ class MainActivity : AppCompatActivity() {
val ret = audioRecord?.read(buffer, 0, buffer.size)
if (ret != null && ret > 0) {
val samples = FloatArray(ret) { buffer[it] / 32768.0f }
model.decodeSamples(samples)
model.acceptWaveform(samples, sampleRate=16000)
while (model.isReady()) {
model.decode()
}
runOnUiThread {
val isEndpoint = model.isEndpoint()
val text = model.text
@@ -177,33 +180,17 @@ class MainActivity : AppCompatActivity() {
val type = 0
println("Select model type ${type}")
val config = OnlineRecognizerConfig(
featConfig = getFeatureConfig(sampleRate = 16000.0f, featureDim = 80),
featConfig = getFeatureConfig(sampleRate = 16000, featureDim = 80),
modelConfig = getModelConfig(type = type)!!,
endpointConfig = getEndpointConfig(),
enableEndpoint = true
enableEndpoint = true,
decodingMethod = "greedy_search",
maxActivePaths = 4,
)
model = SherpaOnnx(
assetManager = application.assets,
config = config,
)
/*
println("reading samples")
val samples = WaveReader.readWave(
assetManager = application.assets,
// filename = "sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/0.wav",
filename = "sherpa-onnx-lstm-zh-2023-02-20/test_wavs/0.wav",
// filename="sherpa-onnx-lstm-en-2023-02-17/test_wavs/1089-134686-0001.wav"
)
println("samples read done!")
model.decodeSamples(samples!!)
val tailPaddings = FloatArray(8000) // 0.5 seconds
model.decodeSamples(tailPaddings)
println("result is: ${model.text}")
model.reset()
*/
}
}

View File

@@ -24,7 +24,7 @@ data class OnlineTransducerModelConfig(
)
data class FeatureConfig(
var sampleRate: Float = 16000.0f,
var sampleRate: Int = 16000,
var featureDim: Int = 80,
)
@@ -32,7 +32,9 @@ data class OnlineRecognizerConfig(
var featConfig: FeatureConfig = FeatureConfig(),
var modelConfig: OnlineTransducerModelConfig,
var endpointConfig: EndpointConfig = EndpointConfig(),
var enableEndpoint: Boolean,
var enableEndpoint: Boolean = true,
var decodingMethod: String = "greedy_search",
var maxActivePaths: Int = 4,
)
class SherpaOnnx(
@@ -49,12 +51,14 @@ class SherpaOnnx(
}
fun decodeSamples(samples: FloatArray) =
decodeSamples(ptr, samples, sampleRate = config.featConfig.sampleRate)
fun acceptWaveform(samples: FloatArray, sampleRate: Int) =
acceptWaveform(ptr, samples, sampleRate)
fun inputFinished() = inputFinished(ptr)
fun reset() = reset(ptr)
fun decode() = decode(ptr)
fun isEndpoint(): Boolean = isEndpoint(ptr)
fun isReady(): Boolean = isReady(ptr)
val text: String
get() = getText(ptr)
@@ -66,11 +70,13 @@ class SherpaOnnx(
config: OnlineRecognizerConfig,
): Long
private external fun decodeSamples(ptr: Long, samples: FloatArray, sampleRate: Float)
private external fun acceptWaveform(ptr: Long, samples: FloatArray, sampleRate: Int)
private external fun inputFinished(ptr: Long)
private external fun getText(ptr: Long): String
private external fun reset(ptr: Long)
private external fun decode(ptr: Long)
private external fun isEndpoint(ptr: Long): Boolean
private external fun isReady(ptr: Long): Boolean
companion object {
init {
@@ -79,7 +85,7 @@ class SherpaOnnx(
}
}
fun getFeatureConfig(sampleRate: Float, featureDim: Int): FeatureConfig {
fun getFeatureConfig(sampleRate: Int, featureDim: Int): FeatureConfig {
return FeatureConfig(sampleRate=sampleRate, featureDim=featureDim)
}