diff --git a/.github/scripts/test-c-api.sh b/.github/scripts/test-c-api.sh index b29d1a0b..ce2f6350 100755 --- a/.github/scripts/test-c-api.sh +++ b/.github/scripts/test-c-api.sh @@ -11,8 +11,21 @@ log() { echo "SLID_EXE is $SLID_EXE" echo "SID_EXE is $SID_EXE" echo "AT_EXE is $AT_EXE" +echo "PUNCT_EXE is $PUNCT_EXE" echo "PATH: $PATH" +log "------------------------------------------------------------" +log "Test adding punctuations " +log "------------------------------------------------------------" + +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2 +ls -lh +tar xf sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2 +ls -lh sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12 +rm sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2 +$PUNCT_EXE +rm -rf sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12 + log "------------------------------------------------------------" log "Test audio tagging " log "------------------------------------------------------------" diff --git a/.github/workflows/linux.yaml b/.github/workflows/linux.yaml index 0b8ba50f..260b99af 100644 --- a/.github/workflows/linux.yaml +++ b/.github/workflows/linux.yaml @@ -126,7 +126,7 @@ jobs: - uses: actions/upload-artifact@v4 with: name: release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }} - path: build/bin/* + path: install/* - name: Test offline punctuation shell: bash @@ -143,6 +143,7 @@ jobs: export SLID_EXE=spoken-language-identification-c-api export SID_EXE=speaker-identification-c-api export AT_EXE=audio-tagging-c-api + export PUNCT_EXE=add-punctuation-c-api .github/scripts/test-c-api.sh diff --git a/.github/workflows/macos.yaml b/.github/workflows/macos.yaml index ecb2f835..e70ff11e 100644 --- a/.github/workflows/macos.yaml +++ b/.github/workflows/macos.yaml @@ -122,6 +122,7 
@@ jobs: export SLID_EXE=spoken-language-identification-c-api export SID_EXE=speaker-identification-c-api export AT_EXE=audio-tagging-c-api + export PUNCT_EXE=add-punctuation-c-api .github/scripts/test-c-api.sh diff --git a/.github/workflows/windows-x64.yaml b/.github/workflows/windows-x64.yaml index 55eedb37..d160e475 100644 --- a/.github/workflows/windows-x64.yaml +++ b/.github/workflows/windows-x64.yaml @@ -89,6 +89,7 @@ jobs: export SLID_EXE=spoken-language-identification-c-api.exe export SID_EXE=speaker-identification-c-api.exe export AT_EXE=audio-tagging-c-api.exe + export PUNCT_EXE=add-punctuation-c-api.exe .github/scripts/test-c-api.sh diff --git a/.github/workflows/windows-x86.yaml b/.github/workflows/windows-x86.yaml index b579487a..c476ab10 100644 --- a/.github/workflows/windows-x86.yaml +++ b/.github/workflows/windows-x86.yaml @@ -89,6 +89,7 @@ jobs: export SLID_EXE=spoken-language-identification-c-api.exe export SID_EXE=speaker-identification-c-api.exe export AT_EXE=audio-tagging-c-api.exe + export PUNCT_EXE=add-punctuation-c-api.exe .github/scripts/test-c-api.sh diff --git a/CMakeLists.txt b/CMakeLists.txt index d42cc8ba..8bacab02 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,7 +1,7 @@ cmake_minimum_required(VERSION 3.13 FATAL_ERROR) project(sherpa-onnx) -set(SHERPA_ONNX_VERSION "1.9.19") +set(SHERPA_ONNX_VERSION "1.9.21") # Disable warning about # diff --git a/c-api-examples/CMakeLists.txt b/c-api-examples/CMakeLists.txt index 8d9bfe98..3aa44547 100644 --- a/c-api-examples/CMakeLists.txt +++ b/c-api-examples/CMakeLists.txt @@ -21,6 +21,9 @@ target_link_libraries(streaming-hlg-decode-file-c-api sherpa-onnx-c-api) add_executable(audio-tagging-c-api audio-tagging-c-api.c) target_link_libraries(audio-tagging-c-api sherpa-onnx-c-api) +add_executable(add-punctuation-c-api add-punctuation-c-api.c) +target_link_libraries(add-punctuation-c-api sherpa-onnx-c-api) + if(SHERPA_ONNX_HAS_ALSA) add_subdirectory(./asr-microphone-example) elseif((UNIX AND 
NOT APPLE) OR LINUX) diff --git a/c-api-examples/add-punctuation-c-api.c b/c-api-examples/add-punctuation-c-api.c new file mode 100644 index 00000000..9041e4ba --- /dev/null +++ b/c-api-examples/add-punctuation-c-api.c @@ -0,0 +1,67 @@ +// c-api-examples/add-punctuation-c-api.c +// +// Copyright (c) 2024 Xiaomi Corporation + +// We assume you have pre-downloaded the model files for testing +// from https://github.com/k2-fsa/sherpa-onnx/releases/tag/punctuation-models +// +// An example is given below: +// +// clang-format off +// +// wget https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2 +// tar xvf sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2 +// rm sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2 +// +// clang-format on + +#include <stdio.h> +#include <stdlib.h> +#include <string.h> + +#include "sherpa-onnx/c-api/c-api.h" + +int32_t main() { + SherpaOnnxOfflinePunctuationConfig config; + memset(&config, 0, sizeof(config)); + + // clang-format off + config.model.ct_transformer = "./sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12/model.onnx"; + // clang-format on + config.model.num_threads = 1; + config.model.debug = 1; + config.model.provider = "cpu"; + + const SherpaOnnxOfflinePunctuation *punct = + SherpaOnnxCreateOfflinePunctuation(&config); + if (!punct) { + fprintf(stderr, + "Failed to create OfflinePunctuation. 
Please check your config\n"); + return -1; + } + + const char *texts[] = { + "这是一个测试你好吗How are you我很好thank you are you ok谢谢你", + "我们都是木头人不会说话不会动", + "The African blogosphere is rapidly expanding bringing more voices " + "online in the form of commentaries opinions analyses rants and poetry", + }; + + int32_t n = sizeof(texts) / sizeof(const char *); + fprintf(stderr, "n: %d\n", n); + + fprintf(stderr, "--------------------\n"); + for (int32_t i = 0; i != n; ++i) { + const char *text_with_punct = + SherpaOfflinePunctuationAddPunct(punct, texts[i]); + + fprintf(stderr, "Input text: %s\n", texts[i]); + fprintf(stderr, "Output text: %s\n", text_with_punct); + SherpaOfflinePunctuationFreeText(text_with_punct); + fprintf(stderr, "--------------------\n"); + } + + SherpaOnnxDestroyOfflinePunctuation(punct); + + return 0; +} diff --git a/python-api-examples/streaming-paraformer-asr-microphone.py b/python-api-examples/streaming-paraformer-asr-microphone.py new file mode 100755 index 00000000..ad5c8f70 --- /dev/null +++ b/python-api-examples/streaming-paraformer-asr-microphone.py @@ -0,0 +1,105 @@ +#!/usr/bin/env python3 + +# Real-time speech recognition from a microphone with sherpa-onnx Python API +# with endpoint detection. +# This script uses a streaming paraformer +# +# Please refer to +# https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-paraformer/paraformer-models.html# +# to download pre-trained models + +import sys +from pathlib import Path + +try: + import sounddevice as sd +except ImportError: + print("Please install sounddevice first. 
You can use") + print() + print(" pip install sounddevice") + print() + print("to install it") + sys.exit(-1) + +import sherpa_onnx + + +def assert_file_exists(filename: str): + assert Path(filename).is_file(), ( + f"{filename} does not exist!\n" + "Please refer to " + "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-paraformer/paraformer-models.html to download it" + ) + + +def create_recognizer(): + encoder = "./sherpa-onnx-streaming-paraformer-bilingual-zh-en/encoder.int8.onnx" + decoder = "./sherpa-onnx-streaming-paraformer-bilingual-zh-en/decoder.int8.onnx" + tokens = "./sherpa-onnx-streaming-paraformer-bilingual-zh-en/tokens.txt" + assert_file_exists(encoder) + assert_file_exists(decoder) + assert_file_exists(tokens) + recognizer = sherpa_onnx.OnlineRecognizer.from_paraformer( + tokens=tokens, + encoder=encoder, + decoder=decoder, + num_threads=1, + sample_rate=16000, + feature_dim=80, + enable_endpoint_detection=True, + rule1_min_trailing_silence=2.4, + rule2_min_trailing_silence=1.2, + rule3_min_utterance_length=300, # it essentially disables this rule + ) + return recognizer + + +def main(): + devices = sd.query_devices() + if len(devices) == 0: + print("No microphone devices found") + sys.exit(0) + + print(devices) + default_input_device_idx = sd.default.device[0] + print(f'Use default device: {devices[default_input_device_idx]["name"]}') + + recognizer = create_recognizer() + print("Started! Please speak") + + # The model is using 16 kHz, we use 48 kHz here to demonstrate that + # sherpa-onnx will do resampling inside. 
+ sample_rate = 48000 + samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms + + stream = recognizer.create_stream() + + last_result = "" + segment_id = 0 + with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s: + while True: + samples, _ = s.read(samples_per_read) # a blocking read + samples = samples.reshape(-1) + stream.accept_waveform(sample_rate, samples) + while recognizer.is_ready(stream): + recognizer.decode_stream(stream) + + is_endpoint = recognizer.is_endpoint(stream) + + result = recognizer.get_result(stream) + + if result and (last_result != result): + last_result = result + print("\r{}:{}".format(segment_id, result), end="", flush=True) + if is_endpoint: + if result: + print("\r{}:{}".format(segment_id, result), flush=True) + segment_id += 1 + recognizer.reset(stream) + + +if __name__ == "__main__": + try: + main() + except KeyboardInterrupt: + print("\nCaught Ctrl + C. Exiting") diff --git a/sherpa-onnx/c-api/c-api.cc b/sherpa-onnx/c-api/c-api.cc index 995817a0..81643e86 100644 --- a/sherpa-onnx/c-api/c-api.cc +++ b/sherpa-onnx/c-api/c-api.cc @@ -15,6 +15,7 @@ #include "sherpa-onnx/csrc/display.h" #include "sherpa-onnx/csrc/keyword-spotter.h" #include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/offline-punctuation.h" #include "sherpa-onnx/csrc/offline-recognizer.h" #include "sherpa-onnx/csrc/online-recognizer.h" #include "sherpa-onnx/csrc/speaker-embedding-extractor.h" @@ -1299,3 +1300,48 @@ void SherpaOnnxAudioTaggingFreeResults( delete[] events; } + +struct SherpaOnnxOfflinePunctuation { + std::unique_ptr<sherpa_onnx::OfflinePunctuation> impl; +}; + +const SherpaOnnxOfflinePunctuation *SherpaOnnxCreateOfflinePunctuation( + const SherpaOnnxOfflinePunctuationConfig *config) { + sherpa_onnx::OfflinePunctuationConfig c; + c.model.ct_transformer = SHERPA_ONNX_OR(config->model.ct_transformer, ""); + c.model.num_threads = SHERPA_ONNX_OR(config->model.num_threads, 1); + c.model.debug = config->model.debug; + c.model.provider = 
SHERPA_ONNX_OR(config->model.provider, "cpu"); + + if (c.model.debug) { + SHERPA_ONNX_LOGE("%s\n", c.ToString().c_str()); + } + + if (!c.Validate()) { + SHERPA_ONNX_LOGE("Errors in config"); + return nullptr; + } + + SherpaOnnxOfflinePunctuation *punct = new SherpaOnnxOfflinePunctuation; + punct->impl = std::make_unique<sherpa_onnx::OfflinePunctuation>(c); + + return punct; +} + +void SherpaOnnxDestroyOfflinePunctuation( + const SherpaOnnxOfflinePunctuation *punct) { + delete punct; +} + +const char *SherpaOfflinePunctuationAddPunct( + const SherpaOnnxOfflinePunctuation *punct, const char *text) { + std::string text_with_punct = punct->impl->AddPunctuation(text); + + char *ans = new char[text_with_punct.size() + 1]; + std::copy(text_with_punct.begin(), text_with_punct.end(), ans); + ans[text_with_punct.size()] = 0; + + return ans; +} + +void SherpaOfflinePunctuationFreeText(const char *text) { delete[] text; } diff --git a/sherpa-onnx/c-api/c-api.h b/sherpa-onnx/c-api/c-api.h index 3833209a..1faf4f71 100644 --- a/sherpa-onnx/c-api/c-api.h +++ b/sherpa-onnx/c-api/c-api.h @@ -1149,6 +1149,41 @@ SherpaOnnxAudioTaggingCompute(const SherpaOnnxAudioTagging *tagger, SHERPA_ONNX_API void SherpaOnnxAudioTaggingFreeResults( const SherpaOnnxAudioEvent *const *p); +// ============================================================ +// For punctuation +// ============================================================ + +SHERPA_ONNX_API typedef struct SherpaOnnxOfflinePunctuationModelConfig { + const char *ct_transformer; + int32_t num_threads; + int32_t debug; // true to print debug information of the model + const char *provider; +} SherpaOnnxOfflinePunctuationModelConfig; + +SHERPA_ONNX_API typedef struct SherpaOnnxOfflinePunctuationConfig { + SherpaOnnxOfflinePunctuationModelConfig model; +} SherpaOnnxOfflinePunctuationConfig; + +SHERPA_ONNX_API typedef struct SherpaOnnxOfflinePunctuation + SherpaOnnxOfflinePunctuation; + +// The user has to invoke SherpaOnnxDestroyOfflinePunctuation() +// to free the returned 
pointer to avoid memory leak +SHERPA_ONNX_API const SherpaOnnxOfflinePunctuation * +SherpaOnnxCreateOfflinePunctuation( + const SherpaOnnxOfflinePunctuationConfig *config); + +SHERPA_ONNX_API void SherpaOnnxDestroyOfflinePunctuation( + const SherpaOnnxOfflinePunctuation *punct); + +// Add punctuations to the input text. +// The user has to invoke SherpaOfflinePunctuationFreeText() +// to free the returned pointer to avoid memory leak +SHERPA_ONNX_API const char *SherpaOfflinePunctuationAddPunct( + const SherpaOnnxOfflinePunctuation *punct, const char *text); + +SHERPA_ONNX_API void SherpaOfflinePunctuationFreeText(const char *text); + #if defined(__GNUC__) #pragma GCC diagnostic pop #endif diff --git a/sherpa-onnx/csrc/offline-punctuation-ct-transformer-impl.h b/sherpa-onnx/csrc/offline-punctuation-ct-transformer-impl.h index 393feba9..4414a5a8 100644 --- a/sherpa-onnx/csrc/offline-punctuation-ct-transformer-impl.h +++ b/sherpa-onnx/csrc/offline-punctuation-ct-transformer-impl.h @@ -134,25 +134,40 @@ class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl { } } // for (int32_t i = 0; i != num_segments; ++i) - std::string ans; + if (punctuations.empty()) { + return text + meta_data.id2punct[meta_data.dot_id]; + } + std::vector<std::string> words_punct; for (int32_t i = 0; i != static_cast<int32_t>(punctuations.size()); ++i) { - if (i > tokens.size()) { + if (i >= tokens.size()) { break; } - const std::string &w = tokens[i]; - if (i > 0 && !(ans.back() & 0x80) && !(w[0] & 0x80)) { - ans.push_back(' '); + std::string &w = tokens[i]; + if (i > 0 && !(words_punct.back()[0] & 0x80) && !(w[0] & 0x80)) { + words_punct.push_back(" "); } - ans.append(w); + words_punct.push_back(std::move(w)); + if (punctuations[i] != meta_data.underline_id) { - ans.append(meta_data.id2punct[punctuations[i]]); + words_punct.push_back(meta_data.id2punct[punctuations[i]]); } } - if (ans.back() != meta_data.dot_id && ans.back() != meta_data.quest_id) { - ans.push_back(meta_data.dot_id); - } + if 
(words_punct.back() == meta_data.id2punct[meta_data.comma_id] || + words_punct.back() == meta_data.id2punct[meta_data.pause_id]) { + words_punct.back() = meta_data.id2punct[meta_data.dot_id]; + } + + if (words_punct.back() != meta_data.id2punct[meta_data.dot_id] && + words_punct.back() != meta_data.id2punct[meta_data.quest_id]) { + words_punct.push_back(meta_data.id2punct[meta_data.dot_id]); + } + + std::string ans; + for (const auto &w : words_punct) { + ans.append(w); + } return ans; } diff --git a/sherpa-onnx/python/csrc/offline-punctuation.cc b/sherpa-onnx/python/csrc/offline-punctuation.cc index 7d3ff86d..0ff25903 100644 --- a/sherpa-onnx/python/csrc/offline-punctuation.cc +++ b/sherpa-onnx/python/csrc/offline-punctuation.cc @@ -4,6 +4,8 @@ #include "sherpa-onnx/python/csrc/offline-punctuation.h" +#include <string> + #include "sherpa-onnx/csrc/offline-punctuation.h" namespace sherpa_onnx {