diff --git a/.github/scripts/test-offline-punctuation.sh b/.github/scripts/test-offline-punctuation.sh index 6a096c36..bca0ede0 100755 --- a/.github/scripts/test-offline-punctuation.sh +++ b/.github/scripts/test-offline-punctuation.sh @@ -14,7 +14,7 @@ echo "PATH: $PATH" which $EXE log "------------------------------------------------------------" -log "Download model " +log "Download the punctuation model " log "------------------------------------------------------------" curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2 diff --git a/.github/scripts/test-python.sh b/.github/scripts/test-python.sh index aa9b795f..fe0f568f 100755 --- a/.github/scripts/test-python.sh +++ b/.github/scripts/test-python.sh @@ -8,6 +8,18 @@ log() { echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" } +log "test offline punctuation" + +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2 +tar xvf sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2 +rm sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2 +repo=sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12 +ls -lh $repo + +python3 ./python-api-examples/add-punctuation.py + +rm -rf $repo + log "test audio tagging" curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/audio-tagging-models/sherpa-onnx-zipformer-audio-tagging-2024-04-09.tar.bz2 diff --git a/.gitignore b/.gitignore index a51cd0ea..3047a1e0 100644 --- a/.gitignore +++ b/.gitignore @@ -91,3 +91,4 @@ sr-data *xcworkspace/xcuserdata/* vits-icefall-* +sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12 diff --git a/go-api-examples/vad-asr-paraformer/run.sh b/go-api-examples/vad-asr-paraformer/run.sh index c2f65d9b..9d136804 100755 --- a/go-api-examples/vad-asr-paraformer/run.sh +++ b/go-api-examples/vad-asr-paraformer/run.sh @@ -2,7 +2,7 @@ if [ ! -f ./silero_vad.onnx ]; then - curl -SL -O https://github.com/snakers4/silero-vad/blob/master/files/silero_vad.onnx + curl -SL -O https://github.com/snakers4/silero-vad/raw/master/files/silero_vad.onnx fi if [ ! -f ./sherpa-onnx-paraformer-trilingual-zh-cantonese-en/model.int8.onnx ]; then diff --git a/go-api-examples/vad-asr-whisper/run.sh b/go-api-examples/vad-asr-whisper/run.sh index 2ae3b5af..8064887d 100755 --- a/go-api-examples/vad-asr-whisper/run.sh +++ b/go-api-examples/vad-asr-whisper/run.sh @@ -2,7 +2,7 @@ if [ ! -f ./silero_vad.onnx ]; then - curl -SL -O https://github.com/snakers4/silero-vad/blob/master/files/silero_vad.onnx + curl -SL -O https://github.com/snakers4/silero-vad/raw/master/files/silero_vad.onnx fi if [ ! -f ./sherpa-onnx-whisper-tiny.en/tiny.en-encoder.int8.onnx ]; then diff --git a/go-api-examples/vad-speaker-identification/run.sh b/go-api-examples/vad-speaker-identification/run.sh index e500f5c0..1df02678 100755 --- a/go-api-examples/vad-speaker-identification/run.sh +++ b/go-api-examples/vad-speaker-identification/run.sh @@ -9,7 +9,7 @@ if [ ! -f ./sr-data/enroll/fangjun-sr-1.wav ]; then fi if [ ! -f ./silero_vad.onnx ]; then - curl -SL -O https://github.com/snakers4/silero-vad/blob/master/files/silero_vad.onnx + curl -SL -O https://github.com/snakers4/silero-vad/raw/master/files/silero_vad.onnx fi go mod tidy diff --git a/go-api-examples/vad-spoken-language-identification/run.sh b/go-api-examples/vad-spoken-language-identification/run.sh index fc3c219e..43ae2525 100755 --- a/go-api-examples/vad-spoken-language-identification/run.sh +++ b/go-api-examples/vad-spoken-language-identification/run.sh @@ -2,7 +2,7 @@ if [ ! -f ./silero_vad.onnx ]; then - curl -SL -O https://github.com/snakers4/silero-vad/blob/master/files/silero_vad.onnx + curl -SL -O https://github.com/snakers4/silero-vad/raw/master/files/silero_vad.onnx fi if [ ! -f ./sherpa-onnx-whisper-tiny/tiny-encoder.int8.onnx ]; then diff --git a/go-api-examples/vad/run.sh b/go-api-examples/vad/run.sh index 1584b99f..13e505ce 100755 --- a/go-api-examples/vad/run.sh +++ b/go-api-examples/vad/run.sh @@ -2,7 +2,7 @@ if [ ! -f ./silero_vad.onnx ]; then - curl -SL -O https://github.com/snakers4/silero-vad/blob/master/files/silero_vad.onnx + curl -SL -O https://github.com/snakers4/silero-vad/raw/master/files/silero_vad.onnx fi go mod tidy diff --git a/python-api-examples/add-punctuation.py b/python-api-examples/add-punctuation.py new file mode 100755 index 00000000..7db3e190 --- /dev/null +++ b/python-api-examples/add-punctuation.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python3 + +""" +This script shows how to add punctuations to text using sherpa-onnx Python API. + +Please download the model from +https://github.com/k2-fsa/sherpa-onnx/releases/tag/punctuation-models + +The following is an example + +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2 +tar xvf sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2 +rm sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2 +""" + +from pathlib import Path + +import sherpa_onnx + + +def main(): + model = "./sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12/model.onnx" + if not Path(model).is_file(): + raise ValueError(f"{model} does not exist") + config = sherpa_onnx.OfflinePunctuationConfig( + model=sherpa_onnx.OfflinePunctuationModelConfig(ct_transformer=model), + ) + + punct = sherpa_onnx.OfflinePunctuation(config) + + text_list = [ + "这是一个测试你好吗How are you我很好thank you are you ok谢谢你", + "我们都是木头人不会说话不会动", + "The African blogosphere is rapidly expanding bringing more voices online in the form of commentaries opinions analyses rants and poetry", + ] + for text in text_list: + text_with_punct = punct.add_punctuation(text) + print("----------") + print(f"input: {text}") + print(f"output: {text_with_punct}") + + print("----------") + + +if __name__ == "__main__": + main() diff --git a/sherpa-onnx/python/csrc/CMakeLists.txt b/sherpa-onnx/python/csrc/CMakeLists.txt index 266b7c31..e4bff01c 100644 --- a/sherpa-onnx/python/csrc/CMakeLists.txt +++ b/sherpa-onnx/python/csrc/CMakeLists.txt @@ -12,6 +12,7 @@ set(srcs offline-model-config.cc offline-nemo-enc-dec-ctc-model-config.cc offline-paraformer-model-config.cc + offline-punctuation.cc offline-recognizer.cc offline-stream.cc offline-tdnn-model-config.cc diff --git a/sherpa-onnx/python/csrc/offline-punctuation.cc b/sherpa-onnx/python/csrc/offline-punctuation.cc new file mode 100644 index 00000000..7d3ff86d --- /dev/null +++ b/sherpa-onnx/python/csrc/offline-punctuation.cc @@ -0,0 +1,49 @@ +// sherpa-onnx/python/csrc/offline-punctuation.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-onnx/python/csrc/offline-punctuation.h" + +#include "sherpa-onnx/csrc/offline-punctuation.h" + +namespace sherpa_onnx { + +static void PybindOfflinePunctuationModelConfig(py::module *m) { + using PyClass = OfflinePunctuationModelConfig; + py::class_(*m, "OfflinePunctuationModelConfig") + .def(py::init<>()) + .def(py::init(), + py::arg("ct_transformer"), py::arg("num_threads") = 1, + py::arg("debug") = false, py::arg("provider") = "cpu") + .def_readwrite("ct_transformer", &PyClass::ct_transformer) + .def_readwrite("num_threads", &PyClass::num_threads) + .def_readwrite("debug", &PyClass::debug) + .def_readwrite("provider", &PyClass::provider) + .def("validate", &PyClass::Validate) + .def("__str__", &PyClass::ToString); +} + +static void PybindOfflinePunctuationConfig(py::module *m) { + PybindOfflinePunctuationModelConfig(m); + using PyClass = OfflinePunctuationConfig; + + py::class_(*m, "OfflinePunctuationConfig") + .def(py::init<>()) + .def(py::init(), py::arg("model")) + .def_readwrite("model", &PyClass::model) + .def("validate", &PyClass::Validate) + .def("__str__", &PyClass::ToString); +} + +void PybindOfflinePunctuation(py::module *m) { + PybindOfflinePunctuationConfig(m); + using PyClass = OfflinePunctuation; + + py::class_(*m, "OfflinePunctuation") + .def(py::init(), py::arg("config"), + py::call_guard()) + .def("add_punctuation", &PyClass::AddPunctuation, py::arg("text"), + py::call_guard()); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/python/csrc/offline-punctuation.h b/sherpa-onnx/python/csrc/offline-punctuation.h new file mode 100644 index 00000000..015f9476 --- /dev/null +++ b/sherpa-onnx/python/csrc/offline-punctuation.h @@ -0,0 +1,16 @@ +// sherpa-onnx/python/csrc/offline-punctuation.h +// +// Copyright (c) 2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_PUNCTUATION_H_ +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_PUNCTUATION_H_ + +#include "sherpa-onnx/python/csrc/sherpa-onnx.h" + +namespace sherpa_onnx { + +void PybindOfflinePunctuation(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_PUNCTUATION_H_ diff --git a/sherpa-onnx/python/csrc/sherpa-onnx.cc b/sherpa-onnx/python/csrc/sherpa-onnx.cc index 31dd9baf..242d8597 100644 --- a/sherpa-onnx/python/csrc/sherpa-onnx.cc +++ b/sherpa-onnx/python/csrc/sherpa-onnx.cc @@ -14,6 +14,7 @@ #include "sherpa-onnx/python/csrc/offline-ctc-fst-decoder-config.h" #include "sherpa-onnx/python/csrc/offline-lm-config.h" #include "sherpa-onnx/python/csrc/offline-model-config.h" +#include "sherpa-onnx/python/csrc/offline-punctuation.h" #include "sherpa-onnx/python/csrc/offline-recognizer.h" #include "sherpa-onnx/python/csrc/offline-stream.h" #include "sherpa-onnx/python/csrc/online-ctc-fst-decoder-config.h" @@ -40,6 +41,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) { PybindWaveWriter(&m); PybindAudioTagging(&m); + PybindOfflinePunctuation(&m); PybindFeatures(&m); PybindOnlineCtcFstDecoderConfig(&m); diff --git a/sherpa-onnx/python/sherpa_onnx/__init__.py b/sherpa-onnx/python/sherpa_onnx/__init__.py index 2b160731..7a832ba0 100644 --- a/sherpa-onnx/python/sherpa_onnx/__init__.py +++ b/sherpa-onnx/python/sherpa_onnx/__init__.py @@ -6,6 +6,9 @@ from _sherpa_onnx import ( AudioTaggingModelConfig, CircularBuffer, Display, + OfflinePunctuation, + OfflinePunctuationConfig, + OfflinePunctuationModelConfig, OfflineStream, OfflineTts, OfflineTtsConfig,