decoder for open vocabulary keyword spotting (#505)
* various fixes to ContextGraph to support open vocabulary keywords decoder * Add keyword spotter runtime * Add binary * First version works * Minor fixes * update text2token * default values * Add jni for kws * add kws android project * Minor fixes * Remove unused interface * Minor fixes * Add workflow * handle extra info in texts * Minor fixes * Add more comments * Fix ci * fix cpp style * Add input box in android demo so that users can specify their keywords * Fix cpp style * Fix comments * Minor fixes * Minor fixes * minor fixes * Minor fixes * Minor fixes * Add CI * Fix code style * cpplint * Fix comments * Fix error
68
.github/scripts/test-kws.sh
vendored
Executable file
@@ -0,0 +1,68 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
|
set -e
|
||||||
|
|
||||||
|
log() {
|
||||||
|
# This function is from espnet
|
||||||
|
local fname=${BASH_SOURCE[1]##*/}
|
||||||
|
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||||
|
}
|
||||||
|
|
||||||
|
echo "EXE is $EXE"
|
||||||
|
echo "PATH: $PATH"
|
||||||
|
|
||||||
|
which $EXE
|
||||||
|
|
||||||
|
log "------------------------------------------------------------"
|
||||||
|
log "Run Chinese keyword spotting (Wenetspeech)"
|
||||||
|
log "------------------------------------------------------------"
|
||||||
|
|
||||||
|
repo_url=https://www.modelscope.cn/pkufool/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01.git
|
||||||
|
log "Start testing ${repo_url}"
|
||||||
|
repo=sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01
|
||||||
|
log "Download pretrained model and test-data from $repo_url"
|
||||||
|
|
||||||
|
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||||
|
pushd $repo
|
||||||
|
git lfs pull --include "*.onnx"
|
||||||
|
ls -lh *.onnx
|
||||||
|
popd
|
||||||
|
|
||||||
|
time $EXE \
|
||||||
|
--tokens=$repo/tokens.txt \
|
||||||
|
--encoder=$repo/encoder-epoch-12-avg-2-chunk-16-left-64.onnx \
|
||||||
|
--decoder=$repo/decoder-epoch-12-avg-2-chunk-16-left-64.onnx \
|
||||||
|
--joiner=$repo/joiner-epoch-12-avg-2-chunk-16-left-64.onnx \
|
||||||
|
--keywords-file=$repo/test_wavs/test_keywords.txt \
|
||||||
|
--max-active-paths=4 \
|
||||||
|
--num-threads=4 \
|
||||||
|
$repo/test_wavs/3.wav $repo/test_wavs/4.wav $repo/test_wavs/5.wav $repo/test_wavs/6.wav
|
||||||
|
|
||||||
|
rm -rf $repo
|
||||||
|
|
||||||
|
log "------------------------------------------------------------"
|
||||||
|
log "Run English keyword spotting (Gigaspeech)"
|
||||||
|
log "------------------------------------------------------------"
|
||||||
|
|
||||||
|
repo_url=https://www.modelscope.cn/pkufool/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01.git
|
||||||
|
log "Start testing ${repo_url}"
|
||||||
|
repo=sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01
|
||||||
|
log "Download pretrained model and test-data from $repo_url"
|
||||||
|
|
||||||
|
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||||
|
pushd $repo
|
||||||
|
git lfs pull --include "*.onnx"
|
||||||
|
ls -lh *.onnx
|
||||||
|
popd
|
||||||
|
|
||||||
|
time $EXE \
|
||||||
|
--tokens=$repo/tokens.txt \
|
||||||
|
--encoder=$repo/encoder-epoch-12-avg-2-chunk-16-left-64.onnx \
|
||||||
|
--decoder=$repo/decoder-epoch-12-avg-2-chunk-16-left-64.onnx \
|
||||||
|
--joiner=$repo/joiner-epoch-12-avg-2-chunk-16-left-64.onnx \
|
||||||
|
--keywords-file=$repo/test_wavs/test_keywords.txt \
|
||||||
|
--max-active-paths=4 \
|
||||||
|
--num-threads=4 \
|
||||||
|
$repo/test_wavs/0.wav $repo/test_wavs/1.wav
|
||||||
|
|
||||||
|
rm -rf $repo
|
||||||
67
.github/workflows/apk-kws.yaml
vendored
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
name: apk-kws
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- apk-kws
|
||||||
|
tags:
|
||||||
|
- '*'
|
||||||
|
|
||||||
|
workflow_dispatch:
|
||||||
|
|
||||||
|
concurrency:
|
||||||
|
group: apk-kws-${{ github.ref }}
|
||||||
|
cancel-in-progress: true
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: write
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
apk:
|
||||||
|
runs-on: ${{ matrix.os }}
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
os: [ubuntu-latest]
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
fetch-depth: 0
|
||||||
|
|
||||||
|
- name: ccache
|
||||||
|
uses: hendrikmuhs/ccache-action@v1.2
|
||||||
|
with:
|
||||||
|
key: ${{ matrix.os }}-android
|
||||||
|
|
||||||
|
- name: Display NDK HOME
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
echo "ANDROID_NDK_LATEST_HOME: ${ANDROID_NDK_LATEST_HOME}"
|
||||||
|
ls -lh ${ANDROID_NDK_LATEST_HOME}
|
||||||
|
|
||||||
|
- name: build APK
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
export CMAKE_CXX_COMPILER_LAUNCHER=ccache
|
||||||
|
export PATH="/usr/lib/ccache:/usr/local/opt/ccache/libexec:$PATH"
|
||||||
|
cmake --version
|
||||||
|
|
||||||
|
export ANDROID_NDK=$ANDROID_NDK_LATEST_HOME
|
||||||
|
./build-kws-apk.sh
|
||||||
|
|
||||||
|
- name: Display APK
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
ls -lh ./apks/
|
||||||
|
|
||||||
|
- uses: actions/upload-artifact@v3
|
||||||
|
with:
|
||||||
|
path: ./apks/*.apk
|
||||||
|
|
||||||
|
- name: Release APK
|
||||||
|
uses: svenstaro/upload-release-action@v2
|
||||||
|
with:
|
||||||
|
file_glob: true
|
||||||
|
file: apks/*.apk
|
||||||
|
overwrite: true
|
||||||
8
.github/workflows/linux.yaml
vendored
@@ -107,6 +107,14 @@ jobs:
|
|||||||
name: release-static
|
name: release-static
|
||||||
path: build/bin/*
|
path: build/bin/*
|
||||||
|
|
||||||
|
- name: Test transducer kws
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
export PATH=$PWD/build/bin:$PATH
|
||||||
|
export EXE=sherpa-onnx-keyword-spotter
|
||||||
|
|
||||||
|
.github/scripts/test-kws.sh
|
||||||
|
|
||||||
- name: Test online CTC
|
- name: Test online CTC
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
9
.github/workflows/macos.yaml
vendored
@@ -98,6 +98,14 @@ jobs:
|
|||||||
otool -L build/bin/sherpa-onnx
|
otool -L build/bin/sherpa-onnx
|
||||||
otool -l build/bin/sherpa-onnx
|
otool -l build/bin/sherpa-onnx
|
||||||
|
|
||||||
|
- name: Test transducer kws
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
export PATH=$PWD/build/bin:$PATH
|
||||||
|
export EXE=sherpa-onnx-keyword-spotter
|
||||||
|
|
||||||
|
.github/scripts/test-kws.sh
|
||||||
|
|
||||||
- name: Test online CTC
|
- name: Test online CTC
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
@@ -106,7 +114,6 @@ jobs:
|
|||||||
|
|
||||||
.github/scripts/test-online-ctc.sh
|
.github/scripts/test-online-ctc.sh
|
||||||
|
|
||||||
|
|
||||||
- name: Test offline TTS
|
- name: Test offline TTS
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
2
.github/workflows/run-python-test.yaml
vendored
@@ -62,7 +62,7 @@ jobs:
|
|||||||
- name: Install Python dependencies
|
- name: Install Python dependencies
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
python3 -m pip install --upgrade pip numpy sentencepiece==0.1.96 soundfile
|
python3 -m pip install --upgrade pip numpy pypinyin sentencepiece==0.1.96 soundfile
|
||||||
|
|
||||||
- name: Install sherpa-onnx
|
- name: Install sherpa-onnx
|
||||||
shell: bash
|
shell: bash
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ jobs:
|
|||||||
- name: Install Python dependencies
|
- name: Install Python dependencies
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
python3 -m pip install --upgrade pip numpy sentencepiece
|
python3 -m pip install --upgrade pip numpy pypinyin sentencepiece
|
||||||
|
|
||||||
- name: Install sherpa-onnx
|
- name: Install sherpa-onnx
|
||||||
shell: bash
|
shell: bash
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ jobs:
|
|||||||
- name: Install Python dependencies
|
- name: Install Python dependencies
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
python3 -m pip install --upgrade pip numpy sentencepiece
|
python3 -m pip install --upgrade pip numpy pypinyin sentencepiece
|
||||||
|
|
||||||
- name: Install sherpa-onnx
|
- name: Install sherpa-onnx
|
||||||
shell: bash
|
shell: bash
|
||||||
|
|||||||
15
android/SherpaOnnxKws/.gitignore
vendored
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
*.iml
|
||||||
|
.gradle
|
||||||
|
/local.properties
|
||||||
|
/.idea/caches
|
||||||
|
/.idea/libraries
|
||||||
|
/.idea/modules.xml
|
||||||
|
/.idea/workspace.xml
|
||||||
|
/.idea/navEditor.xml
|
||||||
|
/.idea/assetWizardSettings.xml
|
||||||
|
.DS_Store
|
||||||
|
/build
|
||||||
|
/captures
|
||||||
|
.externalNativeBuild
|
||||||
|
.cxx
|
||||||
|
local.properties
|
||||||
1
android/SherpaOnnxKws/app/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
|||||||
|
/build
|
||||||
44
android/SherpaOnnxKws/app/build.gradle
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
plugins {
|
||||||
|
id 'com.android.application'
|
||||||
|
id 'org.jetbrains.kotlin.android'
|
||||||
|
}
|
||||||
|
|
||||||
|
android {
|
||||||
|
namespace 'com.k2fsa.sherpa.onnx'
|
||||||
|
compileSdk 32
|
||||||
|
|
||||||
|
defaultConfig {
|
||||||
|
applicationId "com.k2fsa.sherpa.onnx"
|
||||||
|
minSdk 21
|
||||||
|
targetSdk 32
|
||||||
|
versionCode 1
|
||||||
|
versionName "1.0"
|
||||||
|
|
||||||
|
testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
|
||||||
|
}
|
||||||
|
|
||||||
|
buildTypes {
|
||||||
|
release {
|
||||||
|
minifyEnabled false
|
||||||
|
proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
compileOptions {
|
||||||
|
sourceCompatibility JavaVersion.VERSION_1_8
|
||||||
|
targetCompatibility JavaVersion.VERSION_1_8
|
||||||
|
}
|
||||||
|
kotlinOptions {
|
||||||
|
jvmTarget = '1.8'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
dependencies {
|
||||||
|
|
||||||
|
implementation 'androidx.core:core-ktx:1.7.0'
|
||||||
|
implementation 'androidx.appcompat:appcompat:1.5.1'
|
||||||
|
implementation 'com.google.android.material:material:1.7.0'
|
||||||
|
implementation 'androidx.constraintlayout:constraintlayout:2.1.4'
|
||||||
|
testImplementation 'junit:junit:4.13.2'
|
||||||
|
androidTestImplementation 'androidx.test.ext:junit:1.1.4'
|
||||||
|
androidTestImplementation 'androidx.test.espresso:espresso-core:3.5.0'
|
||||||
|
}
|
||||||
21
android/SherpaOnnxKws/app/proguard-rules.pro
vendored
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
# Add project specific ProGuard rules here.
|
||||||
|
# You can control the set of applied configuration files using the
|
||||||
|
# proguardFiles setting in build.gradle.
|
||||||
|
#
|
||||||
|
# For more details, see
|
||||||
|
# http://developer.android.com/guide/developing/tools/proguard.html
|
||||||
|
|
||||||
|
# If your project uses WebView with JS, uncomment the following
|
||||||
|
# and specify the fully qualified class name to the JavaScript interface
|
||||||
|
# class:
|
||||||
|
#-keepclassmembers class fqcn.of.javascript.interface.for.webview {
|
||||||
|
# public *;
|
||||||
|
#}
|
||||||
|
|
||||||
|
# Uncomment this to preserve the line number information for
|
||||||
|
# debugging stack traces.
|
||||||
|
#-keepattributes SourceFile,LineNumberTable
|
||||||
|
|
||||||
|
# If you keep the line number information, uncomment this to
|
||||||
|
# hide the original source file name.
|
||||||
|
#-renamesourcefileattribute SourceFile
|
||||||
@@ -0,0 +1,24 @@
|
|||||||
|
package com.k2fsa.sherpa.onnx
|
||||||
|
|
||||||
|
import androidx.test.platform.app.InstrumentationRegistry
|
||||||
|
import androidx.test.ext.junit.runners.AndroidJUnit4
|
||||||
|
|
||||||
|
import org.junit.Test
|
||||||
|
import org.junit.runner.RunWith
|
||||||
|
|
||||||
|
import org.junit.Assert.*
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Instrumented test, which will execute on an Android device.
|
||||||
|
*
|
||||||
|
* See [testing documentation](http://d.android.com/tools/testing).
|
||||||
|
*/
|
||||||
|
@RunWith(AndroidJUnit4::class)
|
||||||
|
class ExampleInstrumentedTest {
|
||||||
|
@Test
|
||||||
|
fun useAppContext() {
|
||||||
|
// Context of the app under test.
|
||||||
|
val appContext = InstrumentationRegistry.getInstrumentation().targetContext
|
||||||
|
assertEquals("com.k2fsa.sherpa.onnx", appContext.packageName)
|
||||||
|
}
|
||||||
|
}
|
||||||
32
android/SherpaOnnxKws/app/src/main/AndroidManifest.xml
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
<?xml version="1.0" encoding="utf-8"?>
|
||||||
|
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
|
||||||
|
xmlns:tools="http://schemas.android.com/tools">
|
||||||
|
|
||||||
|
<uses-permission android:name="android.permission.RECORD_AUDIO" />
|
||||||
|
|
||||||
|
<application
|
||||||
|
android:allowBackup="true"
|
||||||
|
android:dataExtractionRules="@xml/data_extraction_rules"
|
||||||
|
android:fullBackupContent="@xml/backup_rules"
|
||||||
|
android:icon="@mipmap/ic_launcher"
|
||||||
|
android:label="@string/app_name"
|
||||||
|
android:roundIcon="@mipmap/ic_launcher_round"
|
||||||
|
android:supportsRtl="true"
|
||||||
|
android:theme="@style/Theme.SherpaOnnx"
|
||||||
|
tools:targetApi="31">
|
||||||
|
<activity
|
||||||
|
android:name=".MainActivity"
|
||||||
|
android:exported="true">
|
||||||
|
<intent-filter>
|
||||||
|
<action android:name="android.intent.action.MAIN" />
|
||||||
|
|
||||||
|
<category android:name="android.intent.category.LAUNCHER" />
|
||||||
|
</intent-filter>
|
||||||
|
|
||||||
|
<meta-data
|
||||||
|
android:name="android.app.lib_name"
|
||||||
|
android:value="" />
|
||||||
|
</activity>
|
||||||
|
</application>
|
||||||
|
|
||||||
|
</manifest>
|
||||||
0
android/SherpaOnnxKws/app/src/main/assets/.gitkeep
Normal file
@@ -0,0 +1,207 @@
|
|||||||
|
package com.k2fsa.sherpa.onnx
|
||||||
|
|
||||||
|
import android.Manifest
|
||||||
|
import android.content.pm.PackageManager
|
||||||
|
import android.media.AudioFormat
|
||||||
|
import android.media.AudioRecord
|
||||||
|
import android.media.MediaRecorder
|
||||||
|
import android.os.Bundle
|
||||||
|
import android.text.method.ScrollingMovementMethod
|
||||||
|
import android.util.Log
|
||||||
|
import android.widget.Button
|
||||||
|
import android.widget.EditText
|
||||||
|
import android.widget.TextView
|
||||||
|
import android.widget.Toast
|
||||||
|
import androidx.appcompat.app.AppCompatActivity
|
||||||
|
import androidx.core.app.ActivityCompat
|
||||||
|
import com.k2fsa.sherpa.onnx.*
|
||||||
|
import kotlin.concurrent.thread
|
||||||
|
|
||||||
|
private const val TAG = "sherpa-onnx"
|
||||||
|
private const val REQUEST_RECORD_AUDIO_PERMISSION = 200
|
||||||
|
|
||||||
|
class MainActivity : AppCompatActivity() {
|
||||||
|
private val permissions: Array<String> = arrayOf(Manifest.permission.RECORD_AUDIO)
|
||||||
|
|
||||||
|
private lateinit var model: SherpaOnnxKws
|
||||||
|
private var audioRecord: AudioRecord? = null
|
||||||
|
private lateinit var recordButton: Button
|
||||||
|
private lateinit var textView: TextView
|
||||||
|
private lateinit var inputText: EditText
|
||||||
|
private var recordingThread: Thread? = null
|
||||||
|
|
||||||
|
private val audioSource = MediaRecorder.AudioSource.MIC
|
||||||
|
private val sampleRateInHz = 16000
|
||||||
|
private val channelConfig = AudioFormat.CHANNEL_IN_MONO
|
||||||
|
|
||||||
|
// Note: We don't use AudioFormat.ENCODING_PCM_FLOAT
|
||||||
|
// since the AudioRecord.read(float[]) needs API level >= 23
|
||||||
|
// but we are targeting API level >= 21
|
||||||
|
private val audioFormat = AudioFormat.ENCODING_PCM_16BIT
|
||||||
|
private var idx: Int = 0
|
||||||
|
private var lastText: String = ""
|
||||||
|
|
||||||
|
@Volatile
|
||||||
|
private var isRecording: Boolean = false
|
||||||
|
|
||||||
|
override fun onRequestPermissionsResult(
|
||||||
|
requestCode: Int, permissions: Array<String>, grantResults: IntArray
|
||||||
|
) {
|
||||||
|
super.onRequestPermissionsResult(requestCode, permissions, grantResults)
|
||||||
|
val permissionToRecordAccepted = if (requestCode == REQUEST_RECORD_AUDIO_PERMISSION) {
|
||||||
|
grantResults[0] == PackageManager.PERMISSION_GRANTED
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!permissionToRecordAccepted) {
|
||||||
|
Log.e(TAG, "Audio record is disallowed")
|
||||||
|
finish()
|
||||||
|
}
|
||||||
|
|
||||||
|
Log.i(TAG, "Audio record is permitted")
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun onCreate(savedInstanceState: Bundle?) {
|
||||||
|
super.onCreate(savedInstanceState)
|
||||||
|
setContentView(R.layout.activity_main)
|
||||||
|
|
||||||
|
ActivityCompat.requestPermissions(this, permissions, REQUEST_RECORD_AUDIO_PERMISSION)
|
||||||
|
|
||||||
|
Log.i(TAG, "Start to initialize model")
|
||||||
|
initModel()
|
||||||
|
Log.i(TAG, "Finished initializing model")
|
||||||
|
|
||||||
|
recordButton = findViewById(R.id.record_button)
|
||||||
|
recordButton.setOnClickListener { onclick() }
|
||||||
|
|
||||||
|
textView = findViewById(R.id.my_text)
|
||||||
|
textView.movementMethod = ScrollingMovementMethod()
|
||||||
|
|
||||||
|
inputText = findViewById(R.id.input_text)
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun onclick() {
|
||||||
|
if (!isRecording) {
|
||||||
|
var keywords = inputText.text.toString()
|
||||||
|
|
||||||
|
Log.i(TAG, keywords)
|
||||||
|
keywords = keywords.replace("\n", "/")
|
||||||
|
// If keywords is an empty string, it just resets the decoding stream
|
||||||
|
// always returns true in this case.
|
||||||
|
// If keywords is not empty, it will create a new decoding stream with
|
||||||
|
// the given keywords appended to the default keywords.
|
||||||
|
// Return false if errors occured when adding keywords, true otherwise.
|
||||||
|
val status = model.reset(keywords)
|
||||||
|
if (!status) {
|
||||||
|
Log.i(TAG, "Failed to reset with keywords.")
|
||||||
|
Toast.makeText(this, "Failed to set keywords.", Toast.LENGTH_LONG).show();
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
val ret = initMicrophone()
|
||||||
|
if (!ret) {
|
||||||
|
Log.e(TAG, "Failed to initialize microphone")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
Log.i(TAG, "state: ${audioRecord?.state}")
|
||||||
|
audioRecord!!.startRecording()
|
||||||
|
recordButton.setText(R.string.stop)
|
||||||
|
isRecording = true
|
||||||
|
textView.text = ""
|
||||||
|
lastText = ""
|
||||||
|
idx = 0
|
||||||
|
|
||||||
|
recordingThread = thread(true) {
|
||||||
|
processSamples()
|
||||||
|
}
|
||||||
|
Log.i(TAG, "Started recording")
|
||||||
|
} else {
|
||||||
|
isRecording = false
|
||||||
|
audioRecord!!.stop()
|
||||||
|
audioRecord!!.release()
|
||||||
|
audioRecord = null
|
||||||
|
recordButton.setText(R.string.start)
|
||||||
|
Log.i(TAG, "Stopped recording")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun processSamples() {
|
||||||
|
Log.i(TAG, "processing samples")
|
||||||
|
|
||||||
|
val interval = 0.1 // i.e., 100 ms
|
||||||
|
val bufferSize = (interval * sampleRateInHz).toInt() // in samples
|
||||||
|
val buffer = ShortArray(bufferSize)
|
||||||
|
|
||||||
|
while (isRecording) {
|
||||||
|
val ret = audioRecord?.read(buffer, 0, buffer.size)
|
||||||
|
if (ret != null && ret > 0) {
|
||||||
|
val samples = FloatArray(ret) { buffer[it] / 32768.0f }
|
||||||
|
model.acceptWaveform(samples, sampleRate=sampleRateInHz)
|
||||||
|
while (model.isReady()) {
|
||||||
|
model.decode()
|
||||||
|
}
|
||||||
|
|
||||||
|
val text = model.keyword
|
||||||
|
|
||||||
|
var textToDisplay = lastText;
|
||||||
|
|
||||||
|
if(text.isNotBlank()) {
|
||||||
|
if (lastText.isBlank()) {
|
||||||
|
textToDisplay = "${idx}: ${text}"
|
||||||
|
} else {
|
||||||
|
textToDisplay = "${idx}: ${text}\n${lastText}"
|
||||||
|
}
|
||||||
|
lastText = "${idx}: ${text}\n${lastText}"
|
||||||
|
idx += 1
|
||||||
|
}
|
||||||
|
|
||||||
|
runOnUiThread {
|
||||||
|
textView.text = textToDisplay
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun initMicrophone(): Boolean {
|
||||||
|
if (ActivityCompat.checkSelfPermission(
|
||||||
|
this, Manifest.permission.RECORD_AUDIO
|
||||||
|
) != PackageManager.PERMISSION_GRANTED
|
||||||
|
) {
|
||||||
|
ActivityCompat.requestPermissions(this, permissions, REQUEST_RECORD_AUDIO_PERMISSION)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
val numBytes = AudioRecord.getMinBufferSize(sampleRateInHz, channelConfig, audioFormat)
|
||||||
|
Log.i(
|
||||||
|
TAG, "buffer size in milliseconds: ${numBytes * 1000.0f / sampleRateInHz}"
|
||||||
|
)
|
||||||
|
|
||||||
|
audioRecord = AudioRecord(
|
||||||
|
audioSource,
|
||||||
|
sampleRateInHz,
|
||||||
|
channelConfig,
|
||||||
|
audioFormat,
|
||||||
|
numBytes * 2 // a sample has two bytes as we are using 16-bit PCM
|
||||||
|
)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun initModel() {
|
||||||
|
// Please change getModelConfig() to add new models
|
||||||
|
// See https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html
|
||||||
|
// for a list of available models
|
||||||
|
val type = 0
|
||||||
|
Log.i(TAG, "Select model type ${type}")
|
||||||
|
val config = KeywordSpotterConfig(
|
||||||
|
featConfig = getFeatureConfig(sampleRate = sampleRateInHz, featureDim = 80),
|
||||||
|
modelConfig = getModelConfig(type = type)!!,
|
||||||
|
keywordsFile = getKeywordsFile(type = type)!!,
|
||||||
|
)
|
||||||
|
|
||||||
|
model = SherpaOnnxKws(
|
||||||
|
assetManager = application.assets,
|
||||||
|
config = config,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,162 @@
|
|||||||
|
// Copyright (c) 2024 Xiaomi Corporation
|
||||||
|
package com.k2fsa.sherpa.onnx
|
||||||
|
|
||||||
|
import android.content.res.AssetManager
|
||||||
|
|
||||||
|
data class OnlineTransducerModelConfig(
|
||||||
|
var encoder: String = "",
|
||||||
|
var decoder: String = "",
|
||||||
|
var joiner: String = "",
|
||||||
|
)
|
||||||
|
|
||||||
|
data class OnlineModelConfig(
|
||||||
|
var transducer: OnlineTransducerModelConfig = OnlineTransducerModelConfig(),
|
||||||
|
var tokens: String,
|
||||||
|
var numThreads: Int = 1,
|
||||||
|
var debug: Boolean = false,
|
||||||
|
var provider: String = "cpu",
|
||||||
|
var modelType: String = "",
|
||||||
|
)
|
||||||
|
|
||||||
|
data class FeatureConfig(
|
||||||
|
var sampleRate: Int = 16000,
|
||||||
|
var featureDim: Int = 80,
|
||||||
|
)
|
||||||
|
|
||||||
|
data class KeywordSpotterConfig(
|
||||||
|
var featConfig: FeatureConfig = FeatureConfig(),
|
||||||
|
var modelConfig: OnlineModelConfig,
|
||||||
|
var maxActivePaths: Int = 4,
|
||||||
|
var keywordsFile: String = "keywords.txt",
|
||||||
|
var keywordsScore: Float = 1.5f,
|
||||||
|
var keywordsThreshold: Float = 0.25f,
|
||||||
|
var numTrailingBlanks: Int = 2,
|
||||||
|
)
|
||||||
|
|
||||||
|
class SherpaOnnxKws(
|
||||||
|
assetManager: AssetManager? = null,
|
||||||
|
var config: KeywordSpotterConfig,
|
||||||
|
) {
|
||||||
|
private val ptr: Long
|
||||||
|
|
||||||
|
init {
|
||||||
|
if (assetManager != null) {
|
||||||
|
ptr = new(assetManager, config)
|
||||||
|
} else {
|
||||||
|
ptr = newFromFile(config)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
protected fun finalize() {
|
||||||
|
delete(ptr)
|
||||||
|
}
|
||||||
|
|
||||||
|
fun acceptWaveform(samples: FloatArray, sampleRate: Int) =
|
||||||
|
acceptWaveform(ptr, samples, sampleRate)
|
||||||
|
|
||||||
|
fun inputFinished() = inputFinished(ptr)
|
||||||
|
fun decode() = decode(ptr)
|
||||||
|
fun isReady(): Boolean = isReady(ptr)
|
||||||
|
fun reset(keywords: String): Boolean = reset(ptr, keywords)
|
||||||
|
|
||||||
|
val keyword: String
|
||||||
|
get() = getKeyword(ptr)
|
||||||
|
|
||||||
|
private external fun delete(ptr: Long)
|
||||||
|
|
||||||
|
private external fun new(
|
||||||
|
assetManager: AssetManager,
|
||||||
|
config: KeywordSpotterConfig,
|
||||||
|
): Long
|
||||||
|
|
||||||
|
private external fun newFromFile(
|
||||||
|
config: KeywordSpotterConfig,
|
||||||
|
): Long
|
||||||
|
|
||||||
|
private external fun acceptWaveform(ptr: Long, samples: FloatArray, sampleRate: Int)
|
||||||
|
private external fun inputFinished(ptr: Long)
|
||||||
|
private external fun getKeyword(ptr: Long): String
|
||||||
|
private external fun reset(ptr: Long, keywords: String): Boolean
|
||||||
|
private external fun decode(ptr: Long)
|
||||||
|
private external fun isReady(ptr: Long): Boolean
|
||||||
|
|
||||||
|
companion object {
|
||||||
|
init {
|
||||||
|
System.loadLibrary("sherpa-onnx-jni")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fun getFeatureConfig(sampleRate: Int, featureDim: Int): FeatureConfig {
|
||||||
|
return FeatureConfig(sampleRate = sampleRate, featureDim = featureDim)
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
Please see
|
||||||
|
https://k2-fsa.github.io/sherpa/onnx/kws/pretrained_models/index.html
|
||||||
|
for a list of pre-trained models.
|
||||||
|
|
||||||
|
We only add a few here. Please change the following code
|
||||||
|
to add your own. (It should be straightforward to add a new model
|
||||||
|
by following the code)
|
||||||
|
|
||||||
|
@param type
|
||||||
|
0 - sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01 (Chinese)
|
||||||
|
https://www.modelscope.cn/models/pkufool/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/summary
|
||||||
|
|
||||||
|
1 - sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01 (English)
|
||||||
|
https://www.modelscope.cn/models/pkufool/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01/summary
|
||||||
|
|
||||||
|
*/
|
||||||
|
fun getModelConfig(type: Int): OnlineModelConfig? {
|
||||||
|
when (type) {
|
||||||
|
0 -> {
|
||||||
|
val modelDir = "sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01"
|
||||||
|
return OnlineModelConfig(
|
||||||
|
transducer = OnlineTransducerModelConfig(
|
||||||
|
encoder = "$modelDir/encoder-epoch-12-avg-2-chunk-16-left-64.onnx",
|
||||||
|
decoder = "$modelDir/decoder-epoch-12-avg-2-chunk-16-left-64.onnx",
|
||||||
|
joiner = "$modelDir/joiner-epoch-12-avg-2-chunk-16-left-64.onnx",
|
||||||
|
),
|
||||||
|
tokens = "$modelDir/tokens.txt",
|
||||||
|
modelType = "zipformer2",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
1 -> {
|
||||||
|
val modelDir = "sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01"
|
||||||
|
return OnlineModelConfig(
|
||||||
|
transducer = OnlineTransducerModelConfig(
|
||||||
|
encoder = "$modelDir/encoder-epoch-12-avg-2-chunk-16-left-64.onnx",
|
||||||
|
decoder = "$modelDir/decoder-epoch-12-avg-2-chunk-16-left-64.onnx",
|
||||||
|
joiner = "$modelDir/joiner-epoch-12-avg-2-chunk-16-left-64.onnx",
|
||||||
|
),
|
||||||
|
tokens = "$modelDir/tokens.txt",
|
||||||
|
modelType = "zipformer2",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* Get the default keywords for each model.
|
||||||
|
* Caution: The types and modelDir should be the same as those in getModelConfig
|
||||||
|
* function above.
|
||||||
|
*/
|
||||||
|
fun getKeywordsFile(type: Int) : String {
|
||||||
|
when (type) {
|
||||||
|
0 -> {
|
||||||
|
val modelDir = "sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01"
|
||||||
|
return "$modelDir/keywords.txt"
|
||||||
|
}
|
||||||
|
|
||||||
|
1 -> {
|
||||||
|
val modelDir = "sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01"
|
||||||
|
return "$modelDir/keywords.txt"
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
return "";
|
||||||
|
}
|
||||||
@@ -0,0 +1,29 @@
|
|||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
package com.k2fsa.sherpa.onnx
|
||||||
|
|
||||||
|
import android.content.res.AssetManager
|
||||||
|
|
||||||
|
class WaveReader {
|
||||||
|
companion object {
|
||||||
|
// Read a mono wave file asset
|
||||||
|
// The returned array has two entries:
|
||||||
|
// - the first entry contains an 1-D float array
|
||||||
|
// - the second entry is the sample rate
|
||||||
|
external fun readWaveFromAsset(
|
||||||
|
assetManager: AssetManager,
|
||||||
|
filename: String,
|
||||||
|
): Array<Any>
|
||||||
|
|
||||||
|
// Read a mono wave file from disk
|
||||||
|
// The returned array has two entries:
|
||||||
|
// - the first entry contains an 1-D float array
|
||||||
|
// - the second entry is the sample rate
|
||||||
|
external fun readWaveFromFile(
|
||||||
|
filename: String,
|
||||||
|
): Array<Any>
|
||||||
|
|
||||||
|
init {
|
||||||
|
System.loadLibrary("sherpa-onnx-jni")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
4
android/SherpaOnnxKws/app/src/main/jniLibs/.gitignore
vendored
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
*.so
|
||||||
|
*.txt
|
||||||
|
*.onnx
|
||||||
|
*.wav
|
||||||
@@ -0,0 +1,30 @@
|
|||||||
|
<vector xmlns:android="http://schemas.android.com/apk/res/android"
|
||||||
|
xmlns:aapt="http://schemas.android.com/aapt"
|
||||||
|
android:width="108dp"
|
||||||
|
android:height="108dp"
|
||||||
|
android:viewportWidth="108"
|
||||||
|
android:viewportHeight="108">
|
||||||
|
<path android:pathData="M31,63.928c0,0 6.4,-11 12.1,-13.1c7.2,-2.6 26,-1.4 26,-1.4l38.1,38.1L107,108.928l-32,-1L31,63.928z">
|
||||||
|
<aapt:attr name="android:fillColor">
|
||||||
|
<gradient
|
||||||
|
android:endX="85.84757"
|
||||||
|
android:endY="92.4963"
|
||||||
|
android:startX="42.9492"
|
||||||
|
android:startY="49.59793"
|
||||||
|
android:type="linear">
|
||||||
|
<item
|
||||||
|
android:color="#44000000"
|
||||||
|
android:offset="0.0" />
|
||||||
|
<item
|
||||||
|
android:color="#00000000"
|
||||||
|
android:offset="1.0" />
|
||||||
|
</gradient>
|
||||||
|
</aapt:attr>
|
||||||
|
</path>
|
||||||
|
<path
|
||||||
|
android:fillColor="#FFFFFF"
|
||||||
|
android:fillType="nonZero"
|
||||||
|
android:pathData="M65.3,45.828l3.8,-6.6c0.2,-0.4 0.1,-0.9 -0.3,-1.1c-0.4,-0.2 -0.9,-0.1 -1.1,0.3l-3.9,6.7c-6.3,-2.8 -13.4,-2.8 -19.7,0l-3.9,-6.7c-0.2,-0.4 -0.7,-0.5 -1.1,-0.3C38.8,38.328 38.7,38.828 38.9,39.228l3.8,6.6C36.2,49.428 31.7,56.028 31,63.928h46C76.3,56.028 71.8,49.428 65.3,45.828zM43.4,57.328c-0.8,0 -1.5,-0.5 -1.8,-1.2c-0.3,-0.7 -0.1,-1.5 0.4,-2.1c0.5,-0.5 1.4,-0.7 2.1,-0.4c0.7,0.3 1.2,1 1.2,1.8C45.3,56.528 44.5,57.328 43.4,57.328L43.4,57.328zM64.6,57.328c-0.8,0 -1.5,-0.5 -1.8,-1.2s-0.1,-1.5 0.4,-2.1c0.5,-0.5 1.4,-0.7 2.1,-0.4c0.7,0.3 1.2,1 1.2,1.8C66.5,56.528 65.6,57.328 64.6,57.328L64.6,57.328z"
|
||||||
|
android:strokeWidth="1"
|
||||||
|
android:strokeColor="#00000000" />
|
||||||
|
</vector>
|
||||||
@@ -0,0 +1,170 @@
|
|||||||
|
<?xml version="1.0" encoding="utf-8"?>
|
||||||
|
<vector xmlns:android="http://schemas.android.com/apk/res/android"
|
||||||
|
android:width="108dp"
|
||||||
|
android:height="108dp"
|
||||||
|
android:viewportWidth="108"
|
||||||
|
android:viewportHeight="108">
|
||||||
|
<path
|
||||||
|
android:fillColor="#3DDC84"
|
||||||
|
android:pathData="M0,0h108v108h-108z" />
|
||||||
|
<path
|
||||||
|
android:fillColor="#00000000"
|
||||||
|
android:pathData="M9,0L9,108"
|
||||||
|
android:strokeWidth="0.8"
|
||||||
|
android:strokeColor="#33FFFFFF" />
|
||||||
|
<path
|
||||||
|
android:fillColor="#00000000"
|
||||||
|
android:pathData="M19,0L19,108"
|
||||||
|
android:strokeWidth="0.8"
|
||||||
|
android:strokeColor="#33FFFFFF" />
|
||||||
|
<path
|
||||||
|
android:fillColor="#00000000"
|
||||||
|
android:pathData="M29,0L29,108"
|
||||||
|
android:strokeWidth="0.8"
|
||||||
|
android:strokeColor="#33FFFFFF" />
|
||||||
|
<path
|
||||||
|
android:fillColor="#00000000"
|
||||||
|
android:pathData="M39,0L39,108"
|
||||||
|
android:strokeWidth="0.8"
|
||||||
|
android:strokeColor="#33FFFFFF" />
|
||||||
|
<path
|
||||||
|
android:fillColor="#00000000"
|
||||||
|
android:pathData="M49,0L49,108"
|
||||||
|
android:strokeWidth="0.8"
|
||||||
|
android:strokeColor="#33FFFFFF" />
|
||||||
|
<path
|
||||||
|
android:fillColor="#00000000"
|
||||||
|
android:pathData="M59,0L59,108"
|
||||||
|
android:strokeWidth="0.8"
|
||||||
|
android:strokeColor="#33FFFFFF" />
|
||||||
|
<path
|
||||||
|
android:fillColor="#00000000"
|
||||||
|
android:pathData="M69,0L69,108"
|
||||||
|
android:strokeWidth="0.8"
|
||||||
|
android:strokeColor="#33FFFFFF" />
|
||||||
|
<path
|
||||||
|
android:fillColor="#00000000"
|
||||||
|
android:pathData="M79,0L79,108"
|
||||||
|
android:strokeWidth="0.8"
|
||||||
|
android:strokeColor="#33FFFFFF" />
|
||||||
|
<path
|
||||||
|
android:fillColor="#00000000"
|
||||||
|
android:pathData="M89,0L89,108"
|
||||||
|
android:strokeWidth="0.8"
|
||||||
|
android:strokeColor="#33FFFFFF" />
|
||||||
|
<path
|
||||||
|
android:fillColor="#00000000"
|
||||||
|
android:pathData="M99,0L99,108"
|
||||||
|
android:strokeWidth="0.8"
|
||||||
|
android:strokeColor="#33FFFFFF" />
|
||||||
|
<path
|
||||||
|
android:fillColor="#00000000"
|
||||||
|
android:pathData="M0,9L108,9"
|
||||||
|
android:strokeWidth="0.8"
|
||||||
|
android:strokeColor="#33FFFFFF" />
|
||||||
|
<path
|
||||||
|
android:fillColor="#00000000"
|
||||||
|
android:pathData="M0,19L108,19"
|
||||||
|
android:strokeWidth="0.8"
|
||||||
|
android:strokeColor="#33FFFFFF" />
|
||||||
|
<path
|
||||||
|
android:fillColor="#00000000"
|
||||||
|
android:pathData="M0,29L108,29"
|
||||||
|
android:strokeWidth="0.8"
|
||||||
|
android:strokeColor="#33FFFFFF" />
|
||||||
|
<path
|
||||||
|
android:fillColor="#00000000"
|
||||||
|
android:pathData="M0,39L108,39"
|
||||||
|
android:strokeWidth="0.8"
|
||||||
|
android:strokeColor="#33FFFFFF" />
|
||||||
|
<path
|
||||||
|
android:fillColor="#00000000"
|
||||||
|
android:pathData="M0,49L108,49"
|
||||||
|
android:strokeWidth="0.8"
|
||||||
|
android:strokeColor="#33FFFFFF" />
|
||||||
|
<path
|
||||||
|
android:fillColor="#00000000"
|
||||||
|
android:pathData="M0,59L108,59"
|
||||||
|
android:strokeWidth="0.8"
|
||||||
|
android:strokeColor="#33FFFFFF" />
|
||||||
|
<path
|
||||||
|
android:fillColor="#00000000"
|
||||||
|
android:pathData="M0,69L108,69"
|
||||||
|
android:strokeWidth="0.8"
|
||||||
|
android:strokeColor="#33FFFFFF" />
|
||||||
|
<path
|
||||||
|
android:fillColor="#00000000"
|
||||||
|
android:pathData="M0,79L108,79"
|
||||||
|
android:strokeWidth="0.8"
|
||||||
|
android:strokeColor="#33FFFFFF" />
|
||||||
|
<path
|
||||||
|
android:fillColor="#00000000"
|
||||||
|
android:pathData="M0,89L108,89"
|
||||||
|
android:strokeWidth="0.8"
|
||||||
|
android:strokeColor="#33FFFFFF" />
|
||||||
|
<path
|
||||||
|
android:fillColor="#00000000"
|
||||||
|
android:pathData="M0,99L108,99"
|
||||||
|
android:strokeWidth="0.8"
|
||||||
|
android:strokeColor="#33FFFFFF" />
|
||||||
|
<path
|
||||||
|
android:fillColor="#00000000"
|
||||||
|
android:pathData="M19,29L89,29"
|
||||||
|
android:strokeWidth="0.8"
|
||||||
|
android:strokeColor="#33FFFFFF" />
|
||||||
|
<path
|
||||||
|
android:fillColor="#00000000"
|
||||||
|
android:pathData="M19,39L89,39"
|
||||||
|
android:strokeWidth="0.8"
|
||||||
|
android:strokeColor="#33FFFFFF" />
|
||||||
|
<path
|
||||||
|
android:fillColor="#00000000"
|
||||||
|
android:pathData="M19,49L89,49"
|
||||||
|
android:strokeWidth="0.8"
|
||||||
|
android:strokeColor="#33FFFFFF" />
|
||||||
|
<path
|
||||||
|
android:fillColor="#00000000"
|
||||||
|
android:pathData="M19,59L89,59"
|
||||||
|
android:strokeWidth="0.8"
|
||||||
|
android:strokeColor="#33FFFFFF" />
|
||||||
|
<path
|
||||||
|
android:fillColor="#00000000"
|
||||||
|
android:pathData="M19,69L89,69"
|
||||||
|
android:strokeWidth="0.8"
|
||||||
|
android:strokeColor="#33FFFFFF" />
|
||||||
|
<path
|
||||||
|
android:fillColor="#00000000"
|
||||||
|
android:pathData="M19,79L89,79"
|
||||||
|
android:strokeWidth="0.8"
|
||||||
|
android:strokeColor="#33FFFFFF" />
|
||||||
|
<path
|
||||||
|
android:fillColor="#00000000"
|
||||||
|
android:pathData="M29,19L29,89"
|
||||||
|
android:strokeWidth="0.8"
|
||||||
|
android:strokeColor="#33FFFFFF" />
|
||||||
|
<path
|
||||||
|
android:fillColor="#00000000"
|
||||||
|
android:pathData="M39,19L39,89"
|
||||||
|
android:strokeWidth="0.8"
|
||||||
|
android:strokeColor="#33FFFFFF" />
|
||||||
|
<path
|
||||||
|
android:fillColor="#00000000"
|
||||||
|
android:pathData="M49,19L49,89"
|
||||||
|
android:strokeWidth="0.8"
|
||||||
|
android:strokeColor="#33FFFFFF" />
|
||||||
|
<path
|
||||||
|
android:fillColor="#00000000"
|
||||||
|
android:pathData="M59,19L59,89"
|
||||||
|
android:strokeWidth="0.8"
|
||||||
|
android:strokeColor="#33FFFFFF" />
|
||||||
|
<path
|
||||||
|
android:fillColor="#00000000"
|
||||||
|
android:pathData="M69,19L69,89"
|
||||||
|
android:strokeWidth="0.8"
|
||||||
|
android:strokeColor="#33FFFFFF" />
|
||||||
|
<path
|
||||||
|
android:fillColor="#00000000"
|
||||||
|
android:pathData="M79,19L79,89"
|
||||||
|
android:strokeWidth="0.8"
|
||||||
|
android:strokeColor="#33FFFFFF" />
|
||||||
|
</vector>
|
||||||
@@ -0,0 +1,46 @@
|
|||||||
|
<?xml version="1.0" encoding="utf-8"?>
|
||||||
|
<androidx.constraintlayout.widget.ConstraintLayout xmlns:android="http://schemas.android.com/apk/res/android"
|
||||||
|
xmlns:app="http://schemas.android.com/apk/res-auto"
|
||||||
|
xmlns:tools="http://schemas.android.com/tools"
|
||||||
|
android:layout_width="match_parent"
|
||||||
|
android:layout_height="match_parent"
|
||||||
|
tools:context=".MainActivity">
|
||||||
|
|
||||||
|
<LinearLayout
|
||||||
|
android:layout_width="match_parent"
|
||||||
|
android:layout_height="match_parent"
|
||||||
|
android:gravity="center"
|
||||||
|
android:orientation="vertical">
|
||||||
|
|
||||||
|
<EditText
|
||||||
|
android:id="@+id/input_text"
|
||||||
|
android:layout_width="match_parent"
|
||||||
|
android:layout_height="320dp"
|
||||||
|
android:layout_weight="2.5"
|
||||||
|
android:hint="@string/keyword_hint"
|
||||||
|
android:scrollbars="vertical"
|
||||||
|
android:text=""
|
||||||
|
android:textSize="15dp" />
|
||||||
|
|
||||||
|
<TextView
|
||||||
|
android:id="@+id/my_text"
|
||||||
|
android:layout_width="match_parent"
|
||||||
|
android:layout_height="443dp"
|
||||||
|
android:layout_weight="2.5"
|
||||||
|
android:padding="24dp"
|
||||||
|
android:scrollbars="vertical"
|
||||||
|
android:singleLine="false"
|
||||||
|
android:text="@string/hint"
|
||||||
|
android:textSize="15dp" />
|
||||||
|
|
||||||
|
<Button
|
||||||
|
android:id="@+id/record_button"
|
||||||
|
android:layout_width="wrap_content"
|
||||||
|
android:layout_height="wrap_content"
|
||||||
|
android:layout_weight="0.5"
|
||||||
|
android:text="@string/start" />
|
||||||
|
|
||||||
|
</LinearLayout>
|
||||||
|
|
||||||
|
|
||||||
|
</androidx.constraintlayout.widget.ConstraintLayout>
|
||||||
@@ -0,0 +1,5 @@
|
|||||||
|
<?xml version="1.0" encoding="utf-8"?>
|
||||||
|
<adaptive-icon xmlns:android="http://schemas.android.com/apk/res/android">
|
||||||
|
<background android:drawable="@drawable/ic_launcher_background" />
|
||||||
|
<foreground android:drawable="@drawable/ic_launcher_foreground" />
|
||||||
|
</adaptive-icon>
|
||||||
@@ -0,0 +1,5 @@
|
|||||||
|
<?xml version="1.0" encoding="utf-8"?>
|
||||||
|
<adaptive-icon xmlns:android="http://schemas.android.com/apk/res/android">
|
||||||
|
<background android:drawable="@drawable/ic_launcher_background" />
|
||||||
|
<foreground android:drawable="@drawable/ic_launcher_foreground" />
|
||||||
|
</adaptive-icon>
|
||||||
|
After Width: | Height: | Size: 1.4 KiB |
|
After Width: | Height: | Size: 2.8 KiB |
|
After Width: | Height: | Size: 982 B |
|
After Width: | Height: | Size: 1.7 KiB |
|
After Width: | Height: | Size: 1.9 KiB |
|
After Width: | Height: | Size: 3.8 KiB |
|
After Width: | Height: | Size: 2.8 KiB |
|
After Width: | Height: | Size: 5.8 KiB |
|
After Width: | Height: | Size: 3.8 KiB |
|
After Width: | Height: | Size: 7.6 KiB |
@@ -0,0 +1,16 @@
|
|||||||
|
<resources xmlns:tools="http://schemas.android.com/tools">
|
||||||
|
<!-- Base application theme. -->
|
||||||
|
<style name="Theme.SherpaOnnx" parent="Theme.MaterialComponents.DayNight.DarkActionBar">
|
||||||
|
<!-- Primary brand color. -->
|
||||||
|
<item name="colorPrimary">@color/purple_200</item>
|
||||||
|
<item name="colorPrimaryVariant">@color/purple_700</item>
|
||||||
|
<item name="colorOnPrimary">@color/black</item>
|
||||||
|
<!-- Secondary brand color. -->
|
||||||
|
<item name="colorSecondary">@color/teal_200</item>
|
||||||
|
<item name="colorSecondaryVariant">@color/teal_200</item>
|
||||||
|
<item name="colorOnSecondary">@color/black</item>
|
||||||
|
<!-- Status bar color. -->
|
||||||
|
<item name="android:statusBarColor">?attr/colorPrimaryVariant</item>
|
||||||
|
<!-- Customize your theme here. -->
|
||||||
|
</style>
|
||||||
|
</resources>
|
||||||
10
android/SherpaOnnxKws/app/src/main/res/values/colors.xml
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
<?xml version="1.0" encoding="utf-8"?>
|
||||||
|
<resources>
|
||||||
|
<color name="purple_200">#FFBB86FC</color>
|
||||||
|
<color name="purple_500">#FF6200EE</color>
|
||||||
|
<color name="purple_700">#FF3700B3</color>
|
||||||
|
<color name="teal_200">#FF03DAC5</color>
|
||||||
|
<color name="teal_700">#FF018786</color>
|
||||||
|
<color name="black">#FF000000</color>
|
||||||
|
<color name="white">#FFFFFFFF</color>
|
||||||
|
</resources>
|
||||||
12
android/SherpaOnnxKws/app/src/main/res/values/strings.xml
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
<resources>
|
||||||
|
<string name="app_name">KWS with Next-gen Kaldi</string>
|
||||||
|
<string name="hint">Click the Start button to play keyword spotting with Next-gen Kaldi.
|
||||||
|
\n
|
||||||
|
\n\n\n
|
||||||
|
The source code and pre-trained models are publicly available.
|
||||||
|
Please see https://github.com/k2-fsa/sherpa-onnx for details.
|
||||||
|
</string>
|
||||||
|
<string name="keyword_hint">Input your keywords here, one keyword perline.</string>
|
||||||
|
<string name="start">Start</string>
|
||||||
|
<string name="stop">Stop</string>
|
||||||
|
</resources>
|
||||||
16
android/SherpaOnnxKws/app/src/main/res/values/themes.xml
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
<resources xmlns:tools="http://schemas.android.com/tools">
|
||||||
|
<!-- Base application theme. -->
|
||||||
|
<style name="Theme.SherpaOnnx" parent="Theme.MaterialComponents.DayNight.DarkActionBar">
|
||||||
|
<!-- Primary brand color. -->
|
||||||
|
<item name="colorPrimary">@color/purple_500</item>
|
||||||
|
<item name="colorPrimaryVariant">@color/purple_700</item>
|
||||||
|
<item name="colorOnPrimary">@color/white</item>
|
||||||
|
<!-- Secondary brand color. -->
|
||||||
|
<item name="colorSecondary">@color/teal_200</item>
|
||||||
|
<item name="colorSecondaryVariant">@color/teal_700</item>
|
||||||
|
<item name="colorOnSecondary">@color/black</item>
|
||||||
|
<!-- Status bar color. -->
|
||||||
|
<item name="android:statusBarColor">?attr/colorPrimaryVariant</item>
|
||||||
|
<!-- Customize your theme here. -->
|
||||||
|
</style>
|
||||||
|
</resources>
|
||||||
13
android/SherpaOnnxKws/app/src/main/res/xml/backup_rules.xml
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
<?xml version="1.0" encoding="utf-8"?><!--
|
||||||
|
Sample backup rules file; uncomment and customize as necessary.
|
||||||
|
See https://developer.android.com/guide/topics/data/autobackup
|
||||||
|
for details.
|
||||||
|
Note: This file is ignored for devices older that API 31
|
||||||
|
See https://developer.android.com/about/versions/12/backup-restore
|
||||||
|
-->
|
||||||
|
<full-backup-content>
|
||||||
|
<!--
|
||||||
|
<include domain="sharedpref" path="."/>
|
||||||
|
<exclude domain="sharedpref" path="device.xml"/>
|
||||||
|
-->
|
||||||
|
</full-backup-content>
|
||||||
@@ -0,0 +1,19 @@
|
|||||||
|
<?xml version="1.0" encoding="utf-8"?><!--
|
||||||
|
Sample data extraction rules file; uncomment and customize as necessary.
|
||||||
|
See https://developer.android.com/about/versions/12/backup-restore#xml-changes
|
||||||
|
for details.
|
||||||
|
-->
|
||||||
|
<data-extraction-rules>
|
||||||
|
<cloud-backup>
|
||||||
|
<!-- TODO: Use <include> and <exclude> to control what is backed up.
|
||||||
|
<include .../>
|
||||||
|
<exclude .../>
|
||||||
|
-->
|
||||||
|
</cloud-backup>
|
||||||
|
<!--
|
||||||
|
<device-transfer>
|
||||||
|
<include .../>
|
||||||
|
<exclude .../>
|
||||||
|
</device-transfer>
|
||||||
|
-->
|
||||||
|
</data-extraction-rules>
|
||||||
@@ -0,0 +1,17 @@
|
|||||||
|
package com.k2fsa.sherpa.onnx
|
||||||
|
|
||||||
|
import org.junit.Test
|
||||||
|
|
||||||
|
import org.junit.Assert.*
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Example local unit test, which will execute on the development machine (host).
|
||||||
|
*
|
||||||
|
* See [testing documentation](http://d.android.com/tools/testing).
|
||||||
|
*/
|
||||||
|
class ExampleUnitTest {
|
||||||
|
@Test
|
||||||
|
fun addition_isCorrect() {
|
||||||
|
assertEquals(4, 2 + 2)
|
||||||
|
}
|
||||||
|
}
|
||||||
6
android/SherpaOnnxKws/build.gradle
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
// Top-level build file where you can add configuration options common to all sub-projects/modules.
|
||||||
|
plugins {
|
||||||
|
id 'com.android.application' version '7.3.1' apply false
|
||||||
|
id 'com.android.library' version '7.3.1' apply false
|
||||||
|
id 'org.jetbrains.kotlin.android' version '1.7.20' apply false
|
||||||
|
}
|
||||||
23
android/SherpaOnnxKws/gradle.properties
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
# Project-wide Gradle settings.
|
||||||
|
# IDE (e.g. Android Studio) users:
|
||||||
|
# Gradle settings configured through the IDE *will override*
|
||||||
|
# any settings specified in this file.
|
||||||
|
# For more details on how to configure your build environment visit
|
||||||
|
# http://www.gradle.org/docs/current/userguide/build_environment.html
|
||||||
|
# Specifies the JVM arguments used for the daemon process.
|
||||||
|
# The setting is particularly useful for tweaking memory settings.
|
||||||
|
org.gradle.jvmargs=-Xmx2048m -Dfile.encoding=UTF-8
|
||||||
|
# When configured, Gradle will run in incubating parallel mode.
|
||||||
|
# This option should only be used with decoupled projects. More details, visit
|
||||||
|
# http://www.gradle.org/docs/current/userguide/multi_project_builds.html#sec:decoupled_projects
|
||||||
|
# org.gradle.parallel=true
|
||||||
|
# AndroidX package structure to make it clearer which packages are bundled with the
|
||||||
|
# Android operating system, and which are packaged with your app's APK
|
||||||
|
# https://developer.android.com/topic/libraries/support-library/androidx-rn
|
||||||
|
android.useAndroidX=true
|
||||||
|
# Kotlin code style for this project: "official" or "obsolete":
|
||||||
|
kotlin.code.style=official
|
||||||
|
# Enables namespacing of each library's R class so that its R class includes only the
|
||||||
|
# resources declared in the library itself and none from the library's dependencies,
|
||||||
|
# thereby reducing the size of the R class for that library
|
||||||
|
android.nonTransitiveRClass=true
|
||||||
BIN
android/SherpaOnnxKws/gradle/wrapper/gradle-wrapper.jar
vendored
Normal file
6
android/SherpaOnnxKws/gradle/wrapper/gradle-wrapper.properties
vendored
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
#Thu Feb 23 11:09:06 CST 2023
|
||||||
|
distributionBase=GRADLE_USER_HOME
|
||||||
|
distributionUrl=https\://services.gradle.org/distributions/gradle-7.4-bin.zip
|
||||||
|
distributionPath=wrapper/dists
|
||||||
|
zipStorePath=wrapper/dists
|
||||||
|
zipStoreBase=GRADLE_USER_HOME
|
||||||
185
android/SherpaOnnxKws/gradlew
vendored
Executable file
@@ -0,0 +1,185 @@
|
|||||||
|
#!/usr/bin/env sh
|
||||||
|
|
||||||
|
#
|
||||||
|
# Copyright 2015 the original author or authors.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
|
||||||
|
##############################################################################
|
||||||
|
##
|
||||||
|
## Gradle start up script for UN*X
|
||||||
|
##
|
||||||
|
##############################################################################
|
||||||
|
|
||||||
|
# Attempt to set APP_HOME
|
||||||
|
# Resolve links: $0 may be a link
|
||||||
|
PRG="$0"
|
||||||
|
# Need this for relative symlinks.
|
||||||
|
while [ -h "$PRG" ] ; do
|
||||||
|
ls=`ls -ld "$PRG"`
|
||||||
|
link=`expr "$ls" : '.*-> \(.*\)$'`
|
||||||
|
if expr "$link" : '/.*' > /dev/null; then
|
||||||
|
PRG="$link"
|
||||||
|
else
|
||||||
|
PRG=`dirname "$PRG"`"/$link"
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
SAVED="`pwd`"
|
||||||
|
cd "`dirname \"$PRG\"`/" >/dev/null
|
||||||
|
APP_HOME="`pwd -P`"
|
||||||
|
cd "$SAVED" >/dev/null
|
||||||
|
|
||||||
|
APP_NAME="Gradle"
|
||||||
|
APP_BASE_NAME=`basename "$0"`
|
||||||
|
|
||||||
|
# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
|
||||||
|
DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"'
|
||||||
|
|
||||||
|
# Use the maximum available, or set MAX_FD != -1 to use that value.
|
||||||
|
MAX_FD="maximum"
|
||||||
|
|
||||||
|
warn () {
|
||||||
|
echo "$*"
|
||||||
|
}
|
||||||
|
|
||||||
|
die () {
|
||||||
|
echo
|
||||||
|
echo "$*"
|
||||||
|
echo
|
||||||
|
exit 1
|
||||||
|
}
|
||||||
|
|
||||||
|
# OS specific support (must be 'true' or 'false').
|
||||||
|
cygwin=false
|
||||||
|
msys=false
|
||||||
|
darwin=false
|
||||||
|
nonstop=false
|
||||||
|
case "`uname`" in
|
||||||
|
CYGWIN* )
|
||||||
|
cygwin=true
|
||||||
|
;;
|
||||||
|
Darwin* )
|
||||||
|
darwin=true
|
||||||
|
;;
|
||||||
|
MINGW* )
|
||||||
|
msys=true
|
||||||
|
;;
|
||||||
|
NONSTOP* )
|
||||||
|
nonstop=true
|
||||||
|
;;
|
||||||
|
esac
|
||||||
|
|
||||||
|
CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar
|
||||||
|
|
||||||
|
|
||||||
|
# Determine the Java command to use to start the JVM.
|
||||||
|
if [ -n "$JAVA_HOME" ] ; then
|
||||||
|
if [ -x "$JAVA_HOME/jre/sh/java" ] ; then
|
||||||
|
# IBM's JDK on AIX uses strange locations for the executables
|
||||||
|
JAVACMD="$JAVA_HOME/jre/sh/java"
|
||||||
|
else
|
||||||
|
JAVACMD="$JAVA_HOME/bin/java"
|
||||||
|
fi
|
||||||
|
if [ ! -x "$JAVACMD" ] ; then
|
||||||
|
die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME
|
||||||
|
|
||||||
|
Please set the JAVA_HOME variable in your environment to match the
|
||||||
|
location of your Java installation."
|
||||||
|
fi
|
||||||
|
else
|
||||||
|
JAVACMD="java"
|
||||||
|
which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
|
||||||
|
|
||||||
|
Please set the JAVA_HOME variable in your environment to match the
|
||||||
|
location of your Java installation."
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Increase the maximum file descriptors if we can.
|
||||||
|
if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then
|
||||||
|
MAX_FD_LIMIT=`ulimit -H -n`
|
||||||
|
if [ $? -eq 0 ] ; then
|
||||||
|
if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then
|
||||||
|
MAX_FD="$MAX_FD_LIMIT"
|
||||||
|
fi
|
||||||
|
ulimit -n $MAX_FD
|
||||||
|
if [ $? -ne 0 ] ; then
|
||||||
|
warn "Could not set maximum file descriptor limit: $MAX_FD"
|
||||||
|
fi
|
||||||
|
else
|
||||||
|
warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT"
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
# For Darwin, add options to specify how the application appears in the dock
|
||||||
|
if $darwin; then
|
||||||
|
GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\""
|
||||||
|
fi
|
||||||
|
|
||||||
|
# For Cygwin or MSYS, switch paths to Windows format before running java
|
||||||
|
if [ "$cygwin" = "true" -o "$msys" = "true" ] ; then
|
||||||
|
APP_HOME=`cygpath --path --mixed "$APP_HOME"`
|
||||||
|
CLASSPATH=`cygpath --path --mixed "$CLASSPATH"`
|
||||||
|
|
||||||
|
JAVACMD=`cygpath --unix "$JAVACMD"`
|
||||||
|
|
||||||
|
# We build the pattern for arguments to be converted via cygpath
|
||||||
|
ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null`
|
||||||
|
SEP=""
|
||||||
|
for dir in $ROOTDIRSRAW ; do
|
||||||
|
ROOTDIRS="$ROOTDIRS$SEP$dir"
|
||||||
|
SEP="|"
|
||||||
|
done
|
||||||
|
OURCYGPATTERN="(^($ROOTDIRS))"
|
||||||
|
# Add a user-defined pattern to the cygpath arguments
|
||||||
|
if [ "$GRADLE_CYGPATTERN" != "" ] ; then
|
||||||
|
OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)"
|
||||||
|
fi
|
||||||
|
# Now convert the arguments - kludge to limit ourselves to /bin/sh
|
||||||
|
i=0
|
||||||
|
for arg in "$@" ; do
|
||||||
|
CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -`
|
||||||
|
CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option
|
||||||
|
|
||||||
|
if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition
|
||||||
|
eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"`
|
||||||
|
else
|
||||||
|
eval `echo args$i`="\"$arg\""
|
||||||
|
fi
|
||||||
|
i=`expr $i + 1`
|
||||||
|
done
|
||||||
|
case $i in
|
||||||
|
0) set -- ;;
|
||||||
|
1) set -- "$args0" ;;
|
||||||
|
2) set -- "$args0" "$args1" ;;
|
||||||
|
3) set -- "$args0" "$args1" "$args2" ;;
|
||||||
|
4) set -- "$args0" "$args1" "$args2" "$args3" ;;
|
||||||
|
5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;;
|
||||||
|
6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;;
|
||||||
|
7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;;
|
||||||
|
8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;;
|
||||||
|
9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;;
|
||||||
|
esac
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Escape application args
|
||||||
|
save () {
|
||||||
|
for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done
|
||||||
|
echo " "
|
||||||
|
}
|
||||||
|
APP_ARGS=`save "$@"`
|
||||||
|
|
||||||
|
# Collect all arguments for the java command, following the shell quoting and substitution rules
|
||||||
|
eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS"
|
||||||
|
|
||||||
|
exec "$JAVACMD" "$@"
|
||||||
89
android/SherpaOnnxKws/gradlew.bat
vendored
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
@rem
|
||||||
|
@rem Copyright 2015 the original author or authors.
|
||||||
|
@rem
|
||||||
|
@rem Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@rem you may not use this file except in compliance with the License.
|
||||||
|
@rem You may obtain a copy of the License at
|
||||||
|
@rem
|
||||||
|
@rem https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
@rem
|
||||||
|
@rem Unless required by applicable law or agreed to in writing, software
|
||||||
|
@rem distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
@rem See the License for the specific language governing permissions and
|
||||||
|
@rem limitations under the License.
|
||||||
|
@rem
|
||||||
|
|
||||||
|
@if "%DEBUG%" == "" @echo off
|
||||||
|
@rem ##########################################################################
|
||||||
|
@rem
|
||||||
|
@rem Gradle startup script for Windows
|
||||||
|
@rem
|
||||||
|
@rem ##########################################################################
|
||||||
|
|
||||||
|
@rem Set local scope for the variables with windows NT shell
|
||||||
|
if "%OS%"=="Windows_NT" setlocal
|
||||||
|
|
||||||
|
set DIRNAME=%~dp0
|
||||||
|
if "%DIRNAME%" == "" set DIRNAME=.
|
||||||
|
set APP_BASE_NAME=%~n0
|
||||||
|
set APP_HOME=%DIRNAME%
|
||||||
|
|
||||||
|
@rem Resolve any "." and ".." in APP_HOME to make it shorter.
|
||||||
|
for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi
|
||||||
|
|
||||||
|
@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
|
||||||
|
set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m"
|
||||||
|
|
||||||
|
@rem Find java.exe
|
||||||
|
if defined JAVA_HOME goto findJavaFromJavaHome
|
||||||
|
|
||||||
|
set JAVA_EXE=java.exe
|
||||||
|
%JAVA_EXE% -version >NUL 2>&1
|
||||||
|
if "%ERRORLEVEL%" == "0" goto execute
|
||||||
|
|
||||||
|
echo.
|
||||||
|
echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
|
||||||
|
echo.
|
||||||
|
echo Please set the JAVA_HOME variable in your environment to match the
|
||||||
|
echo location of your Java installation.
|
||||||
|
|
||||||
|
goto fail
|
||||||
|
|
||||||
|
:findJavaFromJavaHome
|
||||||
|
set JAVA_HOME=%JAVA_HOME:"=%
|
||||||
|
set JAVA_EXE=%JAVA_HOME%/bin/java.exe
|
||||||
|
|
||||||
|
if exist "%JAVA_EXE%" goto execute
|
||||||
|
|
||||||
|
echo.
|
||||||
|
echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME%
|
||||||
|
echo.
|
||||||
|
echo Please set the JAVA_HOME variable in your environment to match the
|
||||||
|
echo location of your Java installation.
|
||||||
|
|
||||||
|
goto fail
|
||||||
|
|
||||||
|
:execute
|
||||||
|
@rem Setup the command line
|
||||||
|
|
||||||
|
set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar
|
||||||
|
|
||||||
|
|
||||||
|
@rem Execute Gradle
|
||||||
|
"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %*
|
||||||
|
|
||||||
|
:end
|
||||||
|
@rem End local scope for the variables with windows NT shell
|
||||||
|
if "%ERRORLEVEL%"=="0" goto mainEnd
|
||||||
|
|
||||||
|
:fail
|
||||||
|
rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of
|
||||||
|
rem the _cmd.exe /c_ return code!
|
||||||
|
if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1
|
||||||
|
exit /b 1
|
||||||
|
|
||||||
|
:mainEnd
|
||||||
|
if "%OS%"=="Windows_NT" endlocal
|
||||||
|
|
||||||
|
:omega
|
||||||
16
android/SherpaOnnxKws/settings.gradle
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
pluginManagement {
|
||||||
|
repositories {
|
||||||
|
gradlePluginPortal()
|
||||||
|
google()
|
||||||
|
mavenCentral()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
dependencyResolutionManagement {
|
||||||
|
repositoriesMode.set(RepositoriesMode.FAIL_ON_PROJECT_REPOS)
|
||||||
|
repositories {
|
||||||
|
google()
|
||||||
|
mavenCentral()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
rootProject.name = "SherpaOnnxKws"
|
||||||
|
include ':app'
|
||||||
139
build-kws-apk.sh
Executable file
@@ -0,0 +1,139 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
|
# Please set the environment variable ANDROID_NDK
|
||||||
|
# before running this script
|
||||||
|
|
||||||
|
# Inside the $ANDROID_NDK directory, you can find a binary ndk-build
|
||||||
|
# and some other files like the file "build/cmake/android.toolchain.cmake"
|
||||||
|
|
||||||
|
set -e
|
||||||
|
|
||||||
|
log() {
|
||||||
|
# This function is from espnet
|
||||||
|
local fname=${BASH_SOURCE[1]##*/}
|
||||||
|
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||||
|
}
|
||||||
|
|
||||||
|
SHERPA_ONNX_VERSION=$(grep "SHERPA_ONNX_VERSION" ./CMakeLists.txt | cut -d " " -f 2 | cut -d '"' -f 2)
|
||||||
|
|
||||||
|
log "Building keyword spotting APK for sherpa-onnx v${SHERPA_ONNX_VERSION}"
|
||||||
|
|
||||||
|
log "====================arm64-v8a================="
|
||||||
|
./build-android-arm64-v8a.sh
|
||||||
|
log "====================armv7-eabi================"
|
||||||
|
./build-android-armv7-eabi.sh
|
||||||
|
log "====================x86-64===================="
|
||||||
|
./build-android-x86-64.sh
|
||||||
|
log "====================x86===================="
|
||||||
|
./build-android-x86.sh
|
||||||
|
|
||||||
|
mkdir -p apks
|
||||||
|
|
||||||
|
# Download the model
|
||||||
|
repo=sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01
|
||||||
|
|
||||||
|
if [ ! -d ./android/SherpaOnnxKws/app/src/main/assets/$repo ]; then
|
||||||
|
repo_url=https://www.modelscope.cn/pkufool/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01.git
|
||||||
|
log "Start testing ${repo_url}"
|
||||||
|
log "Download pretrained model and test-data from $repo_url"
|
||||||
|
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||||
|
pushd $repo
|
||||||
|
git lfs pull --include "*.onnx"
|
||||||
|
|
||||||
|
# remove .git to save spaces
|
||||||
|
rm -rf .git
|
||||||
|
rm *.int8.onnx
|
||||||
|
rm README.md configuration.json .gitattributes
|
||||||
|
rm -rfv test_wavs
|
||||||
|
ls -lh
|
||||||
|
popd
|
||||||
|
|
||||||
|
mv -v $repo ./android/SherpaOnnxKws/app/src/main/assets/
|
||||||
|
fi
|
||||||
|
|
||||||
|
tree ./android/SherpaOnnxKws/app/src/main/assets/
|
||||||
|
|
||||||
|
for arch in arm64-v8a armeabi-v7a x86_64 x86; do
|
||||||
|
log "------------------------------------------------------------"
|
||||||
|
log "build apk for $arch"
|
||||||
|
log "------------------------------------------------------------"
|
||||||
|
src_arch=$arch
|
||||||
|
if [ $arch == "armeabi-v7a" ]; then
|
||||||
|
src_arch=armv7-eabi
|
||||||
|
elif [ $arch == "x86_64" ]; then
|
||||||
|
src_arch=x86-64
|
||||||
|
fi
|
||||||
|
|
||||||
|
ls -lh ./build-android-$src_arch/install/lib/*.so
|
||||||
|
|
||||||
|
cp -v ./build-android-$src_arch/install/lib/*.so ./android/SherpaOnnxKws/app/src/main/jniLibs/$arch/
|
||||||
|
|
||||||
|
pushd ./android/SherpaOnnxKws
|
||||||
|
./gradlew build
|
||||||
|
popd
|
||||||
|
|
||||||
|
mv android/SherpaOnnxKws/app/build/outputs/apk/debug/app-debug.apk ./apks/sherpa-onnx-kws-wenetspeech-zh-${SHERPA_ONNX_VERSION}-$arch.apk
|
||||||
|
ls -lh apks
|
||||||
|
rm -v ./android/SherpaOnnxKws/app/src/main/jniLibs/$arch/*.so
|
||||||
|
done
|
||||||
|
|
||||||
|
git checkout .
|
||||||
|
|
||||||
|
rm -rf ./android/SherpaOnnxKws/app/src/main/assets/$repo
|
||||||
|
|
||||||
|
# English model
|
||||||
|
repo=sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01
|
||||||
|
|
||||||
|
if [ ! -d ./android/SherpaOnnxKws/app/src/main/assets/$repo ]; then
|
||||||
|
repo_url=https://www.modelscope.cn/pkufool/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01.git
|
||||||
|
log "Start testing ${repo_url}"
|
||||||
|
log "Download pretrained model and test-data from $repo_url"
|
||||||
|
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||||
|
pushd $repo
|
||||||
|
git lfs pull --include "*.onnx"
|
||||||
|
|
||||||
|
# remove .git to save spaces
|
||||||
|
rm -rf .git
|
||||||
|
rm *.int8.onnx
|
||||||
|
rm README.md configuration.json .gitattributes
|
||||||
|
rm -rfv test_wavs
|
||||||
|
ls -lh
|
||||||
|
popd
|
||||||
|
|
||||||
|
mv -v $repo ./android/SherpaOnnxKws/app/src/main/assets/
|
||||||
|
fi
|
||||||
|
|
||||||
|
tree ./android/SherpaOnnxKws/app/src/main/assets/
|
||||||
|
|
||||||
|
pushd android/SherpaOnnxKws/app/src/main/java/com/k2fsa/sherpa/onnx
|
||||||
|
sed -i.bak s/"type = 0"/"type = 1"/ ./MainActivity.kt
|
||||||
|
git diff
|
||||||
|
popd
|
||||||
|
|
||||||
|
for arch in arm64-v8a armeabi-v7a x86_64 x86; do
|
||||||
|
log "------------------------------------------------------------"
|
||||||
|
log "build apk for $arch"
|
||||||
|
log "------------------------------------------------------------"
|
||||||
|
src_arch=$arch
|
||||||
|
if [ $arch == "armeabi-v7a" ]; then
|
||||||
|
src_arch=armv7-eabi
|
||||||
|
elif [ $arch == "x86_64" ]; then
|
||||||
|
src_arch=x86-64
|
||||||
|
fi
|
||||||
|
|
||||||
|
ls -lh ./build-android-$src_arch/install/lib/*.so
|
||||||
|
|
||||||
|
cp -v ./build-android-$src_arch/install/lib/*.so ./android/SherpaOnnxKws/app/src/main/jniLibs/$arch/
|
||||||
|
|
||||||
|
pushd ./android/SherpaOnnxKws
|
||||||
|
./gradlew build
|
||||||
|
popd
|
||||||
|
|
||||||
|
mv android/SherpaOnnxKws/app/build/outputs/apk/debug/app-debug.apk ./apks/sherpa-onnx-kws-gigaspeech-en-${SHERPA_ONNX_VERSION}-$arch.apk
|
||||||
|
ls -lh apks
|
||||||
|
rm -v ./android/SherpaOnnxKws/app/src/main/jniLibs/$arch/*.so
|
||||||
|
done
|
||||||
|
|
||||||
|
git checkout .
|
||||||
|
|
||||||
|
rm -rf ./android/SherpaOnnxKws/app/src/main/assets/$repo
|
||||||
@@ -151,6 +151,7 @@ class BuildExtension(build_ext):
|
|||||||
# Remember to also change setup.py
|
# Remember to also change setup.py
|
||||||
|
|
||||||
binaries = ["sherpa-onnx"]
|
binaries = ["sherpa-onnx"]
|
||||||
|
binaries += ["sherpa-onnx-keyword-spotter"]
|
||||||
binaries += ["sherpa-onnx-offline"]
|
binaries += ["sherpa-onnx-offline"]
|
||||||
binaries += ["sherpa-onnx-microphone"]
|
binaries += ["sherpa-onnx-microphone"]
|
||||||
binaries += ["sherpa-onnx-microphone-offline"]
|
binaries += ["sherpa-onnx-microphone-offline"]
|
||||||
|
|||||||
@@ -36,13 +36,44 @@ import argparse
|
|||||||
|
|
||||||
from sherpa_onnx import text2token
|
from sherpa_onnx import text2token
|
||||||
|
|
||||||
|
|
||||||
def get_args():
|
def get_args():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--text",
|
"--text",
|
||||||
type=str,
|
type=str,
|
||||||
required=True,
|
required=True,
|
||||||
help="Path to the input texts",
|
help="""Path to the input texts.
|
||||||
|
|
||||||
|
Each line in the texts contains the original phrase, it might also contain some
|
||||||
|
extra items, for example, the boosting score (startting with :), the triggering
|
||||||
|
threshold (startting with #, only used in keyword spotting task) and the original
|
||||||
|
phrase (startting with @). Note: extra items will be kept in the output.
|
||||||
|
|
||||||
|
example input 1 (tokens_type = ppinyin):
|
||||||
|
|
||||||
|
小爱同学 :2.0 #0.6 @小爱同学
|
||||||
|
你好问问 :3.5 @你好问问
|
||||||
|
小艺小艺 #0.6 @小艺小艺
|
||||||
|
|
||||||
|
example output 1:
|
||||||
|
|
||||||
|
x iǎo ài t óng x ué :2.0 #0.6 @小爱同学
|
||||||
|
n ǐ h ǎo w èn w èn :3.5 @你好问问
|
||||||
|
x iǎo y ì x iǎo y ì #0.6 @小艺小艺
|
||||||
|
|
||||||
|
example input 2 (tokens_type = bpe):
|
||||||
|
|
||||||
|
HELLO WORLD :1.5 #0.4
|
||||||
|
HI GOOGLE :2.0 #0.8
|
||||||
|
HEY SIRI #0.35
|
||||||
|
|
||||||
|
example output 2:
|
||||||
|
|
||||||
|
▁HE LL O ▁WORLD :1.5 #0.4
|
||||||
|
▁HI ▁GO O G LE :2.0 #0.8
|
||||||
|
▁HE Y ▁S I RI #0.35
|
||||||
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@@ -56,7 +87,11 @@ def get_args():
|
|||||||
"--tokens-type",
|
"--tokens-type",
|
||||||
type=str,
|
type=str,
|
||||||
required=True,
|
required=True,
|
||||||
help="The type of modeling units, should be cjkchar, bpe or cjkchar+bpe",
|
choices=["cjkchar", "bpe", "cjkchar+bpe", "fpinyin", "ppinyin"],
|
||||||
|
help="""The type of modeling units, should be cjkchar, bpe, cjkchar+bpe, fpinyin or ppinyin.
|
||||||
|
fpinyin means full pinyin, each cjkchar has a pinyin(with tone).
|
||||||
|
ppinyin means partial pinyin, it splits pinyin into initial and final,
|
||||||
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@@ -79,9 +114,21 @@ def main():
|
|||||||
args = get_args()
|
args = get_args()
|
||||||
|
|
||||||
texts = []
|
texts = []
|
||||||
|
# extra information like boosting score (start with :), triggering threshold (start with #)
|
||||||
|
# original keyword (start with @)
|
||||||
|
extra_info = []
|
||||||
with open(args.text, "r", encoding="utf8") as f:
|
with open(args.text, "r", encoding="utf8") as f:
|
||||||
for line in f:
|
for line in f:
|
||||||
texts.append(line.strip())
|
extra = []
|
||||||
|
text = []
|
||||||
|
toks = line.strip().split()
|
||||||
|
for tok in toks:
|
||||||
|
if tok[0] == ":" or tok[0] == "#" or tok[0] == "@":
|
||||||
|
extra.append(tok)
|
||||||
|
else:
|
||||||
|
text.append(tok)
|
||||||
|
texts.append(" ".join(text))
|
||||||
|
extra_info.append(extra)
|
||||||
encoded_texts = text2token(
|
encoded_texts = text2token(
|
||||||
texts,
|
texts,
|
||||||
tokens=args.tokens,
|
tokens=args.tokens,
|
||||||
@@ -89,7 +136,8 @@ def main():
|
|||||||
bpe_model=args.bpe_model,
|
bpe_model=args.bpe_model,
|
||||||
)
|
)
|
||||||
with open(args.output, "w", encoding="utf8") as f:
|
with open(args.output, "w", encoding="utf8") as f:
|
||||||
for txt in encoded_texts:
|
for i, txt in enumerate(encoded_texts):
|
||||||
|
txt += extra_info[i]
|
||||||
f.write(" ".join(txt) + "\n")
|
f.write(" ".join(txt) + "\n")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
1
setup.py
@@ -51,6 +51,7 @@ def get_binaries_to_install():
|
|||||||
|
|
||||||
# Remember to also change cmake/cmake_extension.py
|
# Remember to also change cmake/cmake_extension.py
|
||||||
binaries = ["sherpa-onnx"]
|
binaries = ["sherpa-onnx"]
|
||||||
|
binaries += ["sherpa-onnx-keyword-spotter"]
|
||||||
binaries += ["sherpa-onnx-offline"]
|
binaries += ["sherpa-onnx-offline"]
|
||||||
binaries += ["sherpa-onnx-microphone"]
|
binaries += ["sherpa-onnx-microphone"]
|
||||||
binaries += ["sherpa-onnx-microphone-offline"]
|
binaries += ["sherpa-onnx-microphone-offline"]
|
||||||
|
|||||||
@@ -19,6 +19,8 @@ set(sources
|
|||||||
features.cc
|
features.cc
|
||||||
file-utils.cc
|
file-utils.cc
|
||||||
hypothesis.cc
|
hypothesis.cc
|
||||||
|
keyword-spotter-impl.cc
|
||||||
|
keyword-spotter.cc
|
||||||
offline-ctc-fst-decoder-config.cc
|
offline-ctc-fst-decoder-config.cc
|
||||||
offline-ctc-fst-decoder.cc
|
offline-ctc-fst-decoder.cc
|
||||||
offline-ctc-greedy-search-decoder.cc
|
offline-ctc-greedy-search-decoder.cc
|
||||||
@@ -87,6 +89,7 @@ set(sources
|
|||||||
stack.cc
|
stack.cc
|
||||||
symbol-table.cc
|
symbol-table.cc
|
||||||
text-utils.cc
|
text-utils.cc
|
||||||
|
transducer-keyword-decoder.cc
|
||||||
transpose.cc
|
transpose.cc
|
||||||
unbind.cc
|
unbind.cc
|
||||||
utils.cc
|
utils.cc
|
||||||
@@ -173,12 +176,14 @@ if(NOT BUILD_SHARED_LIBS AND CMAKE_SYSTEM_NAME STREQUAL Linux)
|
|||||||
endif()
|
endif()
|
||||||
|
|
||||||
add_executable(sherpa-onnx sherpa-onnx.cc)
|
add_executable(sherpa-onnx sherpa-onnx.cc)
|
||||||
|
add_executable(sherpa-onnx-keyword-spotter sherpa-onnx-keyword-spotter.cc)
|
||||||
add_executable(sherpa-onnx-offline sherpa-onnx-offline.cc)
|
add_executable(sherpa-onnx-offline sherpa-onnx-offline.cc)
|
||||||
add_executable(sherpa-onnx-offline-parallel sherpa-onnx-offline-parallel.cc)
|
add_executable(sherpa-onnx-offline-parallel sherpa-onnx-offline-parallel.cc)
|
||||||
add_executable(sherpa-onnx-offline-tts sherpa-onnx-offline-tts.cc)
|
add_executable(sherpa-onnx-offline-tts sherpa-onnx-offline-tts.cc)
|
||||||
|
|
||||||
set(main_exes
|
set(main_exes
|
||||||
sherpa-onnx
|
sherpa-onnx
|
||||||
|
sherpa-onnx-keyword-spotter
|
||||||
sherpa-onnx-offline
|
sherpa-onnx-offline
|
||||||
sherpa-onnx-offline-parallel
|
sherpa-onnx-offline-parallel
|
||||||
sherpa-onnx-offline-tts
|
sherpa-onnx-offline-tts
|
||||||
|
|||||||
@@ -5,6 +5,7 @@
|
|||||||
#include "sherpa-onnx/csrc/context-graph.h"
|
#include "sherpa-onnx/csrc/context-graph.h"
|
||||||
|
|
||||||
#include <chrono> // NOLINT
|
#include <chrono> // NOLINT
|
||||||
|
#include <cmath>
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <random>
|
#include <random>
|
||||||
#include <string>
|
#include <string>
|
||||||
@@ -15,27 +16,25 @@
|
|||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
TEST(ContextGraph, TestBasic) {
|
static void TestHelper(const std::map<std::string, float> &queries, float score,
|
||||||
|
bool strict_mode) {
|
||||||
std::vector<std::string> contexts_str(
|
std::vector<std::string> contexts_str(
|
||||||
{"S", "HE", "SHE", "SHELL", "HIS", "HERS", "HELLO", "THIS", "THEM"});
|
{"S", "HE", "SHE", "SHELL", "HIS", "HERS", "HELLO", "THIS", "THEM"});
|
||||||
std::vector<std::vector<int32_t>> contexts;
|
std::vector<std::vector<int32_t>> contexts;
|
||||||
|
std::vector<float> scores;
|
||||||
for (int32_t i = 0; i < contexts_str.size(); ++i) {
|
for (int32_t i = 0; i < contexts_str.size(); ++i) {
|
||||||
contexts.emplace_back(contexts_str[i].begin(), contexts_str[i].end());
|
contexts.emplace_back(contexts_str[i].begin(), contexts_str[i].end());
|
||||||
|
scores.push_back(std::round(score / contexts_str[i].size() * 100) / 100);
|
||||||
}
|
}
|
||||||
auto context_graph = ContextGraph(contexts, 1);
|
auto context_graph = ContextGraph(contexts, 1, scores);
|
||||||
|
|
||||||
auto queries = std::map<std::string, float>{
|
|
||||||
{"HEHERSHE", 14}, {"HERSHE", 12}, {"HISHE", 9},
|
|
||||||
{"SHED", 6}, {"SHELF", 6}, {"HELL", 2},
|
|
||||||
{"HELLO", 7}, {"DHRHISQ", 4}, {"THEN", 2}};
|
|
||||||
|
|
||||||
for (const auto &iter : queries) {
|
for (const auto &iter : queries) {
|
||||||
float total_scores = 0;
|
float total_scores = 0;
|
||||||
auto state = context_graph.Root();
|
auto state = context_graph.Root();
|
||||||
for (auto q : iter.first) {
|
for (auto q : iter.first) {
|
||||||
auto res = context_graph.ForwardOneStep(state, q);
|
auto res = context_graph.ForwardOneStep(state, q, strict_mode);
|
||||||
total_scores += res.first;
|
total_scores += std::get<0>(res);
|
||||||
state = res.second;
|
state = std::get<1>(res);
|
||||||
}
|
}
|
||||||
auto res = context_graph.Finalize(state);
|
auto res = context_graph.Finalize(state);
|
||||||
EXPECT_EQ(res.second->token, -1);
|
EXPECT_EQ(res.second->token, -1);
|
||||||
@@ -44,6 +43,37 @@ TEST(ContextGraph, TestBasic) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(ContextGraph, TestBasic) {
|
||||||
|
auto queries = std::map<std::string, float>{
|
||||||
|
{"HEHERSHE", 14}, {"HERSHE", 12}, {"HISHE", 9},
|
||||||
|
{"SHED", 6}, {"SHELF", 6}, {"HELL", 2},
|
||||||
|
{"HELLO", 7}, {"DHRHISQ", 4}, {"THEN", 2}};
|
||||||
|
TestHelper(queries, 0, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ContextGraph, TestBasicNonStrict) {
|
||||||
|
auto queries = std::map<std::string, float>{
|
||||||
|
{"HEHERSHE", 7}, {"HERSHE", 5}, {"HISHE", 5}, {"SHED", 3}, {"SHELF", 3},
|
||||||
|
{"HELL", 2}, {"HELLO", 2}, {"DHRHISQ", 3}, {"THEN", 2}};
|
||||||
|
TestHelper(queries, 0, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ContextGraph, TestCustomize) {
|
||||||
|
auto queries = std::map<std::string, float>{
|
||||||
|
{"HEHERSHE", 35.84}, {"HERSHE", 30.84}, {"HISHE", 24.18},
|
||||||
|
{"SHED", 18.34}, {"SHELF", 18.34}, {"HELL", 5},
|
||||||
|
{"HELLO", 13}, {"DHRHISQ", 10.84}, {"THEN", 5}};
|
||||||
|
TestHelper(queries, 5, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ContextGraph, TestCustomizeNonStrict) {
|
||||||
|
auto queries = std::map<std::string, float>{
|
||||||
|
{"HEHERSHE", 20}, {"HERSHE", 15}, {"HISHE", 10.84},
|
||||||
|
{"SHED", 10}, {"SHELF", 10}, {"HELL", 5},
|
||||||
|
{"HELLO", 5}, {"DHRHISQ", 5.84}, {"THEN", 5}};
|
||||||
|
TestHelper(queries, 5, false);
|
||||||
|
}
|
||||||
|
|
||||||
TEST(ContextGraph, Benchmark) {
|
TEST(ContextGraph, Benchmark) {
|
||||||
std::random_device rd;
|
std::random_device rd;
|
||||||
std::mt19937 mt(rd());
|
std::mt19937 mt(rd());
|
||||||
|
|||||||
@@ -4,22 +4,59 @@
|
|||||||
|
|
||||||
#include "sherpa-onnx/csrc/context-graph.h"
|
#include "sherpa-onnx/csrc/context-graph.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
#include <queue>
|
#include <queue>
|
||||||
|
#include <string>
|
||||||
|
#include <tuple>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
void ContextGraph::Build(
|
void ContextGraph::Build(const std::vector<std::vector<int32_t>> &token_ids,
|
||||||
const std::vector<std::vector<int32_t>> &token_ids) const {
|
const std::vector<float> &scores,
|
||||||
|
const std::vector<std::string> &phrases,
|
||||||
|
const std::vector<float> &ac_thresholds) const {
|
||||||
|
if (!scores.empty()) {
|
||||||
|
SHERPA_ONNX_CHECK_EQ(token_ids.size(), scores.size());
|
||||||
|
}
|
||||||
|
if (!phrases.empty()) {
|
||||||
|
SHERPA_ONNX_CHECK_EQ(token_ids.size(), phrases.size());
|
||||||
|
}
|
||||||
|
if (!ac_thresholds.empty()) {
|
||||||
|
SHERPA_ONNX_CHECK_EQ(token_ids.size(), ac_thresholds.size());
|
||||||
|
}
|
||||||
for (int32_t i = 0; i < token_ids.size(); ++i) {
|
for (int32_t i = 0; i < token_ids.size(); ++i) {
|
||||||
auto node = root_.get();
|
auto node = root_.get();
|
||||||
|
float score = scores.empty() ? 0.0f : scores[i];
|
||||||
|
score = score == 0.0f ? context_score_ : score;
|
||||||
|
float ac_threshold = ac_thresholds.empty() ? 0.0f : ac_thresholds[i];
|
||||||
|
ac_threshold = ac_threshold == 0.0f ? ac_threshold_ : ac_threshold;
|
||||||
|
std::string phrase = phrases.empty() ? std::string() : phrases[i];
|
||||||
|
|
||||||
for (int32_t j = 0; j < token_ids[i].size(); ++j) {
|
for (int32_t j = 0; j < token_ids[i].size(); ++j) {
|
||||||
int32_t token = token_ids[i][j];
|
int32_t token = token_ids[i][j];
|
||||||
if (0 == node->next.count(token)) {
|
if (0 == node->next.count(token)) {
|
||||||
bool is_end = j == token_ids[i].size() - 1;
|
bool is_end = j == token_ids[i].size() - 1;
|
||||||
node->next[token] = std::make_unique<ContextState>(
|
node->next[token] = std::make_unique<ContextState>(
|
||||||
token, context_score_, node->node_score + context_score_,
|
token, score, node->node_score + score,
|
||||||
is_end ? node->node_score + context_score_ : 0, is_end);
|
is_end ? node->node_score + score : 0, j + 1,
|
||||||
|
is_end ? ac_threshold : 0.0f, is_end,
|
||||||
|
is_end ? phrase : std::string());
|
||||||
|
} else {
|
||||||
|
float token_score = std::max(score, node->next[token]->token_score);
|
||||||
|
node->next[token]->token_score = token_score;
|
||||||
|
float node_score = node->node_score + token_score;
|
||||||
|
node->next[token]->node_score = node_score;
|
||||||
|
bool is_end =
|
||||||
|
(j == token_ids[i].size() - 1) || node->next[token]->is_end;
|
||||||
|
node->next[token]->output_score = is_end ? node_score : 0.0f;
|
||||||
|
node->next[token]->is_end = is_end;
|
||||||
|
if (j == token_ids[i].size() - 1) {
|
||||||
|
node->next[token]->phrase = phrase;
|
||||||
|
node->next[token]->ac_threshold = ac_threshold;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
node = node->next[token].get();
|
node = node->next[token].get();
|
||||||
}
|
}
|
||||||
@@ -27,8 +64,9 @@ void ContextGraph::Build(
|
|||||||
FillFailOutput();
|
FillFailOutput();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::pair<float, const ContextState *> ContextGraph::ForwardOneStep(
|
std::tuple<float, const ContextState *, const ContextState *>
|
||||||
const ContextState *state, int32_t token) const {
|
ContextGraph::ForwardOneStep(const ContextState *state, int32_t token,
|
||||||
|
bool strict_mode /*= true*/) const {
|
||||||
const ContextState *node;
|
const ContextState *node;
|
||||||
float score;
|
float score;
|
||||||
if (1 == state->next.count(token)) {
|
if (1 == state->next.count(token)) {
|
||||||
@@ -45,8 +83,22 @@ std::pair<float, const ContextState *> ContextGraph::ForwardOneStep(
|
|||||||
}
|
}
|
||||||
score = node->node_score - state->node_score;
|
score = node->node_score - state->node_score;
|
||||||
}
|
}
|
||||||
|
|
||||||
SHERPA_ONNX_CHECK(nullptr != node);
|
SHERPA_ONNX_CHECK(nullptr != node);
|
||||||
return std::make_pair(score + node->output_score, node);
|
|
||||||
|
const ContextState *matched_node =
|
||||||
|
node->is_end ? node : (node->output != nullptr ? node->output : nullptr);
|
||||||
|
|
||||||
|
if (!strict_mode && node->output_score != 0) {
|
||||||
|
SHERPA_ONNX_CHECK(nullptr != matched_node);
|
||||||
|
float output_score =
|
||||||
|
node->is_end ? node->node_score
|
||||||
|
: (node->output != nullptr ? node->output->node_score
|
||||||
|
: node->node_score);
|
||||||
|
return std::make_tuple(score + output_score - node->node_score, root_.get(),
|
||||||
|
matched_node);
|
||||||
|
}
|
||||||
|
return std::make_tuple(score + node->output_score, node, matched_node);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::pair<float, const ContextState *> ContextGraph::Finalize(
|
std::pair<float, const ContextState *> ContextGraph::Finalize(
|
||||||
@@ -55,6 +107,22 @@ std::pair<float, const ContextState *> ContextGraph::Finalize(
|
|||||||
return std::make_pair(score, root_.get());
|
return std::make_pair(score, root_.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::pair<bool, const ContextState *> ContextGraph::IsMatched(
|
||||||
|
const ContextState *state) const {
|
||||||
|
bool status = false;
|
||||||
|
const ContextState *node = nullptr;
|
||||||
|
if (state->is_end) {
|
||||||
|
status = true;
|
||||||
|
node = state;
|
||||||
|
} else {
|
||||||
|
if (state->output != nullptr) {
|
||||||
|
status = true;
|
||||||
|
node = state->output;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return std::make_pair(status, node);
|
||||||
|
}
|
||||||
|
|
||||||
void ContextGraph::FillFailOutput() const {
|
void ContextGraph::FillFailOutput() const {
|
||||||
std::queue<const ContextState *> node_queue;
|
std::queue<const ContextState *> node_queue;
|
||||||
for (auto &kv : root_->next) {
|
for (auto &kv : root_->next) {
|
||||||
|
|||||||
@@ -6,6 +6,8 @@
|
|||||||
#define SHERPA_ONNX_CSRC_CONTEXT_GRAPH_H_
|
#define SHERPA_ONNX_CSRC_CONTEXT_GRAPH_H_
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <tuple>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
@@ -22,34 +24,55 @@ struct ContextState {
|
|||||||
float token_score;
|
float token_score;
|
||||||
float node_score;
|
float node_score;
|
||||||
float output_score;
|
float output_score;
|
||||||
|
int32_t level;
|
||||||
|
float ac_threshold;
|
||||||
bool is_end;
|
bool is_end;
|
||||||
|
std::string phrase;
|
||||||
std::unordered_map<int32_t, std::unique_ptr<ContextState>> next;
|
std::unordered_map<int32_t, std::unique_ptr<ContextState>> next;
|
||||||
const ContextState *fail = nullptr;
|
const ContextState *fail = nullptr;
|
||||||
const ContextState *output = nullptr;
|
const ContextState *output = nullptr;
|
||||||
|
|
||||||
ContextState() = default;
|
ContextState() = default;
|
||||||
ContextState(int32_t token, float token_score, float node_score,
|
ContextState(int32_t token, float token_score, float node_score,
|
||||||
float output_score, bool is_end)
|
float output_score, int32_t level = 0, float ac_threshold = 0.0f,
|
||||||
|
bool is_end = false, const std::string &phrase = {})
|
||||||
: token(token),
|
: token(token),
|
||||||
token_score(token_score),
|
token_score(token_score),
|
||||||
node_score(node_score),
|
node_score(node_score),
|
||||||
output_score(output_score),
|
output_score(output_score),
|
||||||
is_end(is_end) {}
|
level(level),
|
||||||
|
ac_threshold(ac_threshold),
|
||||||
|
is_end(is_end),
|
||||||
|
phrase(phrase) {}
|
||||||
};
|
};
|
||||||
|
|
||||||
class ContextGraph {
|
class ContextGraph {
|
||||||
public:
|
public:
|
||||||
ContextGraph() = default;
|
ContextGraph() = default;
|
||||||
ContextGraph(const std::vector<std::vector<int32_t>> &token_ids,
|
ContextGraph(const std::vector<std::vector<int32_t>> &token_ids,
|
||||||
float context_score)
|
float context_score, float ac_threshold,
|
||||||
: context_score_(context_score) {
|
const std::vector<float> &scores = {},
|
||||||
root_ = std::make_unique<ContextState>(-1, 0, 0, 0, false);
|
const std::vector<std::string> &phrases = {},
|
||||||
|
const std::vector<float> &ac_thresholds = {})
|
||||||
|
: context_score_(context_score), ac_threshold_(ac_threshold) {
|
||||||
|
root_ = std::make_unique<ContextState>(-1, 0, 0, 0);
|
||||||
root_->fail = root_.get();
|
root_->fail = root_.get();
|
||||||
Build(token_ids);
|
Build(token_ids, scores, phrases, ac_thresholds);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::pair<float, const ContextState *> ForwardOneStep(
|
ContextGraph(const std::vector<std::vector<int32_t>> &token_ids,
|
||||||
const ContextState *state, int32_t token_id) const;
|
float context_score, const std::vector<float> &scores = {},
|
||||||
|
const std::vector<std::string> &phrases = {})
|
||||||
|
: ContextGraph(token_ids, context_score, 0.0f, scores, phrases,
|
||||||
|
std::vector<float>()) {}
|
||||||
|
|
||||||
|
std::tuple<float, const ContextState *, const ContextState *> ForwardOneStep(
|
||||||
|
const ContextState *state, int32_t token_id,
|
||||||
|
bool strict_mode = true) const;
|
||||||
|
|
||||||
|
std::pair<bool, const ContextState *> IsMatched(
|
||||||
|
const ContextState *state) const;
|
||||||
|
|
||||||
std::pair<float, const ContextState *> Finalize(
|
std::pair<float, const ContextState *> Finalize(
|
||||||
const ContextState *state) const;
|
const ContextState *state) const;
|
||||||
|
|
||||||
@@ -57,8 +80,12 @@ class ContextGraph {
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
float context_score_;
|
float context_score_;
|
||||||
|
float ac_threshold_;
|
||||||
std::unique_ptr<ContextState> root_;
|
std::unique_ptr<ContextState> root_;
|
||||||
void Build(const std::vector<std::vector<int32_t>> &token_ids) const;
|
void Build(const std::vector<std::vector<int32_t>> &token_ids,
|
||||||
|
const std::vector<float> &scores,
|
||||||
|
const std::vector<std::string> &phrases,
|
||||||
|
const std::vector<float> &ac_thresholds) const;
|
||||||
void FillFailOutput() const;
|
void FillFailOutput() const;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@@ -28,6 +28,10 @@ struct Hypothesis {
|
|||||||
// on which ys[i] is decoded.
|
// on which ys[i] is decoded.
|
||||||
std::vector<int32_t> timestamps;
|
std::vector<int32_t> timestamps;
|
||||||
|
|
||||||
|
// The acoustic probability for each token in ys.
|
||||||
|
// Only used for keyword spotting task.
|
||||||
|
std::vector<float> ys_probs;
|
||||||
|
|
||||||
// The total score of ys in log space.
|
// The total score of ys in log space.
|
||||||
// It contains only acoustic scores
|
// It contains only acoustic scores
|
||||||
double log_prob = 0;
|
double log_prob = 0;
|
||||||
|
|||||||
33
sherpa-onnx/csrc/keyword-spotter-impl.cc
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
// sherpa-onnx/csrc/keyword-spotter-impl.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023-2024 Xiaomi Corporation
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/keyword-spotter-impl.h"
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/keyword-spotter-transducer-impl.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
std::unique_ptr<KeywordSpotterImpl> KeywordSpotterImpl::Create(
|
||||||
|
const KeywordSpotterConfig &config) {
|
||||||
|
if (!config.model_config.transducer.encoder.empty()) {
|
||||||
|
return std::make_unique<KeywordSpotterTransducerImpl>(config);
|
||||||
|
}
|
||||||
|
|
||||||
|
SHERPA_ONNX_LOGE("Please specify a model");
|
||||||
|
exit(-1);
|
||||||
|
}
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
std::unique_ptr<KeywordSpotterImpl> KeywordSpotterImpl::Create(
|
||||||
|
AAssetManager *mgr, const KeywordSpotterConfig &config) {
|
||||||
|
if (!config.model_config.transducer.encoder.empty()) {
|
||||||
|
return std::make_unique<KeywordSpotterTransducerImpl>(mgr, config);
|
||||||
|
}
|
||||||
|
|
||||||
|
SHERPA_ONNX_LOGE("Please specify a model");
|
||||||
|
exit(-1);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
48
sherpa-onnx/csrc/keyword-spotter-impl.h
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
// sherpa-onnx/csrc/keyword-spotter-impl.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023-2024 Xiaomi Corporation
|
||||||
|
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_IMPL_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_IMPL_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
#include "android/asset_manager.h"
|
||||||
|
#include "android/asset_manager_jni.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/keyword-spotter.h"
|
||||||
|
#include "sherpa-onnx/csrc/online-stream.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
class KeywordSpotterImpl {
|
||||||
|
public:
|
||||||
|
static std::unique_ptr<KeywordSpotterImpl> Create(
|
||||||
|
const KeywordSpotterConfig &config);
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
static std::unique_ptr<KeywordSpotterImpl> Create(
|
||||||
|
AAssetManager *mgr, const KeywordSpotterConfig &config);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
virtual ~KeywordSpotterImpl() = default;
|
||||||
|
|
||||||
|
virtual std::unique_ptr<OnlineStream> CreateStream() const = 0;
|
||||||
|
|
||||||
|
virtual std::unique_ptr<OnlineStream> CreateStream(
|
||||||
|
const std::string &keywords) const = 0;
|
||||||
|
|
||||||
|
virtual bool IsReady(OnlineStream *s) const = 0;
|
||||||
|
|
||||||
|
virtual void DecodeStreams(OnlineStream **ss, int32_t n) const = 0;
|
||||||
|
|
||||||
|
virtual KeywordResult GetResult(OnlineStream *s) const = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_IMPL_H_
|
||||||
323
sherpa-onnx/csrc/keyword-spotter-transducer-impl.h
Normal file
@@ -0,0 +1,323 @@
|
|||||||
|
// sherpa-onnx/csrc/keyword-spotter-transducer-impl.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023-2024 Xiaomi Corporation
|
||||||
|
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_TRANSDUCER_IMPL_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_TRANSDUCER_IMPL_H_
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <memory>
|
||||||
|
#include <regex> // NOLINT
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
#include <strstream>
|
||||||
|
|
||||||
|
#include "android/asset_manager.h"
|
||||||
|
#include "android/asset_manager_jni.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/file-utils.h"
|
||||||
|
#include "sherpa-onnx/csrc/keyword-spotter-impl.h"
|
||||||
|
#include "sherpa-onnx/csrc/keyword-spotter.h"
|
||||||
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
|
#include "sherpa-onnx/csrc/online-transducer-model.h"
|
||||||
|
#include "sherpa-onnx/csrc/symbol-table.h"
|
||||||
|
#include "sherpa-onnx/csrc/transducer-keyword-decoder.h"
|
||||||
|
#include "sherpa-onnx/csrc/utils.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
static KeywordResult Convert(const TransducerKeywordResult &src,
|
||||||
|
const SymbolTable &sym_table, float frame_shift_ms,
|
||||||
|
int32_t subsampling_factor,
|
||||||
|
int32_t frames_since_start) {
|
||||||
|
KeywordResult r;
|
||||||
|
r.tokens.reserve(src.tokens.size());
|
||||||
|
r.timestamps.reserve(src.tokens.size());
|
||||||
|
r.keyword = src.keyword;
|
||||||
|
bool from_tokens = src.keyword.empty();
|
||||||
|
|
||||||
|
for (auto i : src.tokens) {
|
||||||
|
auto sym = sym_table[i];
|
||||||
|
if (from_tokens) {
|
||||||
|
r.keyword.append(sym);
|
||||||
|
}
|
||||||
|
r.tokens.push_back(std::move(sym));
|
||||||
|
}
|
||||||
|
if (from_tokens && r.keyword.size()) {
|
||||||
|
r.keyword = r.keyword.substr(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor;
|
||||||
|
for (auto t : src.timestamps) {
|
||||||
|
float time = frame_shift_s * t;
|
||||||
|
r.timestamps.push_back(time);
|
||||||
|
}
|
||||||
|
|
||||||
|
r.start_time = frames_since_start * frame_shift_ms / 1000.;
|
||||||
|
|
||||||
|
return r;
|
||||||
|
}
|
||||||
|
|
||||||
|
class KeywordSpotterTransducerImpl : public KeywordSpotterImpl {
|
||||||
|
public:
|
||||||
|
explicit KeywordSpotterTransducerImpl(const KeywordSpotterConfig &config)
|
||||||
|
: config_(config),
|
||||||
|
model_(OnlineTransducerModel::Create(config.model_config)),
|
||||||
|
sym_(config.model_config.tokens) {
|
||||||
|
if (sym_.contains("<unk>")) {
|
||||||
|
unk_id_ = sym_["<unk>"];
|
||||||
|
}
|
||||||
|
|
||||||
|
InitKeywords();
|
||||||
|
|
||||||
|
decoder_ = std::make_unique<TransducerKeywordDecoder>(
|
||||||
|
model_.get(), config_.max_active_paths, config_.num_trailing_blanks,
|
||||||
|
unk_id_);
|
||||||
|
}
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
KeywordSpotterTransducerImpl(AAssetManager *mgr,
|
||||||
|
const KeywordSpotterConfig &config)
|
||||||
|
: config_(config),
|
||||||
|
model_(OnlineTransducerModel::Create(mgr, config.model_config)),
|
||||||
|
sym_(mgr, config.model_config.tokens) {
|
||||||
|
if (sym_.contains("<unk>")) {
|
||||||
|
unk_id_ = sym_["<unk>"];
|
||||||
|
}
|
||||||
|
|
||||||
|
InitKeywords(mgr);
|
||||||
|
|
||||||
|
decoder_ = std::make_unique<TransducerKeywordDecoder>(
|
||||||
|
model_.get(), config_.max_active_paths, config_.num_trailing_blanks,
|
||||||
|
unk_id_);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
std::unique_ptr<OnlineStream> CreateStream() const override {
|
||||||
|
auto stream =
|
||||||
|
std::make_unique<OnlineStream>(config_.feat_config, keywords_graph_);
|
||||||
|
InitOnlineStream(stream.get());
|
||||||
|
return stream;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<OnlineStream> CreateStream(
|
||||||
|
const std::string &keywords) const override {
|
||||||
|
auto kws = std::regex_replace(keywords, std::regex("/"), "\n");
|
||||||
|
std::istringstream is(kws);
|
||||||
|
|
||||||
|
std::vector<std::vector<int32_t>> current_ids;
|
||||||
|
std::vector<std::string> current_kws;
|
||||||
|
std::vector<float> current_scores;
|
||||||
|
std::vector<float> current_thresholds;
|
||||||
|
|
||||||
|
if (!EncodeKeywords(is, sym_, ¤t_ids, ¤t_kws, ¤t_scores,
|
||||||
|
¤t_thresholds)) {
|
||||||
|
SHERPA_ONNX_LOGE("Encode keywords %s failed.", keywords.c_str());
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t num_kws = current_ids.size();
|
||||||
|
int32_t num_default_kws = keywords_id_.size();
|
||||||
|
|
||||||
|
current_ids.insert(current_ids.end(), keywords_id_.begin(),
|
||||||
|
keywords_id_.end());
|
||||||
|
|
||||||
|
if (!current_kws.empty() && !keywords_.empty()) {
|
||||||
|
current_kws.insert(current_kws.end(), keywords_.begin(), keywords_.end());
|
||||||
|
} else if (!current_kws.empty() && keywords_.empty()) {
|
||||||
|
current_kws.insert(current_kws.end(), num_default_kws, std::string());
|
||||||
|
} else if (current_kws.empty() && !keywords_.empty()) {
|
||||||
|
current_kws.insert(current_kws.end(), num_kws, std::string());
|
||||||
|
current_kws.insert(current_kws.end(), keywords_.begin(), keywords_.end());
|
||||||
|
} else {
|
||||||
|
// Do nothing.
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!current_scores.empty() && !boost_scores_.empty()) {
|
||||||
|
current_scores.insert(current_scores.end(), boost_scores_.begin(),
|
||||||
|
boost_scores_.end());
|
||||||
|
} else if (!current_scores.empty() && boost_scores_.empty()) {
|
||||||
|
current_scores.insert(current_scores.end(), num_default_kws,
|
||||||
|
config_.keywords_score);
|
||||||
|
} else if (current_scores.empty() && !boost_scores_.empty()) {
|
||||||
|
current_scores.insert(current_scores.end(), num_kws,
|
||||||
|
config_.keywords_score);
|
||||||
|
current_scores.insert(current_scores.end(), boost_scores_.begin(),
|
||||||
|
boost_scores_.end());
|
||||||
|
} else {
|
||||||
|
// Do nothing.
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!current_thresholds.empty() && !thresholds_.empty()) {
|
||||||
|
current_thresholds.insert(current_thresholds.end(), thresholds_.begin(),
|
||||||
|
thresholds_.end());
|
||||||
|
} else if (!current_thresholds.empty() && thresholds_.empty()) {
|
||||||
|
current_thresholds.insert(current_thresholds.end(), num_default_kws,
|
||||||
|
config_.keywords_threshold);
|
||||||
|
} else if (current_thresholds.empty() && !thresholds_.empty()) {
|
||||||
|
current_thresholds.insert(current_thresholds.end(), num_kws,
|
||||||
|
config_.keywords_threshold);
|
||||||
|
current_thresholds.insert(current_thresholds.end(), thresholds_.begin(),
|
||||||
|
thresholds_.end());
|
||||||
|
} else {
|
||||||
|
// Do nothing.
|
||||||
|
}
|
||||||
|
|
||||||
|
auto keywords_graph = std::make_shared<ContextGraph>(
|
||||||
|
current_ids, config_.keywords_score, config_.keywords_threshold,
|
||||||
|
current_scores, current_kws, current_thresholds);
|
||||||
|
|
||||||
|
auto stream =
|
||||||
|
std::make_unique<OnlineStream>(config_.feat_config, keywords_graph);
|
||||||
|
InitOnlineStream(stream.get());
|
||||||
|
return stream;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool IsReady(OnlineStream *s) const override {
|
||||||
|
return s->GetNumProcessedFrames() + model_->ChunkSize() <
|
||||||
|
s->NumFramesReady();
|
||||||
|
}
|
||||||
|
|
||||||
|
void DecodeStreams(OnlineStream **ss, int32_t n) const override {
|
||||||
|
int32_t chunk_size = model_->ChunkSize();
|
||||||
|
int32_t chunk_shift = model_->ChunkShift();
|
||||||
|
|
||||||
|
int32_t feature_dim = ss[0]->FeatureDim();
|
||||||
|
|
||||||
|
std::vector<TransducerKeywordResult> results(n);
|
||||||
|
std::vector<float> features_vec(n * chunk_size * feature_dim);
|
||||||
|
std::vector<std::vector<Ort::Value>> states_vec(n);
|
||||||
|
std::vector<int64_t> all_processed_frames(n);
|
||||||
|
|
||||||
|
for (int32_t i = 0; i != n; ++i) {
|
||||||
|
SHERPA_ONNX_CHECK(ss[i]->GetContextGraph() != nullptr);
|
||||||
|
|
||||||
|
const auto num_processed_frames = ss[i]->GetNumProcessedFrames();
|
||||||
|
std::vector<float> features =
|
||||||
|
ss[i]->GetFrames(num_processed_frames, chunk_size);
|
||||||
|
|
||||||
|
// Question: should num_processed_frames include chunk_shift?
|
||||||
|
ss[i]->GetNumProcessedFrames() += chunk_shift;
|
||||||
|
|
||||||
|
std::copy(features.begin(), features.end(),
|
||||||
|
features_vec.data() + i * chunk_size * feature_dim);
|
||||||
|
|
||||||
|
results[i] = std::move(ss[i]->GetKeywordResult());
|
||||||
|
states_vec[i] = std::move(ss[i]->GetStates());
|
||||||
|
all_processed_frames[i] = num_processed_frames;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto memory_info =
|
||||||
|
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||||
|
|
||||||
|
std::array<int64_t, 3> x_shape{n, chunk_size, feature_dim};
|
||||||
|
|
||||||
|
Ort::Value x = Ort::Value::CreateTensor(memory_info, features_vec.data(),
|
||||||
|
features_vec.size(), x_shape.data(),
|
||||||
|
x_shape.size());
|
||||||
|
|
||||||
|
std::array<int64_t, 1> processed_frames_shape{
|
||||||
|
static_cast<int64_t>(all_processed_frames.size())};
|
||||||
|
|
||||||
|
Ort::Value processed_frames = Ort::Value::CreateTensor(
|
||||||
|
memory_info, all_processed_frames.data(), all_processed_frames.size(),
|
||||||
|
processed_frames_shape.data(), processed_frames_shape.size());
|
||||||
|
|
||||||
|
auto states = model_->StackStates(states_vec);
|
||||||
|
|
||||||
|
auto pair = model_->RunEncoder(std::move(x), std::move(states),
|
||||||
|
std::move(processed_frames));
|
||||||
|
|
||||||
|
decoder_->Decode(std::move(pair.first), ss, &results);
|
||||||
|
|
||||||
|
std::vector<std::vector<Ort::Value>> next_states =
|
||||||
|
model_->UnStackStates(pair.second);
|
||||||
|
|
||||||
|
for (int32_t i = 0; i != n; ++i) {
|
||||||
|
ss[i]->SetKeywordResult(results[i]);
|
||||||
|
ss[i]->SetStates(std::move(next_states[i]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
KeywordResult GetResult(OnlineStream *s) const override {
|
||||||
|
TransducerKeywordResult decoder_result = s->GetKeywordResult(true);
|
||||||
|
|
||||||
|
// TODO(fangjun): Remember to change these constants if needed
|
||||||
|
int32_t frame_shift_ms = 10;
|
||||||
|
int32_t subsampling_factor = 4;
|
||||||
|
return Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor,
|
||||||
|
s->GetNumFramesSinceStart());
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
void InitKeywords(std::istream &is) {
|
||||||
|
if (!EncodeKeywords(is, sym_, &keywords_id_, &keywords_, &boost_scores_,
|
||||||
|
&thresholds_)) {
|
||||||
|
SHERPA_ONNX_LOGE("Encode keywords failed.");
|
||||||
|
exit(-1);
|
||||||
|
}
|
||||||
|
keywords_graph_ = std::make_shared<ContextGraph>(
|
||||||
|
keywords_id_, config_.keywords_score, config_.keywords_threshold,
|
||||||
|
boost_scores_, keywords_, thresholds_);
|
||||||
|
}
|
||||||
|
|
||||||
|
void InitKeywords() {
|
||||||
|
// each line in keywords_file contains space-separated words
|
||||||
|
|
||||||
|
std::ifstream is(config_.keywords_file);
|
||||||
|
if (!is) {
|
||||||
|
SHERPA_ONNX_LOGE("Open keywords file failed: %s",
|
||||||
|
config_.keywords_file.c_str());
|
||||||
|
exit(-1);
|
||||||
|
}
|
||||||
|
InitKeywords(is);
|
||||||
|
}
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
void InitKeywords(AAssetManager *mgr) {
|
||||||
|
// each line in keywords_file contains space-separated words
|
||||||
|
|
||||||
|
auto buf = ReadFile(mgr, config_.keywords_file);
|
||||||
|
|
||||||
|
std::istrstream is(buf.data(), buf.size());
|
||||||
|
|
||||||
|
if (!is) {
|
||||||
|
SHERPA_ONNX_LOGE("Open keywords file failed: %s",
|
||||||
|
config_.keywords_file.c_str());
|
||||||
|
exit(-1);
|
||||||
|
}
|
||||||
|
InitKeywords(is);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
void InitOnlineStream(OnlineStream *stream) const {
|
||||||
|
auto r = decoder_->GetEmptyResult();
|
||||||
|
SHERPA_ONNX_CHECK_EQ(r.hyps.size(), 1);
|
||||||
|
|
||||||
|
SHERPA_ONNX_CHECK(stream->GetContextGraph() != nullptr);
|
||||||
|
r.hyps.begin()->second.context_state = stream->GetContextGraph()->Root();
|
||||||
|
|
||||||
|
stream->SetKeywordResult(r);
|
||||||
|
stream->SetStates(model_->GetEncoderInitStates());
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
KeywordSpotterConfig config_;
|
||||||
|
std::vector<std::vector<int32_t>> keywords_id_;
|
||||||
|
std::vector<float> boost_scores_;
|
||||||
|
std::vector<float> thresholds_;
|
||||||
|
std::vector<std::string> keywords_;
|
||||||
|
ContextGraphPtr keywords_graph_;
|
||||||
|
std::unique_ptr<OnlineTransducerModel> model_;
|
||||||
|
std::unique_ptr<TransducerKeywordDecoder> decoder_;
|
||||||
|
SymbolTable sym_;
|
||||||
|
int32_t unk_id_ = -1;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_TRANSDUCER_IMPL_H_
|
||||||
152
sherpa-onnx/csrc/keyword-spotter.cc
Normal file
@@ -0,0 +1,152 @@
|
|||||||
|
// sherpa-onnx/csrc/keyword-spotter.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023-2024 Xiaomi Corporation
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/keyword-spotter.h"
|
||||||
|
|
||||||
|
#include <assert.h>
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <fstream>
|
||||||
|
#include <iomanip>
|
||||||
|
#include <memory>
|
||||||
|
#include <sstream>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/keyword-spotter-impl.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
std::string KeywordResult::AsJsonString() const {
|
||||||
|
std::ostringstream os;
|
||||||
|
os << "{";
|
||||||
|
os << "\"start_time\":" << std::fixed << std::setprecision(2) << start_time
|
||||||
|
<< ", ";
|
||||||
|
|
||||||
|
os << "\"keyword\""
|
||||||
|
<< ": ";
|
||||||
|
os << "\"" << keyword << "\""
|
||||||
|
<< ", ";
|
||||||
|
|
||||||
|
os << "\""
|
||||||
|
<< "timestamps"
|
||||||
|
<< "\""
|
||||||
|
<< ": ";
|
||||||
|
os << "[";
|
||||||
|
|
||||||
|
std::string sep = "";
|
||||||
|
for (auto t : timestamps) {
|
||||||
|
os << sep << std::fixed << std::setprecision(2) << t;
|
||||||
|
sep = ", ";
|
||||||
|
}
|
||||||
|
os << "], ";
|
||||||
|
|
||||||
|
os << "\""
|
||||||
|
<< "tokens"
|
||||||
|
<< "\""
|
||||||
|
<< ":";
|
||||||
|
os << "[";
|
||||||
|
|
||||||
|
sep = "";
|
||||||
|
auto oldFlags = os.flags();
|
||||||
|
for (const auto &t : tokens) {
|
||||||
|
if (t.size() == 1 && static_cast<uint8_t>(t[0]) > 0x7f) {
|
||||||
|
const uint8_t *p = reinterpret_cast<const uint8_t *>(t.c_str());
|
||||||
|
os << sep << "\""
|
||||||
|
<< "<0x" << std::hex << std::uppercase << static_cast<uint32_t>(p[0])
|
||||||
|
<< ">"
|
||||||
|
<< "\"";
|
||||||
|
os.flags(oldFlags);
|
||||||
|
} else {
|
||||||
|
os << sep << "\"" << t << "\"";
|
||||||
|
}
|
||||||
|
sep = ", ";
|
||||||
|
}
|
||||||
|
os << "]";
|
||||||
|
os << "}";
|
||||||
|
|
||||||
|
return os.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
void KeywordSpotterConfig::Register(ParseOptions *po) {
|
||||||
|
feat_config.Register(po);
|
||||||
|
model_config.Register(po);
|
||||||
|
|
||||||
|
po->Register("max-active-paths", &max_active_paths,
|
||||||
|
"beam size used in modified beam search.");
|
||||||
|
po->Register("num-trailing-blanks", &num_trailing_blanks,
|
||||||
|
"The number of trailing blanks should have after the keyword.");
|
||||||
|
po->Register("keywords-score", &keywords_score,
|
||||||
|
"The bonus score for each token in context word/phrase.");
|
||||||
|
po->Register("keywords-threshold", &keywords_threshold,
|
||||||
|
"The acoustic threshold (probability) to trigger the keywords.");
|
||||||
|
po->Register(
|
||||||
|
"keywords-file", &keywords_file,
|
||||||
|
"The file containing keywords, one word/phrase per line, and for each"
|
||||||
|
"phrase the bpe/cjkchar are separated by a space. For example: "
|
||||||
|
"▁HE LL O ▁WORLD"
|
||||||
|
"你 好 世 界");
|
||||||
|
}
|
||||||
|
|
||||||
|
bool KeywordSpotterConfig::Validate() const {
|
||||||
|
if (keywords_file.empty()) {
|
||||||
|
SHERPA_ONNX_LOGE("Please provide --keywords-file.");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (!std::ifstream(keywords_file.c_str()).good()) {
|
||||||
|
SHERPA_ONNX_LOGE("Keywords file %s does not exist.", keywords_file.c_str());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return model_config.Validate();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string KeywordSpotterConfig::ToString() const {
|
||||||
|
std::ostringstream os;
|
||||||
|
|
||||||
|
os << "KeywordSpotterConfig(";
|
||||||
|
os << "feat_config=" << feat_config.ToString() << ", ";
|
||||||
|
os << "model_config=" << model_config.ToString() << ", ";
|
||||||
|
os << "max_active_paths=" << max_active_paths << ", ";
|
||||||
|
os << "num_trailing_blanks=" << num_trailing_blanks << ", ";
|
||||||
|
os << "keywords_score=" << keywords_score << ", ";
|
||||||
|
os << "keywords_threshold=" << keywords_threshold << ", ";
|
||||||
|
os << "keywords_file=\"" << keywords_file << "\")";
|
||||||
|
|
||||||
|
return os.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
KeywordSpotter::KeywordSpotter(const KeywordSpotterConfig &config)
|
||||||
|
: impl_(KeywordSpotterImpl::Create(config)) {}
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
KeywordSpotter::KeywordSpotter(AAssetManager *mgr,
|
||||||
|
const KeywordSpotterConfig &config)
|
||||||
|
: impl_(KeywordSpotterImpl::Create(mgr, config)) {}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
KeywordSpotter::~KeywordSpotter() = default;
|
||||||
|
|
||||||
|
std::unique_ptr<OnlineStream> KeywordSpotter::CreateStream() const {
|
||||||
|
return impl_->CreateStream();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<OnlineStream> KeywordSpotter::CreateStream(
|
||||||
|
const std::string &keywords) const {
|
||||||
|
return impl_->CreateStream(keywords);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool KeywordSpotter::IsReady(OnlineStream *s) const {
|
||||||
|
return impl_->IsReady(s);
|
||||||
|
}
|
||||||
|
|
||||||
|
void KeywordSpotter::DecodeStreams(OnlineStream **ss, int32_t n) const {
|
||||||
|
impl_->DecodeStreams(ss, n);
|
||||||
|
}
|
||||||
|
|
||||||
|
KeywordResult KeywordSpotter::GetResult(OnlineStream *s) const {
|
||||||
|
return impl_->GetResult(s);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
148
sherpa-onnx/csrc/keyword-spotter.h
Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
// sherpa-onnx/csrc/keyword-spotter.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023-2024 Xiaomi Corporation
|
||||||
|
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
#include "android/asset_manager.h"
|
||||||
|
#include "android/asset_manager_jni.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/features.h"
|
||||||
|
#include "sherpa-onnx/csrc/online-model-config.h"
|
||||||
|
#include "sherpa-onnx/csrc/online-stream.h"
|
||||||
|
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
|
||||||
|
#include "sherpa-onnx/csrc/parse-options.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
struct KeywordResult {
|
||||||
|
/// The triggered keyword.
|
||||||
|
/// For English, it consists of space separated words.
|
||||||
|
/// For Chinese, it consists of Chinese words without spaces.
|
||||||
|
/// Example 1: "hello world"
|
||||||
|
/// Example 2: "你好世界"
|
||||||
|
std::string keyword;
|
||||||
|
|
||||||
|
/// Decoded results at the token level.
|
||||||
|
/// For instance, for BPE-based models it consists of a list of BPE tokens.
|
||||||
|
std::vector<std::string> tokens;
|
||||||
|
|
||||||
|
/// timestamps.size() == tokens.size()
|
||||||
|
/// timestamps[i] records the time in seconds when tokens[i] is decoded.
|
||||||
|
std::vector<float> timestamps;
|
||||||
|
|
||||||
|
/// Starting time of this segment.
|
||||||
|
/// When an endpoint is detected, it will change
|
||||||
|
float start_time = 0;
|
||||||
|
|
||||||
|
/** Return a json string.
|
||||||
|
*
|
||||||
|
* The returned string contains:
|
||||||
|
* {
|
||||||
|
* "keyword": "The triggered keyword",
|
||||||
|
* "tokens": [x, x, x],
|
||||||
|
* "timestamps": [x, x, x],
|
||||||
|
* "start_time": x,
|
||||||
|
* }
|
||||||
|
*/
|
||||||
|
std::string AsJsonString() const;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct KeywordSpotterConfig {
|
||||||
|
FeatureExtractorConfig feat_config;
|
||||||
|
OnlineModelConfig model_config;
|
||||||
|
|
||||||
|
int32_t max_active_paths = 4;
|
||||||
|
|
||||||
|
int32_t num_trailing_blanks = 1;
|
||||||
|
|
||||||
|
float keywords_score = 1.0;
|
||||||
|
|
||||||
|
float keywords_threshold = 0.25;
|
||||||
|
|
||||||
|
std::string keywords_file;
|
||||||
|
|
||||||
|
KeywordSpotterConfig() = default;
|
||||||
|
|
||||||
|
KeywordSpotterConfig(const FeatureExtractorConfig &feat_config,
|
||||||
|
const OnlineModelConfig &model_config,
|
||||||
|
int32_t max_active_paths, int32_t num_trailing_blanks,
|
||||||
|
float keywords_score, float keywords_threshold,
|
||||||
|
const std::string &keywords_file)
|
||||||
|
: feat_config(feat_config),
|
||||||
|
model_config(model_config),
|
||||||
|
max_active_paths(max_active_paths),
|
||||||
|
num_trailing_blanks(num_trailing_blanks),
|
||||||
|
keywords_score(keywords_score),
|
||||||
|
keywords_threshold(keywords_threshold),
|
||||||
|
keywords_file(keywords_file) {}
|
||||||
|
|
||||||
|
void Register(ParseOptions *po);
|
||||||
|
bool Validate() const;
|
||||||
|
|
||||||
|
std::string ToString() const;
|
||||||
|
};
|
||||||
|
|
||||||
|
class KeywordSpotterImpl;
|
||||||
|
|
||||||
|
class KeywordSpotter {
|
||||||
|
public:
|
||||||
|
explicit KeywordSpotter(const KeywordSpotterConfig &config);
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
KeywordSpotter(AAssetManager *mgr, const KeywordSpotterConfig &config);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
~KeywordSpotter();
|
||||||
|
|
||||||
|
/** Create a stream for decoding.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
std::unique_ptr<OnlineStream> CreateStream() const;
|
||||||
|
|
||||||
|
/** Create a stream for decoding.
|
||||||
|
*
|
||||||
|
* @param The keywords for this string, it might contain several keywords,
|
||||||
|
* the keywords are separated by "/". In each of the keywords, there
|
||||||
|
* are cjkchars or bpes, the bpe/cjkchar are separated by space (" ").
|
||||||
|
* For example, keywords I LOVE YOU and HELLO WORLD, looks like:
|
||||||
|
*
|
||||||
|
* "▁I ▁LOVE ▁YOU/▁HE LL O ▁WORLD"
|
||||||
|
*/
|
||||||
|
std::unique_ptr<OnlineStream> CreateStream(const std::string &keywords) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Return true if the given stream has enough frames for decoding.
|
||||||
|
* Return false otherwise
|
||||||
|
*/
|
||||||
|
bool IsReady(OnlineStream *s) const;
|
||||||
|
|
||||||
|
/** Decode a single stream. */
|
||||||
|
void DecodeStream(OnlineStream *s) const {
|
||||||
|
OnlineStream *ss[1] = {s};
|
||||||
|
DecodeStreams(ss, 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Decode multiple streams in parallel
|
||||||
|
*
|
||||||
|
* @param ss Pointer array containing streams to be decoded.
|
||||||
|
* @param n Number of streams in `ss`.
|
||||||
|
*/
|
||||||
|
void DecodeStreams(OnlineStream **ss, int32_t n) const;
|
||||||
|
|
||||||
|
KeywordResult GetResult(OnlineStream *s) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::unique_ptr<KeywordSpotterImpl> impl_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_KEYWORD_SPOTTER_H_
|
||||||
@@ -93,8 +93,8 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode(
|
|||||||
Repeat(model_->Allocator(), &cur_encoder_out, hyps_row_splits);
|
Repeat(model_->Allocator(), &cur_encoder_out, hyps_row_splits);
|
||||||
// now cur_encoder_out is of shape (num_hyps, joiner_dim)
|
// now cur_encoder_out is of shape (num_hyps, joiner_dim)
|
||||||
|
|
||||||
Ort::Value logit = model_->RunJoiner(
|
Ort::Value logit =
|
||||||
std::move(cur_encoder_out), View(&decoder_out));
|
model_->RunJoiner(std::move(cur_encoder_out), View(&decoder_out));
|
||||||
|
|
||||||
float *p_logit = logit.GetTensorMutableData<float>();
|
float *p_logit = logit.GetTensorMutableData<float>();
|
||||||
LogSoftmax(p_logit, vocab_size, num_hyps);
|
LogSoftmax(p_logit, vocab_size, num_hyps);
|
||||||
@@ -134,8 +134,8 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode(
|
|||||||
if (context_graphs[i] != nullptr) {
|
if (context_graphs[i] != nullptr) {
|
||||||
auto context_res =
|
auto context_res =
|
||||||
context_graphs[i]->ForwardOneStep(context_state, new_token);
|
context_graphs[i]->ForwardOneStep(context_state, new_token);
|
||||||
context_score = context_res.first;
|
context_score = std::get<0>(context_res);
|
||||||
new_hyp.context_state = context_res.second;
|
new_hyp.context_state = std::get<1>(context_res);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -51,6 +51,25 @@ class OnlineStream::Impl {
|
|||||||
|
|
||||||
OnlineTransducerDecoderResult &GetResult() { return result_; }
|
OnlineTransducerDecoderResult &GetResult() { return result_; }
|
||||||
|
|
||||||
|
void SetKeywordResult(const TransducerKeywordResult &r) {
|
||||||
|
keyword_result_ = r;
|
||||||
|
}
|
||||||
|
TransducerKeywordResult &GetKeywordResult(bool remove_duplicates) {
|
||||||
|
if (remove_duplicates) {
|
||||||
|
if (!prev_keyword_result_.timestamps.empty() &&
|
||||||
|
!keyword_result_.timestamps.empty() &&
|
||||||
|
keyword_result_.timestamps[0] <=
|
||||||
|
prev_keyword_result_.timestamps.back()) {
|
||||||
|
return empty_keyword_result_;
|
||||||
|
} else {
|
||||||
|
prev_keyword_result_ = keyword_result_;
|
||||||
|
}
|
||||||
|
return keyword_result_;
|
||||||
|
} else {
|
||||||
|
return keyword_result_;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
OnlineCtcDecoderResult &GetCtcResult() { return ctc_result_; }
|
OnlineCtcDecoderResult &GetCtcResult() { return ctc_result_; }
|
||||||
|
|
||||||
void SetCtcResult(const OnlineCtcDecoderResult &r) { ctc_result_ = r; }
|
void SetCtcResult(const OnlineCtcDecoderResult &r) { ctc_result_ = r; }
|
||||||
@@ -93,6 +112,9 @@ class OnlineStream::Impl {
|
|||||||
int32_t start_frame_index_ = 0; // never reset
|
int32_t start_frame_index_ = 0; // never reset
|
||||||
int32_t segment_ = 0;
|
int32_t segment_ = 0;
|
||||||
OnlineTransducerDecoderResult result_;
|
OnlineTransducerDecoderResult result_;
|
||||||
|
TransducerKeywordResult prev_keyword_result_;
|
||||||
|
TransducerKeywordResult keyword_result_;
|
||||||
|
TransducerKeywordResult empty_keyword_result_;
|
||||||
OnlineCtcDecoderResult ctc_result_;
|
OnlineCtcDecoderResult ctc_result_;
|
||||||
std::vector<Ort::Value> states_; // states for transducer or ctc models
|
std::vector<Ort::Value> states_; // states for transducer or ctc models
|
||||||
std::vector<float> paraformer_feat_cache_;
|
std::vector<float> paraformer_feat_cache_;
|
||||||
@@ -149,6 +171,15 @@ OnlineTransducerDecoderResult &OnlineStream::GetResult() {
|
|||||||
return impl_->GetResult();
|
return impl_->GetResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void OnlineStream::SetKeywordResult(const TransducerKeywordResult &r) {
|
||||||
|
impl_->SetKeywordResult(r);
|
||||||
|
}
|
||||||
|
|
||||||
|
TransducerKeywordResult &OnlineStream::GetKeywordResult(
|
||||||
|
bool remove_duplicates /*=false*/) {
|
||||||
|
return impl_->GetKeywordResult(remove_duplicates);
|
||||||
|
}
|
||||||
|
|
||||||
OnlineCtcDecoderResult &OnlineStream::GetCtcResult() {
|
OnlineCtcDecoderResult &OnlineStream::GetCtcResult() {
|
||||||
return impl_->GetCtcResult();
|
return impl_->GetCtcResult();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,9 +14,11 @@
|
|||||||
#include "sherpa-onnx/csrc/online-ctc-decoder.h"
|
#include "sherpa-onnx/csrc/online-ctc-decoder.h"
|
||||||
#include "sherpa-onnx/csrc/online-paraformer-decoder.h"
|
#include "sherpa-onnx/csrc/online-paraformer-decoder.h"
|
||||||
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
|
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
|
||||||
|
#include "sherpa-onnx/csrc/transducer-keyword-decoder.h"
|
||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
class TransducerKeywordResult;
|
||||||
class OnlineStream {
|
class OnlineStream {
|
||||||
public:
|
public:
|
||||||
explicit OnlineStream(const FeatureExtractorConfig &config = {},
|
explicit OnlineStream(const FeatureExtractorConfig &config = {},
|
||||||
@@ -76,6 +78,9 @@ class OnlineStream {
|
|||||||
void SetResult(const OnlineTransducerDecoderResult &r);
|
void SetResult(const OnlineTransducerDecoderResult &r);
|
||||||
OnlineTransducerDecoderResult &GetResult();
|
OnlineTransducerDecoderResult &GetResult();
|
||||||
|
|
||||||
|
void SetKeywordResult(const TransducerKeywordResult &r);
|
||||||
|
TransducerKeywordResult &GetKeywordResult(bool remove_duplicates = false);
|
||||||
|
|
||||||
void SetCtcResult(const OnlineCtcDecoderResult &r);
|
void SetCtcResult(const OnlineCtcDecoderResult &r);
|
||||||
OnlineCtcDecoderResult &GetCtcResult();
|
OnlineCtcDecoderResult &GetCtcResult();
|
||||||
|
|
||||||
@@ -92,7 +97,7 @@ class OnlineStream {
|
|||||||
*/
|
*/
|
||||||
const ContextGraphPtr &GetContextGraph() const;
|
const ContextGraphPtr &GetContextGraph() const;
|
||||||
|
|
||||||
// for streaming parformer
|
// for streaming paraformer
|
||||||
std::vector<float> &GetParaformerFeatCache();
|
std::vector<float> &GetParaformerFeatCache();
|
||||||
std::vector<float> &GetParaformerEncoderOutCache();
|
std::vector<float> &GetParaformerEncoderOutCache();
|
||||||
std::vector<float> &GetParaformerAlphaCache();
|
std::vector<float> &GetParaformerAlphaCache();
|
||||||
|
|||||||
@@ -75,7 +75,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
|||||||
encoder_out.GetTensorTypeAndShapeInfo().GetShape();
|
encoder_out.GetTensorTypeAndShapeInfo().GetShape();
|
||||||
|
|
||||||
if (encoder_out_shape[0] != result->size()) {
|
if (encoder_out_shape[0] != result->size()) {
|
||||||
fprintf(stderr,
|
SHERPA_ONNX_LOGE(
|
||||||
"Size mismatch! encoder_out.size(0) %d, result.size(0): %d\n",
|
"Size mismatch! encoder_out.size(0) %d, result.size(0): %d\n",
|
||||||
static_cast<int32_t>(encoder_out_shape[0]),
|
static_cast<int32_t>(encoder_out_shape[0]),
|
||||||
static_cast<int32_t>(result->size()));
|
static_cast<int32_t>(result->size()));
|
||||||
@@ -119,8 +119,8 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
|||||||
GetEncoderOutFrame(model_->Allocator(), &encoder_out, t);
|
GetEncoderOutFrame(model_->Allocator(), &encoder_out, t);
|
||||||
cur_encoder_out =
|
cur_encoder_out =
|
||||||
Repeat(model_->Allocator(), &cur_encoder_out, hyps_row_splits);
|
Repeat(model_->Allocator(), &cur_encoder_out, hyps_row_splits);
|
||||||
Ort::Value logit = model_->RunJoiner(
|
Ort::Value logit =
|
||||||
std::move(cur_encoder_out), View(&decoder_out));
|
model_->RunJoiner(std::move(cur_encoder_out), View(&decoder_out));
|
||||||
|
|
||||||
float *p_logit = logit.GetTensorMutableData<float>();
|
float *p_logit = logit.GetTensorMutableData<float>();
|
||||||
LogSoftmax(p_logit, vocab_size, num_hyps);
|
LogSoftmax(p_logit, vocab_size, num_hyps);
|
||||||
@@ -164,8 +164,8 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
|
|||||||
if (ss != nullptr && ss[b]->GetContextGraph() != nullptr) {
|
if (ss != nullptr && ss[b]->GetContextGraph() != nullptr) {
|
||||||
auto context_res = ss[b]->GetContextGraph()->ForwardOneStep(
|
auto context_res = ss[b]->GetContextGraph()->ForwardOneStep(
|
||||||
context_state, new_token);
|
context_state, new_token);
|
||||||
context_score = context_res.first;
|
context_score = std::get<0>(context_res);
|
||||||
new_hyp.context_state = context_res.second;
|
new_hyp.context_state = std::get<1>(context_res);
|
||||||
}
|
}
|
||||||
if (lm_) {
|
if (lm_) {
|
||||||
lm_->ComputeLMScore(lm_scale_, &new_hyp);
|
lm_->ComputeLMScore(lm_scale_, &new_hyp);
|
||||||
|
|||||||
122
sherpa-onnx/csrc/sherpa-onnx-keyword-spotter.cc
Normal file
@@ -0,0 +1,122 @@
|
|||||||
|
// sherpa-onnx/csrc/sherpa-onnx-keyword-spotter.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023-2024 Xiaomi Corporation
|
||||||
|
|
||||||
|
#include <stdio.h>
|
||||||
|
|
||||||
|
#include <iomanip>
|
||||||
|
#include <iostream>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/keyword-spotter.h"
|
||||||
|
#include "sherpa-onnx/csrc/online-stream.h"
|
||||||
|
#include "sherpa-onnx/csrc/parse-options.h"
|
||||||
|
#include "sherpa-onnx/csrc/symbol-table.h"
|
||||||
|
#include "sherpa-onnx/csrc/wave-reader.h"
|
||||||
|
|
||||||
|
typedef struct {
|
||||||
|
std::unique_ptr<sherpa_onnx::OnlineStream> online_stream;
|
||||||
|
std::string filename;
|
||||||
|
} Stream;
|
||||||
|
|
||||||
|
int main(int32_t argc, char *argv[]) {
|
||||||
|
const char *kUsageMessage = R"usage(
|
||||||
|
Usage:
|
||||||
|
|
||||||
|
(1) Streaming transducer
|
||||||
|
|
||||||
|
./bin/sherpa-onnx-keyword-spotter \
|
||||||
|
--tokens=/path/to/tokens.txt \
|
||||||
|
--encoder=/path/to/encoder.onnx \
|
||||||
|
--decoder=/path/to/decoder.onnx \
|
||||||
|
--joiner=/path/to/joiner.onnx \
|
||||||
|
--provider=cpu \
|
||||||
|
--num-threads=2 \
|
||||||
|
--keywords-file=keywords.txt \
|
||||||
|
/path/to/foo.wav [bar.wav foobar.wav ...]
|
||||||
|
|
||||||
|
Note: It supports decoding multiple files in batches
|
||||||
|
|
||||||
|
Default value for num_threads is 2.
|
||||||
|
Valid values for provider: cpu (default), cuda, coreml.
|
||||||
|
foo.wav should be of single channel, 16-bit PCM encoded wave file; its
|
||||||
|
sampling rate can be arbitrary and does not need to be 16kHz.
|
||||||
|
|
||||||
|
Please refer to
|
||||||
|
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
|
||||||
|
for a list of pre-trained models to download.
|
||||||
|
)usage";
|
||||||
|
|
||||||
|
sherpa_onnx::ParseOptions po(kUsageMessage);
|
||||||
|
sherpa_onnx::KeywordSpotterConfig config;
|
||||||
|
|
||||||
|
config.Register(&po);
|
||||||
|
|
||||||
|
po.Read(argc, argv);
|
||||||
|
if (po.NumArgs() < 1) {
|
||||||
|
po.PrintUsage();
|
||||||
|
exit(EXIT_FAILURE);
|
||||||
|
}
|
||||||
|
|
||||||
|
fprintf(stderr, "%s\n", config.ToString().c_str());
|
||||||
|
|
||||||
|
if (!config.Validate()) {
|
||||||
|
fprintf(stderr, "Errors in config!\n");
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
sherpa_onnx::KeywordSpotter keyword_spotter(config);
|
||||||
|
|
||||||
|
std::vector<Stream> ss;
|
||||||
|
|
||||||
|
for (int32_t i = 1; i <= po.NumArgs(); ++i) {
|
||||||
|
const std::string wav_filename = po.GetArg(i);
|
||||||
|
int32_t sampling_rate = -1;
|
||||||
|
|
||||||
|
bool is_ok = false;
|
||||||
|
const std::vector<float> samples =
|
||||||
|
sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok);
|
||||||
|
|
||||||
|
if (!is_ok) {
|
||||||
|
fprintf(stderr, "Failed to read %s\n", wav_filename.c_str());
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto s = keyword_spotter.CreateStream();
|
||||||
|
s->AcceptWaveform(sampling_rate, samples.data(), samples.size());
|
||||||
|
|
||||||
|
std::vector<float> tail_paddings(static_cast<int>(0.8 * sampling_rate));
|
||||||
|
// Note: We can call AcceptWaveform() multiple times.
|
||||||
|
s->AcceptWaveform(sampling_rate, tail_paddings.data(),
|
||||||
|
tail_paddings.size());
|
||||||
|
|
||||||
|
// Call InputFinished() to indicate that no audio samples are available
|
||||||
|
s->InputFinished();
|
||||||
|
ss.push_back({std::move(s), wav_filename});
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<sherpa_onnx::OnlineStream *> ready_streams;
|
||||||
|
for (;;) {
|
||||||
|
ready_streams.clear();
|
||||||
|
for (auto &s : ss) {
|
||||||
|
const auto p_ss = s.online_stream.get();
|
||||||
|
if (keyword_spotter.IsReady(p_ss)) {
|
||||||
|
ready_streams.push_back(p_ss);
|
||||||
|
}
|
||||||
|
std::ostringstream os;
|
||||||
|
const auto r = keyword_spotter.GetResult(p_ss);
|
||||||
|
if (!r.keyword.empty()) {
|
||||||
|
os << s.filename << "\n";
|
||||||
|
os << r.AsJsonString() << "\n\n";
|
||||||
|
fprintf(stderr, "%s", os.str().c_str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ready_streams.empty()) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
keyword_spotter.DecodeStreams(ready_streams.data(), ready_streams.size());
|
||||||
|
}
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
184
sherpa-onnx/csrc/transducer-keyword-decoder.cc
Normal file
@@ -0,0 +1,184 @@
|
|||||||
|
// sherpa-onnx/csrc/transducer-keywords-decoder.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023-2024 Xiaomi Corporation
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <cmath>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/log.h"
|
||||||
|
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||||
|
#include "sherpa-onnx/csrc/transducer-keyword-decoder.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
TransducerKeywordResult TransducerKeywordDecoder::GetEmptyResult() const {
|
||||||
|
int32_t context_size = model_->ContextSize();
|
||||||
|
int32_t blank_id = 0; // always 0
|
||||||
|
TransducerKeywordResult r;
|
||||||
|
std::vector<int64_t> blanks(context_size, -1);
|
||||||
|
blanks.back() = blank_id;
|
||||||
|
|
||||||
|
Hypotheses blank_hyp({{blanks, 0}});
|
||||||
|
r.hyps = std::move(blank_hyp);
|
||||||
|
return r;
|
||||||
|
}
|
||||||
|
|
||||||
|
void TransducerKeywordDecoder::Decode(
|
||||||
|
Ort::Value encoder_out, OnlineStream **ss,
|
||||||
|
std::vector<TransducerKeywordResult> *result) {
|
||||||
|
std::vector<int64_t> encoder_out_shape =
|
||||||
|
encoder_out.GetTensorTypeAndShapeInfo().GetShape();
|
||||||
|
|
||||||
|
if (encoder_out_shape[0] != result->size()) {
|
||||||
|
SHERPA_ONNX_LOGE(
|
||||||
|
"Size mismatch! encoder_out.size(0) %d, result.size(0): %d\n",
|
||||||
|
static_cast<int32_t>(encoder_out_shape[0]),
|
||||||
|
static_cast<int32_t>(result->size()));
|
||||||
|
exit(-1);
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t batch_size = static_cast<int32_t>(encoder_out_shape[0]);
|
||||||
|
|
||||||
|
int32_t num_frames = static_cast<int32_t>(encoder_out_shape[1]);
|
||||||
|
int32_t vocab_size = model_->VocabSize();
|
||||||
|
int32_t context_size = model_->ContextSize();
|
||||||
|
std::vector<int64_t> blanks(context_size, -1);
|
||||||
|
blanks.back() = 0; // blank_id is hardcoded to 0
|
||||||
|
|
||||||
|
std::vector<Hypotheses> cur;
|
||||||
|
for (auto &r : *result) {
|
||||||
|
cur.push_back(std::move(r.hyps));
|
||||||
|
}
|
||||||
|
std::vector<Hypothesis> prev;
|
||||||
|
|
||||||
|
for (int32_t t = 0; t != num_frames; ++t) {
|
||||||
|
// Due to merging paths with identical token sequences,
|
||||||
|
// not all utterances have "num_active_paths" paths.
|
||||||
|
auto hyps_row_splits = GetHypsRowSplits(cur);
|
||||||
|
int32_t num_hyps =
|
||||||
|
hyps_row_splits.back(); // total num hyps for all utterance
|
||||||
|
prev.clear();
|
||||||
|
for (auto &hyps : cur) {
|
||||||
|
for (auto &h : hyps) {
|
||||||
|
prev.push_back(std::move(h.second));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cur.clear();
|
||||||
|
cur.reserve(batch_size);
|
||||||
|
|
||||||
|
Ort::Value decoder_input = model_->BuildDecoderInput(prev);
|
||||||
|
Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input));
|
||||||
|
|
||||||
|
Ort::Value cur_encoder_out =
|
||||||
|
GetEncoderOutFrame(model_->Allocator(), &encoder_out, t);
|
||||||
|
cur_encoder_out =
|
||||||
|
Repeat(model_->Allocator(), &cur_encoder_out, hyps_row_splits);
|
||||||
|
Ort::Value logit =
|
||||||
|
model_->RunJoiner(std::move(cur_encoder_out), View(&decoder_out));
|
||||||
|
|
||||||
|
float *p_logit = logit.GetTensorMutableData<float>();
|
||||||
|
LogSoftmax(p_logit, vocab_size, num_hyps);
|
||||||
|
|
||||||
|
// The acoustic logprobs for current frame
|
||||||
|
std::vector<float> logprobs(vocab_size * num_hyps);
|
||||||
|
std::memcpy(logprobs.data(), p_logit,
|
||||||
|
sizeof(float) * vocab_size * num_hyps);
|
||||||
|
|
||||||
|
// now p_logit contains log_softmax output, we rename it to p_logprob
|
||||||
|
// to match what it actually contains
|
||||||
|
float *p_logprob = p_logit;
|
||||||
|
|
||||||
|
// add log_prob of each hypothesis to p_logprob before taking top_k
|
||||||
|
for (int32_t i = 0; i != num_hyps; ++i) {
|
||||||
|
float log_prob = prev[i].log_prob;
|
||||||
|
for (int32_t k = 0; k != vocab_size; ++k, ++p_logprob) {
|
||||||
|
*p_logprob += log_prob;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
p_logprob = p_logit; // we changed p_logprob in the above for loop
|
||||||
|
|
||||||
|
for (int32_t b = 0; b != batch_size; ++b) {
|
||||||
|
int32_t frame_offset = (*result)[b].frame_offset;
|
||||||
|
int32_t start = hyps_row_splits[b];
|
||||||
|
int32_t end = hyps_row_splits[b + 1];
|
||||||
|
auto topk =
|
||||||
|
TopkIndex(p_logprob, vocab_size * (end - start), max_active_paths_);
|
||||||
|
|
||||||
|
Hypotheses hyps;
|
||||||
|
for (auto k : topk) {
|
||||||
|
int32_t hyp_index = k / vocab_size + start;
|
||||||
|
int32_t new_token = k % vocab_size;
|
||||||
|
|
||||||
|
Hypothesis new_hyp = prev[hyp_index];
|
||||||
|
float context_score = 0;
|
||||||
|
auto context_state = new_hyp.context_state;
|
||||||
|
|
||||||
|
// blank is hardcoded to 0
|
||||||
|
// also, it treats unk as blank
|
||||||
|
if (new_token != 0 && new_token != unk_id_) {
|
||||||
|
new_hyp.ys.push_back(new_token);
|
||||||
|
new_hyp.timestamps.push_back(t + frame_offset);
|
||||||
|
new_hyp.ys_probs.push_back(
|
||||||
|
exp(logprobs[hyp_index * vocab_size + new_token]));
|
||||||
|
|
||||||
|
new_hyp.num_trailing_blanks = 0;
|
||||||
|
auto context_res = ss[b]->GetContextGraph()->ForwardOneStep(
|
||||||
|
context_state, new_token);
|
||||||
|
context_score = std::get<0>(context_res);
|
||||||
|
new_hyp.context_state = std::get<1>(context_res);
|
||||||
|
// Start matching from the start state, forget the decoder history.
|
||||||
|
if (new_hyp.context_state->token == -1) {
|
||||||
|
new_hyp.ys = blanks;
|
||||||
|
new_hyp.timestamps.clear();
|
||||||
|
new_hyp.ys_probs.clear();
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
++new_hyp.num_trailing_blanks;
|
||||||
|
}
|
||||||
|
new_hyp.log_prob = p_logprob[k] + context_score;
|
||||||
|
hyps.Add(std::move(new_hyp));
|
||||||
|
} // for (auto k : topk)
|
||||||
|
|
||||||
|
auto best_hyp = hyps.GetMostProbable(false);
|
||||||
|
|
||||||
|
auto status = ss[b]->GetContextGraph()->IsMatched(best_hyp.context_state);
|
||||||
|
bool matched = std::get<0>(status);
|
||||||
|
const ContextState *matched_state = std::get<1>(status);
|
||||||
|
|
||||||
|
if (matched) {
|
||||||
|
float ys_prob = 0.0;
|
||||||
|
int32_t length = best_hyp.ys_probs.size();
|
||||||
|
for (int32_t i = 1; i <= matched_state->level; ++i) {
|
||||||
|
ys_prob += best_hyp.ys_probs[i];
|
||||||
|
}
|
||||||
|
ys_prob /= matched_state->level;
|
||||||
|
if (best_hyp.num_trailing_blanks > num_trailing_blanks_ &&
|
||||||
|
ys_prob >= matched_state->ac_threshold) {
|
||||||
|
auto &r = (*result)[b];
|
||||||
|
r.tokens = {best_hyp.ys.end() - matched_state->level,
|
||||||
|
best_hyp.ys.end()};
|
||||||
|
r.timestamps = {best_hyp.timestamps.end() - matched_state->level,
|
||||||
|
best_hyp.timestamps.end()};
|
||||||
|
r.keyword = matched_state->phrase;
|
||||||
|
|
||||||
|
hyps = Hypotheses({{blanks, 0, ss[b]->GetContextGraph()->Root()}});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cur.push_back(std::move(hyps));
|
||||||
|
p_logprob += (end - start) * vocab_size;
|
||||||
|
} // for (int32_t b = 0; b != batch_size; ++b)
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int32_t b = 0; b != batch_size; ++b) {
|
||||||
|
auto &hyps = cur[b];
|
||||||
|
auto best_hyp = hyps.GetMostProbable(false);
|
||||||
|
auto &r = (*result)[b];
|
||||||
|
r.hyps = std::move(hyps);
|
||||||
|
r.num_trailing_blanks = best_hyp.num_trailing_blanks;
|
||||||
|
r.frame_offset += num_frames;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
62
sherpa-onnx/csrc/transducer-keyword-decoder.h
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
// sherpa-onnx/csrc/transducer-keywords-decoder.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023-2024 Xiaomi Corporation
|
||||||
|
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_TRANSDUCER_KEYWORD_DECODER_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_TRANSDUCER_KEYWORD_DECODER_H_
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/online-stream.h"
|
||||||
|
#include "sherpa-onnx/csrc/online-transducer-model.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
struct TransducerKeywordResult {
|
||||||
|
/// Number of frames after subsampling we have decoded so far
|
||||||
|
int32_t frame_offset = 0;
|
||||||
|
|
||||||
|
/// The decoded token IDs for keywords
|
||||||
|
std::vector<int64_t> tokens;
|
||||||
|
|
||||||
|
/// The triggered keyword
|
||||||
|
std::string keyword;
|
||||||
|
|
||||||
|
/// number of trailing blank frames decoded so far
|
||||||
|
int32_t num_trailing_blanks = 0;
|
||||||
|
|
||||||
|
/// timestamps[i] contains the output frame index where tokens[i] is decoded.
|
||||||
|
std::vector<int32_t> timestamps;
|
||||||
|
|
||||||
|
// used only in modified beam_search
|
||||||
|
Hypotheses hyps;
|
||||||
|
};
|
||||||
|
|
||||||
|
class TransducerKeywordDecoder {
|
||||||
|
public:
|
||||||
|
TransducerKeywordDecoder(OnlineTransducerModel *model,
|
||||||
|
int32_t max_active_paths,
|
||||||
|
int32_t num_trailing_blanks, int32_t unk_id)
|
||||||
|
: model_(model),
|
||||||
|
max_active_paths_(max_active_paths),
|
||||||
|
num_trailing_blanks_(num_trailing_blanks),
|
||||||
|
unk_id_(unk_id) {}
|
||||||
|
|
||||||
|
TransducerKeywordResult GetEmptyResult() const;
|
||||||
|
|
||||||
|
void Decode(Ort::Value encoder_out, OnlineStream **ss,
|
||||||
|
std::vector<TransducerKeywordResult> *result);
|
||||||
|
|
||||||
|
private:
|
||||||
|
OnlineTransducerModel *model_; // Not owned
|
||||||
|
|
||||||
|
int32_t max_active_paths_;
|
||||||
|
int32_t num_trailing_blanks_;
|
||||||
|
int32_t unk_id_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_TRANSDUCER_KEYWORD_DECODER_H_
|
||||||
@@ -15,16 +15,31 @@
|
|||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
bool EncodeHotwords(std::istream &is, const SymbolTable &symbol_table,
|
static bool EncodeBase(std::istream &is, const SymbolTable &symbol_table,
|
||||||
std::vector<std::vector<int32_t>> *hotwords) {
|
std::vector<std::vector<int32_t>> *ids,
|
||||||
hotwords->clear();
|
std::vector<std::string> *phrases,
|
||||||
std::vector<int32_t> tmp;
|
std::vector<float> *scores,
|
||||||
|
std::vector<float> *thresholds) {
|
||||||
|
SHERPA_ONNX_CHECK(ids != nullptr);
|
||||||
|
ids->clear();
|
||||||
|
|
||||||
|
std::vector<int32_t> tmp_ids;
|
||||||
|
std::vector<float> tmp_scores;
|
||||||
|
std::vector<float> tmp_thresholds;
|
||||||
|
std::vector<std::string> tmp_phrases;
|
||||||
|
|
||||||
std::string line;
|
std::string line;
|
||||||
std::string word;
|
std::string word;
|
||||||
|
bool has_scores = false;
|
||||||
|
bool has_thresholds = false;
|
||||||
|
bool has_phrases = false;
|
||||||
|
|
||||||
while (std::getline(is, line)) {
|
while (std::getline(is, line)) {
|
||||||
|
float score = 0;
|
||||||
|
float threshold = 0;
|
||||||
|
std::string phrase = "";
|
||||||
|
|
||||||
std::istringstream iss(line);
|
std::istringstream iss(line);
|
||||||
std::vector<std::string> syms;
|
|
||||||
while (iss >> word) {
|
while (iss >> word) {
|
||||||
if (word.size() >= 3) {
|
if (word.size() >= 3) {
|
||||||
// For BPE-based models, we replace ▁ with a space
|
// For BPE-based models, we replace ▁ with a space
|
||||||
@@ -35,20 +50,72 @@ bool EncodeHotwords(std::istream &is, const SymbolTable &symbol_table,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (symbol_table.contains(word)) {
|
if (symbol_table.contains(word)) {
|
||||||
int32_t number = symbol_table[word];
|
int32_t id = symbol_table[word];
|
||||||
tmp.push_back(number);
|
tmp_ids.push_back(id);
|
||||||
} else {
|
} else {
|
||||||
|
switch (word[0]) {
|
||||||
|
case ':': // boosting score for current keyword
|
||||||
|
score = std::stof(word.substr(1));
|
||||||
|
has_scores = true;
|
||||||
|
break;
|
||||||
|
case '#': // triggering threshold (probability) for current keyword
|
||||||
|
threshold = std::stof(word.substr(1));
|
||||||
|
has_thresholds = true;
|
||||||
|
break;
|
||||||
|
case '@': // the original keyword string
|
||||||
|
phrase = word.substr(1);
|
||||||
|
has_phrases = true;
|
||||||
|
break;
|
||||||
|
default:
|
||||||
SHERPA_ONNX_LOGE(
|
SHERPA_ONNX_LOGE(
|
||||||
"Cannot find ID for hotword %s at line: %s. (Hint: words on "
|
"Cannot find ID for token %s at line: %s. (Hint: words on "
|
||||||
"the "
|
"the same line are separated by spaces)",
|
||||||
"same line are separated by spaces)",
|
|
||||||
word.c_str(), line.c_str());
|
word.c_str(), line.c_str());
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
hotwords->push_back(std::move(tmp));
|
}
|
||||||
|
ids->push_back(std::move(tmp_ids));
|
||||||
|
tmp_scores.push_back(score);
|
||||||
|
tmp_phrases.push_back(phrase);
|
||||||
|
tmp_thresholds.push_back(threshold);
|
||||||
|
}
|
||||||
|
if (scores != nullptr) {
|
||||||
|
if (has_scores) {
|
||||||
|
scores->swap(tmp_scores);
|
||||||
|
} else {
|
||||||
|
scores->clear();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (phrases != nullptr) {
|
||||||
|
if (has_phrases) {
|
||||||
|
*phrases = std::move(tmp_phrases);
|
||||||
|
} else {
|
||||||
|
phrases->clear();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (thresholds != nullptr) {
|
||||||
|
if (has_thresholds) {
|
||||||
|
thresholds->swap(tmp_thresholds);
|
||||||
|
} else {
|
||||||
|
thresholds->clear();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool EncodeHotwords(std::istream &is, const SymbolTable &symbol_table,
|
||||||
|
std::vector<std::vector<int32_t>> *hotwords) {
|
||||||
|
return EncodeBase(is, symbol_table, hotwords, nullptr, nullptr, nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool EncodeKeywords(std::istream &is, const SymbolTable &symbol_table,
|
||||||
|
std::vector<std::vector<int32_t>> *keywords_id,
|
||||||
|
std::vector<std::string> *keywords,
|
||||||
|
std::vector<float> *boost_scores,
|
||||||
|
std::vector<float> *threshold) {
|
||||||
|
return EncodeBase(is, symbol_table, keywords_id, keywords, boost_scores,
|
||||||
|
threshold);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace sherpa_onnx
|
} // namespace sherpa_onnx
|
||||||
|
|||||||
@@ -26,7 +26,32 @@ namespace sherpa_onnx {
|
|||||||
* otherwise returns false.
|
* otherwise returns false.
|
||||||
*/
|
*/
|
||||||
bool EncodeHotwords(std::istream &is, const SymbolTable &symbol_table,
|
bool EncodeHotwords(std::istream &is, const SymbolTable &symbol_table,
|
||||||
std::vector<std::vector<int32_t>> *hotwords);
|
std::vector<std::vector<int32_t>> *hotwords_id);
|
||||||
|
|
||||||
|
/* Encode the keywords in an input stream to be tokens ids.
|
||||||
|
*
|
||||||
|
* @param is The input stream, it contains several lines, one hotword for each
|
||||||
|
* line. For each hotword, the tokens (cjkchar or bpe) are separated
|
||||||
|
* by spaces, it might contain boosting score (starting with :),
|
||||||
|
* triggering threshold (starting with #) and keyword string (starting
|
||||||
|
* with @) too.
|
||||||
|
* @param symbol_table The tokens table mapping symbols to ids. All the symbols
|
||||||
|
* in the stream should be in the symbol_table, if not this
|
||||||
|
* function returns fasle.
|
||||||
|
*
|
||||||
|
* @param keywords_id The encoded ids to be written to.
|
||||||
|
* @param keywords The original keyword string to be written to.
|
||||||
|
* @param boost_scores The boosting score for each keyword to be written to.
|
||||||
|
* @param threshold The triggering threshold for each keyword to be written to.
|
||||||
|
*
|
||||||
|
* @return If all the symbols from ``is`` are in the symbol_table, returns true
|
||||||
|
* otherwise returns false.
|
||||||
|
*/
|
||||||
|
bool EncodeKeywords(std::istream &is, const SymbolTable &symbol_table,
|
||||||
|
std::vector<std::vector<int32_t>> *keywords_id,
|
||||||
|
std::vector<std::string> *keywords,
|
||||||
|
std::vector<float> *boost_scores,
|
||||||
|
std::vector<float> *threshold);
|
||||||
|
|
||||||
} // namespace sherpa_onnx
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
|||||||
@@ -21,6 +21,7 @@
|
|||||||
#include "android/asset_manager_jni.h"
|
#include "android/asset_manager_jni.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/keyword-spotter.h"
|
||||||
#include "sherpa-onnx/csrc/macros.h"
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
#include "sherpa-onnx/csrc/offline-recognizer.h"
|
#include "sherpa-onnx/csrc/offline-recognizer.h"
|
||||||
#include "sherpa-onnx/csrc/offline-tts.h"
|
#include "sherpa-onnx/csrc/offline-tts.h"
|
||||||
@@ -140,6 +141,73 @@ class SherpaOnnxVad {
|
|||||||
VoiceActivityDetector vad_;
|
VoiceActivityDetector vad_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class SherpaOnnxKws {
|
||||||
|
public:
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
SherpaOnnxKws(AAssetManager *mgr, const KeywordSpotterConfig &config)
|
||||||
|
: keyword_spotter_(mgr, config),
|
||||||
|
stream_(keyword_spotter_.CreateStream()) {}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
explicit SherpaOnnxKws(const KeywordSpotterConfig &config)
|
||||||
|
: keyword_spotter_(config), stream_(keyword_spotter_.CreateStream()) {}
|
||||||
|
|
||||||
|
void AcceptWaveform(int32_t sample_rate, const float *samples, int32_t n) {
|
||||||
|
if (input_sample_rate_ == -1) {
|
||||||
|
input_sample_rate_ = sample_rate;
|
||||||
|
}
|
||||||
|
|
||||||
|
stream_->AcceptWaveform(sample_rate, samples, n);
|
||||||
|
}
|
||||||
|
|
||||||
|
void InputFinished() const {
|
||||||
|
std::vector<float> tail_padding(input_sample_rate_ * 0.6, 0);
|
||||||
|
stream_->AcceptWaveform(input_sample_rate_, tail_padding.data(),
|
||||||
|
tail_padding.size());
|
||||||
|
stream_->InputFinished();
|
||||||
|
}
|
||||||
|
|
||||||
|
// If keywords is an empty string, it just recreates the decoding stream
|
||||||
|
// always returns true in this case.
|
||||||
|
// If keywords is not empty, it will create a new decoding stream with
|
||||||
|
// the given keywords appended to the default keywords.
|
||||||
|
// Return false if errors occurred when adding keywords, true otherwise.
|
||||||
|
bool Reset(const std::string &keywords = {}) {
|
||||||
|
if (keywords.empty()) {
|
||||||
|
stream_ = keyword_spotter_.CreateStream();
|
||||||
|
return true;
|
||||||
|
} else {
|
||||||
|
auto stream = keyword_spotter_.CreateStream(keywords);
|
||||||
|
// Set new keywords failed, the stream_ will not be updated.
|
||||||
|
if (stream == nullptr) {
|
||||||
|
return false;
|
||||||
|
} else {
|
||||||
|
stream_ = std::move(stream);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string GetKeyword() const {
|
||||||
|
auto result = keyword_spotter_.GetResult(stream_.get());
|
||||||
|
return result.keyword;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::string> GetTokens() const {
|
||||||
|
auto result = keyword_spotter_.GetResult(stream_.get());
|
||||||
|
return result.tokens;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool IsReady() const { return keyword_spotter_.IsReady(stream_.get()); }
|
||||||
|
|
||||||
|
void Decode() const { keyword_spotter_.DecodeStream(stream_.get()); }
|
||||||
|
|
||||||
|
private:
|
||||||
|
KeywordSpotter keyword_spotter_;
|
||||||
|
std::unique_ptr<OnlineStream> stream_;
|
||||||
|
int32_t input_sample_rate_ = -1;
|
||||||
|
};
|
||||||
|
|
||||||
static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
|
static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
|
||||||
OnlineRecognizerConfig ans;
|
OnlineRecognizerConfig ans;
|
||||||
|
|
||||||
@@ -457,6 +525,103 @@ static OfflineRecognizerConfig GetOfflineConfig(JNIEnv *env, jobject config) {
|
|||||||
return ans;
|
return ans;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static KeywordSpotterConfig GetKwsConfig(JNIEnv *env, jobject config) {
|
||||||
|
KeywordSpotterConfig ans;
|
||||||
|
|
||||||
|
jclass cls = env->GetObjectClass(config);
|
||||||
|
jfieldID fid;
|
||||||
|
|
||||||
|
// https://docs.oracle.com/javase/7/docs/technotes/guides/jni/spec/types.html
|
||||||
|
// https://courses.cs.washington.edu/courses/cse341/99wi/java/tutorial/native1.1/implementing/field.html
|
||||||
|
|
||||||
|
//---------- decoding ----------
|
||||||
|
fid = env->GetFieldID(cls, "maxActivePaths", "I");
|
||||||
|
ans.max_active_paths = env->GetIntField(config, fid);
|
||||||
|
|
||||||
|
fid = env->GetFieldID(cls, "keywordsFile", "Ljava/lang/String;");
|
||||||
|
jstring s = (jstring)env->GetObjectField(config, fid);
|
||||||
|
const char *p = env->GetStringUTFChars(s, nullptr);
|
||||||
|
ans.keywords_file = p;
|
||||||
|
env->ReleaseStringUTFChars(s, p);
|
||||||
|
|
||||||
|
fid = env->GetFieldID(cls, "keywordsScore", "F");
|
||||||
|
ans.keywords_score = env->GetFloatField(config, fid);
|
||||||
|
|
||||||
|
fid = env->GetFieldID(cls, "keywordsThreshold", "F");
|
||||||
|
ans.keywords_threshold = env->GetFloatField(config, fid);
|
||||||
|
|
||||||
|
fid = env->GetFieldID(cls, "numTrailingBlanks", "I");
|
||||||
|
ans.num_trailing_blanks = env->GetIntField(config, fid);
|
||||||
|
|
||||||
|
//---------- feat config ----------
|
||||||
|
fid = env->GetFieldID(cls, "featConfig",
|
||||||
|
"Lcom/k2fsa/sherpa/onnx/FeatureConfig;");
|
||||||
|
jobject feat_config = env->GetObjectField(config, fid);
|
||||||
|
jclass feat_config_cls = env->GetObjectClass(feat_config);
|
||||||
|
|
||||||
|
fid = env->GetFieldID(feat_config_cls, "sampleRate", "I");
|
||||||
|
ans.feat_config.sampling_rate = env->GetIntField(feat_config, fid);
|
||||||
|
|
||||||
|
fid = env->GetFieldID(feat_config_cls, "featureDim", "I");
|
||||||
|
ans.feat_config.feature_dim = env->GetIntField(feat_config, fid);
|
||||||
|
|
||||||
|
//---------- model config ----------
|
||||||
|
fid = env->GetFieldID(cls, "modelConfig",
|
||||||
|
"Lcom/k2fsa/sherpa/onnx/OnlineModelConfig;");
|
||||||
|
jobject model_config = env->GetObjectField(config, fid);
|
||||||
|
jclass model_config_cls = env->GetObjectClass(model_config);
|
||||||
|
|
||||||
|
// transducer
|
||||||
|
fid = env->GetFieldID(model_config_cls, "transducer",
|
||||||
|
"Lcom/k2fsa/sherpa/onnx/OnlineTransducerModelConfig;");
|
||||||
|
jobject transducer_config = env->GetObjectField(model_config, fid);
|
||||||
|
jclass transducer_config_cls = env->GetObjectClass(transducer_config);
|
||||||
|
|
||||||
|
fid = env->GetFieldID(transducer_config_cls, "encoder", "Ljava/lang/String;");
|
||||||
|
s = (jstring)env->GetObjectField(transducer_config, fid);
|
||||||
|
p = env->GetStringUTFChars(s, nullptr);
|
||||||
|
ans.model_config.transducer.encoder = p;
|
||||||
|
env->ReleaseStringUTFChars(s, p);
|
||||||
|
|
||||||
|
fid = env->GetFieldID(transducer_config_cls, "decoder", "Ljava/lang/String;");
|
||||||
|
s = (jstring)env->GetObjectField(transducer_config, fid);
|
||||||
|
p = env->GetStringUTFChars(s, nullptr);
|
||||||
|
ans.model_config.transducer.decoder = p;
|
||||||
|
env->ReleaseStringUTFChars(s, p);
|
||||||
|
|
||||||
|
fid = env->GetFieldID(transducer_config_cls, "joiner", "Ljava/lang/String;");
|
||||||
|
s = (jstring)env->GetObjectField(transducer_config, fid);
|
||||||
|
p = env->GetStringUTFChars(s, nullptr);
|
||||||
|
ans.model_config.transducer.joiner = p;
|
||||||
|
env->ReleaseStringUTFChars(s, p);
|
||||||
|
|
||||||
|
fid = env->GetFieldID(model_config_cls, "tokens", "Ljava/lang/String;");
|
||||||
|
s = (jstring)env->GetObjectField(model_config, fid);
|
||||||
|
p = env->GetStringUTFChars(s, nullptr);
|
||||||
|
ans.model_config.tokens = p;
|
||||||
|
env->ReleaseStringUTFChars(s, p);
|
||||||
|
|
||||||
|
fid = env->GetFieldID(model_config_cls, "numThreads", "I");
|
||||||
|
ans.model_config.num_threads = env->GetIntField(model_config, fid);
|
||||||
|
|
||||||
|
fid = env->GetFieldID(model_config_cls, "debug", "Z");
|
||||||
|
ans.model_config.debug = env->GetBooleanField(model_config, fid);
|
||||||
|
|
||||||
|
fid = env->GetFieldID(model_config_cls, "provider", "Ljava/lang/String;");
|
||||||
|
s = (jstring)env->GetObjectField(model_config, fid);
|
||||||
|
p = env->GetStringUTFChars(s, nullptr);
|
||||||
|
ans.model_config.provider = p;
|
||||||
|
env->ReleaseStringUTFChars(s, p);
|
||||||
|
|
||||||
|
fid = env->GetFieldID(model_config_cls, "modelType", "Ljava/lang/String;");
|
||||||
|
s = (jstring)env->GetObjectField(model_config, fid);
|
||||||
|
p = env->GetStringUTFChars(s, nullptr);
|
||||||
|
ans.model_config.model_type = p;
|
||||||
|
env->ReleaseStringUTFChars(s, p);
|
||||||
|
|
||||||
|
return ans;
|
||||||
|
}
|
||||||
|
|
||||||
static VadModelConfig GetVadModelConfig(JNIEnv *env, jobject config) {
|
static VadModelConfig GetVadModelConfig(JNIEnv *env, jobject config) {
|
||||||
VadModelConfig ans;
|
VadModelConfig ans;
|
||||||
|
|
||||||
@@ -1013,7 +1178,124 @@ JNIEXPORT jobjectArray JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_getTokens(
|
|||||||
jclass stringClass = env->FindClass("java/lang/String");
|
jclass stringClass = env->FindClass("java/lang/String");
|
||||||
|
|
||||||
// convert C++ list into jni string array
|
// convert C++ list into jni string array
|
||||||
jobjectArray result = env->NewObjectArray(size, stringClass, NULL);
|
jobjectArray result = env->NewObjectArray(size, stringClass, nullptr);
|
||||||
|
for (int32_t i = 0; i < size; i++) {
|
||||||
|
// Convert the C++ string to a C string
|
||||||
|
const char *cstr = tokens[i].c_str();
|
||||||
|
|
||||||
|
// Convert the C string to a jstring
|
||||||
|
jstring jstr = env->NewStringUTF(cstr);
|
||||||
|
|
||||||
|
// Set the array element
|
||||||
|
env->SetObjectArrayElement(result, i, jstr);
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
SHERPA_ONNX_EXTERN_C
|
||||||
|
JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_new(
|
||||||
|
JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) {
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
|
||||||
|
if (!mgr) {
|
||||||
|
SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
auto config = sherpa_onnx::GetKwsConfig(env, _config);
|
||||||
|
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
|
||||||
|
auto model = new sherpa_onnx::SherpaOnnxKws(
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
mgr,
|
||||||
|
#endif
|
||||||
|
config);
|
||||||
|
|
||||||
|
return (jlong)model;
|
||||||
|
}
|
||||||
|
|
||||||
|
SHERPA_ONNX_EXTERN_C
|
||||||
|
JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_newFromFile(
|
||||||
|
JNIEnv *env, jobject /*obj*/, jobject _config) {
|
||||||
|
auto config = sherpa_onnx::GetKwsConfig(env, _config);
|
||||||
|
SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str());
|
||||||
|
auto model = new sherpa_onnx::SherpaOnnxKws(config);
|
||||||
|
|
||||||
|
return (jlong)model;
|
||||||
|
}
|
||||||
|
|
||||||
|
SHERPA_ONNX_EXTERN_C
|
||||||
|
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_delete(
|
||||||
|
JNIEnv *env, jobject /*obj*/, jlong ptr) {
|
||||||
|
delete reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
SHERPA_ONNX_EXTERN_C
|
||||||
|
JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_isReady(
|
||||||
|
JNIEnv *env, jobject /*obj*/, jlong ptr) {
|
||||||
|
auto model = reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr);
|
||||||
|
return model->IsReady();
|
||||||
|
}
|
||||||
|
|
||||||
|
SHERPA_ONNX_EXTERN_C
|
||||||
|
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_decode(
|
||||||
|
JNIEnv *env, jobject /*obj*/, jlong ptr) {
|
||||||
|
auto model = reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr);
|
||||||
|
model->Decode();
|
||||||
|
}
|
||||||
|
|
||||||
|
SHERPA_ONNX_EXTERN_C
|
||||||
|
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_acceptWaveform(
|
||||||
|
JNIEnv *env, jobject /*obj*/, jlong ptr, jfloatArray samples,
|
||||||
|
jint sample_rate) {
|
||||||
|
auto model = reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr);
|
||||||
|
|
||||||
|
jfloat *p = env->GetFloatArrayElements(samples, nullptr);
|
||||||
|
jsize n = env->GetArrayLength(samples);
|
||||||
|
|
||||||
|
model->AcceptWaveform(sample_rate, p, n);
|
||||||
|
|
||||||
|
env->ReleaseFloatArrayElements(samples, p, JNI_ABORT);
|
||||||
|
}
|
||||||
|
|
||||||
|
SHERPA_ONNX_EXTERN_C
|
||||||
|
JNIEXPORT void JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_inputFinished(
|
||||||
|
JNIEnv *env, jobject /*obj*/, jlong ptr) {
|
||||||
|
reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr)->InputFinished();
|
||||||
|
}
|
||||||
|
|
||||||
|
SHERPA_ONNX_EXTERN_C
|
||||||
|
JNIEXPORT jstring JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_getKeyword(
|
||||||
|
JNIEnv *env, jobject /*obj*/, jlong ptr) {
|
||||||
|
// see
|
||||||
|
// https://stackoverflow.com/questions/11621449/send-c-string-to-java-via-jni
|
||||||
|
auto text = reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr)->GetKeyword();
|
||||||
|
return env->NewStringUTF(text.c_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
SHERPA_ONNX_EXTERN_C
|
||||||
|
JNIEXPORT bool JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_reset(
|
||||||
|
JNIEnv *env, jobject /*obj*/, jlong ptr, jstring keywords) {
|
||||||
|
const char *p_keywords = env->GetStringUTFChars(keywords, nullptr);
|
||||||
|
|
||||||
|
std::string keywords_str = p_keywords;
|
||||||
|
|
||||||
|
bool status =
|
||||||
|
reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr)->Reset(keywords_str);
|
||||||
|
env->ReleaseStringUTFChars(keywords, p_keywords);
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
|
||||||
|
SHERPA_ONNX_EXTERN_C
|
||||||
|
JNIEXPORT jobjectArray JNICALL
|
||||||
|
Java_com_k2fsa_sherpa_onnx_SherpaOnnxKws_getTokens(JNIEnv *env, jobject /*obj*/,
|
||||||
|
jlong ptr) {
|
||||||
|
auto tokens =
|
||||||
|
reinterpret_cast<sherpa_onnx::SherpaOnnxKws *>(ptr)->GetTokens();
|
||||||
|
int32_t size = tokens.size();
|
||||||
|
jclass stringClass = env->FindClass("java/lang/String");
|
||||||
|
|
||||||
|
// convert C++ list into jni string array
|
||||||
|
jobjectArray result = env->NewObjectArray(size, stringClass, nullptr);
|
||||||
for (int32_t i = 0; i < size; i++) {
|
for (int32_t i = 0; i < size; i++) {
|
||||||
// Convert the C++ string to a C string
|
// Convert the C++ string to a C string
|
||||||
const char *cstr = tokens[i].c_str();
|
const char *cstr = tokens[i].c_str();
|
||||||
|
|||||||
@@ -28,9 +28,14 @@ def cli():
|
|||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--tokens-type",
|
"--tokens-type",
|
||||||
type=str,
|
type=click.Choice(
|
||||||
|
["cjkchar", "bpe", "cjkchar+bpe", "fpinyin", "ppinyin"], case_sensitive=True
|
||||||
|
),
|
||||||
required=True,
|
required=True,
|
||||||
help="The type of modeling units, should be cjkchar, bpe or cjkchar+bpe",
|
help="""The type of modeling units, should be cjkchar, bpe, cjkchar+bpe, fpinyin or ppinyin.
|
||||||
|
fpinyin means full pinyin, each cjkchar has a pinyin(with tone).
|
||||||
|
ppinyin means partial pinyin, it splits pinyin into initial and final,
|
||||||
|
""",
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--bpe-model",
|
"--bpe-model",
|
||||||
@@ -42,14 +47,56 @@ def encode_text(
|
|||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Encode the texts given by the INPUT to tokens and write the results to the OUTPUT.
|
Encode the texts given by the INPUT to tokens and write the results to the OUTPUT.
|
||||||
|
Each line in the texts contains the original phrase, it might also contain some
|
||||||
|
extra items, for example, the boosting score (startting with :), the triggering
|
||||||
|
threshold (startting with #, only used in keyword spotting task) and the original
|
||||||
|
phrase (startting with @). Note: the extra items will be kept same in the output.
|
||||||
|
|
||||||
|
example input 1 (tokens_type = ppinyin):
|
||||||
|
|
||||||
|
小爱同学 :2.0 #0.6 @小爱同学
|
||||||
|
你好问问 :3.5 @你好问问
|
||||||
|
小艺小艺 #0.6 @小艺小艺
|
||||||
|
|
||||||
|
example output 1:
|
||||||
|
|
||||||
|
x iǎo ài t óng x ué :2.0 #0.6 @小爱同学
|
||||||
|
n ǐ h ǎo w èn w èn :3.5 @你好问问
|
||||||
|
x iǎo y ì x iǎo y ì #0.6 @小艺小艺
|
||||||
|
|
||||||
|
example input 2 (tokens_type = bpe):
|
||||||
|
|
||||||
|
HELLO WORLD :1.5 #0.4
|
||||||
|
HI GOOGLE :2.0 #0.8
|
||||||
|
HEY SIRI #0.35
|
||||||
|
|
||||||
|
example output 2:
|
||||||
|
|
||||||
|
▁HE LL O ▁WORLD :1.5 #0.4
|
||||||
|
▁HI ▁GO O G LE :2.0 #0.8
|
||||||
|
▁HE Y ▁S I RI #0.35
|
||||||
"""
|
"""
|
||||||
texts = []
|
texts = []
|
||||||
|
# extra information like boosting score (start with :), triggering threshold (start with #)
|
||||||
|
# original keyword (start with @)
|
||||||
|
extra_info = []
|
||||||
with open(input, "r", encoding="utf8") as f:
|
with open(input, "r", encoding="utf8") as f:
|
||||||
for line in f:
|
for line in f:
|
||||||
texts.append(line.strip())
|
extra = []
|
||||||
|
text = []
|
||||||
|
toks = line.strip().split()
|
||||||
|
for tok in toks:
|
||||||
|
if tok[0] == ":" or tok[0] == "#" or tok[0] == "@":
|
||||||
|
extra.append(tok)
|
||||||
|
else:
|
||||||
|
text.append(tok)
|
||||||
|
texts.append(" ".join(text))
|
||||||
|
extra_info.append(extra)
|
||||||
|
|
||||||
encoded_texts = text2token(
|
encoded_texts = text2token(
|
||||||
texts, tokens=tokens, tokens_type=tokens_type, bpe_model=bpe_model
|
texts, tokens=tokens, tokens_type=tokens_type, bpe_model=bpe_model
|
||||||
)
|
)
|
||||||
with open(output, "w", encoding="utf8") as f:
|
with open(output, "w", encoding="utf8") as f:
|
||||||
for txt in encoded_texts:
|
for i, txt in enumerate(encoded_texts):
|
||||||
|
txt += extra_info[i]
|
||||||
f.write(" ".join(txt) + "\n")
|
f.write(" ".join(txt) + "\n")
|
||||||
|
|||||||
@@ -6,6 +6,9 @@ from typing import List, Optional, Union
|
|||||||
|
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
|
|
||||||
|
from pypinyin import pinyin
|
||||||
|
from pypinyin.contrib.tone_convert import to_initials, to_finals_tone
|
||||||
|
|
||||||
|
|
||||||
def text2token(
|
def text2token(
|
||||||
texts: List[str],
|
texts: List[str],
|
||||||
@@ -23,7 +26,9 @@ def text2token(
|
|||||||
tokens:
|
tokens:
|
||||||
The path of the tokens.txt.
|
The path of the tokens.txt.
|
||||||
tokens_type:
|
tokens_type:
|
||||||
The valid values are cjkchar, bpe, cjkchar+bpe.
|
The valid values are cjkchar, bpe, cjkchar+bpe, fpinyin, ppinyin.
|
||||||
|
fpinyin means full pinyin, each cjkchar has a pinyin(with tone).
|
||||||
|
ppinyin means partial pinyin, it splits pinyin into initial and final,
|
||||||
bpe_model:
|
bpe_model:
|
||||||
The path of the bpe model. Only required when tokens_type is bpe or
|
The path of the bpe model. Only required when tokens_type is bpe or
|
||||||
cjkchar+bpe.
|
cjkchar+bpe.
|
||||||
@@ -53,6 +58,24 @@ def text2token(
|
|||||||
texts_list = [list("".join(text.split())) for text in texts]
|
texts_list = [list("".join(text.split())) for text in texts]
|
||||||
elif tokens_type == "bpe":
|
elif tokens_type == "bpe":
|
||||||
texts_list = sp.encode(texts, out_type=str)
|
texts_list = sp.encode(texts, out_type=str)
|
||||||
|
elif "pinyin" in tokens_type:
|
||||||
|
for txt in texts:
|
||||||
|
py = [x[0] for x in pinyin(txt)]
|
||||||
|
if "ppinyin" == tokens_type:
|
||||||
|
res = []
|
||||||
|
for x in py:
|
||||||
|
initial = to_initials(x, strict=False)
|
||||||
|
final = to_finals_tone(x, strict=False)
|
||||||
|
if initial == "" and final == "":
|
||||||
|
res.append(x)
|
||||||
|
else:
|
||||||
|
if initial != "":
|
||||||
|
res.append(initial)
|
||||||
|
if final != "":
|
||||||
|
res.append(final)
|
||||||
|
texts_list.append(res)
|
||||||
|
else:
|
||||||
|
texts_list.append(py)
|
||||||
else:
|
else:
|
||||||
assert (
|
assert (
|
||||||
tokens_type == "cjkchar+bpe"
|
tokens_type == "cjkchar+bpe"
|
||||||
|
|||||||