Support CED models (#792)
This commit is contained in:
@@ -1,16 +1,16 @@
|
||||
package com.k2fsa.sherpa.onnx
|
||||
|
||||
import android.content.res.AssetManager
|
||||
import android.util.Log
|
||||
|
||||
private val TAG = "sherpa-onnx"
|
||||
const val TAG = "sherpa-onnx"
|
||||
|
||||
data class OfflineZipformerAudioTaggingModelConfig(
|
||||
var model: String,
|
||||
var model: String = "",
|
||||
)
|
||||
|
||||
data class AudioTaggingModelConfig(
|
||||
var zipformer: OfflineZipformerAudioTaggingModelConfig,
|
||||
var zipformer: OfflineZipformerAudioTaggingModelConfig = OfflineZipformerAudioTaggingModelConfig(),
|
||||
var ced: String = "",
|
||||
var numThreads: Int = 1,
|
||||
var debug: Boolean = false,
|
||||
var provider: String = "cpu",
|
||||
@@ -103,7 +103,7 @@ class AudioTagging(
|
||||
//
|
||||
// See also
|
||||
// https://k2-fsa.github.io/sherpa/onnx/audio-tagging/
|
||||
fun getAudioTaggingConfig(type: Int, numThreads: Int=1): AudioTaggingConfig? {
|
||||
fun getAudioTaggingConfig(type: Int, numThreads: Int = 1): AudioTaggingConfig? {
|
||||
when (type) {
|
||||
0 -> {
|
||||
val modelDir = "sherpa-onnx-zipformer-small-audio-tagging-2024-04-15"
|
||||
@@ -123,7 +123,7 @@ fun getAudioTaggingConfig(type: Int, numThreads: Int=1): AudioTaggingConfig? {
|
||||
return AudioTaggingConfig(
|
||||
model = AudioTaggingModelConfig(
|
||||
zipformer = OfflineZipformerAudioTaggingModelConfig(model = "$modelDir/model.int8.onnx"),
|
||||
numThreads = 1,
|
||||
numThreads = numThreads,
|
||||
debug = true,
|
||||
),
|
||||
labels = "$modelDir/class_labels_indices.csv",
|
||||
@@ -131,6 +131,57 @@ fun getAudioTaggingConfig(type: Int, numThreads: Int=1): AudioTaggingConfig? {
|
||||
)
|
||||
}
|
||||
|
||||
2 -> {
|
||||
val modelDir = "sherpa-onnx-ced-tiny-audio-tagging-2024-04-19"
|
||||
return AudioTaggingConfig(
|
||||
model = AudioTaggingModelConfig(
|
||||
ced = "$modelDir/model.int8.onnx",
|
||||
numThreads = numThreads,
|
||||
debug = true,
|
||||
),
|
||||
labels = "$modelDir/class_labels_indices.csv",
|
||||
topK = 3,
|
||||
)
|
||||
}
|
||||
|
||||
3 -> {
|
||||
val modelDir = "sherpa-onnx-ced-mini-audio-tagging-2024-04-19"
|
||||
return AudioTaggingConfig(
|
||||
model = AudioTaggingModelConfig(
|
||||
ced = "$modelDir/model.int8.onnx",
|
||||
numThreads = numThreads,
|
||||
debug = true,
|
||||
),
|
||||
labels = "$modelDir/class_labels_indices.csv",
|
||||
topK = 3,
|
||||
)
|
||||
}
|
||||
|
||||
4 -> {
|
||||
val modelDir = "sherpa-onnx-ced-small-audio-tagging-2024-04-19"
|
||||
return AudioTaggingConfig(
|
||||
model = AudioTaggingModelConfig(
|
||||
ced = "$modelDir/model.int8.onnx",
|
||||
numThreads = numThreads,
|
||||
debug = true,
|
||||
),
|
||||
labels = "$modelDir/class_labels_indices.csv",
|
||||
topK = 3,
|
||||
)
|
||||
}
|
||||
|
||||
5 -> {
|
||||
val modelDir = "sherpa-onnx-ced-base-audio-tagging-2024-04-19"
|
||||
return AudioTaggingConfig(
|
||||
model = AudioTaggingModelConfig(
|
||||
ced = "$modelDir/model.int8.onnx",
|
||||
numThreads = numThreads,
|
||||
debug = true,
|
||||
),
|
||||
labels = "$modelDir/class_labels_indices.csv",
|
||||
topK = 3,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return null
|
||||
|
||||
@@ -3,24 +3,15 @@
|
||||
package com.k2fsa.sherpa.onnx.audio.tagging
|
||||
|
||||
import android.Manifest
|
||||
|
||||
import android.app.Activity
|
||||
import android.content.pm.PackageManager
|
||||
import android.media.AudioFormat
|
||||
import android.media.AudioRecord
|
||||
import androidx.compose.foundation.lazy.items
|
||||
import android.media.MediaRecorder
|
||||
import android.util.Log
|
||||
import androidx.compose.foundation.ExperimentalFoundationApi
|
||||
import androidx.compose.foundation.background
|
||||
import androidx.compose.foundation.layout.Arrangement
|
||||
import androidx.compose.foundation.layout.Box
|
||||
import androidx.compose.material3.CenterAlignedTopAppBar
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.compose.material3.Scaffold
|
||||
import androidx.compose.material3.TopAppBarDefaults
|
||||
import androidx.compose.material3.MaterialTheme
|
||||
import androidx.compose.material3.Text
|
||||
import androidx.compose.foundation.layout.Column
|
||||
import androidx.compose.foundation.layout.PaddingValues
|
||||
import androidx.compose.foundation.layout.Row
|
||||
@@ -30,10 +21,17 @@ import androidx.compose.foundation.layout.fillMaxWidth
|
||||
import androidx.compose.foundation.layout.height
|
||||
import androidx.compose.foundation.layout.padding
|
||||
import androidx.compose.foundation.lazy.LazyColumn
|
||||
import androidx.compose.foundation.lazy.items
|
||||
import androidx.compose.material3.Button
|
||||
import androidx.compose.material3.CenterAlignedTopAppBar
|
||||
import androidx.compose.material3.ExperimentalMaterial3Api
|
||||
import androidx.compose.material3.MaterialTheme
|
||||
import androidx.compose.material3.Scaffold
|
||||
import androidx.compose.material3.Slider
|
||||
import androidx.compose.material3.Surface
|
||||
import androidx.compose.material3.Text
|
||||
import androidx.compose.material3.TopAppBarDefaults
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.compose.runtime.getValue
|
||||
import androidx.compose.runtime.mutableStateListOf
|
||||
import androidx.compose.runtime.mutableStateOf
|
||||
@@ -41,7 +39,6 @@ import androidx.compose.runtime.remember
|
||||
import androidx.compose.runtime.setValue
|
||||
import androidx.compose.ui.Alignment
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.graphics.Color
|
||||
import androidx.compose.ui.platform.LocalContext
|
||||
import androidx.compose.ui.text.font.FontWeight
|
||||
import androidx.compose.ui.text.style.TextAlign
|
||||
@@ -49,6 +46,7 @@ import androidx.compose.ui.unit.dp
|
||||
import androidx.compose.ui.unit.sp
|
||||
import androidx.core.app.ActivityCompat
|
||||
import com.k2fsa.sherpa.onnx.AudioEvent
|
||||
import com.k2fsa.sherpa.onnx.Tagger
|
||||
import kotlin.concurrent.thread
|
||||
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ import androidx.compose.material3.Surface
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.core.app.ActivityCompat
|
||||
import com.k2fsa.sherpa.onnx.Tagger
|
||||
import com.k2fsa.sherpa.onnx.audio.tagging.ui.theme.SherpaOnnxAudioTaggingTheme
|
||||
|
||||
const val TAG = "sherpa-onnx"
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
package com.k2fsa.sherpa.onnx.audio.tagging
|
||||
package com.k2fsa.sherpa.onnx
|
||||
|
||||
import android.content.res.AssetManager
|
||||
import android.util.Log
|
||||
import com.k2fsa.sherpa.onnx.AudioTagging
|
||||
import com.k2fsa.sherpa.onnx.audio.tagging.wear.os.presentation.TAG
|
||||
import com.k2fsa.sherpa.onnx.getAudioTaggingConfig
|
||||
|
||||
|
||||
object Tagger {
|
||||
private var _tagger: AudioTagging? = null
|
||||
@@ -12,6 +10,7 @@ object Tagger {
|
||||
get() {
|
||||
return _tagger!!
|
||||
}
|
||||
|
||||
fun initTagger(assetManager: AssetManager? = null, numThreads: Int = 1) {
|
||||
synchronized(this) {
|
||||
if (_tagger != null) {
|
||||
@@ -19,7 +18,7 @@ object Tagger {
|
||||
}
|
||||
|
||||
Log.i(TAG, "Initializing audio tagger")
|
||||
val config = getAudioTaggingConfig(type = 0, numThreads=numThreads)!!
|
||||
val config = getAudioTaggingConfig(type = 0, numThreads = numThreads)!!
|
||||
_tagger = AudioTagging(assetManager, config)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user