Add Go API for audio tagging (#1840)
This commit is contained in:
15
.github/workflows/test-go-package.yaml
vendored
15
.github/workflows/test-go-package.yaml
vendored
@@ -26,6 +26,8 @@ jobs:
|
||||
include:
|
||||
- os: ubuntu-latest
|
||||
arch: amd64
|
||||
- os: ubuntu-22.04-arm
|
||||
arch: arm64
|
||||
- os: macos-13
|
||||
arch: amd64
|
||||
- os: macos-14
|
||||
@@ -460,6 +462,19 @@ jobs:
|
||||
./run-tdnn-yesno.sh
|
||||
rm -rf sherpa-onnx-tdnn-yesno
|
||||
|
||||
- name: Test audio tagging (Linux/macOS)
|
||||
if: matrix.os != 'windows-latest'
|
||||
shell: bash
|
||||
run: |
|
||||
cd go-api-examples/audio-tagging
|
||||
ls -lh
|
||||
go mod tidy
|
||||
cat go.mod
|
||||
go build
|
||||
ls -lh
|
||||
|
||||
./run.sh
|
||||
|
||||
- name: Test streaming decoding files (Linux/macOS)
|
||||
if: matrix.os != 'windows-latest'
|
||||
shell: bash
|
||||
|
||||
13
.github/workflows/test-go.yaml
vendored
13
.github/workflows/test-go.yaml
vendored
@@ -33,7 +33,7 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [macos-latest, macos-13, ubuntu-latest, windows-latest]
|
||||
os: [macos-latest, macos-13, ubuntu-latest, windows-latest, ubuntu-22.04-arm]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
@@ -87,7 +87,7 @@ jobs:
|
||||
make -j2 install
|
||||
fi
|
||||
|
||||
if [[ ${{ matrix.os }} == ubuntu-latest ]]; then
|
||||
if [[ ${{ matrix.os }} == ubuntu-latest || ${{ matrix.os }} == ubuntu-22.04-arm ]]; then
|
||||
cp -v ./lib/*.so $upload_dir
|
||||
cp -v _deps/onnxruntime-src/lib/libonnxruntime*so* $upload_dir
|
||||
|
||||
@@ -132,6 +132,15 @@ jobs:
|
||||
name: ${{ matrix.os }}-libs
|
||||
path: to-upload/
|
||||
|
||||
- name: Test audio tagging
|
||||
shell: bash
|
||||
run: |
|
||||
cd scripts/go/_internal/audio-tagging/
|
||||
|
||||
./run.sh
|
||||
|
||||
ls -lh
|
||||
|
||||
- name: Test Keyword spotting
|
||||
shell: bash
|
||||
run: |
|
||||
|
||||
4
go-api-examples/audio-tagging/go.mod
Normal file
4
go-api-examples/audio-tagging/go.mod
Normal file
@@ -0,0 +1,4 @@
|
||||
module audio-tagging
|
||||
|
||||
go 1.12
|
||||
|
||||
36
go-api-examples/audio-tagging/main.go
Normal file
36
go-api-examples/audio-tagging/main.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
sherpa "github.com/k2-fsa/sherpa-onnx-go/sherpa_onnx"
|
||||
"log"
|
||||
)
|
||||
|
||||
func main() {
|
||||
config := sherpa.AudioTaggingConfig{}
|
||||
config.Model.Zipformer.Model = "./sherpa-onnx-zipformer-small-audio-tagging-2024-04-15/model.int8.onnx"
|
||||
config.Model.NumThreads = 1
|
||||
config.Model.Debug = 1
|
||||
config.Model.Provider = "cpu"
|
||||
config.Labels = "./sherpa-onnx-zipformer-small-audio-tagging-2024-04-15/class_labels_indices.csv"
|
||||
config.TopK = 5
|
||||
|
||||
tagging := sherpa.NewAudioTagging(&config)
|
||||
defer sherpa.DeleteAudioTagging(tagging)
|
||||
|
||||
wave_filename := "./sherpa-onnx-zipformer-small-audio-tagging-2024-04-15/test_wavs/3.wav"
|
||||
|
||||
wave := sherpa.ReadWave(wave_filename)
|
||||
if wave == nil {
|
||||
log.Printf("Failed to read %v\n", wave_filename)
|
||||
return
|
||||
}
|
||||
|
||||
stream := sherpa.NewAudioTaggingStream(tagging)
|
||||
defer sherpa.DeleteOfflineStream(stream)
|
||||
|
||||
stream.AcceptWaveform(wave.SampleRate, wave.Samples)
|
||||
|
||||
result := tagging.Compute(stream, 10)
|
||||
fmt.Printf("the tagging result: %v\n", result)
|
||||
}
|
||||
13
go-api-examples/audio-tagging/run.sh
Executable file
13
go-api-examples/audio-tagging/run.sh
Executable file
@@ -0,0 +1,13 @@
|
||||
#!/usr/bin/env bash

# Abort on the first failing command and echo each command as it runs, so a
# failed download, extraction or build is reported at the step that failed
# instead of surfacing later as a confusing error.
set -ex

# Download and unpack the model only when it is not already present.
if [ ! -f ./sherpa-onnx-zipformer-small-audio-tagging-2024-04-15/model.int8.onnx ]; then
  curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/audio-tagging-models/sherpa-onnx-zipformer-small-audio-tagging-2024-04-15.tar.bz2

  tar xvf sherpa-onnx-zipformer-small-audio-tagging-2024-04-15.tar.bz2
  rm sherpa-onnx-zipformer-small-audio-tagging-2024-04-15.tar.bz2
fi

go mod tidy
go build

./audio-tagging
|
||||
1
scripts/go/_internal/audio-tagging/.gitignore
vendored
Normal file
1
scripts/go/_internal/audio-tagging/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
audio-tagging
|
||||
5
scripts/go/_internal/audio-tagging/go.mod
Normal file
5
scripts/go/_internal/audio-tagging/go.mod
Normal file
@@ -0,0 +1,5 @@
|
||||
module audio-tagging
|
||||
|
||||
go 1.12
|
||||
|
||||
replace github.com/k2-fsa/sherpa-onnx-go/sherpa_onnx => ../
|
||||
1
scripts/go/_internal/audio-tagging/main.go
Symbolic link
1
scripts/go/_internal/audio-tagging/main.go
Symbolic link
@@ -0,0 +1 @@
|
||||
../../../../go-api-examples/audio-tagging/main.go
|
||||
1
scripts/go/_internal/audio-tagging/run.sh
Symbolic link
1
scripts/go/_internal/audio-tagging/run.sh
Symbolic link
@@ -0,0 +1 @@
|
||||
../../../../go-api-examples/audio-tagging/run.sh
|
||||
1
scripts/go/_internal/lib/aarch64-unknown-linux-gnu
Symbolic link
1
scripts/go/_internal/lib/aarch64-unknown-linux-gnu
Symbolic link
@@ -0,0 +1 @@
|
||||
../../../../build/lib
|
||||
@@ -1607,3 +1607,95 @@ func (spotter *KeywordSpotter) GetResult(s *OnlineStream) *KeywordSpotterResult
|
||||
result.Keyword = C.GoString(p.keyword)
|
||||
return result
|
||||
}
|
||||
|
||||
// OfflineZipformerAudioTaggingModelConfig is the configuration of a
// zipformer model used for audio tagging.
type OfflineZipformerAudioTaggingModelConfig struct {
	// Model is passed to the C API as a C string. The example program sets
	// it to the path of an ONNX model file (model.int8.onnx).
	Model string
}
|
||||
|
||||
type AudioTaggingModelConfig struct {
|
||||
Zipformer OfflineZipformerAudioTaggingModelConfig
|
||||
Ced string
|
||||
NumThreads int32
|
||||
Debug int32
|
||||
Provider string
|
||||
}
|
||||
|
||||
type AudioTaggingConfig struct {
|
||||
Model AudioTaggingModelConfig
|
||||
Labels string
|
||||
TopK int32
|
||||
}
|
||||
|
||||
type AudioTagging struct {
|
||||
impl *C.struct_SherpaOnnxAudioTagging
|
||||
}
|
||||
|
||||
// AudioEvent is one entry of the result returned by AudioTagging.Compute.
// Its fields are copied from the corresponding C SherpaOnnxAudioEvent.
type AudioEvent struct {
	// Name is the Go copy of the C event's name string.
	Name string

	// Index is the C event's index field.
	Index int

	// Prob is the C event's prob field (presumably the predicted
	// probability of the event - confirm against the C API).
	Prob float32
}
|
||||
|
||||
func DeleteAudioTagging(tagging *AudioTagging) {
|
||||
C.SherpaOnnxDestroyAudioTagging(tagging.impl)
|
||||
tagging.impl = nil
|
||||
}
|
||||
|
||||
// The user is responsible to invoke [DeleteAudioTagging]() to free
|
||||
// the returned tagger to avoid memory leak
|
||||
func NewAudioTagging(config *AudioTaggingConfig) *AudioTagging {
|
||||
c := C.struct_SherpaOnnxAudioTaggingConfig{}
|
||||
|
||||
c.model.zipformer.model = C.CString(config.Model.Zipformer.Model)
|
||||
defer C.free(unsafe.Pointer(c.model.zipformer.model))
|
||||
|
||||
c.model.ced = C.CString(config.Model.Ced)
|
||||
defer C.free(unsafe.Pointer(c.model.ced))
|
||||
|
||||
c.model.num_threads = C.int(config.Model.NumThreads)
|
||||
|
||||
c.model.provider = C.CString(config.Model.Provider)
|
||||
defer C.free(unsafe.Pointer(c.model.provider))
|
||||
|
||||
c.model.debug = C.int(config.Model.Debug)
|
||||
|
||||
c.labels = C.CString(config.Labels)
|
||||
defer C.free(unsafe.Pointer(c.labels))
|
||||
|
||||
c.top_k = C.int(config.TopK)
|
||||
|
||||
tagging := &AudioTagging{}
|
||||
tagging.impl = C.SherpaOnnxCreateAudioTagging(&c)
|
||||
|
||||
return tagging
|
||||
}
|
||||
|
||||
// The user is responsible to invoke [DeleteOfflineStream]() to free
|
||||
// the returned stream to avoid memory leak
|
||||
func NewAudioTaggingStream(tagging *AudioTagging) *OfflineStream {
|
||||
stream := &OfflineStream{}
|
||||
stream.impl = C.SherpaOnnxAudioTaggingCreateOfflineStream(tagging.impl)
|
||||
return stream
|
||||
}
|
||||
|
||||
func (tagging *AudioTagging) Compute(s *OfflineStream, topK int32) []AudioEvent {
|
||||
r := C.SherpaOnnxAudioTaggingCompute(tagging.impl, s.impl, C.int(topK))
|
||||
defer C.SherpaOnnxAudioTaggingFreeResults(r)
|
||||
result := make([]AudioEvent, 0)
|
||||
|
||||
p := (*[1 << 28]*C.struct_SherpaOnnxAudioEvent)(unsafe.Pointer(r))
|
||||
i := 0
|
||||
for {
|
||||
if p[i] == nil {
|
||||
break
|
||||
}
|
||||
result = append(result, AudioEvent{
|
||||
Name: C.GoString(p[i].name),
|
||||
Index: int(p[i].index),
|
||||
Prob: float32(p[i].prob),
|
||||
})
|
||||
i += 1
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user