Support CED models (#792)

This commit is contained in:
Fangjun Kuang
2024-04-19 15:20:37 +08:00
committed by GitHub
parent d97a283dbb
commit c1608b3524
33 changed files with 605 additions and 46 deletions

View File

@@ -0,0 +1,78 @@
# Manually triggered workflow: runs scripts/ced/run.sh, uploads the resulting
# *.tar.bz2 archives to the `audio-tagging-models` release, and pushes each
# model directory to its Hugging Face repository.
name: export-ced-to-onnx

on:
  workflow_dispatch:

concurrency:
  group: export-ced-to-onnx-${{ github.ref }}
  cancel-in-progress: true

jobs:
  export-ced-to-onnx:
    if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj'
    name: export ced
    runs-on: ${{ matrix.os }}
    strategy:
      fail-fast: false
      matrix:
        os: [ubuntu-latest]
        python-version: ["3.8"]

    steps:
      - uses: actions/checkout@v4

      - name: Setup Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}

      - name: Run
        shell: bash
        run: |
          cd scripts/ced
          ./run.sh

      - name: Release
        uses: svenstaro/upload-release-action@v2
        with:
          file_glob: true
          file: ./*.tar.bz2
          overwrite: true
          repo_name: k2-fsa/sherpa-onnx
          repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
          tag: audio-tagging-models

      - name: Publish to huggingface
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
        uses: nick-fields/retry@v3
        with:
          max_attempts: 20
          timeout_seconds: 200
          shell: bash
          command: |
            git config --global user.email "csukuangfj@gmail.com"
            git config --global user.name "Fangjun Kuang"

            models=(
              tiny
              mini
              small
              base
            )

            for m in "${models[@]}"; do
              rm -rf huggingface
              export GIT_LFS_SKIP_SMUDGE=1
              d=sherpa-onnx-ced-$m-audio-tagging-2024-04-19
              git clone https://huggingface.co/k2-fsa/$d huggingface

              # Copy instead of move: this command may be re-run by the retry
              # action, and every attempt starts with `rm -rf huggingface`.
              # Moving the files would leave $d empty for the next attempt.
              cp -av "$d"/* huggingface

              cd huggingface
              git lfs track "*.onnx"
              git status
              git add .
              git status
              git commit -m "first commit"
              git push https://csukuangfj:$HF_TOKEN@huggingface.co/k2-fsa/$d main
              cd ..
            done

View File

@@ -1,5 +1,5 @@
<resources> <resources>
<string name="app_name">ASR with Next-gen Kaldi</string> <string name="app_name">ASR</string>
<string name="hint">Click the Start button to play speech-to-text with Next-gen Kaldi. <string name="hint">Click the Start button to play speech-to-text with Next-gen Kaldi.
\n \n
\n\n\n \n\n\n

View File

@@ -1,5 +1,5 @@
<resources> <resources>
<string name="app_name">ASR with Next-gen Kaldi</string> <string name="app_name">ASR2pass </string>
<string name="hint">Click the Start button to play speech-to-text with Next-gen Kaldi. <string name="hint">Click the Start button to play speech-to-text with Next-gen Kaldi.
\n \n
\n\n\n \n\n\n

View File

@@ -1,16 +1,16 @@
package com.k2fsa.sherpa.onnx package com.k2fsa.sherpa.onnx
import android.content.res.AssetManager import android.content.res.AssetManager
import android.util.Log
private val TAG = "sherpa-onnx" const val TAG = "sherpa-onnx"
data class OfflineZipformerAudioTaggingModelConfig( data class OfflineZipformerAudioTaggingModelConfig(
var model: String, var model: String = "",
) )
data class AudioTaggingModelConfig( data class AudioTaggingModelConfig(
var zipformer: OfflineZipformerAudioTaggingModelConfig, var zipformer: OfflineZipformerAudioTaggingModelConfig = OfflineZipformerAudioTaggingModelConfig(),
var ced: String = "",
var numThreads: Int = 1, var numThreads: Int = 1,
var debug: Boolean = false, var debug: Boolean = false,
var provider: String = "cpu", var provider: String = "cpu",
@@ -103,7 +103,7 @@ class AudioTagging(
// //
// See also // See also
// https://k2-fsa.github.io/sherpa/onnx/audio-tagging/ // https://k2-fsa.github.io/sherpa/onnx/audio-tagging/
fun getAudioTaggingConfig(type: Int, numThreads: Int=1): AudioTaggingConfig? { fun getAudioTaggingConfig(type: Int, numThreads: Int = 1): AudioTaggingConfig? {
when (type) { when (type) {
0 -> { 0 -> {
val modelDir = "sherpa-onnx-zipformer-small-audio-tagging-2024-04-15" val modelDir = "sherpa-onnx-zipformer-small-audio-tagging-2024-04-15"
@@ -123,7 +123,7 @@ fun getAudioTaggingConfig(type: Int, numThreads: Int=1): AudioTaggingConfig? {
return AudioTaggingConfig( return AudioTaggingConfig(
model = AudioTaggingModelConfig( model = AudioTaggingModelConfig(
zipformer = OfflineZipformerAudioTaggingModelConfig(model = "$modelDir/model.int8.onnx"), zipformer = OfflineZipformerAudioTaggingModelConfig(model = "$modelDir/model.int8.onnx"),
numThreads = 1, numThreads = numThreads,
debug = true, debug = true,
), ),
labels = "$modelDir/class_labels_indices.csv", labels = "$modelDir/class_labels_indices.csv",
@@ -131,6 +131,57 @@ fun getAudioTaggingConfig(type: Int, numThreads: Int=1): AudioTaggingConfig? {
) )
} }
2 -> {
val modelDir = "sherpa-onnx-ced-tiny-audio-tagging-2024-04-19"
return AudioTaggingConfig(
model = AudioTaggingModelConfig(
ced = "$modelDir/model.int8.onnx",
numThreads = numThreads,
debug = true,
),
labels = "$modelDir/class_labels_indices.csv",
topK = 3,
)
}
3 -> {
val modelDir = "sherpa-onnx-ced-mini-audio-tagging-2024-04-19"
return AudioTaggingConfig(
model = AudioTaggingModelConfig(
ced = "$modelDir/model.int8.onnx",
numThreads = numThreads,
debug = true,
),
labels = "$modelDir/class_labels_indices.csv",
topK = 3,
)
}
4 -> {
val modelDir = "sherpa-onnx-ced-small-audio-tagging-2024-04-19"
return AudioTaggingConfig(
model = AudioTaggingModelConfig(
ced = "$modelDir/model.int8.onnx",
numThreads = numThreads,
debug = true,
),
labels = "$modelDir/class_labels_indices.csv",
topK = 3,
)
}
5 -> {
val modelDir = "sherpa-onnx-ced-base-audio-tagging-2024-04-19"
return AudioTaggingConfig(
model = AudioTaggingModelConfig(
ced = "$modelDir/model.int8.onnx",
numThreads = numThreads,
debug = true,
),
labels = "$modelDir/class_labels_indices.csv",
topK = 3,
)
}
} }
return null return null

View File

@@ -3,24 +3,15 @@
package com.k2fsa.sherpa.onnx.audio.tagging package com.k2fsa.sherpa.onnx.audio.tagging
import android.Manifest import android.Manifest
import android.app.Activity import android.app.Activity
import android.content.pm.PackageManager import android.content.pm.PackageManager
import android.media.AudioFormat import android.media.AudioFormat
import android.media.AudioRecord import android.media.AudioRecord
import androidx.compose.foundation.lazy.items
import android.media.MediaRecorder import android.media.MediaRecorder
import android.util.Log import android.util.Log
import androidx.compose.foundation.ExperimentalFoundationApi import androidx.compose.foundation.ExperimentalFoundationApi
import androidx.compose.foundation.background
import androidx.compose.foundation.layout.Arrangement import androidx.compose.foundation.layout.Arrangement
import androidx.compose.foundation.layout.Box import androidx.compose.foundation.layout.Box
import androidx.compose.material3.CenterAlignedTopAppBar
import androidx.compose.runtime.Composable
import androidx.compose.material3.Scaffold
import androidx.compose.material3.TopAppBarDefaults
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Text
import androidx.compose.foundation.layout.Column import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.PaddingValues import androidx.compose.foundation.layout.PaddingValues
import androidx.compose.foundation.layout.Row import androidx.compose.foundation.layout.Row
@@ -30,10 +21,17 @@ import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.height import androidx.compose.foundation.layout.height
import androidx.compose.foundation.layout.padding import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.lazy.LazyColumn import androidx.compose.foundation.lazy.LazyColumn
import androidx.compose.foundation.lazy.items
import androidx.compose.material3.Button import androidx.compose.material3.Button
import androidx.compose.material3.CenterAlignedTopAppBar
import androidx.compose.material3.ExperimentalMaterial3Api import androidx.compose.material3.ExperimentalMaterial3Api
import androidx.compose.material3.MaterialTheme
import androidx.compose.material3.Scaffold
import androidx.compose.material3.Slider import androidx.compose.material3.Slider
import androidx.compose.material3.Surface import androidx.compose.material3.Surface
import androidx.compose.material3.Text
import androidx.compose.material3.TopAppBarDefaults
import androidx.compose.runtime.Composable
import androidx.compose.runtime.getValue import androidx.compose.runtime.getValue
import androidx.compose.runtime.mutableStateListOf import androidx.compose.runtime.mutableStateListOf
import androidx.compose.runtime.mutableStateOf import androidx.compose.runtime.mutableStateOf
@@ -41,7 +39,6 @@ import androidx.compose.runtime.remember
import androidx.compose.runtime.setValue import androidx.compose.runtime.setValue
import androidx.compose.ui.Alignment import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.compose.ui.graphics.Color
import androidx.compose.ui.platform.LocalContext import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.text.font.FontWeight import androidx.compose.ui.text.font.FontWeight
import androidx.compose.ui.text.style.TextAlign import androidx.compose.ui.text.style.TextAlign
@@ -49,6 +46,7 @@ import androidx.compose.ui.unit.dp
import androidx.compose.ui.unit.sp import androidx.compose.ui.unit.sp
import androidx.core.app.ActivityCompat import androidx.core.app.ActivityCompat
import com.k2fsa.sherpa.onnx.AudioEvent import com.k2fsa.sherpa.onnx.AudioEvent
import com.k2fsa.sherpa.onnx.Tagger
import kotlin.concurrent.thread import kotlin.concurrent.thread

View File

@@ -13,6 +13,7 @@ import androidx.compose.material3.Surface
import androidx.compose.runtime.Composable import androidx.compose.runtime.Composable
import androidx.compose.ui.Modifier import androidx.compose.ui.Modifier
import androidx.core.app.ActivityCompat import androidx.core.app.ActivityCompat
import com.k2fsa.sherpa.onnx.Tagger
import com.k2fsa.sherpa.onnx.audio.tagging.ui.theme.SherpaOnnxAudioTaggingTheme import com.k2fsa.sherpa.onnx.audio.tagging.ui.theme.SherpaOnnxAudioTaggingTheme
const val TAG = "sherpa-onnx" const val TAG = "sherpa-onnx"

View File

@@ -1,10 +1,8 @@
package com.k2fsa.sherpa.onnx.audio.tagging package com.k2fsa.sherpa.onnx
import android.content.res.AssetManager import android.content.res.AssetManager
import android.util.Log import android.util.Log
import com.k2fsa.sherpa.onnx.AudioTagging
import com.k2fsa.sherpa.onnx.audio.tagging.wear.os.presentation.TAG
import com.k2fsa.sherpa.onnx.getAudioTaggingConfig
object Tagger { object Tagger {
private var _tagger: AudioTagging? = null private var _tagger: AudioTagging? = null
@@ -12,6 +10,7 @@ object Tagger {
get() { get() {
return _tagger!! return _tagger!!
} }
fun initTagger(assetManager: AssetManager? = null, numThreads: Int = 1) { fun initTagger(assetManager: AssetManager? = null, numThreads: Int = 1) {
synchronized(this) { synchronized(this) {
if (_tagger != null) { if (_tagger != null) {
@@ -19,7 +18,7 @@ object Tagger {
} }
Log.i(TAG, "Initializing audio tagger") Log.i(TAG, "Initializing audio tagger")
val config = getAudioTaggingConfig(type = 0, numThreads=numThreads)!! val config = getAudioTaggingConfig(type = 0, numThreads = numThreads)!!
_tagger = AudioTagging(assetManager, config) _tagger = AudioTagging(assetManager, config)
} }
} }

View File

@@ -33,7 +33,7 @@ import androidx.wear.compose.material.Button
import androidx.wear.compose.material.MaterialTheme import androidx.wear.compose.material.MaterialTheme
import androidx.wear.compose.material.Text import androidx.wear.compose.material.Text
import com.k2fsa.sherpa.onnx.AudioEvent import com.k2fsa.sherpa.onnx.AudioEvent
import com.k2fsa.sherpa.onnx.audio.tagging.Tagger import com.k2fsa.sherpa.onnx.Tagger
import com.k2fsa.sherpa.onnx.audio.tagging.wear.os.presentation.theme.SherpaOnnxAudioTaggingWearOsTheme import com.k2fsa.sherpa.onnx.audio.tagging.wear.os.presentation.theme.SherpaOnnxAudioTaggingWearOsTheme
import kotlin.concurrent.thread import kotlin.concurrent.thread

View File

@@ -17,7 +17,7 @@ import androidx.activity.compose.setContent
import androidx.compose.runtime.Composable import androidx.compose.runtime.Composable
import androidx.core.app.ActivityCompat import androidx.core.app.ActivityCompat
import androidx.core.splashscreen.SplashScreen.Companion.installSplashScreen import androidx.core.splashscreen.SplashScreen.Companion.installSplashScreen
import com.k2fsa.sherpa.onnx.audio.tagging.Tagger import com.k2fsa.sherpa.onnx.Tagger
const val TAG = "sherpa-onnx" const val TAG = "sherpa-onnx"
private const val REQUEST_RECORD_AUDIO_PERMISSION = 200 private const val REQUEST_RECORD_AUDIO_PERMISSION = 200

View File

@@ -1,5 +1,5 @@
<resources> <resources>
<string name="app_name">AudioTagging</string> <string name="app_name">Audio Tagging</string>
<!-- <!--
This string is used for square devices and overridden by hello_world in This string is used for square devices and overridden by hello_world in
values-round/strings.xml for round devices. values-round/strings.xml for round devices.

View File

@@ -1,5 +1,5 @@
<resources> <resources>
<string name="app_name">Speaker Identification</string> <string name="app_name">Speaker ID</string>
<string name="start">Start recording</string> <string name="start">Start recording</string>
<string name="stop">Stop recording</string> <string name="stop">Stop recording</string>
<string name="add">Add speaker</string> <string name="add">Add speaker</string>

View File

@@ -1,3 +1,3 @@
<resources> <resources>
<string name="app_name">SherpaOnnxSpokenLanguageIdentification</string> <string name="app_name">Language ID</string>
</resources> </resources>

View File

@@ -1,5 +1,5 @@
<resources> <resources>
<string name="app_name">Next-gen Kaldi: TTS</string> <string name="app_name">TTS</string>
<string name="sid_label">Speaker ID</string> <string name="sid_label">Speaker ID</string>
<string name="sid_hint">0</string> <string name="sid_hint">0</string>
<string name="speed_label">Speech speed (large->fast)</string> <string name="speed_label">Speech speed (large->fast)</string>

View File

@@ -1,3 +1,3 @@
<resources> <resources>
<string name="app_name">Next-gen Kaldi: TTS</string> <string name="app_name">TTS Engine</string>
</resources> </resources>

View File

@@ -1,5 +1,5 @@
<resources> <resources>
<string name="app_name">Next-gen Kaldi: SileroVAD</string> <string name="app_name">VAD</string>
<string name="hint">Click the Start button to play Silero VAD with Next-gen Kaldi.</string> <string name="hint">Click the Start button to play Silero VAD with Next-gen Kaldi.</string>
<string name="start">Start</string> <string name="start">Start</string>

View File

@@ -1,5 +1,5 @@
<resources> <resources>
<string name="app_name">ASR with Next-gen Kaldi</string> <string name="app_name">VAD-ASR</string>
<string name="hint">Click the Start button to play speech-to-text with Next-gen Kaldi. <string name="hint">Click the Start button to play speech-to-text with Next-gen Kaldi.
\n \n
\n\n\n \n\n\n

View File

@@ -46,7 +46,30 @@ def get_models():
), ),
] ]
return icefall_models ced_models = [
AudioTaggingModel(
model_name="sherpa-onnx-ced-tiny-audio-tagging-2024-04-19",
idx=2,
short_name="ced_tiny",
),
AudioTaggingModel(
model_name="sherpa-onnx-ced-mini-audio-tagging-2024-04-19",
idx=3,
short_name="ced_mini",
),
AudioTaggingModel(
model_name="sherpa-onnx-ced-small-audio-tagging-2024-04-19",
idx=4,
short_name="ced_small",
),
AudioTaggingModel(
model_name="sherpa-onnx-ced-base-audio-tagging-2024-04-19",
idx=5,
short_name="ced_base",
),
]
return icefall_models + ced_models
def main(): def main():

69
scripts/ced/run.sh Executable file
View File

@@ -0,0 +1,69 @@
#!/usr/bin/env bash
#
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
#
# Clones the CED repository, exports the tiny/mini/small/base models to ONNX,
# and packages each one as sherpa-onnx-ced-<size>-audio-tagging-2024-04-19
# (directory and .tar.bz2), placed three directories above the CED checkout.

set -ex

# Install CPU builds of torch/torchaudio, the ONNX tooling, and whatever
# ./requirements.txt in the current directory lists.
function install_dependencies() {
  pip install -qq torch==2.1.0+cpu torchaudio==2.1.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
  pip install -qq onnx onnxruntime==1.17.1

  pip install -r ./requirements.txt
}

git clone https://github.com/RicherMans/CED
pushd CED

install_dependencies

models=(
  tiny
  mini
  small
  base
)

for variant in "${models[@]}"; do
  python3 ./export_onnx.py -m "ced_$variant"
done

ls -lh *.onnx

# The zipformer package supplies class_labels_indices.csv and test_wavs,
# which are copied into every CED package below.
src=sherpa-onnx-zipformer-small-audio-tagging-2024-04-15
archive=$src.tar.bz2

curl -SL -O "https://github.com/k2-fsa/sherpa-onnx/releases/download/audio-tagging-models/$archive"
tar xvf "$archive"
rm "$archive"

cat >README.md <<EOF
# Introduction
Models in this repo are converted from
https://github.com/RicherMans/CED
EOF

for variant in "${models[@]}"; do
  out_dir=sherpa-onnx-ced-$variant-audio-tagging-2024-04-19

  mkdir -p "$out_dir"

  cp -v README.md "$out_dir"
  cp -v "$src/class_labels_indices.csv" "$out_dir"
  cp -a "$src/test_wavs" "$out_dir"

  cp -v "ced_$variant.onnx" "$out_dir/model.onnx"
  cp -v "ced_$variant.int8.onnx" "$out_dir/model.int8.onnx"

  echo "----------$variant----------"
  ls -lh "$out_dir"
  echo "----------------------"

  tar cjvf "$out_dir.tar.bz2" "$out_dir"

  mv "$out_dir.tar.bz2" ../../..
  mv "$out_dir" ../../../
done

rm -rf "$src"

cd ../../..

ls -lh *.tar.bz2
echo "======="
ls -lh

View File

@@ -1223,6 +1223,7 @@ const SherpaOnnxAudioTagging *SherpaOnnxCreateAudioTagging(
const SherpaOnnxAudioTaggingConfig *config) { const SherpaOnnxAudioTaggingConfig *config) {
sherpa_onnx::AudioTaggingConfig ac; sherpa_onnx::AudioTaggingConfig ac;
ac.model.zipformer.model = SHERPA_ONNX_OR(config->model.zipformer.model, ""); ac.model.zipformer.model = SHERPA_ONNX_OR(config->model.zipformer.model, "");
ac.model.ced = SHERPA_ONNX_OR(config->model.ced, "");
ac.model.num_threads = SHERPA_ONNX_OR(config->model.num_threads, 1); ac.model.num_threads = SHERPA_ONNX_OR(config->model.num_threads, 1);
ac.model.debug = config->model.debug; ac.model.debug = config->model.debug;
ac.model.provider = SHERPA_ONNX_OR(config->model.provider, "cpu"); ac.model.provider = SHERPA_ONNX_OR(config->model.provider, "cpu");

View File

@@ -1100,6 +1100,7 @@ SHERPA_ONNX_API typedef struct
SHERPA_ONNX_API typedef struct SherpaOnnxAudioTaggingModelConfig { SHERPA_ONNX_API typedef struct SherpaOnnxAudioTaggingModelConfig {
SherpaOnnxOfflineZipformerAudioTaggingModelConfig zipformer; SherpaOnnxOfflineZipformerAudioTaggingModelConfig zipformer;
const char *ced;
int32_t num_threads; int32_t num_threads;
int32_t debug; // true to print debug information of the model int32_t debug; // true to print debug information of the model
const char *provider; const char *provider;

View File

@@ -117,6 +117,7 @@ list(APPEND sources
audio-tagging-label-file.cc audio-tagging-label-file.cc
audio-tagging-model-config.cc audio-tagging-model-config.cc
audio-tagging.cc audio-tagging.cc
offline-ced-model.cc
offline-zipformer-audio-tagging-model-config.cc offline-zipformer-audio-tagging-model-config.cc
offline-zipformer-audio-tagging-model.cc offline-zipformer-audio-tagging-model.cc
) )

View File

@@ -0,0 +1,111 @@
// sherpa-onnx/csrc/audio-tagging-ced-impl.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_AUDIO_TAGGING_CED_IMPL_H_
#define SHERPA_ONNX_CSRC_AUDIO_TAGGING_CED_IMPL_H_
#include <assert.h>
#include <stdlib.h>

#include <array>
#include <memory>
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "sherpa-onnx/csrc/audio-tagging-impl.h"
#include "sherpa-onnx/csrc/audio-tagging-label-file.h"
#include "sherpa-onnx/csrc/audio-tagging.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/math.h"
#include "sherpa-onnx/csrc/offline-ced-model.h"
namespace sherpa_onnx {
/** Audio tagging implementation that uses a CED model.
 *
 * It owns an OfflineCEDModel and the labels loaded from config.labels, and
 * converts the scores returned by the model into the top-k AudioEvent list.
 */
