decoder for open vocabulary keyword spotting (#505)
* various fixes to ContextGraph to support open vocabulary keywords decoder * Add keyword spotter runtime * Add binary * First version works * Minor fixes * update text2token * default values * Add jni for kws * add kws android project * Minor fixes * Remove unused interface * Minor fixes * Add workflow * handle extra info in texts * Minor fixes * Add more comments * Fix ci * fix cpp style * Add input box in android demo so that users can specify their keywords * Fix cpp style * Fix comments * Minor fixes * Minor fixes * minor fixes * Minor fixes * Minor fixes * Add CI * Fix code style * cpplint * Fix comments * Fix error
68
.github/scripts/test-kws.sh
vendored
Executable file
@@ -0,0 +1,68 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -e
|
||||
|
||||
log() {
|
||||
# This function is from espnet
|
||||
local fname=${BASH_SOURCE[1]##*/}
|
||||
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||
}
|
||||
|
||||
echo "EXE is $EXE"
|
||||
echo "PATH: $PATH"
|
||||
|
||||
which $EXE
|
||||
|
||||
log "------------------------------------------------------------"
|
||||
log "Run Chinese keyword spotting (Wenetspeech)"
|
||||
log "------------------------------------------------------------"
|
||||
|
||||
repo_url=https://www.modelscope.cn/pkufool/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01.git
|
||||
log "Start testing ${repo_url}"
|
||||
repo=sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01
|
||||
log "Download pretrained model and test-data from $repo_url"
|
||||
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||
pushd $repo
|
||||
git lfs pull --include "*.onnx"
|
||||
ls -lh *.onnx
|
||||
popd
|
||||
|
||||
time $EXE \
|
||||
--tokens=$repo/tokens.txt \
|
||||
--encoder=$repo/encoder-epoch-12-avg-2-chunk-16-left-64.onnx \
|
||||
--decoder=$repo/decoder-epoch-12-avg-2-chunk-16-left-64.onnx \
|
||||
--joiner=$repo/joiner-epoch-12-avg-2-chunk-16-left-64.onnx \
|
||||
--keywords-file=$repo/test_wavs/test_keywords.txt \
|
||||
--max-active-paths=4 \
|
||||
--num-threads=4 \
|
||||
$repo/test_wavs/3.wav $repo/test_wavs/4.wav $repo/test_wavs/5.wav $repo/test_wavs/6.wav
|
||||
|
||||
rm -rf $repo
|
||||
|
||||
log "------------------------------------------------------------"
|
||||
log "Run English keyword spotting (Gigaspeech)"
|
||||
log "------------------------------------------------------------"
|
||||
|
||||
repo_url=https://www.modelscope.cn/pkufool/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01.git
|
||||
log "Start testing ${repo_url}"
|
||||
repo=sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01
|
||||
log "Download pretrained model and test-data from $repo_url"
|
||||
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||
pushd $repo
|
||||
git lfs pull --include "*.onnx"
|
||||
ls -lh *.onnx
|
||||
popd
|
||||
|
||||
time $EXE \
|
||||
--tokens=$repo/tokens.txt \
|
||||
--encoder=$repo/encoder-epoch-12-avg-2-chunk-16-left-64.onnx \
|
||||
--decoder=$repo/decoder-epoch-12-avg-2-chunk-16-left-64.onnx \
|
||||
--joiner=$repo/joiner-epoch-12-avg-2-chunk-16-left-64.onnx \
|
||||
--keywords-file=$repo/test_wavs/test_keywords.txt \
|
||||
--max-active-paths=4 \
|
||||
--num-threads=4 \
|
||||
$repo/test_wavs/0.wav $repo/test_wavs/1.wav
|
||||
|
||||
rm -rf $repo
|
||||
67
.github/workflows/apk-kws.yaml
vendored
Normal file
@@ -0,0 +1,67 @@
|
||||
name: apk-kws
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- apk-kws
|
||||
tags:
|
||||
- '*'
|
||||
|
||||
workflow_dispatch:
|
||||
|
||||
concurrency:
|
||||
group: apk-kws-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
|
||||
jobs:
|
||||
apk:
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [ubuntu-latest]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: ccache
|
||||
uses: hendrikmuhs/ccache-action@v1.2
|
||||
with:
|
||||
key: ${{ matrix.os }}-android
|
||||
|
||||
- name: Display NDK HOME
|
||||
shell: bash
|
||||
run: |
|
||||
echo "ANDROID_NDK_LATEST_HOME: ${ANDROID_NDK_LATEST_HOME}"
|
||||
ls -lh ${ANDROID_NDK_LATEST_HOME}
|
||||
|
||||
- name: build APK
|
||||
shell: bash
|
||||
run: |
|
||||
export CMAKE_CXX_COMPILER_LAUNCHER=ccache
|
||||
export PATH="/usr/lib/ccache:/usr/local/opt/ccache/libexec:$PATH"
|
||||
cmake --version
|
||||
|
||||
export ANDROID_NDK=$ANDROID_NDK_LATEST_HOME
|
||||
./build-kws-apk.sh
|
||||
|
||||
- name: Display APK
|
||||
shell: bash
|
||||
run: |
|
||||
ls -lh ./apks/
|
||||
|
||||
- uses: actions/upload-artifact@v3
|
||||
with:
|
||||
path: ./apks/*.apk
|
||||
|
||||
- name: Release APK
|
||||
uses: svenstaro/upload-release-action@v2
|
||||
with:
|
||||
file_glob: true
|
||||
file: apks/*.apk
|
||||
overwrite: true
|
||||
8
.github/workflows/linux.yaml
vendored
@@ -107,6 +107,14 @@ jobs:
|
||||
name: release-static
|
||||
path: build/bin/*
|
||||
|
||||
- name: Test transducer kws
|
||||
shell: bash
|
||||
run: |
|
||||
export PATH=$PWD/build/bin:$PATH
|
||||
export EXE=sherpa-onnx-keyword-spotter
|
||||
|
||||
.github/scripts/test-kws.sh
|
||||
|
||||
- name: Test online CTC
|
||||
shell: bash
|
||||
run: |
|
||||
|
||||
9
.github/workflows/macos.yaml
vendored
@@ -98,6 +98,14 @@ jobs:
|
||||
otool -L build/bin/sherpa-onnx
|
||||
otool -l build/bin/sherpa-onnx
|
||||
|
||||
- name: Test transducer kws
|
||||
shell: bash
|
||||
run: |
|
||||
export PATH=$PWD/build/bin:$PATH
|
||||
export EXE=sherpa-onnx-keyword-spotter
|
||||
|
||||
.github/scripts/test-kws.sh
|
||||
|
||||
- name: Test online CTC
|
||||
shell: bash
|
||||
run: |
|
||||
@@ -106,7 +114,6 @@ jobs:
|
||||
|
||||
.github/scripts/test-online-ctc.sh
|
||||
|
||||
|
||||
- name: Test offline TTS
|
||||
shell: bash
|
||||
run: |
|
||||
|
||||
2
.github/workflows/run-python-test.yaml
vendored
@@ -62,7 +62,7 @@ jobs:
|
||||
- name: Install Python dependencies
|
||||
shell: bash
|
||||
run: |
|
||||
python3 -m pip install --upgrade pip numpy sentencepiece==0.1.96 soundfile
|
||||
python3 -m pip install --upgrade pip numpy pypinyin sentencepiece==0.1.96 soundfile
|
||||
|
||||
- name: Install sherpa-onnx
|
||||
shell: bash
|
||||
|
||||
@@ -45,7 +45,7 @@ jobs:
|
||||
- name: Install Python dependencies
|
||||
shell: bash
|
||||
run: |
|
||||
python3 -m pip install --upgrade pip numpy sentencepiece
|
||||
python3 -m pip install --upgrade pip numpy pypinyin sentencepiece
|
||||
|
||||
- name: Install sherpa-onnx
|
||||
shell: bash
|
||||
|
||||
@@ -45,7 +45,7 @@ jobs:
|
||||
- name: Install Python dependencies
|
||||
shell: bash
|
||||
run: |
|
||||
python3 -m pip install --upgrade pip numpy sentencepiece
|
||||
python3 -m pip install --upgrade pip numpy pypinyin sentencepiece
|
||||
|
||||
- name: Install sherpa-onnx
|
||||
shell: bash
|
||||
|
||||
15
android/SherpaOnnxKws/.gitignore
vendored
Normal file
@@ -0,0 +1,15 @@
|
||||
*.iml
|
||||
.gradle
|
||||
/local.properties
|
||||
/.idea/caches
|
||||
/.idea/libraries
|
||||
/.idea/modules.xml
|
||||
/.idea/workspace.xml
|
||||
/.idea/navEditor.xml
|
||||
/.idea/assetWizardSettings.xml
|
||||
.DS_Store
|
||||
/build
|
||||
/captures
|
||||
.externalNativeBuild
|
||||
.cxx
|
||||
local.properties
|
||||
1
android/SherpaOnnxKws/app/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
/build
|
||||
44
android/SherpaOnnxKws/app/build.gradle
Normal file
@@ -0,0 +1,44 @@
|
||||
plugins {
|
||||
id 'com.android.application'
|
||||
id 'org.jetbrains.kotlin.android'
|
||||
}
|
||||
|
||||
android {
|
||||
namespace 'com.k2fsa.sherpa.onnx'
|
||||
compileSdk 32
|
||||
|
||||
defaultConfig {
|
||||
applicationId "com.k2fsa.sherpa.onnx"
|
||||
minSdk 21
|
||||
targetSdk 32
|
||||
versionCode 1
|
||||
versionName "1.0"
|
||||
|
||||
testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
|
||||
}
|
||||
|
||||
buildTypes {
|
||||
release {
|
||||
minifyEnabled false
|
||||
proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'
|
||||
}
|
||||
}
|
||||
compileOptions {
|
||||
sourceCompatibility JavaVersion.VERSION_1_8
|
||||
targetCompatibility JavaVersion.VERSION_1_8
|
||||
}
|
||||
kotlinOptions {
|
||||
jvmTarget = '1.8'
|
||||
}
|
||||
}
|
||||
|
||||
dependencies {
|
||||
|
||||
implementation 'androidx.core:core-ktx:1.7.0'
|
||||
implementation 'androidx.appcompat:appcompat:1.5.1'
|
||||
implementation 'com.google.android.material:material:1.7.0'
|
||||
implementation 'androidx.constraintlayout:constraintlayout:2.1.4'
|
||||
testImplementation 'junit:junit:4.13.2'
|
||||
androidTestImplementation 'androidx.test.ext:junit:1.1.4'
|
||||
androidTestImplementation 'androidx.test.espresso:espresso-core:3.5.0'
|
||||
}
|
||||
21
android/SherpaOnnxKws/app/proguard-rules.pro
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
# Add project specific ProGuard rules here.
|
||||
# You can control the set of applied configuration files using the
|
||||
# proguardFiles setting in build.gradle.
|
||||
#
|
||||
# For more details, see
|
||||
# http://developer.android.com/guide/developing/tools/proguard.html
|
||||
|
||||
# If your project uses WebView with JS, uncomment the following
|
||||
# and specify the fully qualified class name to the JavaScript interface
|
||||
# class:
|
||||
#-keepclassmembers class fqcn.of.javascript.interface.for.webview {
|
||||
# public *;
|
||||
#}
|
||||
|
||||
# Uncomment this to preserve the line number information for
|
||||
# debugging stack traces.
|
||||
#-keepattributes SourceFile,LineNumberTable
|
||||
|
||||
# If you keep the line number information, uncomment this to
|
||||
# hide the original source file name.
|
||||
#-renamesourcefileattribute SourceFile
|
||||
@@ -0,0 +1,24 @@
|
||||
package com.k2fsa.sherpa.onnx
|
||||
|
||||
import androidx.test.platform.app.InstrumentationRegistry
|
||||
import androidx.test.ext.junit.runners.AndroidJUnit4
|
||||
|
||||
import org.junit.Test
|
||||
import org.junit.runner.RunWith
|
||||
|
||||
import org.junit.Assert.*
|
||||
|
||||
/**
|
||||
* Instrumented test, which will execute on an Android device.
|
||||
*
|
||||
* See [testing documentation](http://d.android.com/tools/testing).
|
||||
*/
|
||||
@RunWith(AndroidJUnit4::class)
|
||||
class ExampleInstrumentedTest {
|
||||
@Test
|
||||
fun useAppContext() {
|
||||
// Context of the app under test.
|
||||
val appContext = InstrumentationRegistry.getInstrumentation().targetContext
|
||||
assertEquals("com.k2fsa.sherpa.onnx", appContext.packageName)
|
||||
}
|
||||
}
|
||||
32
android/SherpaOnnxKws/app/src/main/AndroidManifest.xml
Normal file
@@ -0,0 +1,32 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
|
||||
xmlns:tools="http://schemas.android.com/tools">
|
||||
|
||||
<uses-permission android:name="android.permission.RECORD_AUDIO" />
|
||||
|
||||
<application
|
||||
android:allowBackup="true"
|
||||
android:dataExtractionRules="@xml/data_extraction_rules"
|
||||
android:fullBackupContent="@xml/backup_rules"
|
||||
android:icon="@mipmap/ic_launcher"
|
||||
android:label="@string/app_name"
|
||||
android:roundIcon="@mipmap/ic_launcher_round"
|
||||
android:supportsRtl="true"
|
||||
android:theme="@style/Theme.SherpaOnnx"
|
||||
tools:targetApi="31">
|
||||
<activity
|
||||
android:name=".MainActivity"
|
||||
android:exported="true">
|
||||
<intent-filter>
|
||||
<action android:name="android.intent.action.MAIN" />
|
||||
|
||||
<category android:name="android.intent.category.LAUNCHER" />
|
||||
</intent-filter>
|
||||
|
||||
<meta-data
|
||||
android:name="android.app.lib_name"
|
||||
android:value="" />
|
||||
</activity>
|
||||
</application>
|
||||
|
||||
</manifest>
|
||||
0
android/SherpaOnnxKws/app/src/main/assets/.gitkeep
Normal file
@@ -0,0 +1,207 @@
|
||||
package com.k2fsa.sherpa.onnx
|
||||
|
||||
import android.Manifest
|
||||
import android.content.pm.PackageManager
|
||||
import android.media.AudioFormat
|
||||
import android.media.AudioRecord
|
||||
import android.media.MediaRecorder
|
||||
import android.os.Bundle
|
||||
import android.text.method.ScrollingMovementMethod
|
||||
import android.util.Log
|
||||
import android.widget.Button
|
||||
import android.widget.EditText
|
||||
import android.widget.TextView
|
||||
import android.widget.Toast
|
||||
import androidx.appcompat.app.AppCompatActivity
|
||||
import androidx.core.app.ActivityCompat
|
||||
import com.k2fsa.sherpa.onnx.*
|
||||
import kotlin.concurrent.thread
|
||||
|
||||
private const val TAG = "sherpa-onnx"
|
||||
private const val REQUEST_RECORD_AUDIO_PERMISSION = 200
|
||||
|
||||
class MainActivity : AppCompatActivity() {
|
||||
private val permissions: Array<String> = arrayOf(Manifest.permission.RECORD_AUDIO)
|
||||
|
||||
private lateinit var model: SherpaOnnxKws
|
||||
private var audioRecord: AudioRecord? = null
|
||||
private lateinit var recordButton: Button
|
||||
private lateinit var textView: TextView
|
||||
private lateinit var inputText: EditText
|
||||
private var recordingThread: Thread? = null
|
||||
|
||||
private val audioSource = MediaRecorder.AudioSource.MIC
|
||||
private val sampleRateInHz = 16000
|
||||
private val channelConfig = AudioFormat.CHANNEL_IN_MONO
|
||||
|
||||
// Note: We don't use AudioFormat.ENCODING_PCM_FLOAT
|
||||
// since the AudioRecord.read(float[]) needs API level >= 23
|
||||
// but we are targeting API level >= 21
|
||||
private val audioFormat = AudioFormat.ENCODING_PCM_16BIT
|
||||
private var idx: Int = 0
|
||||
private var lastText: String = ""
|
||||
|
||||
@Volatile
|
||||
private var isRecording: Boolean = false
|
||||
|
||||
override fun onRequestPermissionsResult(
|
||||
requestCode: Int, permissions: Array<String>, grantResults: IntArray
|
||||
) {
|
||||
super.onRequestPermissionsResult(requestCode, permissions, grantResults)
|
||||
val permissionToRecordAccepted = if (requestCode == REQUEST_RECORD_AUDIO_PERMISSION) {
|
||||
grantResults[0] == PackageManager.PERMISSION_GRANTED
|
||||
} else {
|
||||
false
|
||||
}
|
||||
|
||||
if (!permissionToRecordAccepted) {
|
||||
Log.e(TAG, "Audio record is disallowed")
|
||||
finish()
|
||||
}
|
||||
|
||||
Log.i(TAG, "Audio record is permitted")
|
||||
}
|
||||
|
||||
override fun onCreate(savedInstanceState: Bundle?) {
|
||||
super.onCreate(savedInstanceState)
|
||||
setContentView(R.layout.activity_main)
|
||||
|
||||
ActivityCompat.requestPermissions(this, permissions, REQUEST_RECORD_AUDIO_PERMISSION)
|
||||
|
||||
Log.i(TAG, "Start to initialize model")
|
||||
initModel()
|
||||
Log.i(TAG, "Finished initializing model")
|
||||
|
||||
recordButton = findViewById(R.id.record_button)
|
||||
recordButton.setOnClickListener { onclick() }
|
||||
|
||||
textView = findViewById(R.id.my_text)
|
||||
textView.movementMethod = ScrollingMovementMethod()
|
||||
|
||||
inputText = findViewById(R.id.input_text)
|
||||
}
|
||||
|
||||
private fun onclick() {
|
||||
if (!isRecording) {
|
||||
var keywords = inputText.text.toString()
|
||||
|
||||
Log.i(TAG, keywords)
|
||||
keywords = keywords.replace("\n", "/")
|
||||
// If keywords is an empty string, it just resets the decoding stream
|
||||
// always returns true in this case.
|
||||
// If keywords is not empty, it will create a new decoding stream with
|
||||
// the given keywords appended to the default keywords.
|
||||
// Return false if errors occured when adding keywords, true otherwise.
|
||||
val status = model.reset(keywords)
|
||||
if (!status) {
|
||||
Log.i(TAG, "Failed to reset with keywords.")
|
||||
Toast.makeText(this, "Failed to set keywords.", Toast.LENGTH_LONG).show();
|
||||
return
|
||||
}
|
||||
|
||||
val ret = initMicrophone()
|
||||
if (!ret) {
|
||||
Log.e(TAG, "Failed to initialize microphone")
|
||||
return
|
||||
}
|
||||
Log.i(TAG, "state: ${audioRecord?.state}")
|
||||
audioRecord!!.startRecording()
|
||||
recordButton.setText(R.string.stop)
|
||||
isRecording = true
|
||||
textView.text = ""
|
||||
lastText = ""
|
||||
idx = 0
|
||||
|
||||
recordingThread = thread(true) {
|
||||
processSamples()
|
||||
}
|
||||
Log.i(TAG, "Started recording")
|
||||
} else {
|
||||
isRecording = false
|
||||
audioRecord!!.stop()
|
||||
audioRecord!!.release()
|
||||
audioRecord = null
|
||||
recordButton.setText(R.string.start)
|
||||
Log.i(TAG, "Stopped recording")
|
||||
}
|
||||
}
|
||||
|
||||
private fun processSamples() {
|
||||
Log.i(TAG, "processing samples")
|
||||
|
||||
val interval = 0.1 // i.e., 100 ms
|
||||
val bufferSize = (interval * sampleRateInHz).toInt() // in samples
|
||||
val buffer = ShortArray(bufferSize)
|
||||
|
||||
while (isRecording) {
|
||||
val ret = audioRecord?.read(buffer, 0, buffer.size)
|
||||
if (ret != null && ret > 0) {
|
||||
val samples = FloatArray(ret) { buffer[it] / 32768.0f }
|
||||
model.acceptWaveform(samples, sampleRate=sampleRateInHz)
|
||||
while (model.isReady()) {
|
||||
model.decode()
|
||||
}
|
||||
|
||||
val text = model.keyword
|
||||
|
||||
var textToDisplay = lastText;
|
||||
|
||||
if(text.isNotBlank()) {
|
||||
if (lastText.isBlank()) {
|
||||
textToDisplay = "${idx}: ${text}"
|
||||
} else {
|
||||
textToDisplay = "${idx}: ${text}\n${lastText}"
|
||||
}
|
||||
lastText = "${idx}: ${text}\n${lastText}"
|
||||
idx += 1
|
||||
}
|
||||
|
||||
runOnUiThread {
|
||||
textView.text = textToDisplay
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun initMicrophone(): Boolean {
|
||||
if (ActivityCompat.checkSelfPermission(
|
||||
this, Manifest.permission.RECORD_AUDIO
|
||||
) != PackageManager.PERMISSION_GRANTED
|
||||
) {
|
||||
ActivityCompat.requestPermissions(this, permissions, REQUEST_RECORD_AUDIO_PERMISSION)
|
||||
return false
|
||||
}
|
||||
|
||||
val numBytes = AudioRecord.getMinBufferSize(sampleRateInHz, channelConfig, audioFormat)
|
||||
Log.i(
|
||||
TAG, "buffer size in milliseconds: ${numBytes * 1000.0f / sampleRateInHz}"
|
||||
)
|
||||
|
||||
audioRecord = AudioRecord(
|
||||
audioSource,
|
||||
sampleRateInHz,
|
||||
channelConfig,
|
||||
audioFormat,
|
||||
numBytes * 2 // a sample has two bytes as we are using 16-bit PCM
|
||||
)
|
||||
return true
|
||||
}
|
||||
|
||||
private fun initModel() {
|
||||
// Please change getModelConfig() to add new models
|
||||
// See https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html
|
||||
// for a list of available models
|
||||
val type = 0
|
||||
Log.i(TAG, "Select model type ${type}")
|
||||
val config = KeywordSpotterConfig(
|
||||
featConfig = getFeatureConfig(sampleRate = sampleRateInHz, featureDim = 80),
|
||||
modelConfig = getModelConfig(type = type)!!,
|
||||
keywordsFile = getKeywordsFile(type = type)!!,
|
||||
)
|
||||
|
||||
model = SherpaOnnxKws(
|
||||
assetManager = application.assets,
|
||||
config = config,
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,162 @@
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
package com.k2fsa.sherpa.onnx
|
||||
|
||||
import android.content.res.AssetManager
|
||||
|
||||
data class OnlineTransducerModelConfig(
|
||||
var encoder: String = "",
|
||||
var decoder: String = "",
|
||||
var joiner: String = "",
|
||||
)
|
||||
|
||||
data class OnlineModelConfig(
|
||||
var transducer: OnlineTransducerModelConfig = OnlineTransducerModelConfig(),
|
||||
var tokens: String,
|
||||
var numThreads: Int = 1,
|
||||
var debug: Boolean = false,
|
||||
var provider: String = "cpu",
|
||||
var modelType: String = "",
|
||||
)
|
||||
|
||||
data class FeatureConfig(
|
||||
var sampleRate: Int = 16000,
|
||||
var featureDim: Int = 80,
|
||||
)
|
||||
|
||||
data class KeywordSpotterConfig(
|
||||
var featConfig: FeatureConfig = FeatureConfig(),
|
||||
var modelConfig: OnlineModelConfig,
|
||||
var maxActivePaths: Int = 4,
|
||||
var keywordsFile: String = "keywords.txt",
|
||||
var keywordsScore: Float = 1.5f,
|
||||
var keywordsThreshold: Float = 0.25f,
|
||||
var numTrailingBlanks: Int = 2,
|
||||
)
|
||||
|
||||
class SherpaOnnxKws(
|
||||
assetManager: AssetManager? = null,
|
||||
var config: KeywordSpotterConfig,
|
||||
) {
|
||||
private val ptr: Long
|
||||
|
||||
init {
|
||||
if (assetManager != null) {
|
||||
ptr = new(assetManager, config)
|
||||
} else {
|
||||
ptr = newFromFile(config)
|
||||
}
|
||||
}
|
||||
|
||||
protected fun finalize() {
|
||||
delete(ptr)
|
||||
}
|
||||
|
||||
fun acceptWaveform(samples: FloatArray, sampleRate: Int) =
|
||||
acceptWaveform(ptr, samples, sampleRate)
|
||||
|
||||
fun inputFinished() = inputFinished(ptr)
|
||||
fun decode() = decode(ptr)
|
||||
fun isReady(): Boolean = isReady(ptr)
|
||||
fun reset(keywords: String): Boolean = reset(ptr, keywords)
|
||||
|
||||
val keyword: String
|
||||
get() = getKeyword(ptr)
|
||||
|
||||
private external fun delete(ptr: Long)
|
||||
|
||||
private external fun new(
|
||||
assetManager: AssetManager,
|
||||
config: KeywordSpotterConfig,
|
||||
): Long
|
||||
|
||||
private external fun newFromFile(
|
||||
config: KeywordSpotterConfig,
|
||||
): Long
|
||||
|
||||
private external fun acceptWaveform(ptr: Long, samples: FloatArray, sampleRate: Int)
|
||||
private external fun inputFinished(ptr: Long)
|
||||
private external fun getKeyword(ptr: Long): String
|
||||
private external fun reset(ptr: Long, keywords: String): Boolean
|
||||
private external fun decode(ptr: Long)
|
||||
private external fun isReady(ptr: Long): Boolean
|
||||
|
||||
companion object {
|
||||
init {
|
||||
System.loadLibrary("sherpa-onnx-jni")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fun getFeatureConfig(sampleRate: Int, featureDim: Int): FeatureConfig {
|
||||
return FeatureConfig(sampleRate = sampleRate, featureDim = featureDim)
|
||||
}
|
||||
|
||||
/*
|
||||
Please see
|
||||
https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html
|
||||
for a list of pre-trained models.
|
||||
|
||||
We only add a few here. Please change the following code
|
||||
to add your own. (It should be straightforward to add a new model
|
||||
by following the code)
|
||||
|
||||
@param type
|
||||
0 - sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01 (Chinese)
|
||||
https://www.modelscope.cn/models/pkufool/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/summary
|
||||
|
||||
1 - sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01 (English)
|
||||
https://www.modelscope.cn/models/pkufool/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/summary
|
||||
|
||||
*/
|
||||
fun getModelConfig(type: Int): OnlineModelConfig? {
|
||||
when (type) {
|
||||
0 -> {
|
||||
val modelDir = "sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01"
|
||||
return OnlineModelConfig(
|
||||
transducer = OnlineTransducerModelConfig(
|
||||
encoder = "$modelDir/encoder-epoch-12-avg-2-chunk-16-left-64.onnx",
|
||||
decoder = "$modelDir/decoder-epoch-12-avg-2-chunk-16-left-64.onnx",
|
||||
joiner = "$modelDir/joiner-epoch-12-avg-2-chunk-16-left-64.onnx",
|
||||
),
|
||||
tokens = "$modelDir/tokens.txt",
|
||||
modelType = "zipformer2",
|
||||
)
|
||||
}
|
||||
|
||||
1 -> {
|
||||
val modelDir = "sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01"
|
||||
return OnlineModelConfig(
|
||||
transducer = OnlineTransducerModelConfig(
|
||||
encoder = "$modelDir/encoder-epoch-12-avg-2-chunk-16-left-64.onnx",
|
||||
decoder = "$modelDir/decoder-epoch-12-avg-2-chunk-16-left-64.onnx",
|
||||
joiner = "$modelDir/joiner-epoch-12-avg-2-chunk-16-left-64.onnx",
|
||||
),
|
||||
tokens = "$modelDir/tokens.txt",
|
||||
modelType = "zipformer2",
|
||||
)
|
||||
}
|
||||
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
/*
|
||||
* Get the default keywords for each model.
|
||||
* Caution: The types and modelDir should be the same as those in getModelConfig
|
||||
* function above.
|
||||
*/
|
||||
fun getKeywordsFile(type: Int) : String {
|
||||
when (type) {
|
||||
0 -> {
|
||||
val modelDir = "sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01"
|
||||
return "$modelDir/keywords.txt"
|
||||
}
|
||||
|
||||
1 -> {
|
||||
val modelDir = "sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01"
|
||||
return "$modelDir/keywords.txt"
|
||||
}
|
||||
|
||||
}
|
||||
return "";
|
||||
}
|
||||
@@ -0,0 +1,29 @@
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
package com.k2fsa.sherpa.onnx
|
||||
|
||||
import android.content.res.AssetManager
|
||||
|
||||
class WaveReader {
|
||||
companion object {
|
||||
// Read a mono wave file asset
|
||||
// The returned array has two entries:
|
||||
// - the first entry contains an 1-D float array
|
||||
// - the second entry is the sample rate
|
||||
external fun readWaveFromAsset(
|
||||
assetManager: AssetManager,
|
||||
filename: String,
|
||||
): Array<Any>
|
||||
|
||||
// Read a mono wave file from disk
|
||||
// The returned array has two entries:
|
||||
// - the first entry contains an 1-D float array
|
||||
// - the second entry is the sample rate
|
||||
external fun readWaveFromFile(
|
||||
filename: String,
|
||||
): Array<Any>
|
||||
|
||||
init {
|
||||
System.loadLibrary("sherpa-onnx-jni")
|
||||
}
|
||||
}
|
||||
}
|
||||
4
android/SherpaOnnxKws/app/src/main/jniLibs/.gitignore
vendored
Normal file
@@ -0,0 +1,4 @@
|
||||
*.so
|
||||
*.txt
|
||||
*.onnx
|
||||
*.wav
|
||||
@@ -0,0 +1,30 @@
|
||||
<vector xmlns:android="http://schemas.android.com/apk/res/android"
|
||||
xmlns:aapt="http://schemas.android.com/aapt"
|
||||
android:width="108dp"
|
||||
android:height="108dp"
|
||||
android:viewportWidth="108"
|
||||
android:viewportHeight="108">
|
||||
<path android:pathData="M31,63.928c0,0 6.4,-11 12.1,-13.1c7.2,-2.6 26,-1.4 26,-1.4l38.1,38.1L107,108.928l-32,-1L31,63.928z">
|
||||
<aapt:attr name="android:fillColor">
|
||||
<gradient
|
||||
android:endX="85.84757"
|
||||
android:endY="92.4963"
|
||||
android:startX="42.9492"
|
||||
android:startY="49.59793"
|
||||
android:type="linear">
|
||||
<item
|
||||
android:color="#44000000"
|
||||
android:offset="0.0" />
|
||||
<item
|
||||
android:color="#00000000"
|
||||
android:offset="1.0" />
|
||||
</gradient>
|
||||
</aapt:attr>
|
||||
</path>
|
||||
<path
|
||||
android:fillColor="#FFFFFF"
|
||||
android:fillType="nonZero"
|
||||
android:pathData="M65.3,45.828l3.8,-6.6c0.2,-0.4 0.1,-0.9 -0.3,-1.1c-0.4,-0.2 -0.9,-0.1 -1.1,0.3l-3.9,6.7c-6.3,-2.8 -13.4,-2.8 -19.7,0l-3.9,-6.7c-0.2,-0.4 -0.7,-0.5 -1.1,-0.3C38.8,38.328 38.7,38.828 38.9,39.228l3.8,6.6C36.2,49.428 31.7,56.028 31,63.928h46C76.3,56.028 71.8,49.428 65.3,45.828zM43.4,57.328c-0.8,0 -1.5,-0.5 -1.8,-1.2c-0.3,-0.7 -0.1,-1.5 0.4,-2.1c0.5,-0.5 1.4,-0.7 2.1,-0.4c0.7,0.3 1.2,1 1.2,1.8C45.3,56.528 44.5,57.328 43.4,57.328L43.4,57.328zM64.6,57.328c-0.8,0 -1.5,-0.5 -1.8,-1.2s-0.1,-1.5 0.4,-2.1c0.5,-0.5 1.4,-0.7 2.1,-0.4c0.7,0.3 1.2,1 1.2,1.8C66.5,56.528 65.6,57.328 64.6,57.328L64.6,57.328z"
|
||||
android:strokeWidth="1"
|
||||
android:strokeColor="#00000000" />
|
||||
</vector>
|
||||
@@ -0,0 +1,170 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<vector xmlns:android="http://schemas.android.com/apk/res/android"
|
||||
android:width="108dp"
|
||||
android:height="108dp"
|
||||
android:viewportWidth="108"
|
||||
android:viewportHeight="108">
|
||||
<path
|
||||
android:fillColor="#3DDC84"
|
||||
android:pathData="M0,0h108v108h-108z" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M9,0L9,108"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M19,0L19,108"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M29,0L29,108"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M39,0L39,108"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M49,0L49,108"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M59,0L59,108"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M69,0L69,108"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M79,0L79,108"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M89,0L89,108"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M99,0L99,108"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M0,9L108,9"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M0,19L108,19"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M0,29L108,29"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M0,39L108,39"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M0,49L108,49"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M0,59L108,59"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M0,69L108,69"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M0,79L108,79"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M0,89L108,89"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M0,99L108,99"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M19,29L89,29"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M19,39L89,39"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M19,49L89,49"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M19,59L89,59"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M19,69L89,69"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M19,79L89,79"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M29,19L29,89"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M39,19L39,89"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M49,19L49,89"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M59,19L59,89"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M69,19L69,89"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
<path
|
||||
android:fillColor="#00000000"
|
||||
android:pathData="M79,19L79,89"
|
||||
android:strokeWidth="0.8"
|
||||
android:strokeColor="#33FFFFFF" />
|
||||
</vector>
|
||||
@@ -0,0 +1,46 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<androidx.constraintlayout.widget.ConstraintLayout xmlns:android="http://schemas.android.com/apk/res/android"
|
||||
xmlns:app="http://schemas.android.com/apk/res-auto"
|
||||
xmlns:tools="http://schemas.android.com/tools"
|
||||
android:layout_width="match_parent"
|
||||
android:layout_height="match_parent"
|
||||
tools:context=".MainActivity">
|
||||
|
||||
<LinearLayout
|
||||
android:layout_width="match_parent"
|
||||
android:layout_height="match_parent"
|
||||
android:gravity="center"
|
||||
android:orientation="vertical">
|
||||
|
||||
<EditText
|
||||
android:id="@+id/input_text"
|
||||
android:layout_width="match_parent"
|
||||
android:layout_height="320dp"
|
||||
android:layout_weight="2.5"
|
||||
android:hint="@string/keyword_hint"
|
||||
android:scrollbars="vertical"
|
||||
android:text=""
|
||||
android:textSize="15dp" />
|
||||
|
||||
<TextView
|
||||
android:id="@+id/my_text"
|
||||
android:layout_width="match_parent"
|
||||
android:layout_height="443dp"
|
||||
android:layout_weight="2.5"
|
||||
android:padding="24dp"
|
||||
android:scrollbars="vertical"
|
||||
android:singleLine="false"
|
||||
android:text="@string/hint"
|
||||
android:textSize="15dp" />
|
||||
|
||||
<Button
|
||||
android:id="@+id/record_button"
|
||||
android:layout_width="wrap_content"
|
||||
android:layout_height="wrap_content"
|
||||
android:layout_weight="0.5"
|
||||
android:text="@string/start" />
|
||||
|
||||
</LinearLayout>
|
||||
|
||||
|
||||
</androidx.constraintlayout.widget.ConstraintLayout>
|
||||
@@ -0,0 +1,5 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<adaptive-icon xmlns:android="http://schemas.android.com/apk/res/android">
|
||||
<background android:drawable="@drawable/ic_launcher_background" />
|
||||
<foreground android:drawable="@drawable/ic_launcher_foreground" />
|
||||
</adaptive-icon>
|
||||
@@ -0,0 +1,5 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<adaptive-icon xmlns:android="http://schemas.android.com/apk/res/android">
|
||||
<background android:drawable="@drawable/ic_launcher_background" />
|
||||
<foreground android:drawable="@drawable/ic_launcher_foreground" />
|
||||
</adaptive-icon>
|
||||
|
After Width: | Height: | Size: 1.4 KiB |
|
After Width: | Height: | Size: 2.8 KiB |
|
After Width: | Height: | Size: 982 B |
|
After Width: | Height: | Size: 1.7 KiB |
|
After Width: | Height: | Size: 1.9 KiB |
|
After Width: | Height: | Size: 3.8 KiB |
|
After Width: | Height: | Size: 2.8 KiB |
|
After Width: | Height: | Size: 5.8 KiB |
|
After Width: | Height: | Size: 3.8 KiB |
|
After Width: | Height: | Size: 7.6 KiB |
@@ -0,0 +1,16 @@
|
||||
<resources xmlns:tools="http://schemas.android.com/tools">
|
||||
<!-- Base application theme. -->
|
||||
<style name="Theme.SherpaOnnx" parent="Theme.MaterialComponents.DayNight.DarkActionBar">
|
||||
<!-- Primary brand color. -->
|
||||
<item name="colorPrimary">@color/purple_200</item>
|
||||
<item name="colorPrimaryVariant">@color/purple_700</item>
|
||||
<item name="colorOnPrimary">@color/black</item>
|
||||
<!-- Secondary brand color. -->
|
||||
<item name="colorSecondary">@color/teal_200</item>
|
||||
<item name="colorSecondaryVariant">@color/teal_200</item>
|
||||
<item name="colorOnSecondary">@color/black</item>
|
||||
<!-- Status bar color. -->
|
||||
<item name="android:statusBarColor">?attr/colorPrimaryVariant</item>
|
||||
<!-- Customize your theme here. -->
|
||||
</style>
|
||||
</resources>
|
||||
10
android/SherpaOnnxKws/app/src/main/res/values/colors.xml
Normal file
@@ -0,0 +1,10 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<resources>
|
||||
<color name="purple_200">#FFBB86FC</color>
|
||||
<color name="purple_500">#FF6200EE</color>
|
||||
<color name="purple_700">#FF3700B3</color>
|
||||
<color name="teal_200">#FF03DAC5</color>
|
||||
<color name="teal_700">#FF018786</color>
|
||||
<color name="black">#FF000000</color>
|
||||
<color name="white">#FFFFFFFF</color>
|
||||
</resources>
|
||||
12
android/SherpaOnnxKws/app/src/main/res/values/strings.xml
Normal file
@@ -0,0 +1,12 @@
|
||||
<resources>
|
||||
<string name="app_name">KWS with Next-gen Kaldi</string>
|
||||
<string name="hint">Click the Start button to play keyword spotting with Next-gen Kaldi.
|
||||
\n
|
||||
\n\n\n
|
||||
The source code and pre-trained models are publicly available.
|
||||
Please see https://github.com/k2-fsa/sherpa-onnx for details.
|
||||
</string>
|
||||
<string name="keyword_hint">Input your keywords here, one keyword perline.</string>
|
||||
<string name="start">Start</string>
|
||||
<string name="stop">Stop</string>
|
||||
</resources>
|
||||
16
android/SherpaOnnxKws/app/src/main/res/values/themes.xml
Normal file
@@ -0,0 +1,16 @@
|
||||
<resources xmlns:tools="http://schemas.android.com/tools">
|
||||
<!-- Base application theme. -->
|
||||
<style name="Theme.SherpaOnnx" parent="Theme.MaterialComponents.DayNight.DarkActionBar">
|
||||
<!-- Primary brand color. -->
|
||||
<item name="colorPrimary">@color/purple_500</item>
|
||||
<item name="colorPrimaryVariant">@color/purple_700</item>
|
||||
<item name="colorOnPrimary">@color/white</item>
|
||||
<!-- Secondary brand color. -->
|
||||
<item name="colorSecondary">@color/teal_200</item>
|
||||
<item name="colorSecondaryVariant">@color/teal_700</item>
|
||||
<item name="colorOnSecondary">@color/black</item>
|
||||
<!-- Status bar color. -->
|
||||
<item name="android:statusBarColor">?attr/colorPrimaryVariant</item>
|
||||
<!-- Customize your theme here. -->
|
||||
</style>
|
||||
</resources>
|
||||
13
android/SherpaOnnxKws/app/src/main/res/xml/backup_rules.xml
Normal file
@@ -0,0 +1,13 @@
|
||||
<?xml version="1.0" encoding="utf-8"?><!--
|
||||
Sample backup rules file; uncomment and customize as necessary.
|
||||
See https://developer.android.com/guide/topics/data/autobackup
|
||||
for details.
|
||||
Note: This file is ignored for devices older that API 31
|
||||
See https://developer.android.com/about/versions/12/backup-restore
|
||||
-->
|
||||
<full-backup-content>
|
||||
<!--
|
||||
<include domain="sharedpref" path="."/>
|
||||
<exclude domain="sharedpref" path="device.xml"/>
|
||||
-->
|
||||
</full-backup-content>
|
||||
@@ -0,0 +1,19 @@
|
||||
<?xml version="1.0" encoding="utf-8"?><!--
|
||||
Sample data extraction rules file; uncomment and customize as necessary.
|
||||
See https://developer.android.com/about/versions/12/backup-restore#xml-changes
|
||||
for details.
|
||||
-->
|
||||
<data-extraction-rules>
|
||||
<cloud-backup>
|
||||
<!-- TODO: Use <include> and <exclude> to control what is backed up.
|
||||
<include .../>
|
||||
<exclude .../>
|
||||
-->
|
||||
</cloud-backup>
|
||||
<!--
|
||||
<device-transfer>
|
||||
<include .../>
|
||||
<exclude .../>
|
||||
</device-transfer>
|
||||
-->
|
||||
</data-extraction-rules>
|
||||
@@ -0,0 +1,17 @@
|
||||
package com.k2fsa.sherpa.onnx
|
||||
|
||||
import org.junit.Test
|
||||
|
||||
import org.junit.Assert.*
|
||||
|
||||
/**
|
||||
* Example local unit test, which will execute on the development machine (host).
|
||||
*
|
||||
* See [testing documentation](http://d.android.com/tools/testing).
|
||||
*/
|
||||
class ExampleUnitTest {
|
||||
@Test
|
||||
fun addition_isCorrect() {
|
||||
assertEquals(4, 2 + 2)
|
||||
}
|
||||
}
|
||||
6
android/SherpaOnnxKws/build.gradle
Normal file
@@ -0,0 +1,6 @@
|
||||
// Top-level build file where you can add configuration options common to all sub-projects/modules.
|
||||
plugins {
|
||||
id 'com.android.application' version '7.3.1' apply false
|
||||
id 'com.android.library' version '7.3.1' apply false
|
||||
id 'org.jetbrains.kotlin.android' version '1.7.20' apply false
|
||||
}
|
||||
23
android/SherpaOnnxKws/gradle.properties
Normal file
@@ -0,0 +1,23 @@
|
||||
# Project-wide Gradle settings.
|
||||
# IDE (e.g. Android Studio) users:
|
||||
# Gradle settings configured through the IDE *will override*
|
||||
# any settings specified in this file.
|
||||
# For more details on how to configure your build environment visit
|
||||
# http://www.gradle.org/docs/current/userguide/build_environment.html
|
||||
# Specifies the JVM arguments used for the daemon process.
|
||||
# The setting is particularly useful for tweaking memory settings.
|
||||
org.gradle.jvmargs=-Xmx2048m -Dfile.encoding=UTF-8
|
||||
# When configured, Gradle will run in incubating parallel mode.
|
||||
# This option should only be used with decoupled projects. More details, visit
|
||||
# http://www.gradle.org/docs/current/userguide/multi_project_builds.html#sec:decoupled_projects
|
||||
# org.gradle.parallel=true
|
||||
# AndroidX package structure to make it clearer which packages are bundled with the
|
||||
# Android operating system, and which are packaged with your app's APK
|
||||
# https://developer.android.com/topic/libraries/support-library/androidx-rn
|
||||
android.useAndroidX=true
|
||||
# Kotlin code style for this project: "official" or "obsolete":
|
||||
kotlin.code.style=official
|
||||
# Enables namespacing of each library's R class so that its R class includes only the
|
||||
# resources declared in the library itself and none from the library's dependencies,
|
||||
# thereby reducing the size of the R class for that library
|
||||
android.nonTransitiveRClass=true
|
||||
BIN
android/SherpaOnnxKws/gradle/wrapper/gradle-wrapper.jar
vendored
Normal file
6
android/SherpaOnnxKws/gradle/wrapper/gradle-wrapper.properties
vendored
Normal file
@@ -0,0 +1,6 @@
|
||||
#Thu Feb 23 11:09:06 CST 2023
|
||||
distributionBase=GRADLE_USER_HOME
|
||||
distributionUrl=https\://services.gradle.org/distributions/gradle-7.4-bin.zip
|
||||
distributionPath=wrapper/dists
|
||||
zipStorePath=wrapper/dists
|
||||
zipStoreBase=GRADLE_USER_HOME
|
||||
185
android/SherpaOnnxKws/gradlew
vendored
Executable file
@@ -0,0 +1,185 @@
|
||||
#!/usr/bin/env sh
|
||||
|
||||
#
|
||||
# Copyright 2015 the original author or authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
##############################################################################
|
||||
##
|
||||
## Gradle start up script for UN*X
|
||||
##
|
||||
##############################################################################
|
||||
|
||||
# Attempt to set APP_HOME
|
||||
# Resolve links: $0 may be a link
|
||||
PRG="$0"
|
||||
# Need this for relative symlinks.
|
||||
while [ -h "$PRG" ] ; do
|
||||
ls=`ls -ld "$PRG"`
|
||||
link=`expr "$ls" : '.*-> \(.*\)$'`
|
||||
if expr "$link" : '/.*' > /dev/null; then
|
||||
PRG="$link"
|
||||
else
|
||||
PRG=`dirname "$PRG"`"/$link"
|
||||
fi
|
||||
done
|
||||
SAVED="`pwd`"
|
||||
cd "`dirname \"$PRG\"`/" >/dev/null
|
||||
APP_HOME="`pwd -P`"
|
||||
cd "$SAVED" >/dev/null
|
||||
|
||||
APP_NAME="Gradle"
|
||||
APP_BASE_NAME=`basename "$0"`
|
||||
|
||||
# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
|
||||
DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"'
|
||||
|
||||
# Use the maximum available, or set MAX_FD != -1 to use that value.
|
||||
MAX_FD="maximum"
|
||||
|
||||
warn () {
|
||||
echo "$*"
|
||||
}
|
||||
|
||||
die () {
|
||||
echo
|
||||
echo "$*"
|
||||
echo
|
||||
exit 1
|
||||
}
|
||||
|
||||
# OS specific support (must be 'true' or 'false').
|
||||
cygwin=false
|
||||
msys=false
|
||||
darwin=false
|
||||
nonstop=false
|
||||
case "`uname`" in
|
||||
CYGWIN* )
|
||||
cygwin=true
|
||||
;;
|
||||
Darwin* )
|
||||
darwin=true
|
||||
;;
|
||||
MINGW* )
|
||||
msys=true
|
||||
;;
|
||||
NONSTOP* )
|
||||
nonstop=true
|
||||
;;
|
||||
esac
|
||||
|
||||
CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar
|
||||
|
||||
|
||||
# Determine the Java command to use to start the JVM.
|
||||
if [ -n "$JAVA_HOME" ] ; then
|
||||
if [ -x "$JAVA_HOME/jre/sh/java" ] ; then
|
||||
# IBM's JDK on AIX uses strange locations for the executables
|
||||
JAVACMD="$JAVA_HOME/jre/sh/java"
|
||||
else
|
||||
JAVACMD="$JAVA_HOME/bin/java"
|
||||
fi
|
||||
if [ ! -x "$JAVACMD" ] ; then
|
||||
die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME
|
||||
|
||||
Please set the JAVA_HOME variable in your environment to match the
|
||||
location of your Java installation."
|
||||
fi
|
||||
else
|
||||
JAVACMD="java"
|
||||
which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
|
||||
|
||||
Please set the JAVA_HOME variable in your environment to match the
|
||||
location of your Java installation."
|
||||
fi
|
||||
|
||||
# Increase the maximum file descriptors if we can.
|
||||
if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then
|
||||
MAX_FD_LIMIT=`ulimit -H -n`
|
||||
if [ $? -eq 0 ] ; then
|
||||
if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then
|
||||
MAX_FD="$MAX_FD_LIMIT"
|
||||
fi
|
||||
ulimit -n $MAX_FD
|
||||
if [ $? -ne 0 ] ; then
|
||||
warn "Could not set maximum file descriptor limit: $MAX_FD"
|
||||
fi
|
||||
else
|
||||
warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT"
|
||||
fi
|
||||
fi
|
||||
|
||||
# For Darwin, add options to specify how the application appears in the dock
|
||||
if $darwin; then
|
||||
GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\""
|
||||
fi
|
||||
|
||||
# For Cygwin or MSYS, switch paths to Windows format before running java
|
||||
if [ "$cygwin" = "true" -o "$msys" = "true" ] ; then
|
||||
APP_HOME=`cygpath --path --mixed "$APP_HOME"`
|
||||
CLASSPATH=`cygpath --path --mixed "$CLASSPATH"`
|
||||
|
||||
JAVACMD=`cygpath --unix "$JAVACMD"`
|
||||
|
||||
# We build the pattern for arguments to be converted via cygpath
|
||||
ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null`
|
||||
SEP=""
|
||||
for dir in $ROOTDIRSRAW ; do
|
||||
ROOTDIRS="$ROOTDIRS$SEP$dir"
|
||||
SEP="|"
|
||||
done
|
||||
OURCYGPATTERN="(^($ROOTDIRS))"
|
||||
# Add a user-defined pattern to the cygpath arguments
|
||||
if [ "$GRADLE_CYGPATTERN" != "" ] ; then
|
||||
OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)"
|
||||
fi
|
||||
# Now convert the arguments - kludge to limit ourselves to /bin/sh
|
||||
i=0
|
||||
for arg in "$@" ; do
|
||||
CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -`
|
||||
CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option
|
||||
|
||||
if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition
|
||||
eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"`
|
||||
else
|
||||
eval `echo args$i`="\"$arg\""
|
||||
fi
|
||||
i=`expr $i + 1`
|
||||
done
|
||||
case $i in
|
||||
0) set -- ;;
|
||||
1) set -- "$args0" ;;
|
||||
2) set -- "$args0" "$args1" ;;
|
||||
3) set -- "$args0" "$args1" "$args2" ;;
|
||||
4) set -- "$args0" "$args1" "$args2" "$args3" ;;
|
||||
5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;;
|
||||
6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;;
|
||||
7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;;
|
||||
8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;;
|
||||
9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;;
|
||||
esac
|
||||
fi
|
||||
|
||||
# Escape application args
|
||||
save () {
|
||||
for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done
|
||||
echo " "
|
||||
}
|
||||
APP_ARGS=`save "$@"`
|
||||
|
||||
# Collect all arguments for the java command, following the shell quoting and substitution rules
|
||||
eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS"
|
||||
|
||||
exec "$JAVACMD" "$@"
|
||||
89
android/SherpaOnnxKws/gradlew.bat
vendored
Normal file
@@ -0,0 +1,89 @@
|
||||
@rem
|
||||
@rem Copyright 2015 the original author or authors.
|
||||
@rem
|
||||
@rem Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@rem you may not use this file except in compliance with the License.
|
||||
@rem You may obtain a copy of the License at
|
||||
@rem
|
||||
@rem https://www.apache.org/licenses/LICENSE-2.0
|
||||
@rem
|
||||
@rem Unless required by applicable law or agreed to in writing, software
|
||||
@rem distributed under the License is distributed on an "AS IS" BASIS,
|
||||
@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
@rem See the License for the specific language governing permissions and
|
||||
@rem limitations under the License.
|
||||
@rem
|
||||
|
||||
@if "%DEBUG%" == "" @echo off
|
||||
@rem ##########################################################################
|
||||
@rem
|
||||
@rem Gradle startup script for Windows
|
||||
@rem
|
||||
@rem ##########################################################################
|
||||
|
||||
@rem Set local scope for the variables with windows NT shell
|
||||
if "%OS%"=="Windows_NT" setlocal
|
||||
|
||||
set DIRNAME=%~dp0
|
||||
if "%DIRNAME%" == "" set DIRNAME=.
|
||||
set APP_BASE_NAME=%~n0
|
||||
set APP_HOME=%DIRNAME%
|
||||
|
||||
@rem Resolve any "." and ".." in APP_HOME to make it shorter.
|
||||
for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi
|
||||
|
||||
@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
|
||||
set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m"
|
||||
|
||||
@rem Find java.exe
|
||||
if defined JAVA_HOME goto findJavaFromJavaHome
|
||||
|
||||
set JAVA_EXE=java.exe
|
||||
%JAVA_EXE% -version >NUL 2>&1
|
||||
if "%ERRORLEVEL%" == "0" goto execute
|
||||
|
||||
echo.
|
||||
echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
|
||||
echo.
|
||||
echo Please set the JAVA_HOME variable in your environment to match the
|
||||
echo location of your Java installation.
|
||||
|
||||
goto fail
|
||||
|
||||
:findJavaFromJavaHome
|
||||
set JAVA_HOME=%JAVA_HOME:"=%
|
||||
set JAVA_EXE=%JAVA_HOME%/bin/java.exe
|
||||
|
||||
if exist "%JAVA_EXE%" goto execute
|
||||
|
||||
echo.
|
||||
echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME%
|
||||
echo.
|
||||
echo Please set the JAVA_HOME variable in your environment to match the
|
||||
echo location of your Java installation.
|
||||
|
||||
goto fail
|
||||
|
||||
:execute
|
||||
@rem Setup the command line
|
||||
|
||||
set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar
|
||||
|
||||
|
||||
@rem Execute Gradle
|
||||
"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %*
|
||||
|
||||
:end
|
||||
@rem End local scope for the variables with windows NT shell
|
||||
if "%ERRORLEVEL%"=="0" goto mainEnd
|
||||
|
||||
:fail
|
||||
rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
|
||||
rem the _cmd.exe /c_ return code!
|
||||
if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1
|
||||
exit /b 1
|
||||
|
||||
:mainEnd
|
||||
if "%OS%"=="Windows_NT" endlocal
|
||||
|
||||
:omega
|
||||
16
android/SherpaOnnxKws/settings.gradle
Normal file
@@ -0,0 +1,16 @@
|
||||
pluginManagement {
|
||||
repositories {
|
||||
gradlePluginPortal()
|
||||
google()
|
||||
mavenCentral()
|
||||
}
|
||||
}
|
||||
dependencyResolutionManagement {
|
||||
repositoriesMode.set(RepositoriesMode.FAIL_ON_PROJECT_REPOS)
|
||||
repositories {
|
||||
google()
|
||||
mavenCentral()
|
||||
}
|
||||
}
|
||||
rootProject.name = "SherpaOnnxKws"
|
||||
include ':app'
|
||||
139
build-kws-apk.sh
Executable file
@@ -0,0 +1,139 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# Please set the environment variable ANDROID_NDK
|
||||
# before running this script
|
||||
|
||||
# Inside the $ANDROID_NDK directory, you can find a binary ndk-build
|
||||
# and some other files like the file "build/cmake/android.toolchain.cmake"
|
||||
|
||||
set -e
|
||||
|
||||
log() {
|
||||
# This function is from espnet
|
||||
local fname=${BASH_SOURCE[1]##*/}
|
||||
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||
}
|
||||
|
||||
SHERPA_ONNX_VERSION=$(grep "SHERPA_ONNX_VERSION" ./CMakeLists.txt | cut -d " " -f 2 | cut -d '"' -f 2)
|
||||
|
||||
log "Building keyword spotting APK for sherpa-onnx v${SHERPA_ONNX_VERSION}"
|
||||
|
||||
log "====================arm64-v8a================="
|
||||
./build-android-arm64-v8a.sh
|
||||
log "====================armv7-eabi================"
|
||||
./build-android-armv7-eabi.sh
|
||||
log "====================x86-64===================="
|
||||
./build-android-x86-64.sh
|
||||
log "====================x86===================="
|
||||
./build-android-x86.sh
|
||||
|
||||
mkdir -p apks
|
||||
|
||||
# Download the model
|
||||
repo=sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01
|
||||
|
||||
if [ ! -d ./android/SherpaOnnxKws/app/src/main/assets/$repo ]; then
|
||||
repo_url=https://www.modelscope.cn/pkufool/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01.git
|
||||
log "Start testing ${repo_url}"
|
||||
log "Download pretrained model and test-data from $repo_url"
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||
pushd $repo
|
||||
git lfs pull --include "*.onnx"
|
||||
|
||||
# remove .git to save spaces
|
||||
rm -rf .git
|
||||
rm *.int8.onnx
|
||||
rm README.md configuration.json .gitattributes
|
||||
rm -rfv test_wavs
|
||||
ls -lh
|
||||
popd
|
||||
|
||||
mv -v $repo ./android/SherpaOnnxKws/app/src/main/assets/
|
||||
fi
|
||||
|
||||
tree ./android/SherpaOnnxKws/app/src/main/assets/
|
||||
|
||||
for arch in arm64-v8a armeabi-v7a x86_64 x86; do
|
||||
log "------------------------------------------------------------"
|
||||
log "build apk for $arch"
|
||||
log "------------------------------------------------------------"
|
||||
src_arch=$arch
|
||||
if [ $arch == "armeabi-v7a" ]; then
|
||||
src_arch=armv7-eabi
|
||||
elif [ $arch == "x86_64" ]; then
|
||||
src_arch=x86-64
|
||||
fi
|
||||
|
||||
ls -lh ./build-android-$src_arch/install/lib/*.so
|
||||
|
||||
cp -v ./build-android-$src_arch/install/lib/*.so ./android/SherpaOnnxKws/app/src/main/jniLibs/$arch/
|
||||
|
||||
pushd ./android/SherpaOnnxKws
|
||||
./gradlew build
|
||||
popd
|
||||
|
||||
mv android/SherpaOnnxKws/app/build/outputs/apk/debug/app-debug.apk ./apks/sherpa-onnx-kws-wenetspeech-zh-${SHERPA_ONNX_VERSION}-$arch.apk
|
||||
ls -lh apks
|
||||
rm -v ./android/SherpaOnnxKws/app/src/main/jniLibs/$arch/*.so
|
||||
done
|
||||
|
||||
git checkout .
|
||||
|
||||
rm -rf ./android/SherpaOnnxKws/app/src/main/assets/$repo
|
||||
|
||||
# English model
|
||||
repo=sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01
|
||||
|
||||
if [ ! -d ./android/SherpaOnnxKws/app/src/main/assets/$repo ]; then
|
||||
repo_url=https://www.modelscope.cn/pkufool/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01.git
|
||||
log "Start testing ${repo_url}"
|
||||
log "Download pretrained model and test-data from $repo_url"
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||
pushd $repo
|
||||
git lfs pull --include "*.onnx"
|
||||
|
||||
# remove .git to save spaces
|
||||
rm -rf .git
|
||||
rm *.int8.onnx
|
||||
rm README.md configuration.json .gitattributes
|
||||
rm -rfv test_wavs
|
||||
ls -lh
|
||||
popd
|
||||
|
||||
mv -v $repo ./android/SherpaOnnxKws/app/src/main/assets/
|
||||
fi
|
||||
|
||||
tree ./android/SherpaOnnxKws/app/src/main/assets/
|
||||
|
||||
pushd android/SherpaOnnxKws/app/src/main/java/com/k2fsa/sherpa/onnx
|
||||
sed -i.bak s/"type = 0"/"type = 1"/ ./MainActivity.kt
|
||||
git diff
|
||||
popd
|
||||
|
||||
for arch in arm64-v8a armeabi-v7a x86_64 x86; do
|
||||
log "------------------------------------------------------------"
|
||||
log "build apk for $arch"
|
||||
log "------------------------------------------------------------"
|
||||
src_arch=$arch
|
||||
if [ $arch == "armeabi-v7a" ]; then
|
||||
src_arch=armv7-eabi
|
||||
elif [ $arch == "x86_64" ]; then
|
||||
src_arch=x86-64
|
||||
fi
|
||||
|
||||
ls -lh ./build-android-$src_arch/install/lib/*.so
|
||||
|
||||
cp -v ./build-android-$src_arch/install/lib/*.so ./android/SherpaOnnxKws/app/src/main/jniLibs/$arch/
|
||||
|
||||
pushd ./android/SherpaOnnxKws
|
||||
./gradlew build
|
||||
popd
|
||||
|
||||
mv android/SherpaOnnxKws/app/build/outputs/apk/debug/app-debug.apk ./apks/sherpa-onnx-kws-gigaspeech-en-${SHERPA_ONNX_VERSION}-$arch.apk
|
||||
ls -lh apks
|
||||
rm -v ./android/SherpaOnnxKws/app/src/main/jniLibs/$arch/*.so
|
||||
done
|
||||
|
||||
git checkout .
|
||||
|
||||
rm -rf ./android/SherpaOnnxKws/app/src/main/assets/$repo
|
||||
@@ -151,6 +151,7 @@ class BuildExtension(build_ext):
|
||||
# Remember to also change setup.py
|
||||
|
||||
binaries = ["sherpa-onnx"]
|
||||
binaries += ["sherpa-onnx-keyword-spotter"]
|
||||
binaries += ["sherpa-onnx-offline"]
|
||||
binaries += ["sherpa-onnx-microphone"]
|
||||
binaries += ["sherpa-onnx-microphone-offline"]
|
||||
|
||||
@@ -36,13 +36,44 @@ import argparse
|
||||
|
||||
from sherpa_onnx import text2token
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--text",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the input texts",
|
||||
help="""Path to the input texts.
|
||||
|
||||
Each line in the texts contains the original phrase, it might also contain some
|
||||
extra items, for example, the boosting score (startting with :), the triggering
|
||||
threshold (startting with #, only used in keyword spotting task) and the original
|
||||
phrase (startting with @). Note: extra items will be kept in the output.
|
||||
|
||||
example input 1 (tokens_type = ppinyin):
|
||||
|
||||
小爱同学 :2.0 #0.6 @小爱同学
|
||||
你好问问 :3.5 @你好问问
|
||||
小艺小艺 #0.6 @小艺小艺
|
||||
|
||||
example output 1:
|
||||
|
||||
x iǎo ài t óng x ué :2.0 #0.6 @小爱同学
|
||||
n ǐ h ǎo w èn w èn :3.5 @你好问问
|
||||
x iǎo y ì x iǎo y ì #0.6 @小艺小艺
|
||||
|
||||
example input 2 (tokens_type = bpe):
|
||||
|
||||
HELLO WORLD :1.5 #0.4
|
||||
HI GOOGLE :2.0 #0.8
|
||||
HEY SIRI #0.35
|
||||
|
||||
example output 2:
|
||||
|
||||
▁HE LL O ▁WORLD :1.5 #0.4
|
||||
▁HI ▁GO O G LE :2.0 #0.8
|
||||
▁HE Y ▁S I RI #0.35
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@@ -56,7 +87,11 @@ def get_args():
|
||||
"--tokens-type",
|
||||
type=str,
|
||||
required=True,
|
||||
help="The type of modeling units, should be cjkchar, bpe or cjkchar+bpe",
|
||||
choices=["cjkchar", "bpe", "cjkchar+bpe", "fpinyin", "ppinyin"],
|
||||
help="""The type of modeling units, should be cjkchar, bpe, cjkchar+bpe, fpinyin or ppinyin.
|
||||
fpinyin means full pinyin, each cjkchar has a pinyin(with tone).
|
||||
ppinyin means partial pinyin, it splits pinyin into initial and final,
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@@ -79,9 +114,21 @@ def main():
|
||||
args = get_args()
|
||||
|
||||
texts = []
|
||||
# extra information like boosting score (start with :), triggering threshold (start with #)
|
||||
# original keyword (start with @)
|
||||
extra_info = []
|
||||
with open(args.text, "r", encoding="utf8") as f:
|
||||
for line in f:
|
||||
texts.append(line.strip())
|
||||
extra = []
|
||||
text = []
|
||||
toks = line.strip().split()
|
||||
for tok in toks:
|
||||
if tok[0] == ":" or tok[0] == "#" or tok[0] == "@":
|
||||
extra.append(tok)
|
||||
else:
|
||||
text.append(tok)
|
||||
texts.append(" ".join(text))
|
||||
extra_info.append(extra)
|
||||
encoded_texts = text2token(
|
||||
texts,
|
||||
tokens=args.tokens,
|
||||
@@ -89,7 +136,8 @@ def main():
|
||||
bpe_model=args.bpe_model,
|
||||
)
|
||||
with open(args.output, "w", encoding="utf8") as f:
|
||||
for txt in encoded_texts:
|
||||
for i, txt in enumerate(encoded_texts):
|
||||
txt += extra_info[i]
|
||||
f.write(" ".join(txt) + "\n")
|
||||
|
||||
|
||||
|
||||
1
setup.py
@@ -51,6 +51,7 @@ def get_binaries_to_install():
|
||||
|
||||
# Remember to also change cmake/cmake_extension.py
|
||||
binaries = ["sherpa-onnx"]
|
||||
binaries += ["sherpa-onnx-keyword-spotter"]
|
||||
binaries += ["sherpa-onnx-offline"]
|
||||
binaries += ["sherpa-onnx-microphone"]
|
||||
binaries += ["sherpa-onnx-microphone-offline"]
|
||||
|
||||
@@ -19,6 +19,8 @@ set(sources
|
||||
features.cc
|
||||
file-utils.cc
|
||||
hypothesis.cc
|
||||
keyword-spotter-impl.cc
|
||||
keyword-spotter.cc
|
||||
offline-ctc-fst-decoder-config.cc
|
||||
offline-ctc-fst-decoder.cc
|
||||
offline-ctc-greedy-search-decoder.cc
|
||||
@@ -87,6 +89,7 @@ set(sources
|
||||
stack.cc
|
||||
symbol-table.cc
|
||||
text-utils.cc
|
||||
transducer-keyword-decoder.cc
|
||||
transpose.cc
|
||||
unbind.cc
|
||||
utils.cc
|
||||
@@ -173,12 +176,14 @@ if(NOT BUILD_SHARED_LIBS AND CMAKE_SYSTEM_NAME STREQUAL Linux)
|
||||
endif()
|
||||
|
||||
add_executable(sherpa-onnx sherpa-onnx.cc)
|
||||
add_executable(sherpa-onnx-keyword-spotter sherpa-onnx-keyword-spotter.cc)
|
||||
add_executable(sherpa-onnx-offline sherpa-onnx-offline.cc)
|
||||
add_executable(sherpa-onnx-offline-parallel sherpa-onnx-offline-parallel.cc)
|
||||
add_executable(sherpa-onnx-offline-tts sherpa-onnx-offline-tts.cc)
|
||||
|
||||
set(main_exes
|
||||
sherpa-onnx
|
||||
sherpa-onnx-keyword-spotter
|
||||
sherpa-onnx-offline
|
||||
sherpa-onnx-offline-parallel
|
||||
sherpa-onnx-offline-tts
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
#include "sherpa-onnx/csrc/context-graph.h"
|
||||
|
||||
#include <chrono> // NOLINT
|
||||
#include <cmath>
|
||||
#include <map>
|
||||
#include <random>
|
||||
#include <string>
|
||||
@@ -15,27 +16,25 @@
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
TEST(ContextGraph, TestBasic) {
|
||||
static void TestHelper(const std::map<std::string, float> &queries, float score,
|
||||
bool strict_mode) {
|
||||
std::vector<std::string> contexts_str(
|
||||
{"S", "HE", "SHE", "SHELL", "HIS", "HERS", "HELLO", "THIS", "THEM"});
|
||||
std::vector<std::vector<int32_t>> contexts;
|
||||
std::vector<float> scores;
|
||||
for (int32_t i = 0; i < contexts_str.size(); ++i) {
|
||||
contexts.emplace_back(contexts_str[i].begin(), contexts_str[i].end());
|
||||
scores.push_back(std::round(score / contexts_str[i].size() * 100) / 100);
|
||||
}
|
||||
auto context_graph = ContextGraph(contexts, 1);
|
||||
|
||||
auto queries = std::map<std::string, float>{
|
||||
{"HEHERSHE", 14}, {"HERSHE", 12}, {"HISHE", 9},
|
||||
{"SHED", 6}, {"SHELF", 6}, {"HELL", 2},
|
||||
{"HELLO", 7}, {"DHRHISQ", 4}, {"THEN", 2}};
|
||||
auto context_graph = ContextGraph(contexts, 1, scores);
|
||||
|
||||
for (const auto &iter : queries) {
|
||||
float total_scores = 0;
|
||||
auto state = context_graph.Root();
|
||||
for (auto q : iter.first) {
|
||||
auto res = context_graph.ForwardOneStep(state, q);
|
||||
total_scores += res.first;
|
||||
state = res.second;
|
||||
auto res = context_graph.ForwardOneStep(state, q, strict_mode);
|
||||
total_scores += std::get<0>(res);
|
||||
state = std::get<1>(res);
|
||||
}
|
||||
auto res = context_graph.Finalize(state);
|
||||
EXPECT_EQ(res.second->token, -1);
|
||||
@@ -44,6 +43,37 @@ TEST(ContextGraph, TestBasic) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST(ContextGraph, TestBasic) {
|
||||
auto queries = std::map<std::string, float>{
|
||||
{"HEHERSHE", 14}, {"HERSHE", 12}, {"HISHE", 9},
|
||||
{"SHED", 6}, {"SHELF", 6}, {"HELL", 2},
|
||||
{"HELLO", 7}, {"DHRHISQ", 4}, {"THEN", 2}};
|
||||
TestHelper(queries, 0, true);
|
||||
}
|
||||
|
||||
TEST(ContextGraph, TestBasicNonStrict) {
|
||||
auto queries = std::map<std::string, float>{
|
||||
{"HEHERSHE", 7}, {"HERSHE", 5}, {"HISHE", 5}, {"SHED", 3}, {"SHELF", 3},
|
||||
{"HELL", 2}, {"HELLO", 2}, {"DHRHISQ", 3}, {"THEN", 2}};
|
||||
TestHelper(queries, 0, false);
|
||||
}
|
||||
|
||||
TEST(ContextGraph, TestCustomize) {
|
||||
auto queries = std::map<std::string, float>{
|
||||
{"HEHERSHE", 35.84}, {"HERSHE", 30.84}, {"HISHE", 24.18},
|
||||
{"SHED", 18.34}, {"SHELF", 18.34}, {"HELL", 5},
|
||||
{"HELLO", 13}, {"DHRHISQ", 10.84}, {"THEN", 5}};
|
||||
TestHelper(queries, 5, true);
|
||||
}
|
||||
|
||||
TEST(ContextGraph, TestCustomizeNonStrict) {
|
||||
auto queries = std::map<std::string, float>{
|
||||
{"HEHERSHE", 20}, {"HERSHE", 15}, {"HISHE", 10.84},
|
||||
{"SHED", 10}, {"SHELF", 10}, {"HELL", 5},
|
||||
{"HELLO", 5}, {"DHRHISQ", 5.84}, {"THEN", 5}};
|
||||
TestHelper(queries, 5, false);
|
||||
}
|
||||
|
||||
TEST(ContextGraph, Benchmark) {
|
||||
std::random_device rd;
|
||||
std::mt19937 mt(rd());
|
||||
|
||||
@@ -4,22 +4,59 @@
|
||||
|
||||
#include "sherpa-onnx/csrc/context-graph.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <queue>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
void ContextGraph::Build(
|
||||
const std::vector<std::vector<int32_t>> &token_ids) const {
|
||||
void ContextGraph::Build(const std::vector<std::vector<int32_t>> &token_ids,
|
||||
const std::vector<float> &scores,
|
||||
const std::vector<std::string> &phrases,
|
||||
const std::vector<float> &ac_thresholds) const {
|
||||
if (!scores.empty()) {
|
||||
SHERPA_ONNX_CHECK_EQ(token_ids.size(), scores.size());
|
||||
}
|
||||
if (!phrases.empty()) {
|
||||
SHERPA_ONNX_CHECK_EQ(token_ids.size(), phrases.size());
|
||||
}
|
||||
if (!ac_thresholds.empty()) {
|
||||
SHERPA_ONNX_CHECK_EQ(token_ids.size(), ac_thresholds.size());
|
||||
}
|
||||
for (int32_t i = 0; i < token_ids.size(); ++i) {
|
||||
auto node = root_.get();
|
||||
float score = scores.empty() ? 0.0f : scores[i];
|
||||
score = score == 0.0f ? context_score_ : score;
|
||||
float ac_threshold = ac_thresholds.empty() ? 0.0f : ac_thresholds[i];
|
||||
ac_threshold = ac_threshold == 0.0f ? ac_threshold_ : ac_threshold;
|
||||
std::string phrase = phrases.empty() ? std::string() : phrases[i];
|
||||
|
||||
for (int32_t j = 0; j < token_ids[i].size(); ++j) {
|
||||
int32_t token = token_ids[i][j];
|
||||
if (0 == node->next.count(token)) {
|
||||
bool is_end = j == token_ids[i].size() - 1;
|
||||
node->next[token] = std::make_unique<ContextState>(
|
||||
token, context_score_, node->node_score + context_score_,
|
||||
is_end ? node->node_score + context_score_ : 0, is_end);
|
||||
token, score, node->node_score + score,
|
||||
is_end ? node->node_score + score : 0, j + 1,
|
||||
is_end ? ac_threshold : 0.0f, is_end,
|
||||
is_end ? phrase : std::string());
|
||||
} else {
|
||||
float token_score = std::max(score, node->next[token]->token_score);
|
||||
node->next[token]->token_score = token_score;
|
||||
float node_score = node->node_score + token_score;
|
||||
node->next[token]->node_score = node_score;
|
||||
bool is_end =
|
||||
(j == token_ids[i].size() - 1) || node->next[token]->is_end;
|
||||
node->next[token]->output_score = is_end ? node_score : 0.0f;
|
||||
node->next[token]->is_end = is_end;
|
||||
if (j == token_ids[i].size() - 1) {
|
||||
node->next[token]->phrase = phrase;
|
||||
node->next[token]->ac_threshold = ac_threshold;
|
||||
}
|
||||
}
|
||||
node = node->next[token].get();
|
||||
}
|
||||
@@ -27,8 +64,9 @@ void ContextGraph::Build(
|
||||
FillFailOutput();
|
||||
}
|
||||
|
||||
std::pair<float, const ContextState *> ContextGraph::ForwardOneStep(
|
||||
const ContextState *state, int32_t token) const {
|
||||
std::tuple<float, const ContextState *, const ContextState *>
|
||||
ContextGraph::ForwardOneStep(const ContextState *state, int32_t token,
|
||||
bool strict_mode /*= true*/) const {
|
||||
const ContextState *node;
|
||||
float score;
|
||||
if (1 == state->next.count(token)) {
|
||||
@@ -45,8 +83,22 @@ std::pair<float, const ContextState *> ContextGraph::ForwardOneStep(
|
||||
}
|
||||
score = node->node_score - state->node_score;
|
||||
}
|
||||
|
||||
SHERPA_ONNX_CHECK(nullptr != node);
|
||||
return std::make_pair(score + node->output_score, node);
|
||||
|
||||
const ContextState *matched_node =
|
||||
node->is_end ? node : (node->output != nullptr ? node->output : nullptr);
|
||||
|
||||
if (!strict_mode && node->output_score != 0) {
|
||||
SHERPA_ONNX_CHECK(nullptr != matched_node);
|
||||
float output_score =
|
||||
node->is_end ? node->node_score
|
||||
: (node->output != nullptr ? node->output->node_score
|
||||
: node->node_score);
|
||||
return std::make_tuple(score + output_score - node->node_score, root_.get(),
|
||||
matched_node);
|
||||
}
|
||||
return std::make_tuple(score + node->output_score, node, matched_node);
|
||||
}
|
||||
|
||||
std::pair<float, const ContextState *> ContextGraph::Finalize(
|
||||
@@ -55,6 +107,22 @@ std::pair<float, const ContextState *> ContextGraph::Finalize(
|
||||
return std::make_pair(score, root_.get());
|
||||
}
|
||||
|
||||
std::pair<bool, const ContextState *> ContextGraph::IsMatched(
|
||||
const ContextState *state) const {
|
||||
bool status = false;
|
||||
const ContextState *node = nullptr;
|
||||
if (state->is_end) {
|
||||
status = true;
|
||||
node = state;
|
||||
} else {
|
||||
if (state->output != nullptr) {
|
||||
status = true;
|
||||
node = state->output;
|
||||
}
|
||||
}
|
||||
return std::make_pair(status, node);
|
||||
}
|
||||
|
||||
void ContextGraph::FillFailOutput() const {
|
||||
std::queue<const ContextState *> node_queue;
|
||||
for (auto &kv : root_->next) {
|
||||
|
||||
@@ -6,6 +6,8 @@
|
||||
#define SHERPA_ONNX_CSRC_CONTEXT_GRAPH_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
@@ -22,34 +24,55 @@ struct ContextState {
|
||||
float token_score;
|
||||
float node_score;
|
||||
float output_score;
|
||||
int32_t level;
|
||||
float ac_threshold;
|
||||
bool is_end;
|
||||
std::string phrase;
|
||||
std::unordered_map<int32_t, std::unique_ptr<ContextState>> next;
|
||||
const ContextState *fail = nullptr;
|
||||
const ContextState *output = nullptr;
|
||||
|
||||
ContextState() = default;
|
||||
ContextState(int32_t token, float token_score, float node_score,
|
||||
float output_score, bool is_end)
|
||||
float output_score, int32_t level = 0, float ac_threshold = 0.0f,
|
||||
bool is_end = false, const std::string &phrase = {})
|
||||
: token(token),
|
||||
token_score(token_score),
|
||||
node_score(node_score),
|
||||
output_score(output_score),
|
||||
is_end(is_end) {}
|
||||
level(level),
|
||||
ac_threshold(ac_threshold),
|
||||
is_end(is_end),
|
||||
phrase(phrase) {}
|
||||
};
|
||||
|
||||
class ContextGraph {
|
||||
public:
|
||||
ContextGraph() = default;
|
||||
ContextGraph(const std::vector<std::vector<int32_t>> &token_ids,
|
||||
float context_score)
|
||||
: context_score_(context_score) {
|
||||
root_ = std::make_unique<ContextState>(-1, 0, 0, 0, false);
|
||||
float context_score, float ac_threshold,
|
||||
const std::vector<float> &scores = {},
|
||||
const std::vector<std::string> &phrases = {},
|
||||
const std::vector<float> &ac_thresholds = {})
|
||||
: context_score_(context_score), ac_threshold_(ac_threshold) {
|
||||
root_ = std::make_unique<ContextState>(-1, 0, 0, 0);
|
||||
root_->fail = root_.get();
|
||||
Build(token_ids);
|
||||
Build(token_ids, scores, phrases, ac_thresholds);
|
||||
}
|
||||
|
||||
std::pair<float, const ContextState *> ForwardOneStep(
|
||||
const ContextState *state, int32_t token_id) const;
|
||||
ContextGraph(const std::vector<std::vector<int32_t>> &token_ids,
|
||||
float context_score, const std::vector<float> &scores = {},
|
||||
const std::vector<std::string> &phrases = {})
|
||||
: ContextGraph(token_ids, context_score, 0.0f, scores, phrases,
|
||||
std::vector<float>()) {}
|
||||
|
||||
std::tuple<float, const ContextState *, const ContextState *> ForwardOneStep(
|
||||
const ContextState *state, int32_t token_id,
|
||||
bool strict_mode = true) const;
|
||||
|
||||
std::pair<bool, const ContextState *> IsMatched(
|
||||
const ContextState *state) const;
|
||||
|
||||
std::pair<float, const ContextState *> Finalize(
|
||||
const ContextState *state) const;
|
||||
|
||||
@@ -57,8 +80,12 @@ class ContextGraph {
|
||||
|
||||
private:
|
||||
float context_score_;
|
||||
float ac_threshold_;
|
||||
std::unique_ptr<ContextState> root_;
|
||||
void Build(const std::vector<std::vector<int32_t>> &token_ids) const;
|
||||
void Build(const std::vector<std::vector<int32_t>> &token_ids,
|
||||
const std::vector<float> &scores,
|
||||
const std::vector<std::string> &phrases,
|
||||
const std::vector<float> &ac_thresholds) const;
|
||||
void FillFailOutput() const;
|
||||
};
|
||||
|
||||
|
||||
@@ -28,6 +28,10 @@ struct Hypothesis {
|
||||
// on which ys[i] is decoded.
|
||||
std::vector<int32_t> timestamps;
|
||||
|
||||
// The acoustic probability for each token in ys.
|
||||
// Only used for keyword spotting task.
|
||||
std::vector<float> ys_probs;
|
||||
|
||||
// The total score of ys in log space.
|
||||
// It contains only acoustic scores
|
||||
double log_prob = 0;
|
||||
|
||||
33
sherpa-onnx/csrc/keyword-spotter-impl.cc
Normal file
@@ -0,0 +1,33 @@
|
||||
// sherpa-onnx/csrc/keyword-spotter-impl.cc
|
||||
//
|
||||
// Copyright (c) 2023-2024 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/keyword-spotter-impl.h"
|
||||
|
||||
#include "sherpa-onnx/csrc/keyword-spotter-transducer-impl.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
std::unique_ptr<KeywordSpotterImpl> KeywordSpotterImpl::Create(
|
||||
const KeywordSpotterConfig &config) {
|
||||
if (!config.model_config.transducer.encoder.empty()) {
|
||||
return std::make_unique<KeywordSpotterTransducerImpl>(config);
|
||||
}
|
||||
|
||||
SHERPA_ONNX_LOGE("Please specify a model");
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
std::unique_ptr<KeywordSpotterImpl> KeywordSpotterImpl::Create(
|
||||
AAssetManager *mgr, const KeywordSpotterConfig &config) {
|
||||
if (!config.model_config.transducer.encoder.empty()) {
|
||||
return std::make_unique<KeywordSpotterTransducerImpl>(mgr, config);
|
||||
}
|
||||
|
||||
SHERPA_ONNX_LOGE("Please specify a model");
|
||||
exit(-1);
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
48
sherpa-onnx/csrc/keyword-spotter-impl.h
Normal file
@@ -0,0 +1,48 @@
|
||||
// sherpa-onnx/csrc/keyword-spotter-impl.h
|
||||
//
|
||||
// Copyright (c) 2023-2024 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_IMPL_H_
|
||||
#define SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_IMPL_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/keyword-spotter.h"
|
||||
#include "sherpa-onnx/csrc/online-stream.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class KeywordSpotterImpl {
|
||||
public:
|
||||
static std::unique_ptr<KeywordSpotterImpl> Create(
|
||||
const KeywordSpotterConfig &config);
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
static std::unique_ptr<KeywordSpotterImpl> Create(
|
||||
AAssetManager *mgr, const KeywordSpotterConfig &config);
|
||||
#endif
|
||||
|
||||
virtual ~KeywordSpotterImpl() = default;
|
||||
|
||||
virtual std::unique_ptr<OnlineStream> CreateStream() const = 0;
|
||||
|
||||
virtual std::unique_ptr<OnlineStream> CreateStream(
|
||||
const std::string &keywords) const = 0;
|
||||
|
||||
virtual bool IsReady(OnlineStream *s) const = 0;
|
||||
|
||||
virtual void DecodeStreams(OnlineStream **ss, int32_t n) const = 0;
|
||||
|
||||
virtual KeywordResult GetResult(OnlineStream *s) const = 0;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_IMPL_H_
|
||||
323
sherpa-onnx/csrc/keyword-spotter-transducer-impl.h
Normal file
@@ -0,0 +1,323 @@
|
||||
// sherpa-onnx/csrc/keyword-spotter-transducer-impl.h
|
||||
//
|
||||
// Copyright (c) 2023-2024 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_TRANSDUCER_IMPL_H_
|
||||
#define SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_TRANSDUCER_IMPL_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <regex> // NOLINT
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include <strstream>
|
||||
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/keyword-spotter-impl.h"
|
||||
#include "sherpa-onnx/csrc/keyword-spotter.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/online-transducer-model.h"
|
||||
#include "sherpa-onnx/csrc/symbol-table.h"
|
||||
#include "sherpa-onnx/csrc/transducer-keyword-decoder.h"
|
||||
#include "sherpa-onnx/csrc/utils.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
static KeywordResult Convert(const TransducerKeywordResult &src,
|
||||
const SymbolTable &sym_table, float frame_shift_ms,
|
||||
int32_t subsampling_factor,
|
||||
int32_t frames_since_start) {
|
||||
KeywordResult r;
|
||||
r.tokens.reserve(src.tokens.size());
|
||||
r.timestamps.reserve(src.tokens.size());
|
||||
r.keyword = src.keyword;
|
||||
bool from_tokens = src.keyword.empty();
|
||||
|
||||
for (auto i : src.tokens) {
|
||||
auto sym = sym_table[i];
|
||||
if (from_tokens) {
|
||||
r.keyword.append(sym);
|
||||
}
|
||||
r.tokens.push_back(std::move(sym));
|
||||
}
|
||||
if (from_tokens && r.keyword.size()) {
|
||||
r.keyword = r.keyword.substr(1);
|
||||
}
|
||||
|
||||
float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor;
|
||||
for (auto t : src.timestamps) {
|
||||
float time = frame_shift_s * t;
|
||||
r.timestamps.push_back(time);
|
||||
}
|
||||
|
||||
r.start_time = frames_since_start * frame_shift_ms / 1000.;
|
||||
|
||||
return r;
|
||||
}
|
||||
|
||||
class KeywordSpotterTransducerImpl : public KeywordSpotterImpl {
|
||||
public:
|
||||
explicit KeywordSpotterTransducerImpl(const KeywordSpotterConfig &config)
|
||||
: config_(config),
|
||||
model_(OnlineTransducerModel::Create(config.model_config)),
|
||||
sym_(config.model_config.tokens) {
|
||||
if (sym_.contains("<unk>")) {
|
||||
unk_id_ = sym_["<unk>"];
|
||||
}
|
||||
|
||||
InitKeywords();
|
||||
|
||||
decoder_ = std::make_unique<TransducerKeywordDecoder>(
|
||||
model_.get(), config_.max_active_paths, config_.num_trailing_blanks,
|
||||
unk_id_);
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
KeywordSpotterTransducerImpl(AAssetManager *mgr,
|
||||
const KeywordSpotterConfig &config)
|
||||
: config_(config),
|
||||
model_(OnlineTransducerModel::Create(mgr, config.model_config)),
|
||||
sym_(mgr, config.model_config.tokens) {
|
||||
if (sym_.contains("<unk>")) {
|
||||
unk_id_ = sym_["<unk>"];
|
||||
}
|
||||
|
||||
InitKeywords(mgr);
|
||||
|
||||
decoder_ = std::make_unique<TransducerKeywordDecoder>(
|
||||
model_.get(), config_.max_active_paths, config_.num_trailing_blanks,
|
||||
unk_id_);
|
||||
}
|
||||
#endif
|
||||
|
||||
std::unique_ptr<OnlineStream> CreateStream() const override {
|
||||
auto stream =
|
||||
std::make_unique<OnlineStream>(config_.feat_config, keywords_graph_);
|
||||
InitOnlineStream(stream.get());
|
||||
return stream;
|
||||
}
|
||||
|
||||
std::unique_ptr<OnlineStream> CreateStream(
|
||||
const std::string &keywords) const override {
|
||||
auto kws = std::regex_replace(keywords, std::regex("/"), "\n");
|
||||
std::istringstream is(kws);
|
||||
|
||||
std::vector<std::vector<int32_t>> current_ids;
|
||||
std::vector<std::string> current_kws;
|
||||
std::vector<float> current_scores;
|
||||
std::vector<float> current_thresholds;
|
||||
|
||||
if (!EncodeKeywords(is, sym_, ¤t_ids, ¤t_kws, ¤t_scores,
|
||||
¤t_thresholds)) {
|
||||
SHERPA_ONNX_LOGE("Encode keywords %s failed.", keywords.c_str());
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
int32_t num_kws = current_ids.size();
|
||||
int32_t num_default_kws = keywords_id_.size();
|
||||
|
||||
current_ids.insert(current_ids.end(), keywords_id_.begin(),
|
||||
keywords_id_.end());
|
||||
|
||||
if (!current_kws.empty() && !keywords_.empty()) {
|
||||
current_kws.insert(current_kws.end(), keywords_.begin(), keywords_.end());
|
||||
} else if (!current_kws.empty() && keywords_.empty()) {
|
||||
current_kws.insert(current_kws.end(), num_default_kws, std::string());
|
||||
} else if (current_kws.empty() && !keywords_.empty()) {
|
||||
current_kws.insert(current_kws.end(), num_kws, std::string());
|
||||
current_kws.insert(current_kws.end(), keywords_.begin(), keywords_.end());
|
||||
} else {
|
||||
// Do nothing.
|
||||
}
|
||||
|
||||
if (!current_scores.empty() && !boost_scores_.empty()) {
|
||||
current_scores.insert(current_scores.end(), boost_scores_.begin(),
|
||||
boost_scores_.end());
|
||||
} else if (!current_scores.empty() && boost_scores_.empty()) {
|
||||
current_scores.insert(current_scores.end(), num_default_kws,
|
||||
config_.keywords_score);
|
||||
} else if (current_scores.empty() && !boost_scores_.empty()) {
|
||||
current_scores.insert(current_scores.end(), num_kws,
|
||||
config_.keywords_score);
|
||||
current_scores.insert(current_scores.end(), boost_scores_.begin(),
|
||||
boost_scores_.end());
|
||||
} else {
|
||||
// Do nothing.
|
||||
}
|
||||
|
||||
if (!current_thresholds.empty() && !thresholds_.empty()) {
|
||||
current_thresholds.insert(current_thresholds.end(), thresholds_.begin(),
|
||||
thresholds_.end());
|
||||
} else if (!current_thresholds.empty() && thresholds_.empty()) {
|
||||
current_thresholds.insert(current_thresholds.end(), num_default_kws,
|
||||
config_.keywords_threshold);
|
||||
} else if (current_thresholds.empty() && !thresholds_.empty()) {
|
||||
current_thresholds.insert(current_thresholds.end(), num_kws,
|
||||
config_.keywords_threshold);
|
||||
current_thresholds.insert(current_thresholds.end(), thresholds_.begin(),
|
||||
thresholds_.end());
|
||||
} else {
|
||||
// Do nothing.
|
||||
}
|
||||
|
||||
auto keywords_graph = std::make_shared<ContextGraph>(
|
||||
current_ids, config_.keywords_score, config_.keywords_threshold,
|
||||
current_scores, current_kws, current_thresholds);
|
||||
|
||||
auto stream =
|
||||
std::make_unique<OnlineStream>(config_.feat_config, keywords_graph);
|
||||
InitOnlineStream(stream.get());
|
||||
return stream;
|
||||
}
|
||||
|
||||
bool IsReady(OnlineStream *s) const override {
|
||||
return s->GetNumProcessedFrames() + model_->ChunkSize() <
|
||||
s->NumFramesReady();
|
||||
}
|
||||
|
||||
void DecodeStreams(OnlineStream **ss, int32_t n) const override {
|
||||
int32_t chunk_size = model_->ChunkSize();
|
||||
int32_t chunk_shift = model_->ChunkShift();
|
||||
|
||||
int32_t feature_dim = ss[0]->FeatureDim();
|
||||
|
||||
std::vector<TransducerKeywordResult> results(n);
|
||||
std::vector<float> features_vec(n * chunk_size * feature_dim);
|
||||
std::vector<std::vector<Ort::Value>> states_vec(n);
|
||||
std::vector<int64_t> all_processed_frames(n);
|
||||
|
||||
for (int32_t i = 0; i != n; ++i) {
|
||||
SHERPA_ONNX_CHECK(ss[i]->GetContextGraph() != nullptr);
|
||||
|
||||
const auto num_processed_frames = ss[i]->GetNumProcessedFrames();
|
||||
std::vector<float> features =
|
||||
ss[i]->GetFrames(num_processed_frames, chunk_size);
|
||||
|
||||
// Question: should num_processed_frames include chunk_shift?
|
||||
ss[i]->GetNumProcessedFrames() += chunk_shift;
|
||||
|
||||
std::copy(features.begin(), features.end(),
|
||||
features_vec.data() + i * chunk_size * feature_dim);
|
||||
|
||||
results[i] = std::move(ss[i]->GetKeywordResult());
|
||||
states_vec[i] = std::move(ss[i]->GetStates());
|
||||
all_processed_frames[i] = num_processed_frames;
|
||||
}
|
||||
|
||||
auto memory_info =
|
||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||
|
||||
std::array<int64_t, 3> x_shape{n, chunk_size, feature_dim};
|
||||
|
||||
Ort::Value x = Ort::Value::CreateTensor(memory_info, features_vec.data(),
|
||||
features_vec.size(), x_shape.data(),
|
||||
x_shape.size());
|
||||
|
||||
std::array<int64_t, 1> processed_frames_shape{
|
||||
static_cast<int64_t>(all_processed_frames.size())};
|
||||
|
||||
Ort::Value processed_frames = Ort::Value::CreateTensor(
|
||||
memory_info, all_processed_frames.data(), all_processed_frames.size(),
|
||||
processed_frames_shape.data(), processed_frames_shape.size());
|
||||
|
||||
auto states = model_->StackStates(states_vec);
|
||||
|
||||
auto pair = model_->RunEncoder(std::move(x), std::move(states),
|
||||
std::move(processed_frames));
|
||||
|
||||
decoder_->Decode(std::move(pair.first), ss, &results);
|
||||
|
||||
std::vector<std::vector<Ort::Value>> next_states =
|
||||
model_->UnStackStates(pair.second);
|
||||
|
||||
for (int32_t i = 0; i != n; ++i) {
|
||||
ss[i]->SetKeywordResult(results[i]);
|
||||
ss[i]->SetStates(std::move(next_states[i]));
|
||||
}
|
||||
}
|
||||
|
||||
KeywordResult GetResult(OnlineStream *s) const override {
|
||||
TransducerKeywordResult decoder_result = s->GetKeywordResult(true);
|
||||
|
||||
// TODO(fangjun): Remember to change these constants if needed
|
||||
int32_t frame_shift_ms = 10;
|
||||
int32_t subsampling_factor = 4;
|
||||
return Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor,
|
||||
s->GetNumFramesSinceStart());
|
||||
}
|
||||
|
||||
private:
|
||||
void InitKeywords(std::istream &is) {
|
||||
if (!EncodeKeywords(is, sym_, &keywords_id_, &keywords_, &boost_scores_,
|
||||
&thresholds_)) {
|
||||
SHERPA_ONNX_LOGE("Encode keywords failed.");
|
||||
exit(-1);
|
||||
}
|
||||
keywords_graph_ = std::make_shared<ContextGraph>(
|
||||
keywords_id_, config_.keywords_score, config_.keywords_threshold,
|
||||
boost_scores_, keywords_, thresholds_);
|
||||
}
|
||||
|
||||
void InitKeywords() {
|
||||
// each line in keywords_file contains space-separated words
|
||||
|
||||
std::ifstream is(config_.keywords_file);
|
||||
if (!is) {
|
||||
SHERPA_ONNX_LOGE("Open keywords file failed: %s",
|
||||
config_.keywords_file.c_str());
|
||||
exit(-1);
|
||||
}
|
||||
InitKeywords(is);
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
void InitKeywords(AAssetManager *mgr) {
|
||||
// each line in keywords_file contains space-separated words
|
||||
|
||||
auto buf = ReadFile(mgr, config_.keywords_file);
|
||||
|
||||
std::istrstream is(buf.data(), buf.size());
|
||||
|
||||
if (!is) {
|
||||
SHERPA_ONNX_LOGE("Open keywords file failed: %s",
|
||||
config_.keywords_file.c_str());
|
||||
exit(-1);
|
||||
}
|
||||
InitKeywords(is);
|
||||
}
|
||||
#endif
|
||||
|
||||
void InitOnlineStream(OnlineStream *stream) const {
|
||||
auto r = decoder_->GetEmptyResult();
|
||||
SHERPA_ONNX_CHECK_EQ(r.hyps.size(), 1);
|
||||
|
||||
SHERPA_ONNX_CHECK(stream->GetContextGraph() != nullptr);
|
||||
r.hyps.begin()->second.context_state = stream->GetContextGraph()->Root();
|
||||
|
||||
stream->SetKeywordResult(r);
|
||||
stream->SetStates(model_->GetEncoderInitStates());
|
||||
}
|
||||
|
||||
private:
|
||||
KeywordSpotterConfig config_;
|
||||
std::vector<std::vector<int32_t>> keywords_id_;
|
||||
std::vector<float> boost_scores_;
|
||||
std::vector<float> thresholds_;
|
||||
std::vector<std::string> keywords_;
|
||||
ContextGraphPtr keywords_graph_;
|
||||
std::unique_ptr<OnlineTransducerModel> model_;
|
||||
std::unique_ptr<TransducerKeywordDecoder> decoder_;
|
||||
SymbolTable sym_;
|
||||
int32_t unk_id_ = -1;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_TRANSDUCER_IMPL_H_
|
||||
152
sherpa-onnx/csrc/keyword-spotter.cc
Normal file
@@ -0,0 +1,152 @@
|
||||
// sherpa-onnx/csrc/keyword-spotter.cc
|
||||
//
|
||||
// Copyright (c) 2023-2024 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/keyword-spotter.h"
|
||||
|
||||
#include <assert.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <fstream>
|
||||
#include <iomanip>
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/keyword-spotter-impl.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
std::string KeywordResult::AsJsonString() const {
|
||||
std::ostringstream os;
|
||||
os << "{";
|
||||
os << "\"start_time\":" << std::fixed << std::setprecision(2) << start_time
|
||||
<< ", ";
|
||||
|
||||
os << "\"keyword\""
|
||||
<< ": ";
|
||||
os << "\"" << keyword << "\""
|
||||
<< ", ";
|
||||
|
||||
os << "\""
|
||||
<< "timestamps"
|
||||
<< "\""
|
||||
<< ": ";
|
||||
os << "[";
|
||||
|
||||
std::string sep = "";
|
||||
for (auto t : timestamps) {
|
||||
os << sep << std::fixed << std::setprecision(2) << t;
|
||||
sep = ", ";
|
||||
}
|
||||
os << "], ";
|
||||
|
||||
os << "\""
|
||||
<< "tokens"
|
||||
<< "\""
|
||||
<< ":";
|
||||
os << "[";
|
||||
|
||||
sep = "";
|
||||
auto oldFlags = os.flags();
|
||||
for (const auto &t : tokens) {
|
||||
if (t.size() == 1 && static_cast<uint8_t>(t[0]) > 0x7f) {
|
||||
const uint8_t *p = reinterpret_cast<const uint8_t *>(t.c_str());
|
||||
os << sep << "\""
|
||||
<< "<0x" << std::hex << std::uppercase << static_cast<uint32_t>(p[0])
|
||||
<< ">"
|
||||
<< "\"";
|
||||
os.flags(oldFlags);
|
||||
} else {
|
||||
os << sep << "\"" << t << "\"";
|
||||
}
|
||||
sep = ", ";
|
||||
}
|
||||
os << "]";
|
||||
os << "}";
|
||||
|
||||
return os.str();
|
||||
}
|
||||
|
||||
void KeywordSpotterConfig::Register(ParseOptions *po) {
|
||||
feat_config.Register(po);
|
||||
model_config.Register(po);
|
||||
|
||||
po->Register("max-active-paths", &max_active_paths,
|
||||
"beam size used in modified beam search.");
|
||||
po->Register("num-trailing-blanks", &num_trailing_blanks,
|
||||
"The number of trailing blanks should have after the keyword.");
|
||||
po->Register("keywords-score", &keywords_score,
|
||||
"The bonus score for each token in context word/phrase.");
|
||||
po->Register("keywords-threshold", &keywords_threshold,
|
||||
"The acoustic threshold (probability) to trigger the keywords.");
|
||||
po->Register(
|
||||
"keywords-file", &keywords_file,
|
||||
"The file containing keywords, one word/phrase per line, and for each"
|
||||
"phrase the bpe/cjkchar are separated by a space. For example: "
|
||||
"▁HE LL O ▁WORLD"
|
||||
"你 好 世 界");
|
||||
}
|
||||
|
||||
bool KeywordSpotterConfig::Validate() const {
|
||||
if (keywords_file.empty()) {
|
||||
SHERPA_ONNX_LOGE("Please provide --keywords-file.");
|
||||
return false;
|
||||
}
|
||||
if (!std::ifstream(keywords_file.c_str()).good()) {
|
||||
SHERPA_ONNX_LOGE("Keywords file %s does not exist.", keywords_file.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
return model_config.Validate();
|
||||
}
|
||||
|
||||
std::string KeywordSpotterConfig::ToString() const {
|
||||
std::ostringstream os;
|
||||
|
||||
os << "KeywordSpotterConfig(";
|
||||
os << "feat_config=" << feat_config.ToString() << ", ";
|
||||
os << "model_config=" << model_config.ToString() << ", ";
|
||||
os << "max_active_paths=" << max_active_paths << ", ";
|
||||
os << "num_trailing_blanks=" << num_trailing_blanks << ", ";
|
||||
os << "keywords_score=" << keywords_score << ", ";
|
||||
os << "keywords_threshold=" << keywords_threshold << ", ";
|
||||
os << "keywords_file=\"" << keywords_file << "\")";
|
||||
|
||||
return os.str();
|
||||
}
|
||||
|
||||
KeywordSpotter::KeywordSpotter(const KeywordSpotterConfig &config)
|
||||
: impl_(KeywordSpotterImpl::Create(config)) {}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
KeywordSpotter::KeywordSpotter(AAssetManager *mgr,
|
||||
const KeywordSpotterConfig &config)
|
||||
: impl_(KeywordSpotterImpl::Create(mgr, config)) {}
|
||||
#endif
|
||||
|
||||
KeywordSpotter::~KeywordSpotter() = default;
|
||||
|
||||
std::unique_ptr<OnlineStream> KeywordSpotter::CreateStream() const {
|
||||
return impl_->CreateStream();
|
||||
}
|
||||
|
||||
std::unique_ptr<OnlineStream> KeywordSpotter::CreateStream(
|
||||
const std::string &keywords) const {
|
||||
return impl_->CreateStream(keywords);
|
||||
}
|
||||
|
||||
bool KeywordSpotter::IsReady(OnlineStream *s) const {
|
||||
return impl_->IsReady(s);
|
||||
}
|
||||
|
||||
void KeywordSpotter::DecodeStreams(OnlineStream **ss, int32_t n) const {
|
||||
impl_->DecodeStreams(ss, n);
|
||||
}
|
||||
|
||||
KeywordResult KeywordSpotter::GetResult(OnlineStream *s) const {
|
||||
return impl_->GetResult(s);
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
148
sherpa-onnx/csrc/keyword-spotter.h
Normal file
@@ -0,0 +1,148 @@
|
||||
// sherpa-onnx/csrc/keyword-spotter.h
|
||||
//
|
||||
// Copyright (c) 2023-2024 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_H_
|
||||
#define SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/features.h"
|
||||
#include "sherpa-onnx/csrc/online-model-config.h"
|
||||
#include "sherpa-onnx/csrc/online-stream.h"
|
||||
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
struct KeywordResult {
|
||||
/// The triggered keyword.
|
||||
/// For English, it consists of space separated words.
|
||||
/// For Chinese, it consists of Chinese words without spaces.
|
||||
/// Example 1: "hello world"
|
||||
/// Example 2: "你好世界"
|
||||
std::string keyword;
|
||||
|
||||
/// Decoded results at the token level.
|
||||
/// For instance, for BPE-based models it consists of a list of BPE tokens.
|
||||
std::vector<std::string> tokens;
|
||||
|
||||
/// timestamps.size() == tokens.size()
|
||||
/// timestamps[i] records the time in seconds when tokens[i] is decoded.
|
||||
std::vector<float> timestamps;
|
||||
|
||||
/// Starting time of this segment.
|
||||
/// When an endpoint is detected, it will change
|
||||
float start_time = 0;
|
||||
|
||||
/** Return a json string.
|
||||
*
|
||||
* The returned string contains:
|
||||
* {
|
||||
* "keyword": "The triggered keyword",
|
||||
* "tokens": [x, x, x],
|
||||
* "timestamps": [x, x, x],
|
||||
* "start_time": x,
|
||||
* }
|
||||
*/
|
||||
std::string AsJsonString() const;
|
||||
};
|
||||
|
||||
struct KeywordSpotterConfig {
|
||||
FeatureExtractorConfig feat_config;
|
||||
OnlineModelConfig model_config;
|
||||
|
||||
int32_t max_active_paths = 4;
|
||||
|
||||
int32_t num_trailing_blanks = 1;
|
||||
|
||||
float keywords_score = 1.0;
|
||||
|
||||
float keywords_threshold = 0.25;
|
||||
|
||||
std::string keywords_file;
|
||||
|
||||
KeywordSpotterConfig() = default;
|
||||
|
||||
KeywordSpotterConfig(const FeatureExtractorConfig &feat_config,
|
||||
const OnlineModelConfig &model_config,
|
||||
int32_t max_active_paths, int32_t num_trailing_blanks,
|
||||
float keywords_score, float keywords_threshold,
|
||||
const std::string &keywords_file)
|
||||
: feat_config(feat_config),
|
||||
model_config(model_config),
|
||||
max_active_paths(max_active_paths),
|
||||
num_trailing_blanks(num_trailing_blanks),
|
||||
keywords_score(keywords_score),
|
||||
keywords_threshold(keywords_threshold),
|
||||
keywords_file(keywords_file) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
|
||||
std::string ToString() const;
|
||||
};
|
||||
|
||||
class KeywordSpotterImpl;
|
||||
|
||||
class KeywordSpotter {
|
||||
public:
|
||||
explicit KeywordSpotter(const KeywordSpotterConfig &config);
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
KeywordSpotter(AAssetManager *mgr, const KeywordSpotterConfig &config);
|
||||
#endif
|
||||
|
||||
~KeywordSpotter();
|
||||
|
||||
/** Create a stream for decoding.
|
||||
*
|
||||
*/
|
||||
std::unique_ptr<OnlineStream> CreateStream() const;
|
||||
|
||||
/** Create a stream for decoding.
|
||||
*
|
||||
* @param The keywords for this string, it might contain several keywords,
|
||||
* the keywords are separated by "/". In each of the keywords, there
|
||||
* are cjkchars or bpes, the bpe/cjkchar are separated by space (" ").
|
||||
* For example, keywords I LOVE YOU and HELLO WORLD, looks like:
|
||||
*
|
||||
* "▁I ▁LOVE ▁YOU/▁HE LL O ▁WORLD"
|
||||
*/
|
||||
std::unique_ptr<OnlineStream> CreateStream(const std::string &keywords) const;
|
||||
|
||||
/**
|
||||
* Return true if the given stream has enough frames for decoding.
|
||||
* Return false otherwise
|
||||
*/
|
||||
bool IsReady(OnlineStream *s) const;
|
||||
|
||||
/** Decode a single stream. */
|
||||
void DecodeStream(OnlineStream *s) const {
|
||||
OnlineStream *ss[1] = {s};
|
||||
DecodeStreams(ss, 1);
|
||||
}
|
||||
|
||||
/** Decode multiple streams in parallel
|
||||
*
|
||||
* @param ss Pointer array containing streams to be decoded.
|
||||
* @param n Number of streams in `ss`.
|
||||
*/
|
||||
void DecodeStreams(OnlineStream **ss, int32_t n) const;
|
||||
|
||||
KeywordResult GetResult(OnlineStream *s) const;
|
||||
|
||||
private:
|
||||
std::unique_ptr<KeywordSpotterImpl> impl_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_H_
|
||||
@@ -93,8 +93,8 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode(
|
||||
Repeat(model_->Allocator(), &cur_encoder_out, hyps_row_splits);
|
||||
// now cur_encoder_out is of shape (num_hyps, joiner_dim)
|
||||
|
||||
Ort::Value logit = model_->RunJoiner(
|
||||
std::move(cur_encoder_out), View(&decoder_out));
|
||||
Ort::Value logit =
|
||||
model_->RunJoiner(std::move(cur_encoder_out), View(&decoder_out));
|
||||
|
||||
float *p_logit = logit.GetTensorMutableData<float>();
|
||||
LogSoftmax(p_logit, vocab_size, num_hyps);
|
||||
@@ -134,8 +134,8 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode(
|
||||
if (context_graphs[i] != nullptr) {
|
||||
auto context_res =
|
||||
context_graphs[i]->ForwardOneStep(context_state, new_token);
|
||||
context_score = context_res.first;
|
||||
new_hyp.context_state = context_res.second;
|
||||
context_score = std::get<0>(context_res);
|
||||
new_hyp.context_state = std::get<1>(context_res);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -51,6 +51,25 @@ class OnlineStream::Impl {
|
||||
|
||||
OnlineTransducerDecoderResult &GetResult() { return result_; }
|
||||
|
||||
void SetKeywordResult(const TransducerKeywordResult &r) {
|
||||
keyword_result_ = r;
|
||||
}
|
||||
TransducerKeywordResult &GetKeywordResult(bool remove_duplicates) {
|
||||
if (remove_duplicates) {
|
||||
if (!prev_keyword_result_.timestamps.empty() &&
|
||||
!keyword_result_.timestamps.empty() &&
|
||||
keyword_result_.timestamps[0] <=
|
||||
prev_keyword_result_.timestamps.back()) {
|
||||
return empty_keyword_result_;
|
||||
} else {
|
||||
prev_keyword_result_ = keyword_result_;
|
||||
}
|
||||
return keyword_result_;
|
||||
} else {
|
||||
return keyword_result_;
|
||||
}
|
||||
}
|
||||
|
||||
OnlineCtcDecoderResult &GetCtcResult() { return ctc_result_; }
|
||||
|
||||
void SetCtcResult(const OnlineCtcDecoderResult &r) { ctc_result_ = r; }
|
||||
@@ -93,6 +112,9 @@ class OnlineStream::Impl {
|
||||
int32_t start_frame_index_ = 0; // never reset
|
||||
int32_t segment_ = 0;
|
||||
OnlineTransducerDecoderResult result_;
|
||||
TransducerKeywordResult prev_keyword_result_;
|
||||
TransducerKeywordResult keyword_result_;
|
||||
TransducerKeywordResult empty_keyword_result_;
|
||||
OnlineCtcDecoderResult ctc_result_;
|
||||
std::vector<Ort::Value> states_; // states for transducer or ctc models
|
||||
std::vector<float> paraformer_feat_cache_;
|
||||
@@ -149,6 +171,15 @@ OnlineTransducerDecoderResult &OnlineStream::GetResult() {
|
||||
return impl_->GetResult();
|
||||
}
|
||||
|
||||
void OnlineStream::SetKeywordResult(const TransducerKeywordResult &r) {
|
||||
impl_->SetKeywordResult(r);
|
||||
}
|
||||
|
||||
TransducerKeywordResult &OnlineStream::GetKeywordResult(
|
||||
bool remove_duplicates /*=false*/) {
|
||||
return impl_->GetKeywordResult(remove_duplicates);
|
||||
}
|
||||
|
||||
OnlineCtcDecoderResult &OnlineStream::GetCtcResult() {
|
||||
return impl_->GetCtcResult();
|
||||
}
|
||||
|
||||
@@ -14,9 +14,11 @@
|
||||
#include "sherpa-onnx/csrc/online-ctc-decoder.h"
|
||||
#include "sherpa-onnx/csrc/online-paraformer-decoder.h"
|
||||
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
|
||||
#include "sherpa-onnx/csrc/transducer-keyword-decoder.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class TransducerKeywordResult;
|
||||
class OnlineStream {
|
||||
public:
|
||||
explicit OnlineStream(const FeatureExtractorConfig &config = {},
|
||||
@@ -76,6 +78,9 @@ class OnlineStream {
|
||||
void SetResult(const OnlineTransducerDecoderResult &r);
|
||||
OnlineTransducerDecoderResult &GetResult();
|
||||
|
||||
void SetKeywordResult(const TransducerKeywordResult &r);
|
||||
TransducerKeywordResult &GetKeywordResult(bool remove_duplicates = false);
|
||||
|
||||
void SetCtcResult(const OnlineCtcDecoderResult &r);
|
||||
OnlineCtcDecoderResult &GetCtcResult();
|
||||
|
||||
@@ -92,7 +97,7 @@ class OnlineStream {
|
||||
*/
|
||||
const ContextGraphPtr &GetContextGraph() const;
|
||||
|
||||
// for streaming parformer
|
||||
// for streaming paraformer
|
||||
std::vector<float> &GetParaformerFeatCache();
|
||||
std::vector<float> &GetParaformerEncoderOutCache();
|
||||
std::vector<float> &GetParaformerAlphaCache();
|
||||
|
||||
@@ -75,10 +75,10 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
||||
encoder_out.GetTensorTypeAndShapeInfo().GetShape();
|
||||
|
||||
if (encoder_out_shape[0] != result->size()) {
|
||||
fprintf(stderr,
|
||||
"Size mismatch! encoder_out.size(0) %d, result.size(0): %d\n",
|
||||
static_cast<int32_t>(encoder_out_shape[0]),
|
||||
static_cast<int32_t>(result->size()));
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Size mismatch! encoder_out.size(0) %d, result.size(0): %d\n",
|
||||
static_cast<int32_t>(encoder_out_shape[0]),
|
||||
static_cast<int32_t>(result->size()));
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
@@ -119,8 +119,8 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
||||
GetEncoderOutFrame(model_->Allocator(), &encoder_out, t);
|
||||
cur_encoder_out =
|
||||
Repeat(model_->Allocator(), &cur_encoder_out, hyps_row_splits);
|
||||
Ort::Value logit = model_->RunJoiner(
|
||||
std::move(cur_encoder_out), View(&decoder_out));
|
||||
Ort::Value logit =
|
||||
model_->RunJoiner(std::move(cur_encoder_out), View(&decoder_out));
|
||||
|
||||
float *p_logit = logit.GetTensorMutableData<float>();
|
||||
LogSoftmax(p_logit, vocab_size, num_hyps);
|
||||
@@ -164,8 +164,8 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
||||
if (ss != nullptr && ss[b]->GetContextGraph() != nullptr) {
|
||||
auto context_res = ss[b]->GetContextGraph()->ForwardOneStep(
|
||||
context_state, new_token);
|
||||
context_score = context_res.first;
|
||||
new_hyp.context_state = context_res.second;
|
||||
context_score = std::get<0>(context_res);
|
||||
new_hyp.context_state = std::get<1>(context_res);
|
||||
}
|
||||
if (lm_) {
|
||||
lm_->ComputeLMScore(lm_scale_, &new_hyp);
|
||||
|
||||
122
sherpa-onnx/csrc/sherpa-onnx-keyword-spotter.cc
Normal file
@@ -0,0 +1,122 @@
|
||||
// sherpa-onnx/csrc/sherpa-onnx-keyword-spotter.cc
|
||||
//
|
||||
// Copyright (c) 2023-2024 Xiaomi Corporation
|
||||
|
||||
#include <stdio.h>
|
||||
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/keyword-spotter.h"
|
||||
#include "sherpa-onnx/csrc/online-stream.h"
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
#include "sherpa-onnx/csrc/symbol-table.h"
|
||||
#include "sherpa-onnx/csrc/wave-reader.h"
|
||||
|
||||
typedef struct {
|
||||
std::unique_ptr<sherpa_onnx::OnlineStream> online_stream;
|
||||
std::string filename;
|
||||
} Stream;
|
||||
|
||||
int main(int32_t argc, char *argv[]) {
|
||||
const char *kUsageMessage = R"usage(
|
||||
Usage:
|
||||
|
||||
(1) Streaming transducer
|
||||
|
||||
./bin/sherpa-onnx-keyword-spotter \
|
||||
--tokens=/path/to/tokens.txt \
|
||||
--encoder=/path/to/encoder.onnx \
|
||||
--decoder=/path/to/decoder.onnx \
|
||||
--joiner=/path/to/joiner.onnx \
|
||||
--provider=cpu \
|
||||
--num-threads=2 \
|
||||
--keywords-file=keywords.txt \
|
||||
/path/to/foo.wav [bar.wav foobar.wav ...]
|
||||
|
||||
Note: It supports decoding multiple files in batches
|
||||
|
||||
Default value for num_threads is 2.
|
||||
Valid values for provider: cpu (default), cuda, coreml.
|
||||
foo.wav should be of single channel, 16-bit PCM encoded wave file; its
|
||||
sampling rate can be arbitrary and does not need to be 16kHz.
|
||||
|
||||
Please refer to
|
||||
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
|
||||
for a list of pre-trained models to download.
|
||||
)usage";
|
||||
|
||||
sherpa_onnx::ParseOptions po(kUsageMessage);
|
||||
sherpa_onnx::KeywordSpotterConfig config;
|
||||
|
||||
config.Register(&po);
|
||||
|
||||
po.Read(argc, argv);
|
||||
if (po.NumArgs() < 1) {
|
||||
po.PrintUsage();
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s\n", config.ToString().c_str());
|
||||
|
||||
if (!config.Validate()) {
|
||||
fprintf(stderr, "Errors in config!\n");
|
||||
return -1;
|
||||
}
|
||||
|
||||
sherpa_onnx::KeywordSpotter keyword_spotter(config);
|
||||
|
||||
std::vector<Stream> ss;
|
||||
|
||||
for (int32_t i = 1; i <= po.NumArgs(); ++i) {
|
||||
const std::string wav_filename = po.GetArg(i);
|
||||
int32_t sampling_rate = -1;
|
||||
|
||||
bool is_ok = false;
|
||||
const std::vector<float> samples =
|
||||
sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok);
|
||||
|
||||
if (!is_ok) {
|
||||
fprintf(stderr, "Failed to read %s\n", wav_filename.c_str());
|
||||
return -1;
|
||||
}
|
||||
|
||||
auto s = keyword_spotter.CreateStream();
|
||||
s->AcceptWaveform(sampling_rate, samples.data(), samples.size());
|
||||
|
||||
std::vector<float> tail_paddings(static_cast<int>(0.8 * sampling_rate));
|
||||
// Note: We can call AcceptWaveform() multiple times.
|
||||
s->AcceptWaveform(sampling_rate, tail_paddings.data(),
|
||||
tail_paddings.size());
|
||||
|
||||
// Call InputFinished() to indicate that no audio samples are available
|
||||
s->InputFinished();
|
||||
ss.push_back({std::move(s), wav_filename});
|
||||
}
|
||||
|
||||
std::vector<sherpa_onnx::OnlineStream *> ready_streams;
|
||||
for (;;) {
|
||||
ready_streams.clear();
|
||||
for (auto &s : ss) {
|
||||
const auto p_ss = s.online_stream.get();
|
||||
if (keyword_spotter.IsReady(p_ss)) {
|
||||
ready_streams.push_back(p_ss);
|
||||
}
|
||||
std::ostringstream os;
|
||||
const auto r = keyword_spotter.GetResult(p_ss);
|
||||
if (!r.keyword.empty()) {
|
||||
os << s.filename << "\n";
|
||||
os << r.AsJsonString() << "\n\n";
|
||||
fprintf(stderr, "%s", os.str().c_str());
|
||||
}
|
||||
}
|
||||
|
||||
if (ready_streams.empty()) {
|
||||
break;
|
||||
}
|
||||
keyword_spotter.DecodeStreams(ready_streams.data(), ready_streams.size());
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
184
sherpa-onnx/csrc/transducer-keyword-decoder.cc
Normal file
@@ -0,0 +1,184 @@
|
||||
// sherpa-onnx/csrc/transducer-keywords-decoder.cc
|
||||
//
|
||||
// Copyright (c) 2023-2024 Xiaomi Corporation
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/log.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/transducer-keyword-decoder.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
TransducerKeywordResult TransducerKeywordDecoder::GetEmptyResult() const {
|
||||
int32_t context_size = model_->ContextSize();
|
||||
int32_t blank_id = 0; // always 0
|
||||
TransducerKeywordResult r;
|
||||
std::vector<int64_t> blanks(context_size, -1);
|
||||
blanks.back() = blank_id;
|
||||
|
||||
Hypotheses blank_hyp({{blanks, 0}});
|
||||
r.hyps = std::move(blank_hyp);
|
||||
return r;
|
||||
}
|
||||
|
||||
void TransducerKeywordDecoder::Decode(
|
||||
Ort::Value encoder_out, OnlineStream **ss,
|
||||
std::vector<TransducerKeywordResult> *result) {
|
||||
std::vector<int64_t> encoder_out_shape =
|
||||
encoder_out.GetTensorTypeAndShapeInfo().GetShape();
|
||||
|
||||
if (encoder_out_shape[0] != result->size()) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Size mismatch! encoder_out.size(0) %d, result.size(0): %d\n",
|
||||
static_cast<int32_t>(encoder_out_shape[0]),
|
||||
static_cast<int32_t>(result->size()));
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
int32_t batch_size = static_cast<int32_t>(encoder_out_shape[0]);
|
||||
|
||||
int32_t num_frames = static_cast<int32_t>(encoder_out_shape[1]);
|
||||
int32_t vocab_size = model_->VocabSize();
|
||||
int32_t context_size = model_->ContextSize();
|
||||
std::vector<int64_t> blanks(context_size, -1);
|
||||
blanks.back() = 0; // blank_id is hardcoded to 0
|
||||
|
||||
std::vector<Hypotheses> cur;
|
||||
for (auto &r : *result) {
|
||||
cur.push_back(std::move(r.hyps));
|
||||
}
|
||||
std::vector<Hypothesis> prev;
|
||||
|
||||
for (int32_t t = 0; t != num_frames; ++t) {
|
||||
// Due to merging paths with identical token sequences,
|
||||
// not all utterances have "num_active_paths" paths.
|
||||
auto hyps_row_splits = GetHypsRowSplits(cur);
|
||||
int32_t num_hyps =
|
||||
hyps_row_splits.back(); // total num hyps for all utterance
|
||||
prev.clear();
|
||||
for (auto &hyps : cur) {
|
||||
for (auto &h : hyps) {
|
||||
prev.push_back(std::move(h.second));
|
||||
}
|
||||
}
|
||||
cur.clear();
|
||||
cur.reserve(batch_size);
|
||||
|
||||
Ort::Value decoder_input = model_->BuildDecoderInput(prev);
|
||||
Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input));
|
||||
|
||||
Ort::Value cur_encoder_out =
|
||||
GetEncoderOutFrame(model_->Allocator(), &encoder_out, t);
|
||||
cur_encoder_out =
|
||||
Repeat(model_->Allocator(), &cur_encoder_out, hyps_row_splits);
|
||||
Ort::Value logit =
|
||||
model_->RunJoiner(std::move(cur_encoder_out), View(&decoder_out));
|
||||
|
||||
float *p_logit = logit.GetTensorMutableData<float>();
|
||||
LogSoftmax(p_logit, vocab_size, num_hyps);
|
||||
|
||||
// The acoustic logprobs for current frame
|
||||
std::vector<float> logprobs(vocab_size * num_hyps);
|
||||
std::memcpy(logprobs.data(), p_logit,
|
||||
sizeof(float) * vocab_size * num_hyps);
|
||||
|
||||
// now p_logit contains log_softmax output, we rename it to p_logprob
|
||||
// to match what it actually contains
|
||||
float *p_logprob = p_logit;
|
||||
|
||||
// add log_prob of each hypothesis to p_logprob before taking top_k
|
||||
for (int32_t i = 0; i != num_hyps; ++i) {
|
||||
float log_prob = prev[i].log_prob;
|
||||
for (int32_t k = 0; k != vocab_size; ++k, ++p_logprob) {
|
||||
*p_logprob += log_prob;
|
||||
}
|
||||
}
|
||||
p_logprob = p_logit; // we changed p_logprob in the above for loop
|
||||
|
||||
for (int32_t b = 0; b != batch_size; ++b) {
|
||||
int32_t frame_offset = (*result)[b].frame_offset;
|
||||
int32_t start = hyps_row_splits[b];
|
||||
int32_t end = hyps_row_splits[b + 1];
|
||||
auto topk =
|
||||
TopkIndex(p_logprob, vocab_size * (end - start), max_active_paths_);
|
||||
|
||||
Hypotheses hyps;
|
||||
for (auto k : topk) {
|
||||
int32_t hyp_index = k / vocab_size + start;
|
||||
int32_t new_token = k % vocab_size;
|
||||
|
||||
Hypothesis new_hyp = prev[hyp_index];
|
||||
float context_score = 0;
|
||||
auto context_state = new_hyp.context_state;
|
||||
|
||||
// blank is hardcoded to 0
|
||||
// also, it treats unk as blank
|
||||
if (new_token != 0 && new_token != unk_id_) {
|
||||
new_hyp.ys.push_back(new_token);
|
||||
new_hyp.timestamps.push_back(t + frame_offset);
|
||||
new_hyp.ys_probs.push_back(
|
||||
exp(logprobs[hyp_index * vocab_size + new_token]));
|
||||
|
||||
new_hyp.num_trailing_blanks = 0;
|
||||
auto context_res = ss[b]->GetContextGraph()->ForwardOneStep(
|
||||
context_state, new_token);
|
||||
context_score = std::get<0>(context_res);
|
||||
new_hyp.context_state = std::get<1>(context_res);
|
||||
// Start matching from the start state, forget the decoder history.
|
||||
if (new_hyp.context_state->token == -1) {
|
||||
new_hyp.ys = blanks;
|
||||
new_hyp.timestamps.clear();
|
||||
new_hyp.ys_probs.clear();
|
||||
}
|
||||
} else {
|
||||
++new_hyp.num_trailing_blanks;
|
||||
}
|
||||
new_hyp.log_prob = p_logprob[k] + context_score;
|
||||
hyps.Add(std::move(new_hyp));
|
||||
} // for (auto k : topk)
|
||||
|
||||
auto best_hyp = hyps.GetMostProbable(false);
|
||||
|
||||
auto status = ss[b]->GetContextGraph()->IsMatched(best_hyp.context_state);
|
||||
bool matched = std::get<0>(status);
|
||||
const ContextState *matched_state = std::get<1>(status);
|
||||
|
||||
if (matched) {
|
||||
float ys_prob = 0.0;
|
||||
int32_t length = best_hyp.ys_probs.size();
|
||||
for (int32_t i = 1; i <= matched_state->level; ++i) {
|
||||
ys_prob += best_hyp.ys_probs[i];
|
||||
}
|
||||
ys_prob /= matched_state->level;
|
||||
if (best_hyp.num_trailing_blanks > num_trailing_blanks_ &&
|
||||
ys_prob >= matched_state->ac_threshold) {
|
||||
auto &r = (*result)[b];
|
||||
r.tokens = {best_hyp.ys.end() - matched_state->level,
|
||||
best_hyp.ys.end()};
|
||||
r.timestamps = {best_hyp.timestamps.end() - matched_state->level,
|
||||
best_hyp.timestamps.end()};
|
||||
r.keyword = matched_state->phrase;
|
||||
|
||||
hyps = Hypotheses({{blanks, 0, ss[b]->GetContextGraph()->Root()}});
|
||||
}
|
||||
}
|
||||
cur.push_back(std::move(hyps));
|
||||
p_logprob += (end - start) * vocab_size;
|
||||
} // for (int32_t b = 0; b != batch_size; ++b)
|
||||
}
|
||||
|
||||
for (int32_t b = 0; b != batch_size; ++b) {
|
||||
auto &hyps = cur[b];
|
||||
auto best_hyp = hyps.GetMostProbable(false);
|
||||
auto &r = (*result)[b];
|
||||
r.hyps = std::move(hyps);
|
||||
r.num_trailing_blanks = best_hyp.num_trailing_blanks;
|
||||
r.frame_offset += num_frames;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
62
sherpa-onnx/csrc/transducer-keyword-decoder.h
Normal file
@@ -0,0 +1,62 @@
|
||||
// sherpa-onnx/csrc/transducer-keywords-decoder.h
|
||||
//
|
||||
// Copyright (c) 2023-2024 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_TRANSDUCER_KEYWORD_DECODER_H_
|
||||
#define SHERPA_ONNX_CSRC_TRANSDUCER_KEYWORD_DECODER_H_
|
||||
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/online-stream.h"
|
||||
#include "sherpa-onnx/csrc/online-transducer-model.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
struct TransducerKeywordResult {
|
||||
/// Number of frames after subsampling we have decoded so far
|
||||
int32_t frame_offset = 0;
|
||||
|
||||
/// The decoded token IDs for keywords
|
||||
std::vector<int64_t> tokens;
|
||||
|
||||
/// The triggered keyword
|
||||
std::string keyword;
|
||||
|
||||
/// number of trailing blank frames decoded so far
|
||||
int32_t num_trailing_blanks = 0;
|
||||
|
||||
/// timestamps[i] contains the output frame index where tokens[i] is decoded.
|
||||
std::vector<int32_t> timestamps;
|
||||
|
||||
// used only in modified beam_search
|
||||
Hypotheses hyps;
|
||||
};
|
||||
|
||||
class TransducerKeywordDecoder {
|
||||
public:
|
||||
TransducerKeywordDecoder(OnlineTransducerModel *model,
|
||||
int32_t max_active_paths,
|
||||
int32_t num_trailing_blanks, int32_t unk_id)
|
||||
: model_(model),
|
||||
max_active_paths_(max_active_paths),
|
||||
num_trailing_blanks_(num_trailing_blanks),
|
||||
unk_id_(unk_id) {}
|
||||
|
||||
TransducerKeywordResult GetEmptyResult() const;
|
||||
|
||||
void Decode(Ort::Value encoder_out, OnlineStream **ss,
|
||||
std::vector<TransducerKeywordResult> *result);
|
||||
|
||||
private:
|
||||
OnlineTransducerModel *model_; // Not owned
|
||||
|
||||
int32_t max_active_paths_;
|
||||
int32_t num_trailing_blanks_;
|
||||
int32_t unk_id_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_TRANSDUCER_KEYWORD_DECODER_H_
|
||||
@@ -15,16 +15,31 @@
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
bool EncodeHotwords(std::istream &is, const SymbolTable &symbol_table,
|
||||
std::vector<std::vector<int32_t>> *hotwords) {
|
||||
hotwords->clear();
|
||||
std::vector<int32_t> tmp;
|
||||
static bool EncodeBase(std::istream &is, const SymbolTable &symbol_table,
|
||||
std::vector<std::vector<int32_t>> *ids,
|
||||
std::vector<std::string> *phrases,
|
||||
std::vector<float> *scores,
|
||||
std::vector<float> *thresholds) {
|
||||
SHERPA_ONNX_CHECK(ids != nullptr);
|
||||
ids->clear();
|
||||
|
||||
std::vector<int32_t> tmp_ids;
|
||||
std::vector<float> tmp_scores;
|
||||
std::vector<float> tmp_thresholds;
|
||||
std::vector<std::string> tmp_phrases;
|
||||
|
||||
std::string line;
|
||||
std::string word;
|
||||
bool has_scores = false;
|
||||
bool has_thresholds = false;
|
||||
bool has_phrases = false;
|
||||
|
||||
while (std::getline(is, line)) {
|
||||
float score = 0;
|
||||
float threshold = 0;
|
||||
std::string phrase = "";
|
||||
|
||||
std::istringstream iss(line);
|
||||
std::vector<std::string> syms;
|
||||
while (iss >> word) {
|
||||
if (word.size() >= 3) {
|
||||
// For BPE-based models, we replace ▁ with a space
|
||||
@@ -35,20 +50,72 @@ bool EncodeHotwords(std::istream &is, const SymbolTable &symbol_table,
|
||||
}
|
||||
}
|
||||
if (symbol_table.contains(word)) {
|
||||
int32_t number = symbol_table[word];
|
||||
tmp.push_back(number);
|
||||
int32_t id = symbol_table[word];
|
||||
tmp_ids.push_back(id);
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Cannot find ID for hotword %s at line: %s. (Hint: words on "
|
||||
"the "
|
||||
"same line are separated by spaces)",
|
||||
word.c_str(), line.c_str());
|
||||
return false;
|
||||
switch (word[0]) {
|
||||
case ':': // boosting score for current keyword
|
||||
score = std::stof(word.substr(1));
|
||||
has_scores = true;
|
||||
break;
|
||||
case '#': // triggering threshold (probability) for current keyword
|
||||
threshold = std::stof(word.substr(1));
|
||||
has_thresholds = true;
|
||||
break;
|
||||
case '@': // the original keyword string
|
||||
phrase = word.substr(1);
|
||||
has_phrases = true;
|
||||
break;
|
||||
default:
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Cannot find ID for token %s at line: %s. (Hint: words on "
|
||||
"the same line are separated by spaces)",
|
||||
word.c_str(), line.c_str());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
hotwords->push_back(std::move(tmp));
|
||||
ids->push_back(std::move(tmp_ids));
|
||||
tmp_scores.push_back(score);
|
||||
tmp_phrases.push_back(phrase);
|
||||
tmp_thresholds.push_back(threshold);
|
||||
}
|
||||
if (scores != nullptr) {
|
||||
if (has_scores) {
|
||||
scores->swap(tmp_scores);
|
||||
} else {
|
||||
scores->clear();
|
||||
}
|
||||
}
|
||||
if (phrases != nullptr) {
|
||||
if (has_phrases) {
|
||||
*phrases = std::move(tmp_phrases);
|
||||
} else {
|
||||
phrases->clear();
|
||||
}
|
||||
}
|
||||
if (thresholds != nullptr) {
|
||||
if (has_thresholds) {
|
||||
thresholds->swap(tmp_thresholds);
|
||||
} else {
|
||||
thresholds->clear();
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool EncodeHotwords(std::istream &is, const SymbolTable &symbol_table,
|
||||
std::vector<std::vector<int32_t>> *hotwords) {
|
||||
return EncodeBase(is, symbol_table, hotwords, nullptr, nullptr, nullptr);
|
||||
}
|
||||
|
||||
bool EncodeKeywords(std::istream &is, const SymbolTable &symbol_table,
|
||||
std::vector<std::vector<int32_t>> *keywords_id,
|
||||
std::vector<std::string> *keywords,
|
||||
std::vector<float> *boost_scores,
|
||||
std::vector<float> *threshold) {
|
||||
return EncodeBase(is, symbol_table, keywords_id, keywords, boost_scores,
|
||||
threshold);
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -26,7 +26,32 @@ namespace sherpa_onnx {
|
||||
* otherwise returns false.
|
||||
*/
|
||||
bool EncodeHotwords(std::istream &is, const SymbolTable &symbol_table,
|
||||
std::vector<std::vector<int32_t>> *hotwords);
|
||||
std::vector<std::vector<int32_t>> *hotwords_id);
|
||||
|
||||
/* Encode the keywords in an input stream to be tokens ids.
|
||||
*
|
||||
* @param is The input stream, it contains several lines, one hotword for each
|
||||
* line. For each hotword, the tokens (cjkchar or bpe) are separated
|
||||
* by spaces, it might contain boosting score (starting with :),
|
||||
* triggering threshold (starting with #) and keyword string (starting
|
||||
* with @) too.
|
||||
* @param symbol_table The tokens table mapping symbols to ids. All the symbols
|
||||
* in the stream should be in the symbol_table, if not this
|
||||
* function returns fasle.
|
||||
*
|
||||
* @param keywords_id The encoded ids to be written to.
|
||||
* @param keywords The original keyword string to be written to.
|
||||
* @param boost_scores The boosting score for each keyword to be written to.
|
||||
* @param threshold The triggering threshold for each keyword to be written to.
|
||||
*
|
||||
* @return If all the symbols from ``is`` are in the symbol_table, returns true
|
||||
* otherwise returns false.
|
||||
*/
|
||||
bool EncodeKeywords(std::istream &is, const SymbolTable &symbol_table,
|
||||
std::vector<std::vector<int32_t>> *keywords_id,
|
||||
std::vector<std::string> *keywords,
|
||||
std::vector<float> *boost_scores,
|
||||
std::vector<float> *threshold);
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/keyword-spotter.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer.h"
|
||||
#include "sherpa-onnx/csrc/offline-tts.h"
|
||||
@@ -140,6 +141,73 @@ class SherpaOnnxVad {
|
||||
VoiceActivityDetector vad_;
|
||||
};
|
||||
|
||||
class SherpaOnnxKws {
|
||||
public:
|
||||
#if __ANDROID_API__ >= 9
|
||||
SherpaOnnxKws(AAssetManager *mgr, const KeywordSpotterConfig &config)
|
||||
: keyword_spotter_(mgr, config),
|
||||
stream_(keyword_spotter_.CreateStream()) {}
|
||||
#endif
|
||||
|
||||
explicit SherpaOnnxKws(const KeywordSpotterConfig &config)
|
||||
: keyword_spotter_(config), stream_(keyword_spotter_.CreateStream()) {}
|
||||
|
||||
void AcceptWaveform(int32_t sample_rate, const float *samples, int32_t n) {
|
||||
if (input_sample_rate_ == -1) {
|
||||
input_sample_rate_ = sample_rate;
|
||||
}
|
||||
|
||||
stream_->AcceptWaveform(sample_rate, samples, n);
|
||||
}
|
||||
|
||||
void InputFinished() const {
|
||||
std::vector<float> tail_padding(input_sample_rate_ * 0.6, 0);
|
||||
stream_->AcceptWaveform(input_sample_rate_, tail_padding.data(),
|
||||
tail_padding.size());
|
||||
stream_->InputFinished();
|
||||
}
|
||||
|
||||
// If keywords is an empty string, it just recreates the decoding stream
|
||||
// always returns true in this case.
|
||||
// If keywords is not empty, it will create a new decoding stream with
|
||||
// the given keywords appended to the default keywords.
|
||||
// Return false if errors occurred when adding keywords, true otherwise.
|
||||
bool Reset(const std::string &keywords = {}) {
|
||||
if (keywords.empty()) {
|
||||
stream_ = keyword_spotter_.CreateStream();
|
||||
return true;
|
||||
} else {
|
||||
auto stream = keyword_spotter_.CreateStream(keywords);
|
||||
// Set new keywords failed, the stream_ will not be updated.
|
||||
if (stream == nullptr) {
|
||||
return false;
|
||||
} else {
|
||||
stream_ = std::move(stream);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::string GetKeyword() const {
|
||||
auto result = keyword_spotter_.GetResult(stream_.get());
|
||||
return result.keyword;
|
||||
}
|
||||
|
||||
std::vector<std::string> GetTokens() const {
|
||||
auto result = keyword_spotter_.GetResult(stream_.get());
|
||||
return result.tokens;
|
||||
}
|
||||
|
||||
bool IsReady() const { return keyword_spotter_.IsReady(stream_.get()); }
|
||||
|
||||
void Decode() const { keyword_spotter_.DecodeStream(stream_.get()); }
|
||||
|
||||
private:
|
||||
KeywordSpotter keyword_spotter_;
|
||||
std::unique_ptr<OnlineStream> stream_;
|
||||
int32_t input_sample_rate_ = -1;
|
||||
};
|
||||
|
||||
static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
|
||||
OnlineRecognizerConfig ans;
|
||||
|
||||
@@ -457,6 +525,103 @@ static OfflineRecognizerConfig GetOfflineConfig(JNIEnv *env, jobject config) {
|
||||
return ans;
|
||||
}
|
||||
|
||||
static KeywordSpotterConfig GetKwsConfig(JNIEnv *env, jobject config) {
|
||||
KeywordSpotterConfig ans;
|
||||
|
||||
jclass cls = env->GetObjectClass(config);
|
||||
jfieldID fid;
|
||||
|
||||
// https://docs.oracle.com/javase/7/docs/technotes/guides/jni/spec/types.html
|
||||
// https://courses.cs.washington.edu/courses/cse341/99wi/java/tutorial/native1.1/implementing/field.html
|
||||
|
||||
//---------- decoding ----------
|
||||
fid = env->GetFieldID(cls, "maxActivePaths", "I");
|
||||
ans.max_active_paths = env->GetIntField(config, fid);
|
||||
|
||||
fid = env->GetFieldID(cls, "keywordsFile", "Ljava/lang/String;");
|
||||
jstring s = (jstring)env->GetObjectField(config, fid);
|
||||
const char *p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.keywords_file = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
fid = env->GetFieldID(cls, "keywordsScore", "F");
|
||||
ans.keywords_score = env->GetFloatField(config, fid);
|
||||
|
||||
fid = env->GetFieldID(cls, "keywordsThreshold", "F");
|
||||
ans.keywords_threshold = env->GetFloatField(config, fid);
|
||||
|
||||
fid = env->GetFieldID(cls, "numTrailingBlanks", "I");
|
||||
ans.num_trailing_blanks = env->GetIntField(config, fid);
|
||||
|
||||
//---------- feat config ----------
|
||||
fid = env->GetFieldID(cls, "featConfig",
|
||||
"Lcom/k2fsa/sherpa/onnx/FeatureConfig;");
|
||||
jobject feat_config = env->GetObjectField(config, fid);
|
||||
jclass feat_config_cls = env->GetObjectClass(feat_config);
|
||||
|
||||
fid = env->GetFieldID(feat_config_cls, "sampleRate", "I");
|
||||
ans.feat_config.sampling_rate = env->GetIntField(feat_config, fid);
|
||||
|
||||
fid = env->GetFieldID(feat_config_cls, "featureDim", "I");
|
||||
ans.feat_config.feature_dim = env->GetIntField(feat_config, fid);
|
||||
|
||||
//---------- model config ----------
|
||||
fid = env->GetFieldID(cls, "modelConfig",
|
||||
"Lcom/k2fsa/sherpa/onnx/OnlineModelConfig;");
|
||||
jobject model_config = env->GetObjectField(config, fid);
|
||||
jclass model_config_cls = env->GetObjectClass(model_config);
|
||||
|
||||
// transducer
|
||||
fid = env->GetFieldID(model_config_cls, "transducer",
|
||||
"Lcom/k2fsa/sherpa/onnx/OnlineTransducerModelConfig;");
|
||||
jobject transducer_config = env->GetObjectField(model_config, fid);
|
||||
jclass transducer_config_cls = env->GetObjectClass(transducer_config);
|
||||
|
||||
fid = env->GetFieldID(transducer_config_cls, "encoder", "Ljava/lang/String;");
|
||||
s = (jstring)env->GetObjectField(transducer_config, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.model_config.transducer.encoder = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
fid = env->GetFieldID(transducer_config_cls, "decoder", "Ljava/lang/String;");
|
||||
s = (jstring)env->GetObjectField(transducer_config, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.model_config.transducer.decoder = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
fid = env->GetFieldID(transducer_config_cls, "joiner", "Ljava/lang/String;");
|
||||
s = (jstring)env->GetObjectField(transducer_config, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.model_config.transducer.joiner = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
fid = env->GetFieldID(model_config_cls, "tokens", "Ljava/lang/String;");
|
||||
s = (jstring)env->GetObjectField(model_config, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.model_config.tokens = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
fid = env->GetFieldID(model_config_cls, "numThreads", "I");
|
||||
ans.model_config.num_threads = env->GetIntField(model_config, fid);
|
||||
|
||||
fid = env->GetFieldID(model_config_cls, "debug", "Z");
|
||||
ans.model_config.debug = env->GetBooleanField(model_config, fid);
|
||||
|
||||
fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;");
|
||||
s = (jstring)env->GetObjectField(model_config, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.model_config.provider = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
fid = env->GetFieldID(model_config_cls, "modelType", "Ljava/lang/String;");
|
||||
s = (jstring)env->GetObjectField(model_config, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.model_config.model_type = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
static VadModelConfig GetVadModelConfig(JNIEnv *env, jobject config) {
|
||||
VadModelConfig ans;
|
||||
|
||||
@@ -1013,7 +1178,124 @@ JNIEXPORT jobjectArray JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_getTokens(
|
||||
jclass stringClass = env->FindClass("java/lang/String");
|
||||
|
||||
// convert C++ list into jni string array
|
||||
jobjectArray result = env->NewObjectArray(size, stringClass, NULL);
|
||||
jobjectArray result = env->NewObjectArray(size, stringClass, nullptr);
|
||||
for (int32_t i = 0; i < size; i++) {
|
||||
// Convert the C++ string to a C string
|
||||
const char *cstr = tokens[i].c_str();
|
||||
|
||||
// Convert the C string to a jstring
|
||||
jstring jstr = env->NewStringUTF(cstr);
|
||||
|
||||
// Set the array element
|
||||
env->SetObjectArrayElement(result, i, jstr);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
|
||||
JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_new(
|
||||
JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) {
|
||||
#if __ANDROID_API__ >= 9
|
||||
AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
|
||||
if (!mgr) {
|
||||
SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
|
||||
}
|
||||
#endif
|
||||
auto config = sherpa_onnx::GetKwsConfig(env, _config);
|
||||
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
|
||||
auto model = new sherpa_onnx::SherpaOnnxKws(
|
||||
#if __ANDROID_API__ >= 9
|
||||
mgr,
|
||||
#endif
|
||||
config);
|
||||
|
||||
return (jlong)model;
|
||||
}
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
|
||||
JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_newFromFile(
|
||||
JNIEnv *env, jobject /*obj*/, jobject _config) {
|
||||
auto config = sherpa_onnx::GetKwsConfig(env, _config);
|
||||
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
|
||||
auto model = new sherpa_onnx::SherpaOnnxKws(config);
|
||||
|
||||
return (jlong)model;
|
||||
}
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
|
||||
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_delete(
|
||||
JNIEnv *env, jobject /*obj*/, jlong ptr) {
|
||||
delete reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr);
|
||||
}
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
|
||||
JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_isReady(
|
||||
JNIEnv *env, jobject /*obj*/, jlong ptr) {
|
||||
auto model = reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr);
|
||||
return model->IsReady();
|
||||
}
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
|
||||
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_decode(
|
||||
JNIEnv *env, jobject /*obj*/, jlong ptr) {
|
||||
auto model = reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr);
|
||||
model->Decode();
|
||||
}
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
|
||||
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_acceptWaveform(
|
||||
JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples,
|
||||
jint sample_rate) {
|
||||
auto model = reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr);
|
||||
|
||||
jfloat *p = env->GetFloatArrayElements(samples, nullptr);
|
||||
jsize n = env->GetArrayLength(samples);
|
||||
|
||||
model->AcceptWaveform(sample_rate, p, n);
|
||||
|
||||
env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);
|
||||
}
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
|
||||
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_inputFinished(
|
||||
JNIEnv *env, jobject /*obj*/, jlong ptr) {
|
||||
reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr)->InputFinished();
|
||||
}
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
|
||||
JNIEXPORT jstring JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_getKeyword(
|
||||
JNIEnv *env, jobject /*obj*/, jlong ptr) {
|
||||
// see
|
||||
// https://stackoverflow.com/questions/11621449/send-c-string-to-java-via-jni
|
||||
auto text = reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr)->GetKeyword();
|
||||
return env->NewStringUTF(text.c_str());
|
||||
}
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
|
||||
JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_reset(
|
||||
JNIEnv *env, jobject /*obj*/, jlong ptr, jstring keywords) {
|
||||
const char *p_keywords = env->GetStringUTFChars(keywords, nullptr);
|
||||
|
||||
std::string keywords_str = p_keywords;
|
||||
|
||||
bool status =
|
||||
reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr)->Reset(keywords_str);
|
||||
env->ReleaseStringUTFChars(keywords, p_keywords);
|
||||
return status;
|
||||
}
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
|
||||
JNIEXPORT jobjectArray JNICALL
|
||||
Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_getTokens(JNIEnv *env, jobject /*obj*/,
|
||||
jlong ptr) {
|
||||
auto tokens =
|
||||
reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr)->GetTokens();
|
||||
int32_t size = tokens.size();
|
||||
jclass stringClass = env->FindClass("java/lang/String");
|
||||
|
||||
// convert C++ list into jni string array
|
||||
jobjectArray result = env->NewObjectArray(size, stringClass, nullptr);
|
||||
for (int32_t i = 0; i < size; i++) {
|
||||
// Convert the C++ string to a C string
|
||||
const char *cstr = tokens[i].c_str();
|
||||
|
||||
@@ -28,9 +28,14 @@ def cli():
|
||||
)
|
||||
@click.option(
|
||||
"--tokens-type",
|
||||
type=str,
|
||||
type=click.Choice(
|
||||
["cjkchar", "bpe", "cjkchar+bpe", "fpinyin", "ppinyin"], case_sensitive=True
|
||||
),
|
||||
required=True,
|
||||
help="The type of modeling units, should be cjkchar, bpe or cjkchar+bpe",
|
||||
help="""The type of modeling units, should be cjkchar, bpe, cjkchar+bpe, fpinyin or ppinyin.
|
||||
fpinyin means full pinyin, each cjkchar has a pinyin(with tone).
|
||||
ppinyin means partial pinyin, it splits pinyin into initial and final,
|
||||
""",
|
||||
)
|
||||
@click.option(
|
||||
"--bpe-model",
|
||||
@@ -42,14 +47,56 @@ def encode_text(
|
||||
):
|
||||
"""
|
||||
Encode the texts given by the INPUT to tokens and write the results to the OUTPUT.
|
||||
Each line in the texts contains the original phrase, it might also contain some
|
||||
extra items, for example, the boosting score (startting with :), the triggering
|
||||
threshold (startting with #, only used in keyword spotting task) and the original
|
||||
phrase (startting with @). Note: the extra items will be kept same in the output.
|
||||
|
||||
example input 1 (tokens_type = ppinyin):
|
||||
|
||||
小爱同学 :2.0 #0.6 @小爱同学
|
||||
你好问问 :3.5 @你好问问
|
||||
小艺小艺 #0.6 @小艺小艺
|
||||
|
||||
example output 1:
|
||||
|
||||
x iǎo ài t óng x ué :2.0 #0.6 @小爱同学
|
||||
n ǐ h ǎo w èn w èn :3.5 @你好问问
|
||||
x iǎo y ì x iǎo y ì #0.6 @小艺小艺
|
||||
|
||||
example input 2 (tokens_type = bpe):
|
||||
|
||||
HELLO WORLD :1.5 #0.4
|
||||
HI GOOGLE :2.0 #0.8
|
||||
HEY SIRI #0.35
|
||||
|
||||
example output 2:
|
||||
|
||||
▁HE LL O ▁WORLD :1.5 #0.4
|
||||
▁HI ▁GO O G LE :2.0 #0.8
|
||||
▁HE Y ▁S I RI #0.35
|
||||
"""
|
||||
texts = []
|
||||
# extra information like boosting score (start with :), triggering threshold (start with #)
|
||||
# original keyword (start with @)
|
||||
extra_info = []
|
||||
with open(input, "r", encoding="utf8") as f:
|
||||
for line in f:
|
||||
texts.append(line.strip())
|
||||
extra = []
|
||||
text = []
|
||||
toks = line.strip().split()
|
||||
for tok in toks:
|
||||
if tok[0] == ":" or tok[0] == "#" or tok[0] == "@":
|
||||
extra.append(tok)
|
||||
else:
|
||||
text.append(tok)
|
||||
texts.append(" ".join(text))
|
||||
extra_info.append(extra)
|
||||
|
||||
encoded_texts = text2token(
|
||||
texts, tokens=tokens, tokens_type=tokens_type, bpe_model=bpe_model
|
||||
)
|
||||
with open(output, "w", encoding="utf8") as f:
|
||||
for txt in encoded_texts:
|
||||
for i, txt in enumerate(encoded_texts):
|
||||
txt += extra_info[i]
|
||||
f.write(" ".join(txt) + "\n")
|
||||
|
||||
@@ -6,6 +6,9 @@ from typing import List, Optional, Union
|
||||
|
||||
import sentencepiece as spm
|
||||
|
||||
from pypinyin import pinyin
|
||||
from pypinyin.contrib.tone_convert import to_initials, to_finals_tone
|
||||
|
||||
|
||||
def text2token(
|
||||
texts: List[str],
|
||||
@@ -23,7 +26,9 @@ def text2token(
|
||||
tokens:
|
||||
The path of the tokens.txt.
|
||||
tokens_type:
|
||||
The valid values are cjkchar, bpe, cjkchar+bpe.
|
||||
The valid values are cjkchar, bpe, cjkchar+bpe, fpinyin, ppinyin.
|
||||
fpinyin means full pinyin, each cjkchar has a pinyin(with tone).
|
||||
ppinyin means partial pinyin, it splits pinyin into initial and final,
|
||||
bpe_model:
|
||||
The path of the bpe model. Only required when tokens_type is bpe or
|
||||
cjkchar+bpe.
|
||||
@@ -53,6 +58,24 @@ def text2token(
|
||||
texts_list = [list("".join(text.split())) for text in texts]
|
||||
elif tokens_type == "bpe":
|
||||
texts_list = sp.encode(texts, out_type=str)
|
||||
elif "pinyin" in tokens_type:
|
||||
for txt in texts:
|
||||
py = [x[0] for x in pinyin(txt)]
|
||||
if "ppinyin" == tokens_type:
|
||||
res = []
|
||||
for x in py:
|
||||
initial = to_initials(x, strict=False)
|
||||
final = to_finals_tone(x, strict=False)
|
||||
if initial == "" and final == "":
|
||||
res.append(x)
|
||||
else:
|
||||
if initial != "":
|
||||
res.append(initial)
|
||||
if final != "":
|
||||
res.append(final)
|
||||
texts_list.append(res)
|
||||
else:
|
||||
texts_list.append(py)
|
||||
else:
|
||||
assert (
|
||||
tokens_type == "cjkchar+bpe"
|
||||
|
||||