Add Go API for audio tagging (#1840)
This commit is contained in:
15
.github/workflows/test-go-package.yaml
vendored
15
.github/workflows/test-go-package.yaml
vendored
@@ -26,6 +26,8 @@ jobs:
|
|||||||
include:
|
include:
|
||||||
- os: ubuntu-latest
|
- os: ubuntu-latest
|
||||||
arch: amd64
|
arch: amd64
|
||||||
|
- os: ubuntu-22.04-arm
|
||||||
|
arch: arm64
|
||||||
- os: macos-13
|
- os: macos-13
|
||||||
arch: amd64
|
arch: amd64
|
||||||
- os: macos-14
|
- os: macos-14
|
||||||
@@ -460,6 +462,19 @@ jobs:
|
|||||||
./run-tdnn-yesno.sh
|
./run-tdnn-yesno.sh
|
||||||
rm -rf sherpa-onnx-tdnn-yesno
|
rm -rf sherpa-onnx-tdnn-yesno
|
||||||
|
|
||||||
|
- name: Test audio tagging (Linux/macOS)
|
||||||
|
if: matrix.os != 'windows-latest'
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
cd go-api-examples/audio-tagging
|
||||||
|
ls -lh
|
||||||
|
go mod tidy
|
||||||
|
cat go.mod
|
||||||
|
go build
|
||||||
|
ls -lh
|
||||||
|
|
||||||
|
./run.sh
|
||||||
|
|
||||||
- name: Test streaming decoding files (Linux/macOS)
|
- name: Test streaming decoding files (Linux/macOS)
|
||||||
if: matrix.os != 'windows-latest'
|
if: matrix.os != 'windows-latest'
|
||||||
shell: bash
|
shell: bash
|
||||||
|
|||||||
13
.github/workflows/test-go.yaml
vendored
13
.github/workflows/test-go.yaml
vendored
@@ -33,7 +33,7 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
os: [macos-latest, macos-13, ubuntu-latest, windows-latest]
|
os: [macos-latest, macos-13, ubuntu-latest, windows-latest, ubuntu-22.04-arm]
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
@@ -87,7 +87,7 @@ jobs:
|
|||||||
make -j2 install
|
make -j2 install
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [[ ${{ matrix.os }} == ubuntu-latest ]]; then
|
if [[ ${{ matrix.os }} == ubuntu-latest || ${{ matrix.os }} == ubuntu-22.04-arm ]]; then
|
||||||
cp -v ./lib/*.so $upload_dir
|
cp -v ./lib/*.so $upload_dir
|
||||||
cp -v _deps/onnxruntime-src/lib/libonnxruntime*so* $upload_dir
|
cp -v _deps/onnxruntime-src/lib/libonnxruntime*so* $upload_dir
|
||||||
|
|
||||||
@@ -132,6 +132,15 @@ jobs:
|
|||||||
name: ${{ matrix.os }}-libs
|
name: ${{ matrix.os }}-libs
|
||||||
path: to-upload/
|
path: to-upload/
|
||||||
|
|
||||||
|
- name: Test audio tagging
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
cd scripts/go/_internal/audio-tagging/
|
||||||
|
|
||||||
|
./run.sh
|
||||||
|
|
||||||
|
ls -lh
|
||||||
|
|
||||||
- name: Test Keyword spotting
|
- name: Test Keyword spotting
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
4
go-api-examples/audio-tagging/go.mod
Normal file
4
go-api-examples/audio-tagging/go.mod
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
module audio-tagging
|
||||||
|
|
||||||
|
go 1.12
|
||||||
|
|
||||||
36
go-api-examples/audio-tagging/main.go
Normal file
36
go-api-examples/audio-tagging/main.go
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
sherpa "github.com/k2-fsa/sherpa-onnx-go/sherpa_onnx"
|
||||||
|
"log"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
config := sherpa.AudioTaggingConfig{}
|
||||||
|
config.Model.Zipformer.Model = "./sherpa-onnx-zipformer-small-audio-tagging-2024-04-15/model.int8.onnx"
|
||||||
|
config.Model.NumThreads = 1
|
||||||
|
config.Model.Debug = 1
|
||||||
|
config.Model.Provider = "cpu"
|
||||||
|
config.Labels = "./sherpa-onnx-zipformer-small-audio-tagging-2024-04-15/class_labels_indices.csv"
|
||||||
|
config.TopK = 5
|
||||||
|
|
||||||
|
tagging := sherpa.NewAudioTagging(&config)
|
||||||
|
defer sherpa.DeleteAudioTagging(tagging)
|
||||||
|
|
||||||
|
wave_filename := "./sherpa-onnx-zipformer-small-audio-tagging-2024-04-15/test_wavs/3.wav"
|
||||||
|
|
||||||
|
wave := sherpa.ReadWave(wave_filename)
|
||||||
|
if wave == nil {
|
||||||
|
log.Printf("Failed to read %v\n", wave_filename)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
stream := sherpa.NewAudioTaggingStream(tagging)
|
||||||
|
defer sherpa.DeleteOfflineStream(stream)
|
||||||
|
|
||||||
|
stream.AcceptWaveform(wave.SampleRate, wave.Samples)
|
||||||
|
|
||||||
|
result := tagging.Compute(stream, 10)
|
||||||
|
fmt.Printf("the tagging result: %v\n", result)
|
||||||
|
}
|
||||||
13
go-api-examples/audio-tagging/run.sh
Executable file
13
go-api-examples/audio-tagging/run.sh
Executable file
@@ -0,0 +1,13 @@
|
|||||||
|
#!/usr/bin/env bash
#
# Downloads the zipformer audio tagging model if it is not present yet,
# builds the Go example in the current directory, and runs it.

# Stop at the first failing command so that a broken download, extraction or
# build is reported instead of being masked by the commands that follow.
set -e

model_dir=sherpa-onnx-zipformer-small-audio-tagging-2024-04-15
archive=$model_dir.tar.bz2

if [ ! -f ./$model_dir/model.int8.onnx ]; then
  curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/audio-tagging-models/$archive

  tar xvf $archive
  rm $archive
fi

go mod tidy
go build

./audio-tagging
|
||||||
1
scripts/go/_internal/audio-tagging/.gitignore
vendored
Normal file
1
scripts/go/_internal/audio-tagging/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
|||||||
|
audio-tagging
|
||||||
5
scripts/go/_internal/audio-tagging/go.mod
Normal file
5
scripts/go/_internal/audio-tagging/go.mod
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
module audio-tagging
|
||||||
|
|
||||||
|
go 1.12
|
||||||
|
|
||||||
|
replace github.com/k2-fsa/sherpa-onnx-go/sherpa_onnx => ../
|
||||||
1
scripts/go/_internal/audio-tagging/main.go
Symbolic link
1
scripts/go/_internal/audio-tagging/main.go
Symbolic link
@@ -0,0 +1 @@
|
|||||||
|
../../../../go-api-examples/audio-tagging/main.go
|
||||||
1
scripts/go/_internal/audio-tagging/run.sh
Symbolic link
1
scripts/go/_internal/audio-tagging/run.sh
Symbolic link
@@ -0,0 +1 @@
|
|||||||
|
../../../../go-api-examples/audio-tagging/run.sh
|
||||||
1
scripts/go/_internal/lib/aarch64-unknown-linux-gnu
Symbolic link
1
scripts/go/_internal/lib/aarch64-unknown-linux-gnu
Symbolic link
@@ -0,0 +1 @@
|
|||||||
|
../../../../build/lib
|
||||||
@@ -1607,3 +1607,95 @@ func (spotter *KeywordSpotter) GetResult(s *OnlineStream) *KeywordSpotterResult
|
|||||||
result.Keyword = C.GoString(p.keyword)
|
result.Keyword = C.GoString(p.keyword)
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// OfflineZipformerAudioTaggingModelConfig is the configuration of a Zipformer
// audio tagging model.
type OfflineZipformerAudioTaggingModelConfig struct {
	// Model is the path to the model file, e.g. "model.int8.onnx".
	Model string
}
|
||||||
|
|
||||||
|
type AudioTaggingModelConfig struct {
|
||||||
|
Zipformer OfflineZipformerAudioTaggingModelConfig
|
||||||
|
Ced string
|
||||||
|
NumThreads int32
|
||||||
|
Debug int32
|
||||||
|
Provider string
|
||||||
|
}
|
||||||
|
|
||||||
|
type AudioTaggingConfig struct {
|
||||||
|
Model AudioTaggingModelConfig
|
||||||
|
Labels string
|
||||||
|
TopK int32
|
||||||
|
}
|
||||||
|
|
||||||
|
type AudioTagging struct {
|
||||||
|
impl *C.struct_SherpaOnnxAudioTagging
|
||||||
|
}
|
||||||
|
|
||||||
|
// AudioEvent is a single entry of an audio tagging result, as returned by
// [AudioTagging.Compute]. Each field is copied from the field of the same
// name in the corresponding C event.
type AudioEvent struct {
	// Name is the name of the event.
	Name string

	// Index is the index of the event.
	// NOTE(review): presumably the row of the event in the labels file; confirm.
	Index int

	// Prob is the value reported by the C API for this event.
	// NOTE(review): presumably a probability in [0, 1]; confirm.
	Prob float32
}
|
||||||
|
|
||||||
|
func DeleteAudioTagging(tagging *AudioTagging) {
|
||||||
|
C.SherpaOnnxDestroyAudioTagging(tagging.impl)
|
||||||
|
tagging.impl = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// The user is responsible to invoke [DeleteAudioTagging]() to free
|
||||||
|
// the returned tagger to avoid memory leak
|
||||||
|
func NewAudioTagging(config *AudioTaggingConfig) *AudioTagging {
|
||||||
|
c := C.struct_SherpaOnnxAudioTaggingConfig{}
|
||||||
|
|
||||||
|
c.model.zipformer.model = C.CString(config.Model.Zipformer.Model)
|
||||||
|
defer C.free(unsafe.Pointer(c.model.zipformer.model))
|
||||||
|
|
||||||
|
c.model.ced = C.CString(config.Model.Ced)
|
||||||
|
defer C.free(unsafe.Pointer(c.model.ced))
|
||||||
|
|
||||||
|
c.model.num_threads = C.int(config.Model.NumThreads)
|
||||||
|
|
||||||
|
c.model.provider = C.CString(config.Model.Provider)
|
||||||
|
defer C.free(unsafe.Pointer(c.model.provider))
|
||||||
|
|
||||||
|
c.model.debug = C.int(config.Model.Debug)
|
||||||
|
|
||||||
|
c.labels = C.CString(config.Labels)
|
||||||
|
defer C.free(unsafe.Pointer(c.labels))
|
||||||
|
|
||||||
|
c.top_k = C.int(config.TopK)
|
||||||
|
|
||||||
|
tagging := &AudioTagging{}
|
||||||
|
tagging.impl = C.SherpaOnnxCreateAudioTagging(&c)
|
||||||
|
|
||||||
|
return tagging
|
||||||
|
}
|
||||||
|
|
||||||
|
// The user is responsible to invoke [DeleteOfflineStream]() to free
|
||||||
|
// the returned stream to avoid memory leak
|
||||||
|
func NewAudioTaggingStream(tagging *AudioTagging) *OfflineStream {
|
||||||
|
stream := &OfflineStream{}
|
||||||
|
stream.impl = C.SherpaOnnxAudioTaggingCreateOfflineStream(tagging.impl)
|
||||||
|
return stream
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tagging *AudioTagging) Compute(s *OfflineStream, topK int32) []AudioEvent {
|
||||||
|
r := C.SherpaOnnxAudioTaggingCompute(tagging.impl, s.impl, C.int(topK))
|
||||||
|
defer C.SherpaOnnxAudioTaggingFreeResults(r)
|
||||||
|
result := make([]AudioEvent, 0)
|
||||||
|
|
||||||
|
p := (*[1 << 28]*C.struct_SherpaOnnxAudioEvent)(unsafe.Pointer(r))
|
||||||
|
i := 0
|
||||||
|
for {
|
||||||
|
if p[i] == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
result = append(result, AudioEvent{
|
||||||
|
Name: C.GoString(p[i].name),
|
||||||
|
Index: int(p[i].index),
|
||||||
|
Prob: float32(p[i].prob),
|
||||||
|
})
|
||||||
|
i += 1
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user