class AudioTaggingCEDImpl : public AudioTaggingImpl {
 public:
  /** Load the model and the labels described by `config`.
   *
   * The process exits if the model and the label file disagree on the
   * number of event classes.
   */
  explicit AudioTaggingCEDImpl(const AudioTaggingConfig &config)
      : config_(config), model_(config.model), labels_(config.labels) {
    CheckNumEventClasses();
  }

#if __ANDROID_API__ >= 9
  /** Same as above, but the model and the labels are constructed with the
   * given Android asset manager.
   */
  explicit AudioTaggingCEDImpl(AAssetManager *mgr,
                               const AudioTaggingConfig &config)
      : config_(config),
        model_(mgr, config.model),
        labels_(mgr, config.labels) {
    CheckNumEventClasses();
  }
#endif

  /** Create a stream whose feature extractor is configured for CED models. */
  std::unique_ptr<OfflineStream> CreateStream() const override {
    return std::make_unique<OfflineStream>(CEDTag{});
  }

  /** Run the model on the frames held by `s`.
   *
   * @param s      Stream providing the feature frames via GetFrames().
   * @param top_k  Number of events to return. A negative value selects
   *               config.top_k; values larger than the number of event
   *               classes are clamped to it.
   *
   * @return One AudioEvent (name, index, prob) per index returned by
   *         TopkIndex().
   */
  std::vector<AudioEvent> Compute(OfflineStream *s,
                                  int32_t top_k = -1) const override {
    if (top_k < 0) {
      top_k = config_.top_k;
    }

    int32_t num_event_classes = model_.NumEventClasses();

    if (top_k > num_event_classes) {
      top_k = num_event_classes;
    }

    auto memory_info =
        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

    // WARNING(fangjun): It is fixed to 64 for CED models
    int32_t feat_dim = 64;
    std::vector<float> f = s->GetFrames();

    int32_t num_frames = static_cast<int32_t>(f.size() / feat_dim);

    // Both sides are size_t so that the check does not mix signed and
    // unsigned operands.
    assert(static_cast<size_t>(feat_dim) * static_cast<size_t>(num_frames) ==
           f.size());

    std::array<int64_t, 3> shape = {1, num_frames, feat_dim};

    // The tensor only borrows the memory of `f`; `f` must stay alive until
    // Forward() has returned.
    Ort::Value x = Ort::Value::CreateTensor(memory_info, f.data(), f.size(),
                                            shape.data(), shape.size());

    Ort::Value probs = model_.Forward(std::move(x));

    const float *p = probs.GetTensorData<float>();

    std::vector<int32_t> top_k_indexes = TopkIndex(p, num_event_classes, top_k);

    std::vector<AudioEvent> ans(top_k);

    int32_t i = 0;

    for (int32_t index : top_k_indexes) {
      ans[i].name = labels_.GetEventName(index);
      ans[i].index = index;
      ans[i].prob = p[index];
      i += 1;
    }

    return ans;
  }

