diff --git a/.github/workflows/test-go-package.yaml b/.github/workflows/test-go-package.yaml index f2e4cb1b..f587032e 100644 --- a/.github/workflows/test-go-package.yaml +++ b/.github/workflows/test-go-package.yaml @@ -26,6 +26,8 @@ jobs: include: - os: ubuntu-latest arch: amd64 + - os: ubuntu-22.04-arm + arch: arm64 - os: macos-13 arch: amd64 - os: macos-14 @@ -460,6 +462,19 @@ jobs: ./run-tdnn-yesno.sh rm -rf sherpa-onnx-tdnn-yesno + - name: Test audio tagging (Linux/macOS) + if: matrix.os != 'windows-latest' + shell: bash + run: | + cd go-api-examples/audio-tagging + ls -lh + go mod tidy + cat go.mod + go build + ls -lh + + ./run.sh + - name: Test streaming decoding files (Linux/macOS) if: matrix.os != 'windows-latest' shell: bash diff --git a/.github/workflows/test-go.yaml b/.github/workflows/test-go.yaml index 8d68076d..569ba578 100644 --- a/.github/workflows/test-go.yaml +++ b/.github/workflows/test-go.yaml @@ -33,7 +33,7 @@ jobs: strategy: fail-fast: false matrix: - os: [macos-latest, macos-13, ubuntu-latest, windows-latest] + os: [macos-latest, macos-13, ubuntu-latest, windows-latest, ubuntu-22.04-arm] steps: - uses: actions/checkout@v4 @@ -87,7 +87,7 @@ jobs: make -j2 install fi - if [[ ${{ matrix.os }} == ubuntu-latest ]]; then + if [[ ${{ matrix.os }} == ubuntu-latest || ${{ matrix.os }} == ubuntu-22.04-arm ]]; then cp -v ./lib/*.so $upload_dir cp -v _deps/onnxruntime-src/lib/libonnxruntime*so* $upload_dir @@ -132,6 +132,15 @@ jobs: name: ${{ matrix.os }}-libs path: to-upload/ + - name: Test audio tagging + shell: bash + run: | + cd scripts/go/_internal/audio-tagging/ + + ./run.sh + + ls -lh + - name: Test Keyword spotting shell: bash run: | diff --git a/go-api-examples/audio-tagging/go.mod b/go-api-examples/audio-tagging/go.mod new file mode 100644 index 00000000..63158071 --- /dev/null +++ b/go-api-examples/audio-tagging/go.mod @@ -0,0 +1,4 @@ +module audio-tagging + +go 1.12 + diff --git a/go-api-examples/audio-tagging/main.go 
b/go-api-examples/audio-tagging/main.go
new file mode 100644
index 00000000..fd86e49f
--- /dev/null
+++ b/go-api-examples/audio-tagging/main.go
@@ -0,0 +1,52 @@
+package main
+
+import (
+	"errors"
+	"fmt"
+	"log"
+
+	sherpa "github.com/k2-fsa/sherpa-onnx-go/sherpa_onnx"
+)
+
+// modelDir is the directory that run.sh extracts the model archive into.
+const modelDir = "./sherpa-onnx-zipformer-small-audio-tagging-2024-04-15"
+
+func main() {
+	if err := run(); err != nil {
+		log.Fatal(err)
+	}
+}
+
+// run tags one test wave file and prints the result. Errors are returned to
+// main instead of exiting here, so the deferred cleanups run before exit.
+func run() error {
+	config := sherpa.AudioTaggingConfig{}
+	config.Model.Zipformer.Model = modelDir + "/model.int8.onnx"
+	config.Model.NumThreads = 1
+	config.Model.Debug = 1
+	config.Model.Provider = "cpu"
+	config.Labels = modelDir + "/class_labels_indices.csv"
+	config.TopK = 5
+
+	tagging := sherpa.NewAudioTagging(&config)
+	if tagging == nil {
+		return errors.New("failed to create the audio tagger")
+	}
+	defer sherpa.DeleteAudioTagging(tagging)
+
+	waveFilename := modelDir + "/test_wavs/3.wav"
+
+	wave := sherpa.ReadWave(waveFilename)
+	if wave == nil {
+		return fmt.Errorf("failed to read %v", waveFilename)
+	}
+
+	stream := sherpa.NewAudioTaggingStream(tagging)
+	defer sherpa.DeleteOfflineStream(stream)
+
+	stream.AcceptWaveform(wave.SampleRate, wave.Samples)
+
+	result := tagging.Compute(stream, 10)
+	fmt.Printf("the tagging result: %v\n", result)
+	return nil
+}
diff --git a/go-api-examples/audio-tagging/run.sh b/go-api-examples/audio-tagging/run.sh
new file mode 100755
index 00000000..a4da9519
--- /dev/null
+++ b/go-api-examples/audio-tagging/run.sh
@@ -0,0 +1,16 @@
+#!/usr/bin/env bash
+
+# Abort as soon as a command fails, e.g. when the model cannot be extracted.
+set -e
+
+if [ ! -f ./sherpa-onnx-zipformer-small-audio-tagging-2024-04-15/model.int8.onnx ]; then
+  curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/audio-tagging-models/sherpa-onnx-zipformer-small-audio-tagging-2024-04-15.tar.bz2
+
+  tar xvf sherpa-onnx-zipformer-small-audio-tagging-2024-04-15.tar.bz2
+  rm sherpa-onnx-zipformer-small-audio-tagging-2024-04-15.tar.bz2
+fi
+
+go mod tidy
+go build
+
+./audio-tagging
diff --git a/scripts/go/_internal/audio-tagging/.gitignore b/scripts/go/_internal/audio-tagging/.gitignore
new file mode 100644
index 00000000..1623c402
--- /dev/null
+++ b/scripts/go/_internal/audio-tagging/.gitignore
@@ -0,0 +1 @@
+audio-tagging
diff --git a/scripts/go/_internal/audio-tagging/go.mod b/scripts/go/_internal/audio-tagging/go.mod
new file mode 100644
index 00000000..e19aae58
--- /dev/null
+++ b/scripts/go/_internal/audio-tagging/go.mod
@@ -0,0 +1,5 @@
+module audio-tagging
+
+go 1.12
+
+replace github.com/k2-fsa/sherpa-onnx-go/sherpa_onnx => ../
diff --git a/scripts/go/_internal/audio-tagging/main.go b/scripts/go/_internal/audio-tagging/main.go
new file mode 120000
index 00000000..cfc179ad
--- /dev/null
+++ b/scripts/go/_internal/audio-tagging/main.go
@@ -0,0 +1 @@
+../../../../go-api-examples/audio-tagging/main.go
\ No newline at end of file
diff --git a/scripts/go/_internal/audio-tagging/run.sh b/scripts/go/_internal/audio-tagging/run.sh
new file mode 120000
index 00000000..34ccc7d0
--- /dev/null
+++ b/scripts/go/_internal/audio-tagging/run.sh
@@ -0,0 +1 @@
+../../../../go-api-examples/audio-tagging/run.sh
\ No newline at end of file
diff --git a/scripts/go/_internal/lib/aarch64-unknown-linux-gnu b/scripts/go/_internal/lib/aarch64-unknown-linux-gnu
new file mode 120000
index 00000000..b21b8dfa
--- /dev/null
+++ b/scripts/go/_internal/lib/aarch64-unknown-linux-gnu
@@ -0,0 +1 @@
+../../../../build/lib
\ No newline at end of file
diff --git a/scripts/go/sherpa_onnx.go b/scripts/go/sherpa_onnx.go
index 88777a49..490f36ab 100644
--- 
a/scripts/go/sherpa_onnx.go
+++ b/scripts/go/sherpa_onnx.go
@@ -1607,3 +1607,124 @@ func (spotter *KeywordSpotter) GetResult(s *OnlineStream) *KeywordSpotterResult
 	result.Keyword = C.GoString(p.keyword)
 	return result
 }
+
+// OfflineZipformerAudioTaggingModelConfig configures the zipformer model
+// used for audio tagging.
+type OfflineZipformerAudioTaggingModelConfig struct {
+	Model string // Path to the zipformer model file
+}
+
+// AudioTaggingModelConfig holds the model related options for audio tagging.
+type AudioTaggingModelConfig struct {
+	Zipformer  OfflineZipformerAudioTaggingModelConfig
+	Ced        string // Passed to the C API as model.ced
+	NumThreads int32  // Passed to the C API as model.num_threads
+	Debug      int32  // Passed to the C API as model.debug
+	Provider   string // Passed to the C API as model.provider, e.g., "cpu"
+}
+
+// AudioTaggingConfig is the configuration for [NewAudioTagging].
+type AudioTaggingConfig struct {
+	Model  AudioTaggingModelConfig
+	Labels string // Path to the file with the class labels
+	TopK   int32  // Passed to the C API as top_k
+}
+
+// AudioTagging wraps a pointer to the C struct SherpaOnnxAudioTagging.
+type AudioTagging struct {
+	impl *C.struct_SherpaOnnxAudioTagging
+}
+
+// AudioEvent is a single entry of the result returned by
+// [AudioTagging.Compute].
+type AudioEvent struct {
+	Name  string
+	Index int
+	Prob  float32
+}
+
+// DeleteAudioTagging frees the C object owned by tagging. It does nothing if
+// tagging is nil or has already been deleted.
+func DeleteAudioTagging(tagging *AudioTagging) {
+	if tagging == nil || tagging.impl == nil {
+		return
+	}
+	C.SherpaOnnxDestroyAudioTagging(tagging.impl)
+	tagging.impl = nil
+}
+
+// NewAudioTagging creates an audio tagger from config. It returns nil if the
+// C API does not return a valid handle.
+//
+// The user is responsible to invoke [DeleteAudioTagging]() to free
+// the returned tagger to avoid memory leak
+func NewAudioTagging(config *AudioTaggingConfig) *AudioTagging {
+	c := C.struct_SherpaOnnxAudioTaggingConfig{}
+
+	// The C strings below are freed when this function returns.
+	c.model.zipformer.model = C.CString(config.Model.Zipformer.Model)
+	defer C.free(unsafe.Pointer(c.model.zipformer.model))
+
+	c.model.ced = C.CString(config.Model.Ced)
+	defer C.free(unsafe.Pointer(c.model.ced))
+
+	c.model.num_threads = C.int(config.Model.NumThreads)
+
+	c.model.provider = C.CString(config.Model.Provider)
+	defer C.free(unsafe.Pointer(c.model.provider))
+
+	c.model.debug = C.int(config.Model.Debug)
+
+	c.labels = C.CString(config.Labels)
+	defer C.free(unsafe.Pointer(c.labels))
+
+	c.top_k = C.int(config.TopK)
+
+	impl := C.SherpaOnnxCreateAudioTagging(&c)
+	if impl == nil {
+		return nil
+	}
+
+	return &AudioTagging{impl: impl}
+}
+
+// NewAudioTaggingStream creates a stream for tagging. It returns nil if
+// tagging is nil or has already been deleted.
+//
+// The user is responsible to invoke [DeleteOfflineStream]() to free
+// the returned stream to avoid memory leak
+func NewAudioTaggingStream(tagging *AudioTagging) *OfflineStream {
+	if tagging == nil || tagging.impl == nil {
+		return nil
+	}
+
+	stream := &OfflineStream{}
+	stream.impl = C.SherpaOnnxAudioTaggingCreateOfflineStream(tagging.impl)
+	return stream
+}
+
+// Compute runs the tagger on stream s and returns the events reported by the
+// C API. topK is passed to the C API unchanged. The returned slice is never
+// nil.
+func (tagging *AudioTagging) Compute(s *OfflineStream, topK int32) []AudioEvent {
+	result := make([]AudioEvent, 0)
+
+	r := C.SherpaOnnxAudioTaggingCompute(tagging.impl, s.impl, C.int(topK))
+	if r == nil {
+		return result
+	}
+	defer C.SherpaOnnxAudioTaggingFreeResults(r)
+
+	// r is an array of event pointers terminated by a nil entry. The array
+	// type is only used to index into C memory; nothing of that size is
+	// allocated.
+	p := (*[1 << 28]*C.struct_SherpaOnnxAudioEvent)(unsafe.Pointer(r))
+	for i := 0; p[i] != nil; i++ {
+		result = append(result, AudioEvent{
+			Name:  C.GoString(p[i].name),
+			Index: int(p[i].index),
+			Prob:  float32(p[i].prob),
+		})
+	}
+	return result
+}