Fix modified beam search for iOS and android (#76)

* Use Int type for sampling rate

* Fix swift

* Fix iOS
This commit is contained in:
Fangjun Kuang
2023-03-03 15:18:31 +08:00
committed by GitHub
parent 7f72c13d9a
commit 5f31b22c12
15 changed files with 125 additions and 93 deletions

View File

@@ -1,2 +1,3 @@
Makefile Makefile
*.jar *.jar
hs_err_pid*.log

View File

@@ -4,7 +4,7 @@ import android.content.res.AssetManager
fun main() { fun main() {
var featConfig = FeatureConfig( var featConfig = FeatureConfig(
sampleRate = 16000.0f, sampleRate = 16000,
featureDim = 80, featureDim = 80,
) )
@@ -13,7 +13,7 @@ fun main() {
decoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/decoder-epoch-99-avg-1.onnx", decoder = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/decoder-epoch-99-avg-1.onnx",
joiner = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/joiner-epoch-99-avg-1.onnx", joiner = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/joiner-epoch-99-avg-1.onnx",
tokens = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/tokens.txt", tokens = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/tokens.txt",
numThreads = 4, numThreads = 1,
debug = false, debug = false,
) )
@@ -24,22 +24,31 @@ fun main() {
featConfig = featConfig, featConfig = featConfig,
endpointConfig = endpointConfig, endpointConfig = endpointConfig,
enableEndpoint = true, enableEndpoint = true,
decodingMethod = "greedy_search",
maxActivePaths = 4,
) )
var model = SherpaOnnx( var model = SherpaOnnx(
assetManager = AssetManager(), assetManager = AssetManager(),
config = config, config = config,
) )
var samples = WaveReader.readWave( var samples = WaveReader.readWave(
assetManager = AssetManager(), assetManager = AssetManager(),
filename = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/test_wavs/1089-134686-0001.wav", filename = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/test_wavs/1089-134686-0001.wav",
) )
model.decodeSamples(samples!!) model.acceptWaveform(samples!!, sampleRate=16000)
while (model.isReady()) {
model.decode()
}
var tail_paddings = FloatArray(8000) // 0.5 seconds var tail_paddings = FloatArray(8000) // 0.5 seconds
model.decodeSamples(tail_paddings) model.acceptWaveform(tail_paddings, sampleRate=16000)
model.inputFinished() model.inputFinished()
while (model.isReady()) {
model.decode()
}
println("results: ${model.text}") println("results: ${model.text}")
} }

1
.gitignore vendored
View File

@@ -38,3 +38,4 @@ log.txt
tags tags
run-decode-file-python.sh run-decode-file-python.sh
android/SherpaOnnx/app/src/main/assets/ android/SherpaOnnx/app/src/main/assets/
*.ncnn.*

View File