 private:
  // Shared by all constructors: log an error and terminate the process when
  // the model and the label file report different numbers of event classes.
  void CheckNumEventClasses() const {
    if (model_.NumEventClasses() != labels_.NumEventClasses()) {
      SHERPA_ONNX_LOGE("number of classes: %d (model) != %d (label file)",
                       model_.NumEventClasses(), labels_.NumEventClasses());
      exit(-1);
    }
  }

 private:
  AudioTaggingConfig config_;
  OfflineCEDModel model_;
  AudioTaggingLabels labels_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_AUDIO_TAGGING_CED_IMPL_H_

View File

@@ -11,6 +11,7 @@
#include "android/asset_manager_jni.h" #include "android/asset_manager_jni.h"
#endif #endif
#include "sherpa-onnx/csrc/audio-tagging-ced-impl.h"
#include "sherpa-onnx/csrc/audio-tagging-zipformer-impl.h" #include "sherpa-onnx/csrc/audio-tagging-zipformer-impl.h"
#include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/macros.h"
@@ -20,6 +21,8 @@ std::unique_ptr<AudioTaggingImpl> AudioTaggingImpl::Create(
const AudioTaggingConfig &config) { const AudioTaggingConfig &config) {
if (!config.model.zipformer.model.empty()) { if (!config.model.zipformer.model.empty()) {
return std::make_unique<AudioTaggingZipformerImpl>(config); return std::make_unique<AudioTaggingZipformerImpl>(config);
} else if (!config.model.ced.empty()) {
return std::make_unique<AudioTaggingCEDImpl>(config);
} }
SHERPA_ONNX_LOG( SHERPA_ONNX_LOG(
@@ -32,6 +35,8 @@ std::unique_ptr<AudioTaggingImpl> AudioTaggingImpl::Create(
AAssetManager *mgr, const AudioTaggingConfig &config) { AAssetManager *mgr, const AudioTaggingConfig &config) {
if (!config.model.zipformer.model.empty()) { if (!config.model.zipformer.model.empty()) {
return std::make_unique<AudioTaggingZipformerImpl>(mgr, config); return std::make_unique<AudioTaggingZipformerImpl>(mgr, config);
} else if (!config.model.ced.empty()) {
return std::make_unique<AudioTaggingCEDImpl>(mgr, config);
} }
SHERPA_ONNX_LOG( SHERPA_ONNX_LOG(

View File

@@ -4,11 +4,18 @@
#include "sherpa-onnx/csrc/audio-tagging-model-config.h" #include "sherpa-onnx/csrc/audio-tagging-model-config.h"
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
namespace sherpa_onnx { namespace sherpa_onnx {
void AudioTaggingModelConfig::Register(ParseOptions *po) { void AudioTaggingModelConfig::Register(ParseOptions *po) {
zipformer.Register(po); zipformer.Register(po);
po->Register("ced-model", &ced,
"Path to CED model. Only need to pass one of --zipformer-model "
"or --ced-model");
po->Register("num-threads", &num_threads, po->Register("num-threads", &num_threads,
"Number of threads to run the neural network"); "Number of threads to run the neural network");
@@ -24,6 +31,16 @@ bool AudioTaggingModelConfig::Validate() const {
return false; return false;
} }
if (!ced.empty() && !FileExists(ced)) {
SHERPA_ONNX_LOGE("CED model file %s does not exist", ced.c_str());
return false;
}
if (zipformer.model.empty() && ced.empty()) {
SHERPA_ONNX_LOGE("Please provide either --zipformer-model or --ced-model");
return false;
}
return true; return true;
} }
@@ -32,6 +49,7 @@ std::string AudioTaggingModelConfig::ToString() const {
os << "AudioTaggingModelConfig("; os << "AudioTaggingModelConfig(";
os << "zipformer=" << zipformer.ToString() << ", "; os << "zipformer=" << zipformer.ToString() << ", ";
os << "ced=\"" << ced << "\", ";
os << "num_threads=" << num_threads << ", "; os << "num_threads=" << num_threads << ", ";
os << "debug=" << (debug ? "True" : "False") << ", "; os << "debug=" << (debug ? "True" : "False") << ", ";
os << "provider=\"" << provider << "\")"; os << "provider=\"" << provider << "\")";

View File

@@ -13,6 +13,7 @@ namespace sherpa_onnx {
struct AudioTaggingModelConfig { struct AudioTaggingModelConfig {
struct OfflineZipformerAudioTaggingModelConfig zipformer; struct OfflineZipformerAudioTaggingModelConfig zipformer;
std::string ced;
int32_t num_threads = 1; int32_t num_threads = 1;
bool debug = false; bool debug = false;
@@ -22,8 +23,10 @@ struct AudioTaggingModelConfig {
AudioTaggingModelConfig( AudioTaggingModelConfig(
const OfflineZipformerAudioTaggingModelConfig &zipformer, const OfflineZipformerAudioTaggingModelConfig &zipformer,
int32_t num_threads, bool debug, const std::string &provider) const std::string &ced, int32_t num_threads, bool debug,
const std::string &provider)
: zipformer(zipformer), : zipformer(zipformer),
ced(ced),
num_threads(num_threads), num_threads(num_threads),
debug(debug), debug(debug),
provider(provider) {} provider(provider) {}

View File

@@ -4,6 +4,8 @@
#ifndef SHERPA_ONNX_CSRC_AUDIO_TAGGING_ZIPFORMER_IMPL_H_ #ifndef SHERPA_ONNX_CSRC_AUDIO_TAGGING_ZIPFORMER_IMPL_H_
#define SHERPA_ONNX_CSRC_AUDIO_TAGGING_ZIPFORMER_IMPL_H_ #define SHERPA_ONNX_CSRC_AUDIO_TAGGING_ZIPFORMER_IMPL_H_
#include <assert.h>
#include <memory> #include <memory>
#include <utility> #include <utility>
#include <vector> #include <vector>
@@ -72,6 +74,8 @@ class AudioTaggingZipformerImpl : public AudioTaggingImpl {
int32_t num_frames = f.size() / feat_dim; int32_t num_frames = f.size() / feat_dim;
assert(feat_dim * num_frames == f.size());
std::array<int64_t, 3> shape = {1, num_frames, feat_dim}; std::array<int64_t, 3> shape = {1, num_frames, feat_dim};
Ort::Value x = Ort::Value::CreateTensor(memory_info, f.data(), f.size(), Ort::Value x = Ort::Value::CreateTensor(memory_info, f.data(), f.size(),

View File

@@ -24,7 +24,8 @@ void FeatureExtractorConfig::Register(ParseOptions *po) {
"inside the feature extractor"); "inside the feature extractor");
po->Register("feat-dim", &feature_dim, po->Register("feat-dim", &feature_dim,
"Feature dimension. Must match the one expected by the model."); "Feature dimension. Must match the one expected by the model. "
"Not used by whisper and CED models");
po->Register("low-freq", &low_freq, "Low cutoff frequency for mel bins"); po->Register("low-freq", &low_freq, "Low cutoff frequency for mel bins");

View File

@@ -0,0 +1,112 @@
// sherpa-onnx/csrc/offline-ced-model.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-ced-model.h"
#include <string>
#include <vector>
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/session.h"
#include "sherpa-onnx/csrc/text-utils.h"
#include "sherpa-onnx/csrc/transpose.h"
namespace sherpa_onnx {
/** Private implementation of OfflineCEDModel.
 *
 * Holds the ONNX Runtime environment, session and allocator, plus the
 * input/output names and the number of event classes read from the session.
 */
class OfflineCEDModel::Impl {
 public:
  /** Create the session from the model file config.ced. */
  explicit Impl(const AudioTaggingModelConfig &config)
      : config_(config),
        env_(ORT_LOGGING_LEVEL_ERROR),
        sess_opts_(GetSessionOptions(config)),
        allocator_{} {
    auto model_bytes = ReadFile(config_.ced);
    Init(model_bytes.data(), model_bytes.size());
  }

#if __ANDROID_API__ >= 9
  /** Create the session from config.ced, read through the asset manager. */
  Impl(AAssetManager *mgr, const AudioTaggingModelConfig &config)
      : config_(config),
        env_(ORT_LOGGING_LEVEL_ERROR),
        sess_opts_(GetSessionOptions(config)),
        allocator_{} {
    auto model_bytes = ReadFile(mgr, config_.ced);
    Init(model_bytes.data(), model_bytes.size());
  }
#endif

  /** Pass `features` through Transpose12() and run the session on the
   * result, returning the first output of the model.
   */
  Ort::Value Forward(Ort::Value features) {
    Ort::Value input = Transpose12(allocator_, &features);

    auto outputs =
        sess_->Run({}, input_names_ptr_.data(), &input, 1,
                   output_names_ptr_.data(), output_names_ptr_.size());

    Ort::Value probs = std::move(outputs[0]);
    return probs;
  }

  int32_t NumEventClasses() const { return num_event_classes_; }

  OrtAllocator *Allocator() const { return allocator_; }

 private:
  // Build the session from an in-memory model and cache what is needed at
  // inference time: input/output names and the number of event classes.
  void Init(void *model_data, size_t model_data_length) {
    sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
                                           sess_opts_);

    GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
    GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);

    // get meta data
    Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
    if (config_.debug) {
      std::ostringstream os;
      PrintModelMetadata(os, meta_data);
      SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
    }

    // output[0] has shape (N, num_event_classes); dimension 1 gives the
    // number of event classes.
    auto output_info = sess_->GetOutputTypeInfo(0);
    auto output_shape = output_info.GetTensorTypeAndShapeInfo().GetShape();
    num_event_classes_ = output_shape[1];
  }

  // NOTE: the declaration order below is also the initialization order used
  // by the constructors; keep it unchanged.
  AudioTaggingModelConfig config_;
  Ort::Env env_;
  Ort::SessionOptions sess_opts_;
  Ort::AllocatorWithDefaultOptions allocator_;

  std::unique_ptr<Ort::Session> sess_;

  std::vector<std::string> input_names_;
  std::vector<const char *> input_names_ptr_;

  std::vector<std::string> output_names_;
  std::vector<const char *> output_names_ptr_;

  int32_t num_event_classes_ = 0;
};
// Public OfflineCEDModel members. Each one forwards to the Impl object held
// in impl_.

// Construct the implementation from `config`.
OfflineCEDModel::OfflineCEDModel(const AudioTaggingModelConfig &config)
: impl_(std::make_unique<Impl>(config)) {}
#if __ANDROID_API__ >= 9
// Android variant: the asset manager is handed to Impl together with
// `config`.
OfflineCEDModel::OfflineCEDModel(AAssetManager *mgr,
const AudioTaggingModelConfig &config)
: impl_(std::make_unique<Impl>(mgr, config)) {}
#endif
// Defined in this file, where Impl is a complete type, as required for
// destroying std::unique_ptr<Impl>.
OfflineCEDModel::~OfflineCEDModel() = default;
// Hand the features over to Impl::Forward() and return its result.
Ort::Value OfflineCEDModel::Forward(Ort::Value features) const {
return impl_->Forward(std::move(features));
}
// Number of event classes, as reported by the implementation.
int32_t OfflineCEDModel::NumEventClasses() const {
return impl_->NumEventClasses();
}
// Allocator owned by the implementation.
OrtAllocator *OfflineCEDModel::Allocator() const { return impl_->Allocator(); }
} // namespace sherpa_onnx

View File

@@ -0,0 +1,56 @@
// sherpa-onnx/csrc/offline-ced-model.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_OFFLINE_CED_MODEL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_CED_MODEL_H_
#include <memory>
#include <utility>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/audio-tagging-model-config.h"
namespace sherpa_onnx {
/** Wrapper around a CED audio tagging model exported to ONNX with
 * https://github.com/RicherMans/CED/blob/main/export_onnx.py
 *
 * The implementation details are hidden behind the private Impl class.
 */
class OfflineCEDModel {
public:
/** Load the model whose path is given by config.ced.
 */
explicit OfflineCEDModel(const AudioTaggingModelConfig &config);
#if __ANDROID_API__ >= 9
/** Load the model given by config.ced using the Android asset manager.
 */
OfflineCEDModel(AAssetManager *mgr, const AudioTaggingModelConfig &config);
#endif
/** Declared here and defined in the .cc file, where Impl is complete.
 */
~OfflineCEDModel();
/** Run the forward method of the model.
 *
 * @param features A tensor of shape (N, T, C).
 *
 * @return Return a tensor
 *  - probs: A 2-D tensor of shape (N, num_event_classes).
 */
Ort::Value Forward(Ort::Value features) const;
/** Return the number of event classes of the model.
 */
int32_t NumEventClasses() const;
/** Return an allocator for allocating memory.
 */
OrtAllocator *Allocator() const;
private:
// Pointer-to-implementation; Impl is defined in offline-ced-model.cc.
class Impl;
std::unique_ptr<Impl> impl_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_OFFLINE_CED_MODEL_H_

View File

@@ -92,15 +92,32 @@ class OfflineStream::Impl {
fbank_ = std::make_unique<knf::OnlineFbank>(opts_); fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
} }
Impl(WhisperTag /*tag*/, ContextGraphPtr context_graph) explicit Impl(WhisperTag /*tag*/) {
: context_graph_(context_graph) {
config_.normalize_samples = true; config_.normalize_samples = true;
opts_.frame_opts.samp_freq = 16000; opts_.frame_opts.samp_freq = 16000;
opts_.mel_opts.num_bins = 80; opts_.mel_opts.num_bins = 80; // not used
whisper_fbank_ = whisper_fbank_ =
std::make_unique<knf::OnlineWhisperFbank>(opts_.frame_opts); std::make_unique<knf::OnlineWhisperFbank>(opts_.frame_opts);
} }
/** Configure the fbank feature extractor with the settings used for CED
 * models.
 */
explicit Impl(CEDTag /*tag*/) {
  // The values below follow
  // https://github.com/RicherMans/CED/blob/main/onnx_inference_with_kaldi.py
  auto &frame = opts_.frame_opts;
  frame.samp_freq = 16000;  // fixed to 16000
  frame.frame_length_ms = 32;
  frame.window_type = "hann";
  frame.snip_edges = false;
  frame.dither = 0;
  frame.preemph_coeff = 0;
  frame.remove_dc_offset = false;

  auto &mel = opts_.mel_opts;
  mel.num_bins = 64;
  mel.high_freq = 8000;

  fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
}
void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) { void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
if (config_.normalize_samples) { if (config_.normalize_samples) {
AcceptWaveformImpl(sampling_rate, waveform, n); AcceptWaveformImpl(sampling_rate, waveform, n);
@@ -233,9 +250,10 @@ OfflineStream::OfflineStream(
ContextGraphPtr context_graph /*= nullptr*/) ContextGraphPtr context_graph /*= nullptr*/)
: impl_(std::make_unique<Impl>(config, context_graph)) {} : impl_(std::make_unique<Impl>(config, context_graph)) {}
OfflineStream::OfflineStream(WhisperTag tag, OfflineStream::OfflineStream(WhisperTag tag)
ContextGraphPtr context_graph /*= {}*/) : impl_(std::make_unique<Impl>(tag)) {}
: impl_(std::make_unique<Impl>(tag, context_graph)) {}
OfflineStream::OfflineStream(CEDTag tag) : impl_(std::make_unique<Impl>(tag)) {}
OfflineStream::~OfflineStream() = default; OfflineStream::~OfflineStream() = default;

View File

@@ -67,13 +67,15 @@ struct OfflineFeatureExtractorConfig {
}; };
struct WhisperTag {}; struct WhisperTag {};
struct CEDTag {};
class OfflineStream { class OfflineStream {
public: public:
explicit OfflineStream(const OfflineFeatureExtractorConfig &config = {}, explicit OfflineStream(const OfflineFeatureExtractorConfig &config = {},
ContextGraphPtr context_graph = {}); ContextGraphPtr context_graph = {});
explicit OfflineStream(WhisperTag tag, ContextGraphPtr context_graph = {}); explicit OfflineStream(WhisperTag tag);
explicit OfflineStream(CEDTag tag);
~OfflineStream(); ~OfflineStream();
/** /**

View File

@@ -31,6 +31,12 @@ static AudioTaggingConfig GetAudioTaggingConfig(JNIEnv *env, jobject config) {
ans.model.zipformer.model = p; ans.model.zipformer.model = p;
env->ReleaseStringUTFChars(s, p); env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(model_cls, "ced", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(model, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.model.ced = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(model_cls, "numThreads", "I"); fid = env->GetFieldID(model_cls, "numThreads", "I");
ans.model.num_threads = env->GetIntField(model, fid); ans.model.num_threads = env->GetIntField(model, fid);

View File

@@ -27,10 +27,11 @@ static void PybindAudioTaggingModelConfig(py::module *m) {
py::class_<PyClass>(*m, "AudioTaggingModelConfig") py::class_<PyClass>(*m, "AudioTaggingModelConfig")
.def(py::init<>()) .def(py::init<>())
.def(py::init<const OfflineZipformerAudioTaggingModelConfig &, int32_t, .def(py::init<const OfflineZipformerAudioTaggingModelConfig &,
bool, const std::string &>(), const std::string &, int32_t, bool, const std::string &>(),
py::arg("zipformer"), py::arg("num_threads") = 1, py::arg("zipformer"), py::arg("ced") = "",
py::arg("debug") = false, py::arg("provider") = "cpu") py::arg("num_threads") = 1, py::arg("debug") = false,
py::arg("provider") = "cpu")
.def_readwrite("zipformer", &PyClass::zipformer) .def_readwrite("zipformer", &PyClass::zipformer)
.def_readwrite("num_threads", &PyClass::num_threads) .def_readwrite("num_threads", &PyClass::num_threads)
.def_readwrite("debug", &PyClass::debug) .def_readwrite("debug", &PyClass::debug)