Support CED models (#792)
This commit is contained in:
78
.github/workflows/export-ced-to-onnx.yaml
vendored
Normal file
78
.github/workflows/export-ced-to-onnx.yaml
vendored
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
name: export-ced-to-onnx
|
||||||
|
|
||||||
|
on:
|
||||||
|
workflow_dispatch:
|
||||||
|
|
||||||
|
concurrency:
|
||||||
|
group: export-ced-to-onnx-${{ github.ref }}
|
||||||
|
cancel-in-progress: true
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
export-ced-to-onnx:
|
||||||
|
if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj'
|
||||||
|
name: export ced
|
||||||
|
runs-on: ${{ matrix.os }}
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
os: [ubuntu-latest]
|
||||||
|
python-version: ["3.8"]
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Setup Python ${{ matrix.python-version }}
|
||||||
|
uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: ${{ matrix.python-version }}
|
||||||
|
|
||||||
|
- name: Run
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
cd scripts/ced
|
||||||
|
./run.sh
|
||||||
|
|
||||||
|
- name: Release
|
||||||
|
uses: svenstaro/upload-release-action@v2
|
||||||
|
with:
|
||||||
|
file_glob: true
|
||||||
|
file: ./*.tar.bz2
|
||||||
|
overwrite: true
|
||||||
|
repo_name: k2-fsa/sherpa-onnx
|
||||||
|
repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
|
||||||
|
tag: audio-tagging-models
|
||||||
|
|
||||||
|
- name: Publish to huggingface
|
||||||
|
env:
|
||||||
|
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||||
|
uses: nick-fields/retry@v3
|
||||||
|
with:
|
||||||
|
max_attempts: 20
|
||||||
|
timeout_seconds: 200
|
||||||
|
shell: bash
|
||||||
|
command: |
|
||||||
|
git config --global user.email "csukuangfj@gmail.com"
|
||||||
|
git config --global user.name "Fangjun Kuang"
|
||||||
|
|
||||||
|
models=(
|
||||||
|
tiny
|
||||||
|
mini
|
||||||
|
small
|
||||||
|
base
|
||||||
|
)
|
||||||
|
|
||||||
|
for m in ${models[@]}; do
|
||||||
|
rm -rf huggingface
|
||||||
|
export GIT_LFS_SKIP_SMUDGE=1
|
||||||
|
d=sherpa-onnx-ced-$m-audio-tagging-2024-04-19
|
||||||
|
git clone https://huggingface.co/k2-fsa/$d huggingface
|
||||||
|
mv -v $d/* huggingface
|
||||||
|
cd huggingface
|
||||||
|
git lfs track "*.onnx"
|
||||||
|
git status
|
||||||
|
git add .
|
||||||
|
git status
|
||||||
|
git commit -m "first commit"
|
||||||
|
git push https://csukuangfj:$HF_TOKEN@huggingface.co/k2-fsa/$d main
|
||||||
|
cd ..
|
||||||
|
done
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
<resources>
|
<resources>
|
||||||
<string name="app_name">ASR with Next-gen Kaldi</string>
|
<string name="app_name">ASR</string>
|
||||||
<string name="hint">Click the Start button to play speech-to-text with Next-gen Kaldi.
|
<string name="hint">Click the Start button to play speech-to-text with Next-gen Kaldi.
|
||||||
\n
|
\n
|
||||||
\n\n\n
|
\n\n\n
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
<resources>
|
<resources>
|
||||||
<string name="app_name">ASR with Next-gen Kaldi</string>
|
<string name="app_name">ASR2pass </string>
|
||||||
<string name="hint">Click the Start button to play speech-to-text with Next-gen Kaldi.
|
<string name="hint">Click the Start button to play speech-to-text with Next-gen Kaldi.
|
||||||
\n
|
\n
|
||||||
\n\n\n
|
\n\n\n
|
||||||
|
|||||||
@@ -1,16 +1,16 @@
|
|||||||
package com.k2fsa.sherpa.onnx
|
package com.k2fsa.sherpa.onnx
|
||||||
|
|
||||||
import android.content.res.AssetManager
|
import android.content.res.AssetManager
|
||||||
import android.util.Log
|
|
||||||
|
|
||||||
private val TAG = "sherpa-onnx"
|
const val TAG = "sherpa-onnx"
|
||||||
|
|
||||||
data class OfflineZipformerAudioTaggingModelConfig(
|
data class OfflineZipformerAudioTaggingModelConfig(
|
||||||
var model: String,
|
var model: String = "",
|
||||||
)
|
)
|
||||||
|
|
||||||
data class AudioTaggingModelConfig(
|
data class AudioTaggingModelConfig(
|
||||||
var zipformer: OfflineZipformerAudioTaggingModelConfig,
|
var zipformer: OfflineZipformerAudioTaggingModelConfig = OfflineZipformerAudioTaggingModelConfig(),
|
||||||
|
var ced: String = "",
|
||||||
var numThreads: Int = 1,
|
var numThreads: Int = 1,
|
||||||
var debug: Boolean = false,
|
var debug: Boolean = false,
|
||||||
var provider: String = "cpu",
|
var provider: String = "cpu",
|
||||||
@@ -103,7 +103,7 @@ class AudioTagging(
|
|||||||
//
|
//
|
||||||
// See also
|
// See also
|
||||||
// https://k2-fsa.github.io/sherpa/onnx/audio-tagging/
|
// https://k2-fsa.github.io/sherpa/onnx/audio-tagging/
|
||||||
fun getAudioTaggingConfig(type: Int, numThreads: Int=1): AudioTaggingConfig? {
|
fun getAudioTaggingConfig(type: Int, numThreads: Int = 1): AudioTaggingConfig? {
|
||||||
when (type) {
|
when (type) {
|
||||||
0 -> {
|
0 -> {
|
||||||
val modelDir = "sherpa-onnx-zipformer-small-audio-tagging-2024-04-15"
|
val modelDir = "sherpa-onnx-zipformer-small-audio-tagging-2024-04-15"
|
||||||
@@ -123,7 +123,7 @@ fun getAudioTaggingConfig(type: Int, numThreads: Int=1): AudioTaggingConfig? {
|
|||||||
return AudioTaggingConfig(
|
return AudioTaggingConfig(
|
||||||
model = AudioTaggingModelConfig(
|
model = AudioTaggingModelConfig(
|
||||||
zipformer = OfflineZipformerAudioTaggingModelConfig(model = "$modelDir/model.int8.onnx"),
|
zipformer = OfflineZipformerAudioTaggingModelConfig(model = "$modelDir/model.int8.onnx"),
|
||||||
numThreads = 1,
|
numThreads = numThreads,
|
||||||
debug = true,
|
debug = true,
|
||||||
),
|
),
|
||||||
labels = "$modelDir/class_labels_indices.csv",
|
labels = "$modelDir/class_labels_indices.csv",
|
||||||
@@ -131,6 +131,57 @@ fun getAudioTaggingConfig(type: Int, numThreads: Int=1): AudioTaggingConfig? {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
2 -> {
|
||||||
|
val modelDir = "sherpa-onnx-ced-tiny-audio-tagging-2024-04-19"
|
||||||
|
return AudioTaggingConfig(
|
||||||
|
model = AudioTaggingModelConfig(
|
||||||
|
ced = "$modelDir/model.int8.onnx",
|
||||||
|
numThreads = numThreads,
|
||||||
|
debug = true,
|
||||||
|
),
|
||||||
|
labels = "$modelDir/class_labels_indices.csv",
|
||||||
|
topK = 3,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
3 -> {
|
||||||
|
val modelDir = "sherpa-onnx-ced-mini-audio-tagging-2024-04-19"
|
||||||
|
return AudioTaggingConfig(
|
||||||
|
model = AudioTaggingModelConfig(
|
||||||
|
ced = "$modelDir/model.int8.onnx",
|
||||||
|
numThreads = numThreads,
|
||||||
|
debug = true,
|
||||||
|
),
|
||||||
|
labels = "$modelDir/class_labels_indices.csv",
|
||||||
|
topK = 3,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
4 -> {
|
||||||
|
val modelDir = "sherpa-onnx-ced-small-audio-tagging-2024-04-19"
|
||||||
|
return AudioTaggingConfig(
|
||||||
|
model = AudioTaggingModelConfig(
|
||||||
|
ced = "$modelDir/model.int8.onnx",
|
||||||
|
numThreads = numThreads,
|
||||||
|
debug = true,
|
||||||
|
),
|
||||||
|
labels = "$modelDir/class_labels_indices.csv",
|
||||||
|
topK = 3,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
5 -> {
|
||||||
|
val modelDir = "sherpa-onnx-ced-base-audio-tagging-2024-04-19"
|
||||||
|
return AudioTaggingConfig(
|
||||||
|
model = AudioTaggingModelConfig(
|
||||||
|
ced = "$modelDir/model.int8.onnx",
|
||||||
|
numThreads = numThreads,
|
||||||
|
debug = true,
|
||||||
|
),
|
||||||
|
labels = "$modelDir/class_labels_indices.csv",
|
||||||
|
topK = 3,
|
||||||
|
)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return null
|
return null
|
||||||
|
|||||||
@@ -3,24 +3,15 @@
|
|||||||
package com.k2fsa.sherpa.onnx.audio.tagging
|
package com.k2fsa.sherpa.onnx.audio.tagging
|
||||||
|
|
||||||
import android.Manifest
|
import android.Manifest
|
||||||
|
|
||||||
import android.app.Activity
|
import android.app.Activity
|
||||||
import android.content.pm.PackageManager
|
import android.content.pm.PackageManager
|
||||||
import android.media.AudioFormat
|
import android.media.AudioFormat
|
||||||
import android.media.AudioRecord
|
import android.media.AudioRecord
|
||||||
import androidx.compose.foundation.lazy.items
|
|
||||||
import android.media.MediaRecorder
|
import android.media.MediaRecorder
|
||||||
import android.util.Log
|
import android.util.Log
|
||||||
import androidx.compose.foundation.ExperimentalFoundationApi
|
import androidx.compose.foundation.ExperimentalFoundationApi
|
||||||
import androidx.compose.foundation.background
|
|
||||||
import androidx.compose.foundation.layout.Arrangement
|
import androidx.compose.foundation.layout.Arrangement
|
||||||
import androidx.compose.foundation.layout.Box
|
import androidx.compose.foundation.layout.Box
|
||||||
import androidx.compose.material3.CenterAlignedTopAppBar
|
|
||||||
import androidx.compose.runtime.Composable
|
|
||||||
import androidx.compose.material3.Scaffold
|
|
||||||
import androidx.compose.material3.TopAppBarDefaults
|
|
||||||
import androidx.compose.material3.MaterialTheme
|
|
||||||
import androidx.compose.material3.Text
|
|
||||||
import androidx.compose.foundation.layout.Column
|
import androidx.compose.foundation.layout.Column
|
||||||
import androidx.compose.foundation.layout.PaddingValues
|
import androidx.compose.foundation.layout.PaddingValues
|
||||||
import androidx.compose.foundation.layout.Row
|
import androidx.compose.foundation.layout.Row
|
||||||
@@ -30,10 +21,17 @@ import androidx.compose.foundation.layout.fillMaxWidth
|
|||||||
import androidx.compose.foundation.layout.height
|
import androidx.compose.foundation.layout.height
|
||||||
import androidx.compose.foundation.layout.padding
|
import androidx.compose.foundation.layout.padding
|
||||||
import androidx.compose.foundation.lazy.LazyColumn
|
import androidx.compose.foundation.lazy.LazyColumn
|
||||||
|
import androidx.compose.foundation.lazy.items
|
||||||
import androidx.compose.material3.Button
|
import androidx.compose.material3.Button
|
||||||
|
import androidx.compose.material3.CenterAlignedTopAppBar
|
||||||
import androidx.compose.material3.ExperimentalMaterial3Api
|
import androidx.compose.material3.ExperimentalMaterial3Api
|
||||||
|
import androidx.compose.material3.MaterialTheme
|
||||||
|
import androidx.compose.material3.Scaffold
|
||||||
import androidx.compose.material3.Slider
|
import androidx.compose.material3.Slider
|
||||||
import androidx.compose.material3.Surface
|
import androidx.compose.material3.Surface
|
||||||
|
import androidx.compose.material3.Text
|
||||||
|
import androidx.compose.material3.TopAppBarDefaults
|
||||||
|
import androidx.compose.runtime.Composable
|
||||||
import androidx.compose.runtime.getValue
|
import androidx.compose.runtime.getValue
|
||||||
import androidx.compose.runtime.mutableStateListOf
|
import androidx.compose.runtime.mutableStateListOf
|
||||||
import androidx.compose.runtime.mutableStateOf
|
import androidx.compose.runtime.mutableStateOf
|
||||||
@@ -41,7 +39,6 @@ import androidx.compose.runtime.remember
|
|||||||
import androidx.compose.runtime.setValue
|
import androidx.compose.runtime.setValue
|
||||||
import androidx.compose.ui.Alignment
|
import androidx.compose.ui.Alignment
|
||||||
import androidx.compose.ui.Modifier
|
import androidx.compose.ui.Modifier
|
||||||
import androidx.compose.ui.graphics.Color
|
|
||||||
import androidx.compose.ui.platform.LocalContext
|
import androidx.compose.ui.platform.LocalContext
|
||||||
import androidx.compose.ui.text.font.FontWeight
|
import androidx.compose.ui.text.font.FontWeight
|
||||||
import androidx.compose.ui.text.style.TextAlign
|
import androidx.compose.ui.text.style.TextAlign
|
||||||
@@ -49,6 +46,7 @@ import androidx.compose.ui.unit.dp
|
|||||||
import androidx.compose.ui.unit.sp
|
import androidx.compose.ui.unit.sp
|
||||||
import androidx.core.app.ActivityCompat
|
import androidx.core.app.ActivityCompat
|
||||||
import com.k2fsa.sherpa.onnx.AudioEvent
|
import com.k2fsa.sherpa.onnx.AudioEvent
|
||||||
|
import com.k2fsa.sherpa.onnx.Tagger
|
||||||
import kotlin.concurrent.thread
|
import kotlin.concurrent.thread
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import androidx.compose.material3.Surface
|
|||||||
import androidx.compose.runtime.Composable
|
import androidx.compose.runtime.Composable
|
||||||
import androidx.compose.ui.Modifier
|
import androidx.compose.ui.Modifier
|
||||||
import androidx.core.app.ActivityCompat
|
import androidx.core.app.ActivityCompat
|
||||||
|
import com.k2fsa.sherpa.onnx.Tagger
|
||||||
import com.k2fsa.sherpa.onnx.audio.tagging.ui.theme.SherpaOnnxAudioTaggingTheme
|
import com.k2fsa.sherpa.onnx.audio.tagging.ui.theme.SherpaOnnxAudioTaggingTheme
|
||||||
|
|
||||||
const val TAG = "sherpa-onnx"
|
const val TAG = "sherpa-onnx"
|
||||||
|
|||||||
@@ -1,10 +1,8 @@
|
|||||||
package com.k2fsa.sherpa.onnx.audio.tagging
|
package com.k2fsa.sherpa.onnx
|
||||||
|
|
||||||
import android.content.res.AssetManager
|
import android.content.res.AssetManager
|
||||||
import android.util.Log
|
import android.util.Log
|
||||||
import com.k2fsa.sherpa.onnx.AudioTagging
|
|
||||||
import com.k2fsa.sherpa.onnx.audio.tagging.wear.os.presentation.TAG
|
|
||||||
import com.k2fsa.sherpa.onnx.getAudioTaggingConfig
|
|
||||||
|
|
||||||
object Tagger {
|
object Tagger {
|
||||||
private var _tagger: AudioTagging? = null
|
private var _tagger: AudioTagging? = null
|
||||||
@@ -12,6 +10,7 @@ object Tagger {
|
|||||||
get() {
|
get() {
|
||||||
return _tagger!!
|
return _tagger!!
|
||||||
}
|
}
|
||||||
|
|
||||||
fun initTagger(assetManager: AssetManager? = null, numThreads: Int = 1) {
|
fun initTagger(assetManager: AssetManager? = null, numThreads: Int = 1) {
|
||||||
synchronized(this) {
|
synchronized(this) {
|
||||||
if (_tagger != null) {
|
if (_tagger != null) {
|
||||||
@@ -19,7 +18,7 @@ object Tagger {
|
|||||||
}
|
}
|
||||||
|
|
||||||
Log.i(TAG, "Initializing audio tagger")
|
Log.i(TAG, "Initializing audio tagger")
|
||||||
val config = getAudioTaggingConfig(type = 0, numThreads=numThreads)!!
|
val config = getAudioTaggingConfig(type = 0, numThreads = numThreads)!!
|
||||||
_tagger = AudioTagging(assetManager, config)
|
_tagger = AudioTagging(assetManager, config)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ import androidx.wear.compose.material.Button
|
|||||||
import androidx.wear.compose.material.MaterialTheme
|
import androidx.wear.compose.material.MaterialTheme
|
||||||
import androidx.wear.compose.material.Text
|
import androidx.wear.compose.material.Text
|
||||||
import com.k2fsa.sherpa.onnx.AudioEvent
|
import com.k2fsa.sherpa.onnx.AudioEvent
|
||||||
import com.k2fsa.sherpa.onnx.audio.tagging.Tagger
|
import com.k2fsa.sherpa.onnx.Tagger
|
||||||
import com.k2fsa.sherpa.onnx.audio.tagging.wear.os.presentation.theme.SherpaOnnxAudioTaggingWearOsTheme
|
import com.k2fsa.sherpa.onnx.audio.tagging.wear.os.presentation.theme.SherpaOnnxAudioTaggingWearOsTheme
|
||||||
import kotlin.concurrent.thread
|
import kotlin.concurrent.thread
|
||||||
|
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ import androidx.activity.compose.setContent
|
|||||||
import androidx.compose.runtime.Composable
|
import androidx.compose.runtime.Composable
|
||||||
import androidx.core.app.ActivityCompat
|
import androidx.core.app.ActivityCompat
|
||||||
import androidx.core.splashscreen.SplashScreen.Companion.installSplashScreen
|
import androidx.core.splashscreen.SplashScreen.Companion.installSplashScreen
|
||||||
import com.k2fsa.sherpa.onnx.audio.tagging.Tagger
|
import com.k2fsa.sherpa.onnx.Tagger
|
||||||
|
|
||||||
const val TAG = "sherpa-onnx"
|
const val TAG = "sherpa-onnx"
|
||||||
private const val REQUEST_RECORD_AUDIO_PERMISSION = 200
|
private const val REQUEST_RECORD_AUDIO_PERMISSION = 200
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
<resources>
|
<resources>
|
||||||
<string name="app_name">AudioTagging</string>
|
<string name="app_name">Audio Tagging</string>
|
||||||
<!--
|
<!--
|
||||||
This string is used for square devices and overridden by hello_world in
|
This string is used for square devices and overridden by hello_world in
|
||||||
values-round/strings.xml for round devices.
|
values-round/strings.xml for round devices.
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
<resources>
|
<resources>
|
||||||
<string name="app_name">Speaker Identification</string>
|
<string name="app_name">Speaker ID</string>
|
||||||
<string name="start">Start recording</string>
|
<string name="start">Start recording</string>
|
||||||
<string name="stop">Stop recording</string>
|
<string name="stop">Stop recording</string>
|
||||||
<string name="add">Add speaker</string>
|
<string name="add">Add speaker</string>
|
||||||
|
|||||||
@@ -1,3 +1,3 @@
|
|||||||
<resources>
|
<resources>
|
||||||
<string name="app_name">SherpaOnnxSpokenLanguageIdentification</string>
|
<string name="app_name">Language ID</string>
|
||||||
</resources>
|
</resources>
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
<resources>
|
<resources>
|
||||||
<string name="app_name">Next-gen Kaldi: TTS</string>
|
<string name="app_name">TTS</string>
|
||||||
<string name="sid_label">Speaker ID</string>
|
<string name="sid_label">Speaker ID</string>
|
||||||
<string name="sid_hint">0</string>
|
<string name="sid_hint">0</string>
|
||||||
<string name="speed_label">Speech speed (large->fast)</string>
|
<string name="speed_label">Speech speed (large->fast)</string>
|
||||||
|
|||||||
@@ -1,3 +1,3 @@
|
|||||||
<resources>
|
<resources>
|
||||||
<string name="app_name">Next-gen Kaldi: TTS</string>
|
<string name="app_name">TTS Engine</string>
|
||||||
</resources>
|
</resources>
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
<resources>
|
<resources>
|
||||||
<string name="app_name">Next-gen Kaldi: SileroVAD</string>
|
<string name="app_name">VAD</string>
|
||||||
|
|
||||||
<string name="hint">Click the Start button to play Silero VAD with Next-gen Kaldi.</string>
|
<string name="hint">Click the Start button to play Silero VAD with Next-gen Kaldi.</string>
|
||||||
<string name="start">Start</string>
|
<string name="start">Start</string>
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
<resources>
|
<resources>
|
||||||
<string name="app_name">ASR with Next-gen Kaldi</string>
|
<string name="app_name">VAD-ASR</string>
|
||||||
<string name="hint">Click the Start button to play speech-to-text with Next-gen Kaldi.
|
<string name="hint">Click the Start button to play speech-to-text with Next-gen Kaldi.
|
||||||
\n
|
\n
|
||||||
\n\n\n
|
\n\n\n
|
||||||
|
|||||||
@@ -46,7 +46,30 @@ def get_models():
|
|||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
return icefall_models
|
ced_models = [
|
||||||
|
AudioTaggingModel(
|
||||||
|
model_name="sherpa-onnx-ced-tiny-audio-tagging-2024-04-19",
|
||||||
|
idx=2,
|
||||||
|
short_name="ced_tiny",
|
||||||
|
),
|
||||||
|
AudioTaggingModel(
|
||||||
|
model_name="sherpa-onnx-ced-mini-audio-tagging-2024-04-19",
|
||||||
|
idx=3,
|
||||||
|
short_name="ced_mini",
|
||||||
|
),
|
||||||
|
AudioTaggingModel(
|
||||||
|
model_name="sherpa-onnx-ced-small-audio-tagging-2024-04-19",
|
||||||
|
idx=4,
|
||||||
|
short_name="ced_small",
|
||||||
|
),
|
||||||
|
AudioTaggingModel(
|
||||||
|
model_name="sherpa-onnx-ced-base-audio-tagging-2024-04-19",
|
||||||
|
idx=5,
|
||||||
|
short_name="ced_base",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
return icefall_models + ced_models
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
|||||||
69
scripts/ced/run.sh
Executable file
69
scripts/ced/run.sh
Executable file
@@ -0,0 +1,69 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
#
|
||||||
|
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||||
|
|
||||||
|
set -ex
|
||||||
|
|
||||||
|
function install_dependencies() {
|
||||||
|
pip install -qq torch==2.1.0+cpu torchaudio==2.1.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
|
||||||
|
pip install -qq onnx onnxruntime==1.17.1
|
||||||
|
|
||||||
|
pip install -r ./requirements.txt
|
||||||
|
}
|
||||||
|
|
||||||
|
git clone https://github.com/RicherMans/CED
|
||||||
|
pushd CED
|
||||||
|
|
||||||
|
install_dependencies
|
||||||
|
|
||||||
|
models=(
|
||||||
|
tiny
|
||||||
|
mini
|
||||||
|
small
|
||||||
|
base
|
||||||
|
)
|
||||||
|
|
||||||
|
for m in ${models[@]}; do
|
||||||
|
python3 ./export_onnx.py -m ced_$m
|
||||||
|
done
|
||||||
|
|
||||||
|
ls -lh *.onnx
|
||||||
|
|
||||||
|
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/audio-tagging-models/sherpa-onnx-zipformer-small-audio-tagging-2024-04-15.tar.bz2
|
||||||
|
|
||||||
|
tar xvf sherpa-onnx-zipformer-small-audio-tagging-2024-04-15.tar.bz2
|
||||||
|
rm sherpa-onnx-zipformer-small-audio-tagging-2024-04-15.tar.bz2
|
||||||
|
src=sherpa-onnx-zipformer-small-audio-tagging-2024-04-15
|
||||||
|
|
||||||
|
cat >README.md <<EOF
|
||||||
|
# Introduction
|
||||||
|
|
||||||
|
Models in this repo are converted from
|
||||||
|
https://github.com/RicherMans/CED
|
||||||
|
EOF
|
||||||
|
|
||||||
|
for m in ${models[@]}; do
|
||||||
|
d=sherpa-onnx-ced-$m-audio-tagging-2024-04-19
|
||||||
|
|
||||||
|
mkdir -p $d
|
||||||
|
|
||||||
|
cp -v README.md $d
|
||||||
|
cp -v $src/class_labels_indices.csv $d
|
||||||
|
cp -a $src/test_wavs $d
|
||||||
|
cp -v ced_$m.onnx $d/model.onnx
|
||||||
|
cp -v ced_$m.int8.onnx $d/model.int8.onnx
|
||||||
|
echo "----------$m----------"
|
||||||
|
ls -lh $d
|
||||||
|
echo "----------------------"
|
||||||
|
tar cjvf $d.tar.bz2 $d
|
||||||
|
mv $d.tar.bz2 ../../..
|
||||||
|
mv $d ../../../
|
||||||
|
done
|
||||||
|
|
||||||
|
rm -rf sherpa-onnx-zipformer-small-audio-tagging-2024-04-15
|
||||||
|
|
||||||
|
cd ../../..
|
||||||
|
|
||||||
|
ls -lh *.tar.bz2
|
||||||
|
echo "======="
|
||||||
|
ls -lh
|
||||||
@@ -1223,6 +1223,7 @@ const SherpaOnnxAudioTagging *SherpaOnnxCreateAudioTagging(
|
|||||||
const SherpaOnnxAudioTaggingConfig *config) {
|
const SherpaOnnxAudioTaggingConfig *config) {
|
||||||
sherpa_onnx::AudioTaggingConfig ac;
|
sherpa_onnx::AudioTaggingConfig ac;
|
||||||
ac.model.zipformer.model = SHERPA_ONNX_OR(config->model.zipformer.model, "");
|
ac.model.zipformer.model = SHERPA_ONNX_OR(config->model.zipformer.model, "");
|
||||||
|
ac.model.ced = SHERPA_ONNX_OR(config->model.ced, "");
|
||||||
ac.model.num_threads = SHERPA_ONNX_OR(config->model.num_threads, 1);
|
ac.model.num_threads = SHERPA_ONNX_OR(config->model.num_threads, 1);
|
||||||
ac.model.debug = config->model.debug;
|
ac.model.debug = config->model.debug;
|
||||||
ac.model.provider = SHERPA_ONNX_OR(config->model.provider, "cpu");
|
ac.model.provider = SHERPA_ONNX_OR(config->model.provider, "cpu");
|
||||||
|
|||||||
@@ -1100,6 +1100,7 @@ SHERPA_ONNX_API typedef struct
|
|||||||
|
|
||||||
SHERPA_ONNX_API typedef struct SherpaOnnxAudioTaggingModelConfig {
|
SHERPA_ONNX_API typedef struct SherpaOnnxAudioTaggingModelConfig {
|
||||||
SherpaOnnxOfflineZipformerAudioTaggingModelConfig zipformer;
|
SherpaOnnxOfflineZipformerAudioTaggingModelConfig zipformer;
|
||||||
|
const char *ced;
|
||||||
int32_t num_threads;
|
int32_t num_threads;
|
||||||
int32_t debug; // true to print debug information of the model
|
int32_t debug; // true to print debug information of the model
|
||||||
const char *provider;
|
const char *provider;
|
||||||
|
|||||||
@@ -117,6 +117,7 @@ list(APPEND sources
|
|||||||
audio-tagging-label-file.cc
|
audio-tagging-label-file.cc
|
||||||
audio-tagging-model-config.cc
|
audio-tagging-model-config.cc
|
||||||
audio-tagging.cc
|
audio-tagging.cc
|
||||||
|
offline-ced-model.cc
|
||||||
offline-zipformer-audio-tagging-model-config.cc
|
offline-zipformer-audio-tagging-model-config.cc
|
||||||
offline-zipformer-audio-tagging-model.cc
|
offline-zipformer-audio-tagging-model.cc
|
||||||
)
|
)
|
||||||
|
|||||||
111
sherpa-onnx/csrc/audio-tagging-ced-impl.h
Normal file
111
sherpa-onnx/csrc/audio-tagging-ced-impl.h
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
// sherpa-onnx/csrc/audio-tagging-ced-impl.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2024 Xiaomi Corporation
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_AUDIO_TAGGING_CED_IMPL_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_AUDIO_TAGGING_CED_IMPL_H_
|
||||||
|
|
||||||
|
#include <assert.h>
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
#include "android/asset_manager.h"
|
||||||
|
#include "android/asset_manager_jni.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/audio-tagging-impl.h"
|
||||||
|
#include "sherpa-onnx/csrc/audio-tagging-label-file.h"
|
||||||
|
#include "sherpa-onnx/csrc/audio-tagging.h"
|
||||||
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
|
#include "sherpa-onnx/csrc/math.h"
|
||||||
|
#include "sherpa-onnx/csrc/offline-ced-model.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
class AudioTaggingCEDImpl : public AudioTaggingImpl {
|
||||||
|
public:
|
||||||
|
explicit AudioTaggingCEDImpl(const AudioTaggingConfig &config)
|
||||||
|
: config_(config), model_(config.model), labels_(config.labels) {
|
||||||
|
if (model_.NumEventClasses() != labels_.NumEventClasses()) {
|
||||||
|
SHERPA_ONNX_LOGE("number of classes: %d (model) != %d (label file)",
|
||||||
|
model_.NumEventClasses(), labels_.NumEventClasses());
|
||||||
|
exit(-1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
explicit AudioTaggingCEDImpl(AAssetManager *mgr,
|
||||||
|
const AudioTaggingConfig &config)
|
||||||
|
: config_(config),
|
||||||
|
model_(mgr, config.model),
|
||||||
|
labels_(mgr, config.labels) {
|
||||||
|
if (model_.NumEventClasses() != labels_.NumEventClasses()) {
|
||||||
|
SHERPA_ONNX_LOGE("number of classes: %d (model) != %d (label file)",
|
||||||
|
model_.NumEventClasses(), labels_.NumEventClasses());
|
||||||
|
exit(-1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
std::unique_ptr<OfflineStream> CreateStream() const override {
|
||||||
|
return std::make_unique<OfflineStream>(CEDTag{});
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<AudioEvent> Compute(OfflineStream *s,
|
||||||
|
int32_t top_k = -1) const override {
|
||||||
|
if (top_k < 0) {
|
||||||
|
top_k = config_.top_k;
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t num_event_classes = model_.NumEventClasses();
|
||||||
|
|
||||||
|
if (top_k > num_event_classes) {
|
||||||
|
top_k = num_event_classes;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto memory_info =
|
||||||
|
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||||
|
|
||||||
|
// WARNING(fangjun): It is fixed to 64 for CED models
|
||||||
|
int32_t feat_dim = 64;
|
||||||
|
std::vector<float> f = s->GetFrames();
|
||||||
|
|
||||||
|
int32_t num_frames = f.size() / feat_dim;
|
||||||
|
assert(feat_dim * num_frames == f.size());
|
||||||
|
|
||||||
|
std::array<int64_t, 3> shape = {1, num_frames, feat_dim};
|
||||||
|
|
||||||
|
Ort::Value x = Ort::Value::CreateTensor(memory_info, f.data(), f.size(),
|
||||||
|
shape.data(), shape.size());
|
||||||
|
|
||||||
|
Ort::Value probs = model_.Forward(std::move(x));
|
||||||
|
|
||||||
|
const float *p = probs.GetTensorData<float>();
|
||||||
|
|
||||||
|
std::vector<int32_t> top_k_indexes = TopkIndex(p, num_event_classes, top_k);
|
||||||
|
|
||||||
|
std::vector<AudioEvent> ans(top_k);
|
||||||
|
|
||||||
|
int32_t i = 0;
|
||||||
|
|
||||||
|
for (int32_t index : top_k_indexes) {
|
||||||
|
ans[i].name = labels_.GetEventName(index);
|
||||||
|
ans[i].index = index;
|
||||||
|
ans[i].prob = p[index];
|
||||||
|
i += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
return ans;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
AudioTaggingConfig config_;
|
||||||
|
OfflineCEDModel model_;
|
||||||
|
AudioTaggingLabels labels_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_AUDIO_TAGGING_CED_IMPL_H_
|
||||||
@@ -11,6 +11,7 @@
|
|||||||
#include "android/asset_manager_jni.h"
|
#include "android/asset_manager_jni.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/audio-tagging-ced-impl.h"
|
||||||
#include "sherpa-onnx/csrc/audio-tagging-zipformer-impl.h"
|
#include "sherpa-onnx/csrc/audio-tagging-zipformer-impl.h"
|
||||||
#include "sherpa-onnx/csrc/macros.h"
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
|
|
||||||
@@ -20,6 +21,8 @@ std::unique_ptr<AudioTaggingImpl> AudioTaggingImpl::Create(
|
|||||||
const AudioTaggingConfig &config) {
|
const AudioTaggingConfig &config) {
|
||||||
if (!config.model.zipformer.model.empty()) {
|
if (!config.model.zipformer.model.empty()) {
|
||||||
return std::make_unique<AudioTaggingZipformerImpl>(config);
|
return std::make_unique<AudioTaggingZipformerImpl>(config);
|
||||||
|
} else if (!config.model.ced.empty()) {
|
||||||
|
return std::make_unique<AudioTaggingCEDImpl>(config);
|
||||||
}
|
}
|
||||||
|
|
||||||
SHERPA_ONNX_LOG(
|
SHERPA_ONNX_LOG(
|
||||||
@@ -32,6 +35,8 @@ std::unique_ptr<AudioTaggingImpl> AudioTaggingImpl::Create(
|
|||||||
AAssetManager *mgr, const AudioTaggingConfig &config) {
|
AAssetManager *mgr, const AudioTaggingConfig &config) {
|
||||||
if (!config.model.zipformer.model.empty()) {
|
if (!config.model.zipformer.model.empty()) {
|
||||||
return std::make_unique<AudioTaggingZipformerImpl>(mgr, config);
|
return std::make_unique<AudioTaggingZipformerImpl>(mgr, config);
|
||||||
|
} else if (!config.model.ced.empty()) {
|
||||||
|
return std::make_unique<AudioTaggingCEDImpl>(mgr, config);
|
||||||
}
|
}
|
||||||
|
|
||||||
SHERPA_ONNX_LOG(
|
SHERPA_ONNX_LOG(
|
||||||
|
|||||||
@@ -4,11 +4,18 @@
|
|||||||
|
|
||||||
#include "sherpa-onnx/csrc/audio-tagging-model-config.h"
|
#include "sherpa-onnx/csrc/audio-tagging-model-config.h"
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/file-utils.h"
|
||||||
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
void AudioTaggingModelConfig::Register(ParseOptions *po) {
|
void AudioTaggingModelConfig::Register(ParseOptions *po) {
|
||||||
zipformer.Register(po);
|
zipformer.Register(po);
|
||||||
|
|
||||||
|
po->Register("ced-model", &ced,
|
||||||
|
"Path to CED model. Only need to pass one of --zipformer-model "
|
||||||
|
"or --ced-model");
|
||||||
|
|
||||||
po->Register("num-threads", &num_threads,
|
po->Register("num-threads", &num_threads,
|
||||||
"Number of threads to run the neural network");
|
"Number of threads to run the neural network");
|
||||||
|
|
||||||
@@ -24,6 +31,16 @@ bool AudioTaggingModelConfig::Validate() const {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!ced.empty() && !FileExists(ced)) {
|
||||||
|
SHERPA_ONNX_LOGE("CED model file %s does not exist", ced.c_str());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (zipformer.model.empty() && ced.empty()) {
|
||||||
|
SHERPA_ONNX_LOGE("Please provide either --zipformer-model or --ced-model");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -32,6 +49,7 @@ std::string AudioTaggingModelConfig::ToString() const {
|
|||||||
|
|
||||||
os << "AudioTaggingModelConfig(";
|
os << "AudioTaggingModelConfig(";
|
||||||
os << "zipformer=" << zipformer.ToString() << ", ";
|
os << "zipformer=" << zipformer.ToString() << ", ";
|
||||||
|
os << "ced=\"" << ced << "\", ";
|
||||||
os << "num_threads=" << num_threads << ", ";
|
os << "num_threads=" << num_threads << ", ";
|
||||||
os << "debug=" << (debug ? "True" : "False") << ", ";
|
os << "debug=" << (debug ? "True" : "False") << ", ";
|
||||||
os << "provider=\"" << provider << "\")";
|
os << "provider=\"" << provider << "\")";
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ namespace sherpa_onnx {
|
|||||||
|
|
||||||
struct AudioTaggingModelConfig {
|
struct AudioTaggingModelConfig {
|
||||||
struct OfflineZipformerAudioTaggingModelConfig zipformer;
|
struct OfflineZipformerAudioTaggingModelConfig zipformer;
|
||||||
|
std::string ced;
|
||||||
|
|
||||||
int32_t num_threads = 1;
|
int32_t num_threads = 1;
|
||||||
bool debug = false;
|
bool debug = false;
|
||||||
@@ -22,8 +23,10 @@ struct AudioTaggingModelConfig {
|
|||||||
|
|
||||||
AudioTaggingModelConfig(
|
AudioTaggingModelConfig(
|
||||||
const OfflineZipformerAudioTaggingModelConfig &zipformer,
|
const OfflineZipformerAudioTaggingModelConfig &zipformer,
|
||||||
int32_t num_threads, bool debug, const std::string &provider)
|
const std::string &ced, int32_t num_threads, bool debug,
|
||||||
|
const std::string &provider)
|
||||||
: zipformer(zipformer),
|
: zipformer(zipformer),
|
||||||
|
ced(ced),
|
||||||
num_threads(num_threads),
|
num_threads(num_threads),
|
||||||
debug(debug),
|
debug(debug),
|
||||||
provider(provider) {}
|
provider(provider) {}
|
||||||
|
|||||||
@@ -4,6 +4,8 @@
|
|||||||
#ifndef SHERPA_ONNX_CSRC_AUDIO_TAGGING_ZIPFORMER_IMPL_H_
|
#ifndef SHERPA_ONNX_CSRC_AUDIO_TAGGING_ZIPFORMER_IMPL_H_
|
||||||
#define SHERPA_ONNX_CSRC_AUDIO_TAGGING_ZIPFORMER_IMPL_H_
|
#define SHERPA_ONNX_CSRC_AUDIO_TAGGING_ZIPFORMER_IMPL_H_
|
||||||
|
|
||||||
|
#include <assert.h>
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
@@ -72,6 +74,8 @@ class AudioTaggingZipformerImpl : public AudioTaggingImpl {
|
|||||||
|
|
||||||
int32_t num_frames = f.size() / feat_dim;
|
int32_t num_frames = f.size() / feat_dim;
|
||||||
|
|
||||||
|
assert(feat_dim * num_frames == f.size());
|
||||||
|
|
||||||
std::array<int64_t, 3> shape = {1, num_frames, feat_dim};
|
std::array<int64_t, 3> shape = {1, num_frames, feat_dim};
|
||||||
|
|
||||||
Ort::Value x = Ort::Value::CreateTensor(memory_info, f.data(), f.size(),
|
Ort::Value x = Ort::Value::CreateTensor(memory_info, f.data(), f.size(),
|
||||||
|
|||||||
@@ -24,7 +24,8 @@ void FeatureExtractorConfig::Register(ParseOptions *po) {
|
|||||||
"inside the feature extractor");
|
"inside the feature extractor");
|
||||||
|
|
||||||
po->Register("feat-dim", &feature_dim,
|
po->Register("feat-dim", &feature_dim,
|
||||||
"Feature dimension. Must match the one expected by the model.");
|
"Feature dimension. Must match the one expected by the model. "
|
||||||
|
"Not used by whisper and CED models");
|
||||||
|
|
||||||
po->Register("low-freq", &low_freq, "Low cutoff frequency for mel bins");
|
po->Register("low-freq", &low_freq, "Low cutoff frequency for mel bins");
|
||||||
|
|
||||||
|
|||||||
112
sherpa-onnx/csrc/offline-ced-model.cc
Normal file
112
sherpa-onnx/csrc/offline-ced-model.cc
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
// sherpa-onnx/csrc/offline-ced-model.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2024 Xiaomi Corporation
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/offline-ced-model.h"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||||
|
#include "sherpa-onnx/csrc/session.h"
|
||||||
|
#include "sherpa-onnx/csrc/text-utils.h"
|
||||||
|
#include "sherpa-onnx/csrc/transpose.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
// Private implementation of OfflineCEDModel.
//
// Owns the onnxruntime environment, session options, allocator and session
// for a single CED audio tagging model whose path is given by config.ced.
class OfflineCEDModel::Impl {
 public:
  // Read the model file named by config.ced and create the session from it.
  explicit Impl(const AudioTaggingModelConfig &config)
      : config_(config),
        env_(ORT_LOGGING_LEVEL_ERROR),
        sess_opts_(GetSessionOptions(config)),
        allocator_{} {
    auto buf = ReadFile(config_.ced);
    Init(buf.data(), buf.size());
  }

#if __ANDROID_API__ >= 9
  // Same as above, but config.ced is read through the Android asset manager.
  Impl(AAssetManager *mgr, const AudioTaggingModelConfig &config)
      : config_(config),
        env_(ORT_LOGGING_LEVEL_ERROR),
        sess_opts_(GetSessionOptions(config)),
        allocator_{} {
    auto buf = ReadFile(mgr, config_.ced);
    Init(buf.data(), buf.size());
  }
#endif

  // Run the model on `features` and return its first output.
  //
  // The input is passed through Transpose12() before being fed to the
  // session.
  // NOTE(review): presumably this swaps axes 1 and 2, i.e. (N, T, C) becomes
  // (N, C, T) -- confirm against transpose.h.
  Ort::Value Forward(Ort::Value features) {
    features = Transpose12(allocator_, &features);

    auto ans = sess_->Run({}, input_names_ptr_.data(), &features, 1,
                          output_names_ptr_.data(), output_names_ptr_.size());
    return std::move(ans[0]);
  }

  // Number of event classes read from the model's first output shape.
  // It is 0 if that shape could not be interpreted; see Init().
  int32_t NumEventClasses() const { return num_event_classes_; }

  OrtAllocator *Allocator() const { return allocator_; }

 private:
  // Create the session from an in-memory model and cache its input/output
  // names and the number of event classes.
  //
  // @param model_data Pointer to the serialized model.
  // @param model_data_length Size of the serialized model in bytes.
  void Init(void *model_data, size_t model_data_length) {
    sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
                                           sess_opts_);

    GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);

    GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);

    // get meta data
    Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
    if (config_.debug) {
      std::ostringstream os;
      PrintModelMetadata(os, meta_data);
      SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
    }

    // get num_event_classes from the output[0].shape,
    // which is (N, num_event_classes)
    std::vector<int64_t> output_shape =
        sess_->GetOutputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape();

    // Do not index past the end of the shape if the model's first output
    // does not have the expected rank; report it and keep the default of 0.
    if (output_shape.size() < 2) {
      SHERPA_ONNX_LOGE(
          "Expect the first output of the CED model to have at least 2 "
          "dimensions. Given: %d",
          static_cast<int32_t>(output_shape.size()));
      return;
    }

    num_event_classes_ = static_cast<int32_t>(output_shape[1]);
  }

  AudioTaggingModelConfig config_;
  Ort::Env env_;
  Ort::SessionOptions sess_opts_;
  Ort::AllocatorWithDefaultOptions allocator_;

  std::unique_ptr<Ort::Session> sess_;

  // The *_ptr_ vectors hold the raw C strings handed to Ort::Session::Run().
  std::vector<std::string> input_names_;
  std::vector<const char *> input_names_ptr_;

  std::vector<std::string> output_names_;
  std::vector<const char *> output_names_ptr_;

  int32_t num_event_classes_ = 0;
};
|
||||||
|
|
||||||
|
OfflineCEDModel::OfflineCEDModel(const AudioTaggingModelConfig &config)
|
||||||
|
: impl_(std::make_unique<Impl>(config)) {}
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
OfflineCEDModel::OfflineCEDModel(AAssetManager *mgr,
|
||||||
|
const AudioTaggingModelConfig &config)
|
||||||
|
: impl_(std::make_unique<Impl>(mgr, config)) {}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
OfflineCEDModel::~OfflineCEDModel() = default;
|
||||||
|
|
||||||
|
Ort::Value OfflineCEDModel::Forward(Ort::Value features) const {
|
||||||
|
return impl_->Forward(std::move(features));
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t OfflineCEDModel::NumEventClasses() const {
|
||||||
|
return impl_->NumEventClasses();
|
||||||
|
}
|
||||||
|
|
||||||
|
OrtAllocator *OfflineCEDModel::Allocator() const { return impl_->Allocator(); }
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
56
sherpa-onnx/csrc/offline-ced-model.h
Normal file
56
sherpa-onnx/csrc/offline-ced-model.h
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
// sherpa-onnx/csrc/offline-ced-model.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2024 Xiaomi Corporation
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_OFFLINE_CED_MODEL_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_OFFLINE_CED_MODEL_H_
|
||||||
|
#include <memory>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
#include "android/asset_manager.h"
|
||||||
|
#include "android/asset_manager_jni.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||||
|
#include "sherpa-onnx/csrc/audio-tagging-model-config.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
/** This class implements the CED model from
|
||||||
|
* https://github.com/RicherMans/CED/blob/main/export_onnx.py
|
||||||
|
*/
|
||||||
|
class OfflineCEDModel {
|
||||||
|
public:
|
||||||
|
explicit OfflineCEDModel(const AudioTaggingModelConfig &config);
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
OfflineCEDModel(AAssetManager *mgr, const AudioTaggingModelConfig &config);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
~OfflineCEDModel();
|
||||||
|
|
||||||
|
/** Run the forward method of the model.
|
||||||
|
*
|
||||||
|
* @param features A tensor of shape (N, T, C).
|
||||||
|
*
|
||||||
|
* @return Return a tensor
|
||||||
|
* - probs: A 2-D tensor of shape (N, num_event_classes).
|
||||||
|
*/
|
||||||
|
Ort::Value Forward(Ort::Value features) const;
|
||||||
|
|
||||||
|
/** Return the number of event classes of the model
|
||||||
|
*/
|
||||||
|
int32_t NumEventClasses() const;
|
||||||
|
|
||||||
|
/** Return an allocator for allocating memory
|
||||||
|
*/
|
||||||
|
OrtAllocator *Allocator() const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
class Impl;
|
||||||
|
std::unique_ptr<Impl> impl_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_OFFLINE_CED_MODEL_H_
|
||||||
@@ -92,15 +92,32 @@ class OfflineStream::Impl {
|
|||||||
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
|
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
|
||||||
}
|
}
|
||||||
|
|
||||||
Impl(WhisperTag /*tag*/, ContextGraphPtr context_graph)
|
explicit Impl(WhisperTag /*tag*/) {
|
||||||
: context_graph_(context_graph) {
|
|
||||||
config_.normalize_samples = true;
|
config_.normalize_samples = true;
|
||||||
opts_.frame_opts.samp_freq = 16000;
|
opts_.frame_opts.samp_freq = 16000;
|
||||||
opts_.mel_opts.num_bins = 80;
|
opts_.mel_opts.num_bins = 80; // not used
|
||||||
whisper_fbank_ =
|
whisper_fbank_ =
|
||||||
std::make_unique<knf::OnlineWhisperFbank>(opts_.frame_opts);
|
std::make_unique<knf::OnlineWhisperFbank>(opts_.frame_opts);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Set up the fbank extractor with the fixed options used for CED models.
// The values follow
// https://github.com/RicherMans/CED/blob/main/onnx_inference_with_kaldi.py
explicit Impl(CEDTag /*tag*/) {
  auto &frame_opts = opts_.frame_opts;
  frame_opts.samp_freq = 16000;  // fixed to 16000
  frame_opts.frame_length_ms = 32;
  frame_opts.window_type = "hann";
  frame_opts.snip_edges = false;
  frame_opts.dither = 0;
  frame_opts.preemph_coeff = 0;
  frame_opts.remove_dc_offset = false;

  auto &mel_opts = opts_.mel_opts;
  mel_opts.num_bins = 64;
  mel_opts.high_freq = 8000;

  // Must be created last so that it sees the options assigned above.
  fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
}
|
||||||
|
|
||||||
void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
|
void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
|
||||||
if (config_.normalize_samples) {
|
if (config_.normalize_samples) {
|
||||||
AcceptWaveformImpl(sampling_rate, waveform, n);
|
AcceptWaveformImpl(sampling_rate, waveform, n);
|
||||||
@@ -233,9 +250,10 @@ OfflineStream::OfflineStream(
|
|||||||
ContextGraphPtr context_graph /*= nullptr*/)
|
ContextGraphPtr context_graph /*= nullptr*/)
|
||||||
: impl_(std::make_unique<Impl>(config, context_graph)) {}
|
: impl_(std::make_unique<Impl>(config, context_graph)) {}
|
||||||
|
|
||||||
OfflineStream::OfflineStream(WhisperTag tag,
|
OfflineStream::OfflineStream(WhisperTag tag)
|
||||||
ContextGraphPtr context_graph /*= {}*/)
|
: impl_(std::make_unique<Impl>(tag)) {}
|
||||||
: impl_(std::make_unique<Impl>(tag, context_graph)) {}
|
|
||||||
|
OfflineStream::OfflineStream(CEDTag tag) : impl_(std::make_unique<Impl>(tag)) {}
|
||||||
|
|
||||||
OfflineStream::~OfflineStream() = default;
|
OfflineStream::~OfflineStream() = default;
|
||||||
|
|
||||||
|
|||||||
@@ -67,13 +67,15 @@ struct OfflineFeatureExtractorConfig {
|
|||||||
};
|
};
|
||||||
|
|
||||||
struct WhisperTag {};
|
struct WhisperTag {};
|
||||||
|
// Empty tag type used to select the OfflineStream(CEDTag) constructor.
struct CEDTag {};
|
||||||
|
|
||||||
class OfflineStream {
|
class OfflineStream {
|
||||||
public:
|
public:
|
||||||
explicit OfflineStream(const OfflineFeatureExtractorConfig &config = {},
|
explicit OfflineStream(const OfflineFeatureExtractorConfig &config = {},
|
||||||
ContextGraphPtr context_graph = {});
|
ContextGraphPtr context_graph = {});
|
||||||
|
|
||||||
explicit OfflineStream(WhisperTag tag, ContextGraphPtr context_graph = {});
|
explicit OfflineStream(WhisperTag tag);
|
||||||
|
explicit OfflineStream(CEDTag tag);
|
||||||
~OfflineStream();
|
~OfflineStream();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@@ -31,6 +31,12 @@ static AudioTaggingConfig GetAudioTaggingConfig(JNIEnv *env, jobject config) {
|
|||||||
ans.model.zipformer.model = p;
|
ans.model.zipformer.model = p;
|
||||||
env->ReleaseStringUTFChars(s, p);
|
env->ReleaseStringUTFChars(s, p);
|
||||||
|
|
||||||
|
fid = env->GetFieldID(model_cls, "ced", "Ljava/lang/String;");
|
||||||
|
s = (jstring)env->GetObjectField(model, fid);
|
||||||
|
p = env->GetStringUTFChars(s, nullptr);
|
||||||
|
ans.model.ced = p;
|
||||||
|
env->ReleaseStringUTFChars(s, p);
|
||||||
|
|
||||||
fid = env->GetFieldID(model_cls, "numThreads", "I");
|
fid = env->GetFieldID(model_cls, "numThreads", "I");
|
||||||
ans.model.num_threads = env->GetIntField(model, fid);
|
ans.model.num_threads = env->GetIntField(model, fid);
|
||||||
|
|
||||||
|
|||||||
@@ -27,10 +27,11 @@ static void PybindAudioTaggingModelConfig(py::module *m) {
|
|||||||
|
|
||||||
py::class_<PyClass>(*m, "AudioTaggingModelConfig")
|
py::class_<PyClass>(*m, "AudioTaggingModelConfig")
|
||||||
.def(py::init<>())
|
.def(py::init<>())
|
||||||
.def(py::init<const OfflineZipformerAudioTaggingModelConfig &, int32_t,
|
.def(py::init<const OfflineZipformerAudioTaggingModelConfig &,
|
||||||
bool, const std::string &>(),
|
const std::string &, int32_t, bool, const std::string &>(),
|
||||||
py::arg("zipformer"), py::arg("num_threads") = 1,
|
py::arg("zipformer"), py::arg("ced") = "",
|
||||||
py::arg("debug") = false, py::arg("provider") = "cpu")
|
py::arg("num_threads") = 1, py::arg("debug") = false,
|
||||||
|
py::arg("provider") = "cpu")
|
||||||
.def_readwrite("zipformer", &PyClass::zipformer)
|
.def_readwrite("zipformer", &PyClass::zipformer)
|
||||||
.def_readwrite("num_threads", &PyClass::num_threads)
|
.def_readwrite("num_threads", &PyClass::num_threads)
|
||||||
.def_readwrite("debug", &PyClass::debug)
|
.def_readwrite("debug", &PyClass::debug)
|
||||||
|
|||||||
Reference in New Issue
Block a user