@@ -121,7 +121,10 @@ class MainActivity : AppCompatActivity() {
val ret = audioRecord?.read(buffer, 0, buffer.size) val ret = audioRecord?.read(buffer, 0, buffer.size)
if (ret != null && ret > 0) { if (ret != null && ret > 0) {
val samples = FloatArray(ret) { buffer[it] / 32768.0f } val samples = FloatArray(ret) { buffer[it] / 32768.0f }
model.decodeSamples(samples) model.acceptWaveform(samples, sampleRate=16000)
while (model.isReady()) {
model.decode()
}
runOnUiThread { runOnUiThread {
val isEndpoint = model.isEndpoint() val isEndpoint = model.isEndpoint()
val text = model.text val text = model.text
@@ -177,33 +180,17 @@ class MainActivity : AppCompatActivity() {
val type = 0 val type = 0
println("Select model type ${type}") println("Select model type ${type}")
val config = OnlineRecognizerConfig( val config = OnlineRecognizerConfig(
featConfig = getFeatureConfig(sampleRate = 16000.0f, featureDim = 80), featConfig = getFeatureConfig(sampleRate = 16000, featureDim = 80),
modelConfig = getModelConfig(type = type)!!, modelConfig = getModelConfig(type = type)!!,
endpointConfig = getEndpointConfig(), endpointConfig = getEndpointConfig(),
enableEndpoint = true enableEndpoint = true,
decodingMethod = "greedy_search",
maxActivePaths = 4,
) )
model = SherpaOnnx( model = SherpaOnnx(
assetManager = application.assets, assetManager = application.assets,
config = config, config = config,
) )
/*
println("reading samples")
val samples = WaveReader.readWave(
assetManager = application.assets,
// filename = "sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/test_wavs/0.wav",
filename = "sherpa-onnx-lstm-zh-2023-02-20/test_wavs/0.wav",
// filename="sherpa-onnx-lstm-en-2023-02-17/test_wavs/1089-134686-0001.wav"
)
println("samples read done!")
model.decodeSamples(samples!!)
val tailPaddings = FloatArray(8000) // 0.5 seconds
model.decodeSamples(tailPaddings)
println("result is: ${model.text}")
model.reset()
*/
} }
} }

View File

@@ -24,7 +24,7 @@ data class OnlineTransducerModelConfig(
) )
data class FeatureConfig( data class FeatureConfig(
var sampleRate: Float = 16000.0f, var sampleRate: Int = 16000,
var featureDim: Int = 80, var featureDim: Int = 80,
) )
@@ -32,7 +32,9 @@ data class OnlineRecognizerConfig(
var featConfig: FeatureConfig = FeatureConfig(), var featConfig: FeatureConfig = FeatureConfig(),
var modelConfig: OnlineTransducerModelConfig, var modelConfig: OnlineTransducerModelConfig,
var endpointConfig: EndpointConfig = EndpointConfig(), var endpointConfig: EndpointConfig = EndpointConfig(),
var enableEndpoint: Boolean, var enableEndpoint: Boolean = true,
var decodingMethod: String = "greedy_search",
var maxActivePaths: Int = 4,
) )
class SherpaOnnx( class SherpaOnnx(
@@ -49,12 +51,14 @@ class SherpaOnnx(
} }
fun decodeSamples(samples: FloatArray) = fun acceptWaveform(samples: FloatArray, sampleRate: Int) =
decodeSamples(ptr, samples, sampleRate = config.featConfig.sampleRate) acceptWaveform(ptr, samples, sampleRate)
fun inputFinished() = inputFinished(ptr) fun inputFinished() = inputFinished(ptr)
fun reset() = reset(ptr) fun reset() = reset(ptr)
fun decode() = decode(ptr)
fun isEndpoint(): Boolean = isEndpoint(ptr) fun isEndpoint(): Boolean = isEndpoint(ptr)
fun isReady(): Boolean = isReady(ptr)
val text: String val text: String
get() = getText(ptr) get() = getText(ptr)
@@ -66,11 +70,13 @@ class SherpaOnnx(
config: OnlineRecognizerConfig, config: OnlineRecognizerConfig,
): Long ): Long
private external fun decodeSamples(ptr: Long, samples: FloatArray, sampleRate: Float) private external fun acceptWaveform(ptr: Long, samples: FloatArray, sampleRate: Int)
private external fun inputFinished(ptr: Long) private external fun inputFinished(ptr: Long)
private external fun getText(ptr: Long): String private external fun getText(ptr: Long): String
private external fun reset(ptr: Long) private external fun reset(ptr: Long)
private external fun decode(ptr: Long)
private external fun isEndpoint(ptr: Long): Boolean private external fun isEndpoint(ptr: Long): Boolean
private external fun isReady(ptr: Long): Boolean
companion object { companion object {
init { init {
@@ -79,7 +85,7 @@ class SherpaOnnx(
} }
} }
fun getFeatureConfig(sampleRate: Float, featureDim: Int): FeatureConfig { fun getFeatureConfig(sampleRate: Int, featureDim: Int): FeatureConfig {
return FeatureConfig(sampleRate=sampleRate, featureDim=featureDim) return FeatureConfig(sampleRate=sampleRate, featureDim=featureDim)
} }

View File

@@ -23,10 +23,10 @@ extension AVAudioPCMBuffer {
class ViewController: UIViewController { class ViewController: UIViewController {
@IBOutlet weak var resultLabel: UILabel! @IBOutlet weak var resultLabel: UILabel!
@IBOutlet weak var recordBtn: UIButton! @IBOutlet weak var recordBtn: UIButton!
var audioEngine: AVAudioEngine? = nil var audioEngine: AVAudioEngine? = nil
var recognizer: SherpaOnnxRecognizer! = nil var recognizer: SherpaOnnxRecognizer! = nil
/// It saves the decoded results so far /// It saves the decoded results so far
var sentences: [String] = [] { var sentences: [String] = [] {
didSet { didSet {
@@ -42,7 +42,7 @@ class ViewController: UIViewController {
if sentences.isEmpty { if sentences.isEmpty {
return "0: \(lastSentence.lowercased())" return "0: \(lastSentence.lowercased())"
} }
let start = max(sentences.count - maxSentence, 0) let start = max(sentences.count - maxSentence, 0)
if lastSentence.isEmpty { if lastSentence.isEmpty {
return sentences.enumerated().map { (index, s) in "\(index): \(s.lowercased())" }[start...] return sentences.enumerated().map { (index, s) in "\(index): \(s.lowercased())" }[start...]
@@ -52,23 +52,23 @@ class ViewController: UIViewController {
.joined(separator: "\n") + "\n\(sentences.count): \(lastSentence.lowercased())" .joined(separator: "\n") + "\n\(sentences.count): \(lastSentence.lowercased())"
} }
} }
func updateLabel() { func updateLabel() {
DispatchQueue.main.async { DispatchQueue.main.async {
self.resultLabel.text = self.results self.resultLabel.text = self.results
} }
} }
override func viewDidLoad() { override func viewDidLoad() {
super.viewDidLoad() super.viewDidLoad()
// Do any additional setup after loading the view. // Do any additional setup after loading the view.
resultLabel.text = "ASR with Next-gen Kaldi\n\nSee https://github.com/k2-fsa/sherpa-onnx\n\nPress the Start button to run!" resultLabel.text = "ASR with Next-gen Kaldi\n\nSee https://github.com/k2-fsa/sherpa-onnx\n\nPress the Start button to run!"
recordBtn.setTitle("Start", for: .normal) recordBtn.setTitle("Start", for: .normal)
initRecognizer() initRecognizer()
initRecorder() initRecorder()
} }
@IBAction func onRecordBtnClick(_ sender: UIButton) { @IBAction func onRecordBtnClick(_ sender: UIButton) {
if recordBtn.currentTitle == "Start" { if recordBtn.currentTitle == "Start" {
startRecorder() startRecorder()
@@ -78,30 +78,32 @@ class ViewController: UIViewController {
recordBtn.setTitle("Start", for: .normal) recordBtn.setTitle("Start", for: .normal)
} }
} }
func initRecognizer() { func initRecognizer() {
// Please select one model that is best suitable for you. // Please select one model that is best suitable for you.
// //
// You can also modify Model.swift to add new pre-trained models from // You can also modify Model.swift to add new pre-trained models from
// https://k2-fsa.github.io/sherpa/ncnn/pretrained_models/index.html // https://k2-fsa.github.io/sherpa/ncnn/pretrained_models/index.html
let modelConfig = getBilingualStreamZhEnZipformer20230220() let modelConfig = getBilingualStreamZhEnZipformer20230220()
let featConfig = sherpaOnnxFeatureConfig( let featConfig = sherpaOnnxFeatureConfig(
sampleRate: 16000, sampleRate: 16000,
featureDim: 80) featureDim: 80)
var config = sherpaOnnxOnlineRecognizerConfig( var config = sherpaOnnxOnlineRecognizerConfig(
featConfig: featConfig, featConfig: featConfig,
modelConfig: modelConfig, modelConfig: modelConfig,
enableEndpoint: true, enableEndpoint: true,
rule1MinTrailingSilence: 2.4, rule1MinTrailingSilence: 2.4,
rule2MinTrailingSilence: 0.8, rule2MinTrailingSilence: 0.8,
rule3MinUtteranceLength: 30 rule3MinUtteranceLength: 30,
decodingMethod: "greedy_search",
maxActivePaths: 4
) )
recognizer = SherpaOnnxRecognizer(config: &config) recognizer = SherpaOnnxRecognizer(config: &config)
} }
func initRecorder() { func initRecorder() {
print("init recorder") print("init recorder")
audioEngine = AVAudioEngine() audioEngine = AVAudioEngine()
@@ -112,9 +114,9 @@ class ViewController: UIViewController {
commonFormat: .pcmFormatFloat32, commonFormat: .pcmFormatFloat32,
sampleRate: 16000, channels: 1, sampleRate: 16000, channels: 1,
interleaved: false)! interleaved: false)!
let converter = AVAudioConverter(from: inputFormat!, to: outputFormat)! let converter = AVAudioConverter(from: inputFormat!, to: outputFormat)!
inputNode!.installTap( inputNode!.installTap(
onBus: bus, onBus: bus,
bufferSize: 1024, bufferSize: 1024,
@@ -122,34 +124,34 @@ class ViewController: UIViewController {
) { ) {
(buffer: AVAudioPCMBuffer, when: AVAudioTime) in (buffer: AVAudioPCMBuffer, when: AVAudioTime) in
var newBufferAvailable = true var newBufferAvailable = true
let inputCallback: AVAudioConverterInputBlock = { let inputCallback: AVAudioConverterInputBlock = {
inNumPackets, outStatus in inNumPackets, outStatus in
if newBufferAvailable { if newBufferAvailable {
outStatus.pointee = .haveData outStatus.pointee = .haveData
newBufferAvailable = false newBufferAvailable = false
return buffer return buffer
} else { } else {
outStatus.pointee = .noDataNow outStatus.pointee = .noDataNow
return nil return nil
} }
} }
let convertedBuffer = AVAudioPCMBuffer( let convertedBuffer = AVAudioPCMBuffer(
pcmFormat: outputFormat, pcmFormat: outputFormat,
frameCapacity: frameCapacity:
AVAudioFrameCount(outputFormat.sampleRate) AVAudioFrameCount(outputFormat.sampleRate)
* buffer.frameLength * buffer.frameLength
/ AVAudioFrameCount(buffer.format.sampleRate))! / AVAudioFrameCount(buffer.format.sampleRate))!
var error: NSError? var error: NSError?
let _ = converter.convert( let _ = converter.convert(
to: convertedBuffer, to: convertedBuffer,
error: &error, withInputFrom: inputCallback) error: &error, withInputFrom: inputCallback)
// TODO(fangjun): Handle status != haveData // TODO(fangjun): Handle status != haveData
let array = convertedBuffer.array() let array = convertedBuffer.array()
if !array.isEmpty { if !array.isEmpty {
self.recognizer.acceptWaveform(samples: array) self.recognizer.acceptWaveform(samples: array)
@@ -158,13 +160,13 @@ class ViewController: UIViewController {
} }
let isEndpoint = self.recognizer.isEndpoint() let isEndpoint = self.recognizer.isEndpoint()
let text = self.recognizer.getResult().text let text = self.recognizer.getResult().text
if !text.isEmpty && self.lastSentence != text { if !text.isEmpty && self.lastSentence != text {
self.lastSentence = text self.lastSentence = text
self.updateLabel() self.updateLabel()
print(text) print(text)
} }
if isEndpoint { if isEndpoint {
if !text.isEmpty { if !text.isEmpty {
let tmp = self.lastSentence let tmp = self.lastSentence
@@ -175,13 +177,13 @@ class ViewController: UIViewController {
} }
} }
} }
} }
func startRecorder() { func startRecorder() {
lastSentence = "" lastSentence = ""
sentences = [] sentences = []
do { do {
try self.audioEngine?.start() try self.audioEngine?.start()
} catch let error as NSError { } catch let error as NSError {
@@ -189,7 +191,7 @@ class ViewController: UIViewController {
} }
print("started") print("started")
} }
func stopRecorder() { func stopRecorder() {
audioEngine?.stop() audioEngine?.stop()
print("stopped") print("stopped")

View File

@@ -76,7 +76,7 @@ SherpaOnnxOnlineStream *CreateOnlineStream(
void DestoryOnlineStream(SherpaOnnxOnlineStream *stream) { delete stream; } void DestoryOnlineStream(SherpaOnnxOnlineStream *stream) { delete stream; }
void AcceptWaveform(SherpaOnnxOnlineStream *stream, float sample_rate, void AcceptWaveform(SherpaOnnxOnlineStream *stream, int32_t sample_rate,
const float *samples, int32_t n) { const float *samples, int32_t n) {
stream->impl->AcceptWaveform(sample_rate, samples, n); stream->impl->AcceptWaveform(sample_rate, samples, n);
} }

View File

@@ -120,7 +120,7 @@ void DestoryOnlineStream(SherpaOnnxOnlineStream *stream);
/// @param samples A pointer to a 1-D array containing audio samples. /// @param samples A pointer to a 1-D array containing audio samples.
/// The range of samples has to be normalized to [-1, 1]. /// The range of samples has to be normalized to [-1, 1].
/// @param n Number of elements in the samples array. /// @param n Number of elements in the samples array.
void AcceptWaveform(SherpaOnnxOnlineStream *stream, float sample_rate, void AcceptWaveform(SherpaOnnxOnlineStream *stream, int32_t sample_rate,
const float *samples, int32_t n); const float *samples, int32_t n);
/// Return 1 if there are enough number of feature frames for decoding. /// Return 1 if there are enough number of feature frames for decoding.

View File

@@ -48,7 +48,7 @@ class FeatureExtractor::Impl {
fbank_ = std::make_unique<knf::OnlineFbank>(opts_); fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
} }
void AcceptWaveform(float sampling_rate, const float *waveform, int32_t n) { void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
fbank_->AcceptWaveform(sampling_rate, waveform, n); fbank_->AcceptWaveform(sampling_rate, waveform, n);
} }
@@ -107,7 +107,7 @@ FeatureExtractor::FeatureExtractor(const FeatureExtractorConfig &config /*={}*/)
FeatureExtractor::~FeatureExtractor() = default; FeatureExtractor::~FeatureExtractor() = default;
void FeatureExtractor::AcceptWaveform(float sampling_rate, void FeatureExtractor::AcceptWaveform(int32_t sampling_rate,
const float *waveform, int32_t n) { const float *waveform, int32_t n) {
impl_->AcceptWaveform(sampling_rate, waveform, n); impl_->AcceptWaveform(sampling_rate, waveform, n);
} }

View File

@@ -14,7 +14,7 @@
namespace sherpa_onnx { namespace sherpa_onnx {
struct FeatureExtractorConfig { struct FeatureExtractorConfig {
float sampling_rate = 16000; int32_t sampling_rate = 16000;
int32_t feature_dim = 80; int32_t feature_dim = 80;
int32_t max_feature_vectors = -1; int32_t max_feature_vectors = -1;
@@ -34,7 +34,7 @@ class FeatureExtractor {
@param waveform Pointer to a 1-D array of size n @param waveform Pointer to a 1-D array of size n
@param n Number of entries in waveform @param n Number of entries in waveform
*/ */
void AcceptWaveform(float sampling_rate, const float *waveform, int32_t n); void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n);
/** /**
* InputFinished() tells the class you won't be providing any * InputFinished() tells the class you won't be providing any

View File

@@ -112,7 +112,7 @@ for a list of pre-trained models to download.
param.suggestedLatency = info->defaultLowInputLatency; param.suggestedLatency = info->defaultLowInputLatency;
param.hostApiSpecificStreamInfo = nullptr; param.hostApiSpecificStreamInfo = nullptr;
const float sample_rate = 16000; float sample_rate = 16000;
PaStream *stream; PaStream *stream;
PaError err = PaError err =

View File

@@ -61,7 +61,7 @@ for a list of pre-trained models to download.
sherpa_onnx::OnlineRecognizer recognizer(config); sherpa_onnx::OnlineRecognizer recognizer(config);
float expected_sampling_rate = config.feat_config.sampling_rate; int32_t expected_sampling_rate = config.feat_config.sampling_rate;
bool is_ok = false; bool is_ok = false;
std::vector<float> samples = std::vector<float> samples =
@@ -72,7 +72,7 @@ for a list of pre-trained models to download.
return -1; return -1;
} }
float duration = samples.size() / expected_sampling_rate; float duration = samples.size() / static_cast<float>(expected_sampling_rate);
fprintf(stderr, "wav filename: %s\n", wav_filename.c_str()); fprintf(stderr, "wav filename: %s\n", wav_filename.c_str());
fprintf(stderr, "wav duration (s): %.3f\n", duration); fprintf(stderr, "wav duration (s): %.3f\n", duration);

View File

@@ -40,19 +40,18 @@ class SherpaOnnx {
mgr, mgr,
#endif #endif
config), config),
stream_(recognizer_.CreateStream()), stream_(recognizer_.CreateStream()) {
tail_padding_(16000 * 0.32, 0) {
} }
void DecodeSamples(float sample_rate, const float *samples, int32_t n) const { void AcceptWaveform(int32_t sample_rate, const float *samples,
int32_t n) const {
stream_->AcceptWaveform(sample_rate, samples, n); stream_->AcceptWaveform(sample_rate, samples, n);
Decode();
} }
void InputFinished() const { void InputFinished() const {
stream_->AcceptWaveform(16000, tail_padding_.data(), tail_padding_.size()); std::vector<float> tail_padding(16000 * 0.32, 0);
stream_->AcceptWaveform(16000, tail_padding.data(), tail_padding.size());
stream_->InputFinished(); stream_->InputFinished();
Decode();
} }
const std::string GetText() const { const std::string GetText() const {
@@ -62,19 +61,15 @@ class SherpaOnnx {
bool IsEndpoint() const { return recognizer_.IsEndpoint(stream_.get()); } bool IsEndpoint() const { return recognizer_.IsEndpoint(stream_.get()); }
bool IsReady() const { return recognizer_.IsReady(stream_.get()); }
void Reset() const { return recognizer_.Reset(stream_.get()); } void Reset() const { return recognizer_.Reset(stream_.get()); }
private: void Decode() const { recognizer_.DecodeStream(stream_.get()); }
void Decode() const {
while (recognizer_.IsReady(stream_.get())) {
recognizer_.DecodeStream(stream_.get());
}
}
private: private:
sherpa_onnx::OnlineRecognizer recognizer_; sherpa_onnx::OnlineRecognizer recognizer_;
std::unique_ptr<sherpa_onnx::OnlineStream> stream_; std::unique_ptr<sherpa_onnx::OnlineStream> stream_;
std::vector<float> tail_padding_;
}; };
static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) { static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
@@ -86,14 +81,24 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
// https://docs.oracle.com/javase/7/docs/technotes/guides/jni/spec/types.html // https://docs.oracle.com/javase/7/docs/technotes/guides/jni/spec/types.html
// https://courses.cs.washington.edu/courses/cse341/99wi/java/tutorial/native1.1/implementing/field.html // https://courses.cs.washington.edu/courses/cse341/99wi/java/tutorial/native1.1/implementing/field.html
//---------- decoding ----------
fid = env->GetFieldID(cls, "decodingMethod", "Ljava/lang/String;");
jstring s = (jstring)env->GetObjectField(config, fid);
const char *p = env->GetStringUTFChars(s, nullptr);
ans.decoding_method = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(cls, "maxActivePaths", "I");
ans.max_active_paths = env->GetIntField(config, fid);
//---------- feat config ---------- //---------- feat config ----------
fid = env->GetFieldID(cls, "featConfig", fid = env->GetFieldID(cls, "featConfig",
"Lcom/k2fsa/sherpa/onnx/FeatureConfig;"); "Lcom/k2fsa/sherpa/onnx/FeatureConfig;");
jobject feat_config = env->GetObjectField(config, fid); jobject feat_config = env->GetObjectField(config, fid);
jclass feat_config_cls = env->GetObjectClass(feat_config); jclass feat_config_cls = env->GetObjectClass(feat_config);
fid = env->GetFieldID(feat_config_cls, "sampleRate", "F"); fid = env->GetFieldID(feat_config_cls, "sampleRate", "I");
ans.feat_config.sampling_rate = env->GetFloatField(feat_config, fid); ans.feat_config.sampling_rate = env->GetIntField(feat_config, fid);
fid = env->GetFieldID(feat_config_cls, "featureDim", "I"); fid = env->GetFieldID(feat_config_cls, "featureDim", "I");
ans.feat_config.feature_dim = env->GetIntField(feat_config, fid); ans.feat_config.feature_dim = env->GetIntField(feat_config, fid);
@@ -153,8 +158,8 @@ static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
jclass model_config_cls = env->GetObjectClass(model_config); jclass model_config_cls = env->GetObjectClass(model_config);
fid = env->GetFieldID(model_config_cls, "encoder", "Ljava/lang/String;"); fid = env->GetFieldID(model_config_cls, "encoder", "Ljava/lang/String;");
jstring s = (jstring)env->GetObjectField(model_config, fid); s = (jstring)env->GetObjectField(model_config, fid);
const char *p = env->GetStringUTFChars(s, nullptr); p = env->GetStringUTFChars(s, nullptr);
ans.model_config.encoder_filename = p; ans.model_config.encoder_filename = p;
env->ReleaseStringUTFChars(s, p); env->ReleaseStringUTFChars(s, p);
@@ -198,6 +203,7 @@ JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_new(
#endif #endif
auto config = sherpa_onnx::GetConfig(env, _config); auto config = sherpa_onnx::GetConfig(env, _config);
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
auto model = new sherpa_onnx::SherpaOnnx( auto model = new sherpa_onnx::SherpaOnnx(
#if __ANDROID_API__ >= 9 #if __ANDROID_API__ >= 9
mgr, mgr,
@@ -220,6 +226,13 @@ JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_reset(
model->Reset(); model->Reset();
} }
SHERPA_ONNX_EXTERN_C
JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_isReady(
JNIEnv *env, jobject /*obj*/, jlong ptr) {
auto model = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);
return model->IsReady();
}
SHERPA_ONNX_EXTERN_C SHERPA_ONNX_EXTERN_C
JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_isEndpoint( JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_isEndpoint(
JNIEnv *env, jobject /*obj*/, jlong ptr) { JNIEnv *env, jobject /*obj*/, jlong ptr) {
@@ -228,15 +241,22 @@ JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_isEndpoint(
} }
SHERPA_ONNX_EXTERN_C SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_decodeSamples( JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_decode(
JNIEnv *env, jobject /*obj*/, jlong ptr) {
auto model = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);
model->Decode();
}
SHERPA_ONNX_EXTERN_C
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_acceptWaveform(
JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples, JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples,
jfloat sample_rate) { jint sample_rate) {
auto model = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr); auto model = reinterpret_cast<sherpa_onnx::SherpaOnnx *>(ptr);
jfloat *p = env->GetFloatArrayElements(samples, nullptr); jfloat *p = env->GetFloatArrayElements(samples, nullptr);
jsize n = env->GetArrayLength(samples); jsize n = env->GetArrayLength(samples);
model->DecodeSamples(sample_rate, p, n); model->AcceptWaveform(sample_rate, p, n);
env->ReleaseFloatArrayElements(samples, p, JNI_ABORT); env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);
} }

View File

@@ -62,11 +62,15 @@ func sherpaOnnxOnlineRecognizerConfig(
enableEndpoint: Bool = false, enableEndpoint: Bool = false,
rule1MinTrailingSilence: Float = 2.4, rule1MinTrailingSilence: Float = 2.4,
rule2MinTrailingSilence: Float = 1.2, rule2MinTrailingSilence: Float = 1.2,
rule3MinUtteranceLength: Float = 30 rule3MinUtteranceLength: Float = 30,
decodingMethod: String = "greedy_search",
maxActivePaths: Int = 4
) -> SherpaOnnxOnlineRecognizerConfig{ ) -> SherpaOnnxOnlineRecognizerConfig{
return SherpaOnnxOnlineRecognizerConfig( return SherpaOnnxOnlineRecognizerConfig(
feat_config: featConfig, feat_config: featConfig,
model_config: modelConfig, model_config: modelConfig,
decoding_method: toCPointer(decodingMethod),
max_active_paths: Int32(maxActivePaths),
enable_endpoint: enableEndpoint ? 1 : 0, enable_endpoint: enableEndpoint ? 1 : 0,
rule1_min_trailing_silence: rule1MinTrailingSilence, rule1_min_trailing_silence: rule1MinTrailingSilence,
rule2_min_trailing_silence: rule2MinTrailingSilence, rule2_min_trailing_silence: rule2MinTrailingSilence,
@@ -128,12 +132,12 @@ class SherpaOnnxRecognizer {
/// Decode wave samples. /// Decode wave samples.
/// ///
/// - Parameters: /// - Parameters:
/// - samples: Audio samples normalzed to the range [-1, 1] /// - samples: Audio samples normalized to the range [-1, 1]
/// - sampleRate: Sample rate of the input audio samples. Must match /// - sampleRate: Sample rate of the input audio samples. Must match
/// the one expected by the model. It must be 16000 for /// the one expected by the model. It must be 16000 for
/// models from icefall. /// models from icefall.
func acceptWaveform(samples: [Float], sampleRate: Float = 16000) { func acceptWaveform(samples: [Float], sampleRate: Int = 16000) {
AcceptWaveform(stream, sampleRate, samples, Int32(samples.count)) AcceptWaveform(stream, Int32(sampleRate), samples, Int32(samples.count))
} }
func isReady() -> Bool { func isReady() -> Bool {

View File

@@ -32,7 +32,9 @@ func run() {
var config = sherpaOnnxOnlineRecognizerConfig( var config = sherpaOnnxOnlineRecognizerConfig(
featConfig: featConfig, featConfig: featConfig,
modelConfig: modelConfig, modelConfig: modelConfig,
enableEndpoint: false enableEndpoint: false,
decodingMethod: "modified_beam_search",
maxActivePaths: 4
) )