Support CED models (#792)
This commit is contained in:
78
.github/workflows/export-ced-to-onnx.yaml
vendored
Normal file
78
.github/workflows/export-ced-to-onnx.yaml
vendored
Normal file
@@ -0,0 +1,78 @@
|
||||
name: export-ced-to-onnx
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
|
||||
concurrency:
|
||||
group: export-ced-to-onnx-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
export-ced-to-onnx:
|
||||
if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj'
|
||||
name: export ced
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [ubuntu-latest]
|
||||
python-version: ["3.8"]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Run
|
||||
shell: bash
|
||||
run: |
|
||||
cd scripts/ced
|
||||
./run.sh
|
||||
|
||||
- name: Release
|
||||
uses: svenstaro/upload-release-action@v2
|
||||
with:
|
||||
file_glob: true
|
||||
file: ./*.tar.bz2
|
||||
overwrite: true
|
||||
repo_name: k2-fsa/sherpa-onnx
|
||||
repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
|
||||
tag: audio-tagging-models
|
||||
|
||||
- name: Publish to huggingface
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||
uses: nick-fields/retry@v3
|
||||
with:
|
||||
max_attempts: 20
|
||||
timeout_seconds: 200
|
||||
shell: bash
|
||||
command: |
|
||||
git config --global user.email "csukuangfj@gmail.com"
|
||||
git config --global user.name "Fangjun Kuang"
|
||||
|
||||
models=(
|
||||
tiny
|
||||
mini
|
||||
small
|
||||
base
|
||||
)
|
||||
|
||||
for m in ${models[@]}; do
|
||||
rm -rf huggingface
|
||||
export GIT_LFS_SKIP_SMUDGE=1
|
||||
d=sherpa-onnx-ced-$m-audio-tagging-2024-04-19
|
||||
git clone https://huggingface.co/k2-fsa/$d huggingface
|
||||
mv -v $d/* huggingface
|
||||
cd huggingface
|
||||
git lfs track "*.onnx"
|
||||
git status
|
||||
git add .
|
||||
git status
|
||||
git commit -m "first commit"
|
||||
git push https://csukuangfj:$HF_TOKEN@huggingface.co/k2-fsa/$d main
|
||||
cd ..
|
||||
done
|
||||
@@ -1,5 +1,5 @@
|
||||
<resources>
|
||||
<string name="app_name">ASR with Next-gen Kaldi</string>
|
||||
<string name="app_name">ASR</string>
|
||||
<string name="hint">Click the Start button to play speech-to-text with Next-gen Kaldi.
|
||||
\n
|
||||
\n\n\n
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
<resources>
|
||||
<string name="app_name">ASR with Next-gen Kaldi</string>
|
||||
<string name="app_name">ASR2pass </string>
|
||||
<string name="hint">Click the Start button to play speech-to-text with Next-gen Kaldi.
|
||||
\n
|
||||
\n\n\n
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
package com.k2fsa.sherpa.onnx
|
||||
|
||||
import android.content.res.AssetManager
|
||||
import android.util.Log
|
||||
|
||||
private val TAG = "sherpa-onnx"
|
||||
const val TAG = "sherpa-onnx"
|
||||
|
||||
data class OfflineZipformerAudioTaggingModelConfig(
|
||||
var model: String,
|
||||
var model: String = "",
|
||||
)
|
||||
|
||||
data class AudioTaggingModelConfig(
|
||||
var zipformer: OfflineZipformerAudioTaggingModelConfig,
|
||||
var zipformer: OfflineZipformerAudioTaggingModelConfig = OfflineZipformerAudioTaggingModelConfig(),
|
||||
var ced: String = "",
|
||||
var numThreads: Int = 1,
|
||||
var debug: Boolean = false,
|
||||
var provider: String = "cpu",
|
||||
@@ -103,7 +103,7 @@ class AudioTagging(
|
||||
//
|
||||
// See also
|
||||
// https://k2-fsa.github.io/sherpa/onnx/audio-tagging/
|
||||
fun getAudioTaggingConfig(type: Int, numThreads: Int=1): AudioTaggingConfig? {
|
||||
fun getAudioTaggingConfig(type: Int, numThreads: Int = 1): AudioTaggingConfig? {
|
||||
when (type) {
|
||||
0 -> {
|
||||
val modelDir = "sherpa-onnx-zipformer-small-audio-tagging-2024-04-15"
|
||||
@@ -123,7 +123,7 @@ fun getAudioTaggingConfig(type: Int, numThreads: Int=1): AudioTaggingConfig? {
|
||||
return AudioTaggingConfig(
|
||||
model = AudioTaggingModelConfig(
|
||||
zipformer = OfflineZipformerAudioTaggingModelConfig(model = "$modelDir/model.int8.onnx"),
|
||||
numThreads = 1,
|
||||
numThreads = numThreads,
|
||||
debug = true,
|
||||
),
|
||||
labels = "$modelDir/class_labels_indices.csv",
|
||||
@@ -131,6 +131,57 @@ fun getAudioTaggingConfig(type: Int, numThreads: Int=1): AudioTaggingConfig? {
|
||||
)
|
||||
}
|
||||
|
||||
2 -> {
|
||||
val modelDir = "sherpa-onnx-ced-tiny-audio-tagging-2024-04-19"
|
||||
return AudioTaggingConfig(
|
||||
model = AudioTaggingModelConfig(
|
||||
ced = "$modelDir/model.int8.onnx",
|
||||
numThreads = numThreads,
|
||||
debug = true,
|
||||
),
|
||||
labels = "$modelDir/class_labels_indices.csv",
|
||||
topK = 3,
|
||||
)
|
||||
}
|
||||
|
||||
3 -> {
|
||||
val modelDir = "sherpa-onnx-ced-mini-audio-tagging-2024-04-19"
|
||||
return AudioTaggingConfig(
|
||||
model = AudioTaggingModelConfig(
|
||||
ced = "$modelDir/model.int8.onnx",
|
||||
numThreads = numThreads,
|
||||
debug = true,
|
||||
),
|
||||
labels = "$modelDir/class_labels_indices.csv",
|
||||
topK = 3,
|
||||
)
|
||||
}
|
||||
|
||||
4 -> {
|
||||
val modelDir = "sherpa-onnx-ced-small-audio-tagging-2024-04-19"
|
||||
return AudioTaggingConfig(
|
||||
model = AudioTaggingModelConfig(
|
||||
ced = "$modelDir/model.int8.onnx",
|
||||
numThreads = numThreads,
|
||||
debug = true,
|
||||
),
|
||||
labels = "$modelDir/class_labels_indices.csv",
|
||||
topK = 3,
|
||||
)
|
||||
}
|
||||
|
||||
5 -> {
|
||||
val modelDir = "sherpa-onnx-ced-base-audio-tagging-2024-04-19"
|
||||
return AudioTaggingConfig(
|
||||
model = AudioTaggingModelConfig(
|
||||
ced = "$modelDir/model.int8.onnx",
|
||||
numThreads = numThreads,
|
||||
debug = true,
|
||||
),
|
||||
labels = "$modelDir/class_labels_indices.csv",
|
||||
topK = 3,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return null
|
||||
|
||||
@@ -3,24 +3,15 @@
|
||||
package com.k2fsa.sherpa.onnx.audio.tagging
|
||||
|
||||
import android.Manifest
|
||||
|
||||
import android.app.Activity
|
||||
import android.content.pm.PackageManager
|
||||
import android.media.AudioFormat
|
||||
import android.media.AudioRecord
|
||||
import androidx.compose.foundation.lazy.items
|
||||
import android.media.MediaRecorder
|
||||
import android.util.Log
|
||||
import androidx.compose.foundation.ExperimentalFoundationApi
|
||||
import androidx.compose.foundation.background
|
||||
import androidx.compose.foundation.layout.Arrangement
|
||||
import androidx.compose.foundation.layout.Box
|
||||
import androidx.compose.material3.CenterAlignedTopAppBar
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.compose.material3.Scaffold
|
||||
import androidx.compose.material3.TopAppBarDefaults
|
||||
import androidx.compose.material3.MaterialTheme
|
||||
import androidx.compose.material3.Text
|
||||
import androidx.compose.foundation.layout.Column
|
||||
import androidx.compose.foundation.layout.PaddingValues
|
||||
import androidx.compose.foundation.layout.Row
|
||||
@@ -30,10 +21,17 @@ import androidx.compose.foundation.layout.fillMaxWidth
|
||||
import androidx.compose.foundation.layout.height
|
||||
import androidx.compose.foundation.layout.padding
|
||||
import androidx.compose.foundation.lazy.LazyColumn
|
||||
import androidx.compose.foundation.lazy.items
|
||||
import androidx.compose.material3.Button
|
||||
import androidx.compose.material3.CenterAlignedTopAppBar
|
||||
import androidx.compose.material3.ExperimentalMaterial3Api
|
||||
import androidx.compose.material3.MaterialTheme
|
||||
import androidx.compose.material3.Scaffold
|
||||
import androidx.compose.material3.Slider
|
||||
import androidx.compose.material3.Surface
|
||||
import androidx.compose.material3.Text
|
||||
import androidx.compose.material3.TopAppBarDefaults
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.compose.runtime.getValue
|
||||
import androidx.compose.runtime.mutableStateListOf
|
||||
import androidx.compose.runtime.mutableStateOf
|
||||
@@ -41,7 +39,6 @@ import androidx.compose.runtime.remember
|
||||
import androidx.compose.runtime.setValue
|
||||
import androidx.compose.ui.Alignment
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.graphics.Color
|
||||
import androidx.compose.ui.platform.LocalContext
|
||||
import androidx.compose.ui.text.font.FontWeight
|
||||
import androidx.compose.ui.text.style.TextAlign
|
||||
@@ -49,6 +46,7 @@ import androidx.compose.ui.unit.dp
|
||||
import androidx.compose.ui.unit.sp
|
||||
import androidx.core.app.ActivityCompat
|
||||
import com.k2fsa.sherpa.onnx.AudioEvent
|
||||
import com.k2fsa.sherpa.onnx.Tagger
|
||||
import kotlin.concurrent.thread
|
||||
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ import androidx.compose.material3.Surface
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.core.app.ActivityCompat
|
||||
import com.k2fsa.sherpa.onnx.Tagger
|
||||
import com.k2fsa.sherpa.onnx.audio.tagging.ui.theme.SherpaOnnxAudioTaggingTheme
|
||||
|
||||
const val TAG = "sherpa-onnx"
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
package com.k2fsa.sherpa.onnx.audio.tagging
|
||||
package com.k2fsa.sherpa.onnx
|
||||
|
||||
import android.content.res.AssetManager
|
||||
import android.util.Log
|
||||
import com.k2fsa.sherpa.onnx.AudioTagging
|
||||
import com.k2fsa.sherpa.onnx.audio.tagging.wear.os.presentation.TAG
|
||||
import com.k2fsa.sherpa.onnx.getAudioTaggingConfig
|
||||
|
||||
|
||||
object Tagger {
|
||||
private var _tagger: AudioTagging? = null
|
||||
@@ -12,6 +10,7 @@ object Tagger {
|
||||
get() {
|
||||
return _tagger!!
|
||||
}
|
||||
|
||||
fun initTagger(assetManager: AssetManager? = null, numThreads: Int = 1) {
|
||||
synchronized(this) {
|
||||
if (_tagger != null) {
|
||||
@@ -19,7 +18,7 @@ object Tagger {
|
||||
}
|
||||
|
||||
Log.i(TAG, "Initializing audio tagger")
|
||||
val config = getAudioTaggingConfig(type = 0, numThreads=numThreads)!!
|
||||
val config = getAudioTaggingConfig(type = 0, numThreads = numThreads)!!
|
||||
_tagger = AudioTagging(assetManager, config)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -33,7 +33,7 @@ import androidx.wear.compose.material.Button
|
||||
import androidx.wear.compose.material.MaterialTheme
|
||||
import androidx.wear.compose.material.Text
|
||||
import com.k2fsa.sherpa.onnx.AudioEvent
|
||||
import com.k2fsa.sherpa.onnx.audio.tagging.Tagger
|
||||
import com.k2fsa.sherpa.onnx.Tagger
|
||||
import com.k2fsa.sherpa.onnx.audio.tagging.wear.os.presentation.theme.SherpaOnnxAudioTaggingWearOsTheme
|
||||
import kotlin.concurrent.thread
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ import androidx.activity.compose.setContent
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.core.app.ActivityCompat
|
||||
import androidx.core.splashscreen.SplashScreen.Companion.installSplashScreen
|
||||
import com.k2fsa.sherpa.onnx.audio.tagging.Tagger
|
||||
import com.k2fsa.sherpa.onnx.Tagger
|
||||
|
||||
const val TAG = "sherpa-onnx"
|
||||
private const val REQUEST_RECORD_AUDIO_PERMISSION = 200
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
<resources>
|
||||
<string name="app_name">AudioTagging</string>
|
||||
<string name="app_name">Audio Tagging</string>
|
||||
<!--
|
||||
This string is used for square devices and overridden by hello_world in
|
||||
values-round/strings.xml for round devices.
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
<resources>
|
||||
<string name="app_name">Speaker Identification</string>
|
||||
<string name="app_name">Speaker ID</string>
|
||||
<string name="start">Start recording</string>
|
||||
<string name="stop">Stop recording</string>
|
||||
<string name="add">Add speaker</string>
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
<resources>
|
||||
<string name="app_name">SherpaOnnxSpokenLanguageIdentification</string>
|
||||
<string name="app_name">Language ID</string>
|
||||
</resources>
|
||||
@@ -1,5 +1,5 @@
|
||||
<resources>
|
||||
<string name="app_name">Next-gen Kaldi: TTS</string>
|
||||
<string name="app_name">TTS</string>
|
||||
<string name="sid_label">Speaker ID</string>
|
||||
<string name="sid_hint">0</string>
|
||||
<string name="speed_label">Speech speed (large->fast)</string>
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
<resources>
|
||||
<string name="app_name">Next-gen Kaldi: TTS</string>
|
||||
<string name="app_name">TTS Engine</string>
|
||||
</resources>
|
||||
@@ -1,5 +1,5 @@
|
||||
<resources>
|
||||
<string name="app_name">Next-gen Kaldi: SileroVAD</string>
|
||||
<string name="app_name">VAD</string>
|
||||
|
||||
<string name="hint">Click the Start button to play Silero VAD with Next-gen Kaldi.</string>
|
||||
<string name="start">Start</string>
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
<resources>
|
||||
<string name="app_name">ASR with Next-gen Kaldi</string>
|
||||
<string name="app_name">VAD-ASR</string>
|
||||
<string name="hint">Click the Start button to play speech-to-text with Next-gen Kaldi.
|
||||
\n
|
||||
\n\n\n
|
||||
|
||||
@@ -46,7 +46,30 @@ def get_models():
|
||||
),
|
||||
]
|
||||
|
||||
return icefall_models
|
||||
ced_models = [
|
||||
AudioTaggingModel(
|
||||
model_name="sherpa-onnx-ced-tiny-audio-tagging-2024-04-19",
|
||||
idx=2,
|
||||
short_name="ced_tiny",
|
||||
),
|
||||
AudioTaggingModel(
|
||||
model_name="sherpa-onnx-ced-mini-audio-tagging-2024-04-19",
|
||||
idx=3,
|
||||
short_name="ced_mini",
|
||||
),
|
||||
AudioTaggingModel(
|
||||
model_name="sherpa-onnx-ced-small-audio-tagging-2024-04-19",
|
||||
idx=4,
|
||||
short_name="ced_small",
|
||||
),
|
||||
AudioTaggingModel(
|
||||
model_name="sherpa-onnx-ced-base-audio-tagging-2024-04-19",
|
||||
idx=5,
|
||||
short_name="ced_base",
|
||||
),
|
||||
]
|
||||
|
||||
return icefall_models + ced_models
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
69
scripts/ced/run.sh
Executable file
69
scripts/ced/run.sh
Executable file
@@ -0,0 +1,69 @@
|
||||
#!/usr/bin/env bash
|
||||
#
|
||||
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
|
||||
set -ex
|
||||
|
||||
function install_dependencies() {
|
||||
pip install -qq torch==2.1.0+cpu torchaudio==2.1.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
|
||||
pip install -qq onnx onnxruntime==1.17.1
|
||||
|
||||
pip install -r ./requirements.txt
|
||||
}
|
||||
|
||||
git clone https://github.com/RicherMans/CED
|
||||
pushd CED
|
||||
|
||||
install_dependencies
|
||||
|
||||
models=(
|
||||
tiny
|
||||
mini
|
||||
small
|
||||
base
|
||||
)
|
||||
|
||||
for m in ${models[@]}; do
|
||||
python3 ./export_onnx.py -m ced_$m
|
||||
done
|
||||
|
||||
ls -lh *.onnx
|
||||
|
||||
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/audio-tagging-models/sherpa-onnx-zipformer-small-audio-tagging-2024-04-15.tar.bz2
|
||||
|
||||
tar xvf sherpa-onnx-zipformer-small-audio-tagging-2024-04-15.tar.bz2
|
||||
rm sherpa-onnx-zipformer-small-audio-tagging-2024-04-15.tar.bz2
|
||||
src=sherpa-onnx-zipformer-small-audio-tagging-2024-04-15
|
||||
|
||||
cat >README.md <<EOF
|
||||
# Introduction
|
||||
|
||||
Models in this repo are converted from
|
||||
https://github.com/RicherMans/CED
|
||||
EOF
|
||||
|
||||
for m in ${models[@]}; do
|
||||
d=sherpa-onnx-ced-$m-audio-tagging-2024-04-19
|
||||
|
||||
mkdir -p $d
|
||||
|
||||
cp -v README.md $d
|
||||
cp -v $src/class_labels_indices.csv $d
|
||||
cp -a $src/test_wavs $d
|
||||
cp -v ced_$m.onnx $d/model.onnx
|
||||
cp -v ced_$m.int8.onnx $d/model.int8.onnx
|
||||
echo "----------$m----------"
|
||||
ls -lh $d
|
||||
echo "----------------------"
|
||||
tar cjvf $d.tar.bz2 $d
|
||||
mv $d.tar.bz2 ../../..
|
||||
mv $d ../../../
|
||||
done
|
||||
|
||||
rm -rf sherpa-onnx-zipformer-small-audio-tagging-2024-04-15
|
||||
|
||||
cd ../../..
|
||||
|
||||
ls -lh *.tar.bz2
|
||||
echo "======="
|
||||
ls -lh
|
||||
@@ -1223,6 +1223,7 @@ const SherpaOnnxAudioTagging *SherpaOnnxCreateAudioTagging(
|
||||
const SherpaOnnxAudioTaggingConfig *config) {
|
||||
sherpa_onnx::AudioTaggingConfig ac;
|
||||
ac.model.zipformer.model = SHERPA_ONNX_OR(config->model.zipformer.model, "");
|
||||
ac.model.ced = SHERPA_ONNX_OR(config->model.ced, "");
|
||||
ac.model.num_threads = SHERPA_ONNX_OR(config->model.num_threads, 1);
|
||||
ac.model.debug = config->model.debug;
|
||||
ac.model.provider = SHERPA_ONNX_OR(config->model.provider, "cpu");
|
||||
|
||||
@@ -1100,6 +1100,7 @@ SHERPA_ONNX_API typedef struct
|
||||
|
||||
SHERPA_ONNX_API typedef struct SherpaOnnxAudioTaggingModelConfig {
|
||||
SherpaOnnxOfflineZipformerAudioTaggingModelConfig zipformer;
|
||||
const char *ced;
|
||||
int32_t num_threads;
|
||||
int32_t debug; // true to print debug information of the model
|
||||
const char *provider;
|
||||
|
||||
@@ -117,6 +117,7 @@ list(APPEND sources
|
||||
audio-tagging-label-file.cc
|
||||
audio-tagging-model-config.cc
|
||||
audio-tagging.cc
|
||||
offline-ced-model.cc
|
||||
offline-zipformer-audio-tagging-model-config.cc
|
||||
offline-zipformer-audio-tagging-model.cc
|
||||
)
|
||||
|
||||
111
sherpa-onnx/csrc/audio-tagging-ced-impl.h
Normal file
111
sherpa-onnx/csrc/audio-tagging-ced-impl.h
Normal file
@@ -0,0 +1,111 @@
|
||||
// sherpa-onnx/csrc/audio-tagging-ced-impl.h
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_AUDIO_TAGGING_CED_IMPL_H_
|
||||
#define SHERPA_ONNX_CSRC_AUDIO_TAGGING_CED_IMPL_H_
|
||||
|
||||
#include <assert.h>
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/audio-tagging-impl.h"
|
||||
#include "sherpa-onnx/csrc/audio-tagging-label-file.h"
|
||||
#include "sherpa-onnx/csrc/audio-tagging.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/math.h"
|
||||
#include "sherpa-onnx/csrc/offline-ced-model.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class AudioTaggingCEDImpl : public AudioTaggingImpl {
|
||||
public:
|
||||
explicit AudioTaggingCEDImpl(const AudioTaggingConfig &config)
|
||||
: config_(config), model_(config.model), labels_(config.labels) {
|
||||
if (model_.NumEventClasses() != labels_.NumEventClasses()) {
|
||||
SHERPA_ONNX_LOGE("number of classes: %d (model) != %d (label file)",
|
||||
model_.NumEventClasses(), labels_.NumEventClasses());
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
explicit AudioTaggingCEDImpl(AAssetManager *mgr,
|
||||
const AudioTaggingConfig &config)
|
||||
: config_(config),
|
||||
model_(mgr, config.model),
|
||||
labels_(mgr, config.labels) {
|
||||
if (model_.NumEventClasses() != labels_.NumEventClasses()) {
|
||||
SHERPA_ONNX_LOGE("number of classes: %d (model) != %d (label file)",
|
||||
model_.NumEventClasses(), labels_.NumEventClasses());
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
std::unique_ptr<OfflineStream> CreateStream() const override {
|
||||
return std::make_unique<OfflineStream>(CEDTag{});
|
||||
}
|
||||
|
||||
std::vector<AudioEvent> Compute(OfflineStream *s,
|
||||
int32_t top_k = -1) const override {
|
||||
if (top_k < 0) {
|
||||
top_k = config_.top_k;
|
||||
}
|
||||
|
||||
int32_t num_event_classes = model_.NumEventClasses();
|
||||
|
||||
if (top_k > num_event_classes) {
|
||||
top_k = num_event_classes;
|
||||
}
|
||||
|
||||
auto memory_info =
|
||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||
|
||||
// WARNING(fangjun): It is fixed to 64 for CED models
|
||||
int32_t feat_dim = 64;
|
||||
std::vector<float> f = s->GetFrames();
|
||||
|
||||
int32_t num_frames = f.size() / feat_dim;
|
||||
assert(feat_dim * num_frames == f.size());
|
||||
|
||||
std::array<int64_t, 3> shape = {1, num_frames, feat_dim};
|
||||
|
||||
Ort::Value x = Ort::Value::CreateTensor(memory_info, f.data(), f.size(),
|
||||
shape.data(), shape.size());
|
||||
|
||||
Ort::Value probs = model_.Forward(std::move(x));
|
||||
|
||||
const float *p = probs.GetTensorData<float>();
|
||||
|
||||
std::vector<int32_t> top_k_indexes = TopkIndex(p, num_event_classes, top_k);
|
||||
|
||||
std::vector<AudioEvent> ans(top_k);
|
||||
|
||||
int32_t i = 0;
|
||||
|
||||
for (int32_t index : top_k_indexes) {
|
||||
ans[i].name = labels_.GetEventName(index);
|
||||
ans[i].index = index;
|
||||
ans[i].prob = p[index];
|
||||
i += 1;
|
||||
}
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
private:
|
||||
AudioTaggingConfig config_;
|
||||
OfflineCEDModel model_;
|
||||
AudioTaggingLabels labels_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_AUDIO_TAGGING_CED_IMPL_H_
|
||||
@@ -11,6 +11,7 @@
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/audio-tagging-ced-impl.h"
|
||||
#include "sherpa-onnx/csrc/audio-tagging-zipformer-impl.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
|
||||
@@ -20,6 +21,8 @@ std::unique_ptr<AudioTaggingImpl> AudioTaggingImpl::Create(
|
||||
const AudioTaggingConfig &config) {
|
||||
if (!config.model.zipformer.model.empty()) {
|
||||
return std::make_unique<AudioTaggingZipformerImpl>(config);
|
||||
} else if (!config.model.ced.empty()) {
|
||||
return std::make_unique<AudioTaggingCEDImpl>(config);
|
||||
}
|
||||
|
||||
SHERPA_ONNX_LOG(
|
||||
@@ -32,6 +35,8 @@ std::unique_ptr<AudioTaggingImpl> AudioTaggingImpl::Create(
|
||||
AAssetManager *mgr, const AudioTaggingConfig &config) {
|
||||
if (!config.model.zipformer.model.empty()) {
|
||||
return std::make_unique<AudioTaggingZipformerImpl>(mgr, config);
|
||||
} else if (!config.model.ced.empty()) {
|
||||
return std::make_unique<AudioTaggingCEDImpl>(mgr, config);
|
||||
}
|
||||
|
||||
SHERPA_ONNX_LOG(
|
||||
|
||||
@@ -4,11 +4,18 @@
|
||||
|
||||
#include "sherpa-onnx/csrc/audio-tagging-model-config.h"
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void AudioTaggingModelConfig::Register(ParseOptions *po) {
|
||||
zipformer.Register(po);
|
||||
|
||||
po->Register("ced-model", &ced,
|
||||
"Path to CED model. Only need to pass one of --zipformer-model "
|
||||
"or --ced-model");
|
||||
|
||||
po->Register("num-threads", &num_threads,
|
||||
"Number of threads to run the neural network");
|
||||
|
||||
@@ -24,6 +31,16 @@ bool AudioTaggingModelConfig::Validate() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!ced.empty() && !FileExists(ced)) {
|
||||
SHERPA_ONNX_LOGE("CED model file %s does not exist", ced.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
if (zipformer.model.empty() && ced.empty()) {
|
||||
SHERPA_ONNX_LOGE("Please provide either --zipformer-model or --ced-model");
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -32,6 +49,7 @@ std::string AudioTaggingModelConfig::ToString() const {
|
||||
|
||||
os << "AudioTaggingModelConfig(";
|
||||
os << "zipformer=" << zipformer.ToString() << ", ";
|
||||
os << "ced=\"" << ced << "\", ";
|
||||
os << "num_threads=" << num_threads << ", ";
|
||||
os << "debug=" << (debug ? "True" : "False") << ", ";
|
||||
os << "provider=\"" << provider << "\")";
|
||||
|
||||
@@ -13,6 +13,7 @@ namespace sherpa_onnx {
|
||||
|
||||
struct AudioTaggingModelConfig {
|
||||
struct OfflineZipformerAudioTaggingModelConfig zipformer;
|
||||
std::string ced;
|
||||
|
||||
int32_t num_threads = 1;
|
||||
bool debug = false;
|
||||
@@ -22,8 +23,10 @@ struct AudioTaggingModelConfig {
|
||||
|
||||
AudioTaggingModelConfig(
|
||||
const OfflineZipformerAudioTaggingModelConfig &zipformer,
|
||||
int32_t num_threads, bool debug, const std::string &provider)
|
||||
const std::string &ced, int32_t num_threads, bool debug,
|
||||
const std::string &provider)
|
||||
: zipformer(zipformer),
|
||||
ced(ced),
|
||||
num_threads(num_threads),
|
||||
debug(debug),
|
||||
provider(provider) {}
|
||||
|
||||
@@ -4,6 +4,8 @@
|
||||
#ifndef SHERPA_ONNX_CSRC_AUDIO_TAGGING_ZIPFORMER_IMPL_H_
|
||||
#define SHERPA_ONNX_CSRC_AUDIO_TAGGING_ZIPFORMER_IMPL_H_
|
||||
|
||||
#include <assert.h>
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
@@ -72,6 +74,8 @@ class AudioTaggingZipformerImpl : public AudioTaggingImpl {
|
||||
|
||||
int32_t num_frames = f.size() / feat_dim;
|
||||
|
||||
assert(feat_dim * num_frames == f.size());
|
||||
|
||||
std::array<int64_t, 3> shape = {1, num_frames, feat_dim};
|
||||
|
||||
Ort::Value x = Ort::Value::CreateTensor(memory_info, f.data(), f.size(),
|
||||
|
||||
@@ -24,7 +24,8 @@ void FeatureExtractorConfig::Register(ParseOptions *po) {
|
||||
"inside the feature extractor");
|
||||
|
||||
po->Register("feat-dim", &feature_dim,
|
||||
"Feature dimension. Must match the one expected by the model.");
|
||||
"Feature dimension. Must match the one expected by the model. "
|
||||
"Not used by whisper and CED models");
|
||||
|
||||
po->Register("low-freq", &low_freq, "Low cutoff frequency for mel bins");
|
||||
|
||||
|
||||
112
sherpa-onnx/csrc/offline-ced-model.cc
Normal file
112
sherpa-onnx/csrc/offline-ced-model.cc
Normal file
@@ -0,0 +1,112 @@
|
||||
// sherpa-onnx/csrc/offline-ced-model.cc
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-ced-model.h"
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/session.h"
|
||||
#include "sherpa-onnx/csrc/text-utils.h"
|
||||
#include "sherpa-onnx/csrc/transpose.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class OfflineCEDModel::Impl {
|
||||
public:
|
||||
explicit Impl(const AudioTaggingModelConfig &config)
|
||||
: config_(config),
|
||||
env_(ORT_LOGGING_LEVEL_ERROR),
|
||||
sess_opts_(GetSessionOptions(config)),
|
||||
allocator_{} {
|
||||
auto buf = ReadFile(config_.ced);
|
||||
Init(buf.data(), buf.size());
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
Impl(AAssetManager *mgr, const AudioTaggingModelConfig &config)
|
||||
: config_(config),
|
||||
env_(ORT_LOGGING_LEVEL_ERROR),
|
||||
sess_opts_(GetSessionOptions(config)),
|
||||
allocator_{} {
|
||||
auto buf = ReadFile(mgr, config_.ced);
|
||||
Init(buf.data(), buf.size());
|
||||
}
|
||||
#endif
|
||||
|
||||
Ort::Value Forward(Ort::Value features) {
|
||||
features = Transpose12(allocator_, &features);
|
||||
|
||||
auto ans = sess_->Run({}, input_names_ptr_.data(), &features, 1,
|
||||
output_names_ptr_.data(), output_names_ptr_.size());
|
||||
return std::move(ans[0]);
|
||||
}
|
||||
|
||||
int32_t NumEventClasses() const { return num_event_classes_; }
|
||||
|
||||
OrtAllocator *Allocator() const { return allocator_; }
|
||||
|
||||
private:
|
||||
void Init(void *model_data, size_t model_data_length) {
|
||||
sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
|
||||
sess_opts_);
|
||||
|
||||
GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
|
||||
|
||||
GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);
|
||||
|
||||
// get meta data
|
||||
Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
|
||||
if (config_.debug) {
|
||||
std::ostringstream os;
|
||||
PrintModelMetadata(os, meta_data);
|
||||
SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
|
||||
}
|
||||
|
||||
// get num_event_classes from the output[0].shape,
|
||||
// which is (N, num_event_classes)
|
||||
num_event_classes_ =
|
||||
sess_->GetOutputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape()[1];
|
||||
}
|
||||
|
||||
private:
|
||||
AudioTaggingModelConfig config_;
|
||||
Ort::Env env_;
|
||||
Ort::SessionOptions sess_opts_;
|
||||
Ort::AllocatorWithDefaultOptions allocator_;
|
||||
|
||||
std::unique_ptr<Ort::Session> sess_;
|
||||
|
||||
std::vector<std::string> input_names_;
|
||||
std::vector<const char *> input_names_ptr_;
|
||||
|
||||
std::vector<std::string> output_names_;
|
||||
std::vector<const char *> output_names_ptr_;
|
||||
|
||||
int32_t num_event_classes_ = 0;
|
||||
};
|
||||
|
||||
OfflineCEDModel::OfflineCEDModel(const AudioTaggingModelConfig &config)
|
||||
: impl_(std::make_unique<Impl>(config)) {}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
OfflineCEDModel::OfflineCEDModel(AAssetManager *mgr,
|
||||
const AudioTaggingModelConfig &config)
|
||||
: impl_(std::make_unique<Impl>(mgr, config)) {}
|
||||
#endif
|
||||
|
||||
OfflineCEDModel::~OfflineCEDModel() = default;
|
||||
|
||||
Ort::Value OfflineCEDModel::Forward(Ort::Value features) const {
|
||||
return impl_->Forward(std::move(features));
|
||||
}
|
||||
|
||||
int32_t OfflineCEDModel::NumEventClasses() const {
|
||||
return impl_->NumEventClasses();
|
||||
}
|
||||
|
||||
OrtAllocator *OfflineCEDModel::Allocator() const { return impl_->Allocator(); }
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
56
sherpa-onnx/csrc/offline-ced-model.h
Normal file
56
sherpa-onnx/csrc/offline-ced-model.h
Normal file
@@ -0,0 +1,56 @@
|
||||
// sherpa-onnx/csrc/offline-ced-model.h
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_CED_MODEL_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_CED_MODEL_H_
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/audio-tagging-model-config.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
/** This class implements the CED model from
|
||||
* https://github.com/RicherMans/CED/blob/main/export_onnx.py
|
||||
*/
|
||||
class OfflineCEDModel {
|
||||
public:
|
||||
explicit OfflineCEDModel(const AudioTaggingModelConfig &config);
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
OfflineCEDModel(AAssetManager *mgr, const AudioTaggingModelConfig &config);
|
||||
#endif
|
||||
|
||||
~OfflineCEDModel();
|
||||
|
||||
/** Run the forward method of the model.
|
||||
*
|
||||
* @param features A tensor of shape (N, T, C).
|
||||
*
|
||||
* @return Return a tensor
|
||||
* - probs: A 2-D tensor of shape (N, num_event_classes).
|
||||
*/
|
||||
Ort::Value Forward(Ort::Value features) const;
|
||||
|
||||
/** Return the number of event classes of the model
|
||||
*/
|
||||
int32_t NumEventClasses() const;
|
||||
|
||||
/** Return an allocator for allocating memory
|
||||
*/
|
||||
OrtAllocator *Allocator() const;
|
||||
|
||||
private:
|
||||
class Impl;
|
||||
std::unique_ptr<Impl> impl_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_CED_MODEL_H_
|
||||
@@ -92,15 +92,32 @@ class OfflineStream::Impl {
|
||||
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
|
||||
}
|
||||
|
||||
Impl(WhisperTag /*tag*/, ContextGraphPtr context_graph)
|
||||
: context_graph_(context_graph) {
|
||||
explicit Impl(WhisperTag /*tag*/) {
|
||||
config_.normalize_samples = true;
|
||||
opts_.frame_opts.samp_freq = 16000;
|
||||
opts_.mel_opts.num_bins = 80;
|
||||
opts_.mel_opts.num_bins = 80; // not used
|
||||
whisper_fbank_ =
|
||||
std::make_unique<knf::OnlineWhisperFbank>(opts_.frame_opts);
|
||||
}
|
||||
|
||||
explicit Impl(CEDTag /*tag*/) {
|
||||
// see
|
||||
// https://github.com/RicherMans/CED/blob/main/onnx_inference_with_kaldi.py
|
||||
|
||||
opts_.frame_opts.frame_length_ms = 32;
|
||||
opts_.frame_opts.dither = 0;
|
||||
opts_.frame_opts.preemph_coeff = 0;
|
||||
opts_.frame_opts.remove_dc_offset = false;
|
||||
opts_.frame_opts.window_type = "hann";
|
||||
opts_.frame_opts.snip_edges = false;
|
||||
|
||||
opts_.frame_opts.samp_freq = 16000; // fixed to 16000
|
||||
opts_.mel_opts.num_bins = 64;
|
||||
opts_.mel_opts.high_freq = 8000;
|
||||
|
||||
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
|
||||
}
|
||||
|
||||
void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
|
||||
if (config_.normalize_samples) {
|
||||
AcceptWaveformImpl(sampling_rate, waveform, n);
|
||||
@@ -233,9 +250,10 @@ OfflineStream::OfflineStream(
|
||||
ContextGraphPtr context_graph /*= nullptr*/)
|
||||
: impl_(std::make_unique<Impl>(config, context_graph)) {}
|
||||
|
||||
OfflineStream::OfflineStream(WhisperTag tag,
|
||||
ContextGraphPtr context_graph /*= {}*/)
|
||||
: impl_(std::make_unique<Impl>(tag, context_graph)) {}
|
||||
OfflineStream::OfflineStream(WhisperTag tag)
|
||||
: impl_(std::make_unique<Impl>(tag)) {}
|
||||
|
||||
OfflineStream::OfflineStream(CEDTag tag) : impl_(std::make_unique<Impl>(tag)) {}
|
||||
|
||||
OfflineStream::~OfflineStream() = default;
|
||||
|
||||
|
||||
@@ -67,13 +67,15 @@ struct OfflineFeatureExtractorConfig {
|
||||
};
|
||||
|
||||
struct WhisperTag {};
|
||||
struct CEDTag {};
|
||||
|
||||
class OfflineStream {
|
||||
public:
|
||||
explicit OfflineStream(const OfflineFeatureExtractorConfig &config = {},
|
||||
ContextGraphPtr context_graph = {});
|
||||
|
||||
explicit OfflineStream(WhisperTag tag, ContextGraphPtr context_graph = {});
|
||||
explicit OfflineStream(WhisperTag tag);
|
||||
explicit OfflineStream(CEDTag tag);
|
||||
~OfflineStream();
|
||||
|
||||
/**
|
||||
|
||||
@@ -31,6 +31,12 @@ static AudioTaggingConfig GetAudioTaggingConfig(JNIEnv *env, jobject config) {
|
||||
ans.model.zipformer.model = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
fid = env->GetFieldID(model_cls, "ced", "Ljava/lang/String;");
|
||||
s = (jstring)env->GetObjectField(model, fid);
|
||||
p = env->GetStringUTFChars(s, nullptr);
|
||||
ans.model.ced = p;
|
||||
env->ReleaseStringUTFChars(s, p);
|
||||
|
||||
fid = env->GetFieldID(model_cls, "numThreads", "I");
|
||||
ans.model.num_threads = env->GetIntField(model, fid);
|
||||
|
||||
|
||||
@@ -27,10 +27,11 @@ static void PybindAudioTaggingModelConfig(py::module *m) {
|
||||
|
||||
py::class_<PyClass>(*m, "AudioTaggingModelConfig")
|
||||
.def(py::init<>())
|
||||
.def(py::init<const OfflineZipformerAudioTaggingModelConfig &, int32_t,
|
||||
bool, const std::string &>(),
|
||||
py::arg("zipformer"), py::arg("num_threads") = 1,
|
||||
py::arg("debug") = false, py::arg("provider") = "cpu")
|
||||
.def(py::init<const OfflineZipformerAudioTaggingModelConfig &,
|
||||
const std::string &, int32_t, bool, const std::string &>(),
|
||||
py::arg("zipformer"), py::arg("ced") = "",
|
||||
py::arg("num_threads") = 1, py::arg("debug") = false,
|
||||
py::arg("provider") = "cpu")
|
||||
.def_readwrite("zipformer", &PyClass::zipformer)
|
||||
.def_readwrite("num_threads", &PyClass::num_threads)
|
||||
.def_readwrite("debug", &PyClass::debug)
|
||||
|
||||
Reference in New Issue
Block a user