From 052b8645ba49194460bbf12a50d5079f33ba9545 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sun, 27 Oct 2024 09:04:05 +0800 Subject: [PATCH] Add Go API examples for adding punctuations to text. (#1478) --- .github/workflows/test-go-package.yaml | 7 +++++ .github/workflows/test-go.yaml | 6 ++++ go-api-examples/add-punctuation/go.mod | 3 ++ go-api-examples/add-punctuation/main.go | 31 ++++++++++++++++++++ go-api-examples/add-punctuation/run.sh | 14 +++++++++ scripts/go/_internal/add-punctuation/go.mod | 5 ++++ scripts/go/_internal/add-punctuation/main.go | 1 + scripts/go/_internal/add-punctuation/run.sh | 1 + scripts/go/sherpa_onnx.go | 12 ++++---- 9 files changed, 74 insertions(+), 6 deletions(-) create mode 100644 go-api-examples/add-punctuation/go.mod create mode 100644 go-api-examples/add-punctuation/main.go create mode 100755 go-api-examples/add-punctuation/run.sh create mode 100644 scripts/go/_internal/add-punctuation/go.mod create mode 120000 scripts/go/_internal/add-punctuation/main.go create mode 120000 scripts/go/_internal/add-punctuation/run.sh diff --git a/.github/workflows/test-go-package.yaml b/.github/workflows/test-go-package.yaml index 6839b0ba..9074449b 100644 --- a/.github/workflows/test-go-package.yaml +++ b/.github/workflows/test-go-package.yaml @@ -68,6 +68,13 @@ jobs: run: | gcc --version + - name: Test adding punctuation + if: matrix.os != 'windows-latest' + shell: bash + run: | + cd go-api-examples/add-punctuation/ + ./run.sh + - name: Test non-streaming speaker diarization if: matrix.os != 'windows-latest' shell: bash diff --git a/.github/workflows/test-go.yaml b/.github/workflows/test-go.yaml index ccb3eb8d..1e8ad9c8 100644 --- a/.github/workflows/test-go.yaml +++ b/.github/workflows/test-go.yaml @@ -134,6 +134,12 @@ jobs: name: ${{ matrix.os }}-libs path: to-upload/ + - name: Test adding punctuation + shell: bash + run: | + cd scripts/go/_internal/add-punctuation/ + ./run.sh + - name: Test non-streaming speaker diarization shell: bash 
run: |
diff --git a/go-api-examples/add-punctuation/go.mod b/go-api-examples/add-punctuation/go.mod
new file mode 100644
index 00000000..ec6d7580
--- /dev/null
+++ b/go-api-examples/add-punctuation/go.mod
@@ -0,0 +1,3 @@
+module add-punctuation
+
+go 1.12
diff --git a/go-api-examples/add-punctuation/main.go b/go-api-examples/add-punctuation/main.go
new file mode 100644
index 00000000..055748ea
--- /dev/null
+++ b/go-api-examples/add-punctuation/main.go
@@ -0,0 +1,41 @@
+// Command add-punctuation adds punctuation to a few sample sentences
+// using a sherpa-onnx offline punctuation model.
+package main
+
+import (
+	"log"
+
+	sherpa "github.com/k2-fsa/sherpa-onnx-go/sherpa_onnx"
+)
+
+// separator is the rule logged before the first sample and after each one.
+const separator = "----------"
+
+// main configures the CT-Transformer model from a path relative to the
+// current working directory, then logs every sample and its punctuated form.
+func main() {
+	log.SetFlags(log.LstdFlags | log.Lmicroseconds)
+
+	var cfg sherpa.OfflinePunctuationConfig
+	cfg.Model.CtTransformer = "./sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12/model.onnx"
+	cfg.Model.NumThreads = 1
+	cfg.Model.Provider = "cpu"
+
+	punctuator := sherpa.NewOfflinePunctuation(&cfg)
+	// Deferred so DeleteOfflinePunc runs when main returns.
+	defer sherpa.DeleteOfflinePunc(punctuator)
+
+	samples := []string{
+		"这是一个测试你好吗How are you我很好thank you are you ok谢谢你",
+		"我们都是木头人不会说话不会动",
+		"The African blogosphere is rapidly expanding bringing more voices online in the form of commentaries opinions analyses rants and poetry",
+	}
+
+	log.Println(separator)
+	for _, input := range samples {
+		output := punctuator.AddPunct(input)
+		log.Printf("Input text: %v", input)
+		log.Printf("Output text: %v", output)
+		log.Println(separator)
+	}
+}
diff --git a/go-api-examples/add-punctuation/run.sh b/go-api-examples/add-punctuation/run.sh
new file mode 100755
index 00000000..6d43b84f
--- /dev/null
+++ b/go-api-examples/add-punctuation/run.sh
@@ -0,0 +1,14 @@
+#!/usr/bin/env bash
+
+set -ex
+
+if [ !
-d ./sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12 ]; then + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2 + tar xvf sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2 + rm sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2 +fi + +go mod tidy +go build + +./add-punctuation diff --git a/scripts/go/_internal/add-punctuation/go.mod b/scripts/go/_internal/add-punctuation/go.mod new file mode 100644 index 00000000..f25ca0a8 --- /dev/null +++ b/scripts/go/_internal/add-punctuation/go.mod @@ -0,0 +1,5 @@ +module add-punctuation + +go 1.12 + +replace github.com/k2-fsa/sherpa-onnx-go/sherpa_onnx => ../ diff --git a/scripts/go/_internal/add-punctuation/main.go b/scripts/go/_internal/add-punctuation/main.go new file mode 120000 index 00000000..84df910c --- /dev/null +++ b/scripts/go/_internal/add-punctuation/main.go @@ -0,0 +1 @@ +../../../../go-api-examples/add-punctuation/main.go \ No newline at end of file diff --git a/scripts/go/_internal/add-punctuation/run.sh b/scripts/go/_internal/add-punctuation/run.sh new file mode 120000 index 00000000..2b1ee21b --- /dev/null +++ b/scripts/go/_internal/add-punctuation/run.sh @@ -0,0 +1 @@ +../../../../go-api-examples/add-punctuation/run.sh \ No newline at end of file diff --git a/scripts/go/sherpa_onnx.go b/scripts/go/sherpa_onnx.go index 45bf714b..30ca31dc 100644 --- a/scripts/go/sherpa_onnx.go +++ b/scripts/go/sherpa_onnx.go @@ -1322,10 +1322,10 @@ func (sd *OfflineSpeakerDiarization) Process(samples []float32) []OfflineSpeaker // For punctuation // ============================================================ type OfflinePunctuationModelConfig struct { - Ct_transformer string - Num_threads C.int - Debug C.int // true to print debug information of the model - Provider string + CtTransformer string + NumThreads C.int + Debug C.int // true to print debug information of 
the model + Provider string } type OfflinePunctuationConfig struct { @@ -1338,10 +1338,10 @@ type OfflinePunctuation struct { func NewOfflinePunctuation(config *OfflinePunctuationConfig) *OfflinePunctuation { cfg := C.struct_SherpaOnnxOfflinePunctuationConfig{} - cfg.model.ct_transformer = C.CString(config.Model.Ct_transformer) + cfg.model.ct_transformer = C.CString(config.Model.CtTransformer) defer C.free(unsafe.Pointer(cfg.model.ct_transformer)) - cfg.model.num_threads = config.Model.Num_threads + cfg.model.num_threads = config.Model.NumThreads cfg.model.debug = config.Model.Debug cfg.model.provider = C.CString(config.Model.Provider) defer C.free(unsafe.Pointer(cfg.model.provider))