diff --git a/.github/scripts/test-offline-punctuation.sh b/.github/scripts/test-offline-punctuation.sh
new file mode 100755
index 00000000..6a096c36
--- /dev/null
+++ b/.github/scripts/test-offline-punctuation.sh
@@ -0,0 +1,46 @@
+#!/usr/bin/env bash
+#
+# Smoke test for the offline punctuation binary.
+# Expects $EXE to name the binary under test, resolvable through $PATH.
+
+set -ex
+
+log() {
+  # This function is from espnet
+  local fname=${BASH_SOURCE[1]##*/}
+  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+echo "EXE is $EXE"
+echo "PATH: $PATH"
+
+which "$EXE"
+
+log "------------------------------------------------------------"
+log "Download model "
+log "------------------------------------------------------------"
+
+# -f: fail on an HTTP error instead of saving the error page as the archive
+curl -fSL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2
+tar xvf sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2
+rm sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2
+repo=sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12
+ls -lh "$repo"
+
+"$EXE" \
+  --debug=1 \
+  --ct-transformer="$repo/model.onnx" \
+  "这是一个测试你好吗How are you我很好thank you are you ok谢谢你"
+
+"$EXE" \
+  --debug=1 \
+  --ct-transformer="$repo/model.onnx" \
+  "我们都是木头人不会说话不会动"
+
+"$EXE" \
+  --debug=1 \
+  --ct-transformer="$repo/model.onnx" \
+  "The African blogosphere is rapidly expanding bringing more voices online in the form of commentaries opinions analyses rants and poetry"
+
+# Quoted so that an empty or space-containing value cannot widen the delete
+rm -rf "$repo"
diff --git a/.github/workflows/linux.yaml b/.github/workflows/linux.yaml
index ee58cc40..0b8ba50f 100644
--- a/.github/workflows/linux.yaml
+++ b/.github/workflows/linux.yaml
@@ -16,6 +16,7 @@ on:
       - '.github/scripts/test-online-ctc.sh'
       - '.github/scripts/test-offline-tts.sh'
       - '.github/scripts/test-audio-tagging.sh'
+      - '.github/scripts/test-offline-punctuation.sh'
       - 'CMakeLists.txt'
       - 'cmake/**'
       - 'sherpa-onnx/csrc/*'
@@ -34,6 +35,7 @@ on:
       - '.github/scripts/test-online-ctc.sh'
       - '.github/scripts/test-offline-tts.sh'
       - 
'.github/scripts/test-audio-tagging.sh' + - '.github/scripts/test-offline-punctuation.sh' - 'CMakeLists.txt' - 'cmake/**' - 'sherpa-onnx/csrc/*' @@ -126,6 +128,14 @@ jobs: name: release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }} path: build/bin/* + - name: Test offline punctuation + shell: bash + run: | + export PATH=$PWD/build/bin:$PATH + export EXE=sherpa-onnx-offline-punctuation + + .github/scripts/test-offline-punctuation.sh + - name: Test C API shell: bash run: | diff --git a/.github/workflows/macos.yaml b/.github/workflows/macos.yaml index 99b4f301..ecb2f835 100644 --- a/.github/workflows/macos.yaml +++ b/.github/workflows/macos.yaml @@ -16,6 +16,7 @@ on: - '.github/scripts/test-offline-tts.sh' - '.github/scripts/test-online-ctc.sh' - '.github/scripts/test-audio-tagging.sh' + - '.github/scripts/test-offline-punctuation.sh' - 'CMakeLists.txt' - 'cmake/**' - 'sherpa-onnx/csrc/*' @@ -33,6 +34,7 @@ on: - '.github/scripts/test-offline-tts.sh' - '.github/scripts/test-online-ctc.sh' - '.github/scripts/test-audio-tagging.sh' + - '.github/scripts/test-offline-punctuation.sh' - 'CMakeLists.txt' - 'cmake/**' - 'sherpa-onnx/csrc/*' @@ -105,6 +107,14 @@ jobs: otool -L build/bin/sherpa-onnx otool -l build/bin/sherpa-onnx + - name: Test offline punctuation + shell: bash + run: | + export PATH=$PWD/build/bin:$PATH + export EXE=sherpa-onnx-offline-punctuation + + .github/scripts/test-offline-punctuation.sh + - name: Test C API shell: bash run: | diff --git a/.github/workflows/windows-x64.yaml b/.github/workflows/windows-x64.yaml index cf000be8..55eedb37 100644 --- a/.github/workflows/windows-x64.yaml +++ b/.github/workflows/windows-x64.yaml @@ -15,6 +15,7 @@ on: - '.github/scripts/test-online-ctc.sh' - '.github/scripts/test-offline-tts.sh' - '.github/scripts/test-audio-tagging.sh' + - '.github/scripts/test-offline-punctuation.sh' - 'CMakeLists.txt' - 'cmake/**' - 'sherpa-onnx/csrc/*' @@ -30,6 +31,7 @@ on: - 
'.github/scripts/test-online-ctc.sh' - '.github/scripts/test-offline-tts.sh' - '.github/scripts/test-audio-tagging.sh' + - '.github/scripts/test-offline-punctuation.sh' - 'CMakeLists.txt' - 'cmake/**' - 'sherpa-onnx/csrc/*' @@ -72,6 +74,14 @@ jobs: ls -lh ./bin/Release/sherpa-onnx.exe + - name: Test offline punctuation + shell: bash + run: | + export PATH=$PWD/build/bin/Release:$PATH + export EXE=sherpa-onnx-offline-punctuation.exe + + .github/scripts/test-offline-punctuation.sh + - name: Test C API shell: bash run: | @@ -82,7 +92,6 @@ jobs: .github/scripts/test-c-api.sh - - name: Test Audio tagging shell: bash run: | diff --git a/.github/workflows/windows-x86.yaml b/.github/workflows/windows-x86.yaml index 7a18e0be..b579487a 100644 --- a/.github/workflows/windows-x86.yaml +++ b/.github/workflows/windows-x86.yaml @@ -15,6 +15,7 @@ on: - '.github/scripts/test-offline-tts.sh' - '.github/scripts/test-online-ctc.sh' - '.github/scripts/test-audio-tagging.sh' + - '.github/scripts/test-offline-punctuation.sh' - 'CMakeLists.txt' - 'cmake/**' - 'sherpa-onnx/csrc/*' @@ -30,6 +31,7 @@ on: - '.github/scripts/test-offline-tts.sh' - '.github/scripts/test-online-ctc.sh' - '.github/scripts/test-audio-tagging.sh' + - '.github/scripts/test-offline-punctuation.sh' - 'CMakeLists.txt' - 'cmake/**' - 'sherpa-onnx/csrc/*' @@ -72,6 +74,14 @@ jobs: ls -lh ./bin/Release/sherpa-onnx.exe + - name: Test offline punctuation + shell: bash + run: | + export PATH=$PWD/build/bin/Release:$PATH + export EXE=sherpa-onnx-offline-punctuation.exe + + .github/scripts/test-offline-punctuation.sh + - name: Test spoken language identification (C API) shell: bash run: | diff --git a/cmake/cmake_extension.py b/cmake/cmake_extension.py index ca57ff30..16327613 100644 --- a/cmake/cmake_extension.py +++ b/cmake/cmake_extension.py @@ -46,14 +46,15 @@ def enable_alsa(): def get_binaries(): binaries = [ "sherpa-onnx", - "sherpa-onnx-offline-audio-tagging", "sherpa-onnx-keyword-spotter", "sherpa-onnx-microphone", 
"sherpa-onnx-microphone-offline", "sherpa-onnx-microphone-offline-audio-tagging", "sherpa-onnx-microphone-offline-speaker-identification", "sherpa-onnx-offline", + "sherpa-onnx-offline-audio-tagging", "sherpa-onnx-offline-language-identification", + "sherpa-onnx-offline-punctuation", "sherpa-onnx-offline-tts", "sherpa-onnx-offline-tts-play", "sherpa-onnx-offline-websocket-server", diff --git a/python-api-examples/speaker-identification-with-vad-non-streaming-asr.py b/python-api-examples/speaker-identification-with-vad-non-streaming-asr.py index dfa54bb0..fe735e17 100755 --- a/python-api-examples/speaker-identification-with-vad-non-streaming-asr.py +++ b/python-api-examples/speaker-identification-with-vad-non-streaming-asr.py @@ -408,8 +408,11 @@ def main(): vad_config.silero_vad.min_silence_duration = 0.25 vad_config.silero_vad.min_speech_duration = 0.25 vad_config.sample_rate = g_sample_rate + if not vad_config.validate(): + raise ValueError("Errors in vad config") window_size = vad_config.silero_vad.window_size + vad = sherpa_onnx.VoiceActivityDetector(vad_config, buffer_size_in_seconds=100) samples_per_read = int(0.1 * g_sample_rate) # 0.1 second = 100 ms diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index fe2a1a93..5ec49050 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -121,6 +121,14 @@ list(APPEND sources offline-zipformer-audio-tagging-model.cc ) +# punctuation +list(APPEND sources + offline-ct-transformer-model.cc + offline-punctuation-impl.cc + offline-punctuation-model-config.cc + offline-punctuation.cc +) + if(SHERPA_ONNX_ENABLE_TTS) list(APPEND sources lexicon.cc @@ -201,9 +209,10 @@ if(SHERPA_ONNX_ENABLE_BINARY) add_executable(sherpa-onnx sherpa-onnx.cc) add_executable(sherpa-onnx-keyword-spotter sherpa-onnx-keyword-spotter.cc) add_executable(sherpa-onnx-offline sherpa-onnx-offline.cc) - add_executable(sherpa-onnx-offline-parallel sherpa-onnx-offline-parallel.cc) - 
add_executable(sherpa-onnx-offline-language-identification sherpa-onnx-offline-language-identification.cc) add_executable(sherpa-onnx-offline-audio-tagging sherpa-onnx-offline-audio-tagging.cc) + add_executable(sherpa-onnx-offline-language-identification sherpa-onnx-offline-language-identification.cc) + add_executable(sherpa-onnx-offline-parallel sherpa-onnx-offline-parallel.cc) + add_executable(sherpa-onnx-offline-punctuation sherpa-onnx-offline-punctuation.cc) if(SHERPA_ONNX_ENABLE_TTS) add_executable(sherpa-onnx-offline-tts sherpa-onnx-offline-tts.cc) @@ -213,9 +222,10 @@ if(SHERPA_ONNX_ENABLE_BINARY) sherpa-onnx sherpa-onnx-keyword-spotter sherpa-onnx-offline - sherpa-onnx-offline-parallel - sherpa-onnx-offline-language-identification sherpa-onnx-offline-audio-tagging + sherpa-onnx-offline-language-identification + sherpa-onnx-offline-parallel + sherpa-onnx-offline-punctuation ) if(SHERPA_ONNX_ENABLE_TTS) list(APPEND main_exes @@ -260,11 +270,11 @@ endif() if(SHERPA_ONNX_HAS_ALSA AND SHERPA_ONNX_ENABLE_BINARY) add_executable(sherpa-onnx-alsa sherpa-onnx-alsa.cc alsa.cc) - add_executable(sherpa-onnx-keyword-spotter-alsa sherpa-onnx-keyword-spotter-alsa.cc alsa.cc) add_executable(sherpa-onnx-alsa-offline sherpa-onnx-alsa-offline.cc alsa.cc) - add_executable(sherpa-onnx-alsa-offline-speaker-identification sherpa-onnx-alsa-offline-speaker-identification.cc alsa.cc) - add_executable(sherpa-onnx-vad-alsa sherpa-onnx-vad-alsa.cc alsa.cc) add_executable(sherpa-onnx-alsa-offline-audio-tagging sherpa-onnx-alsa-offline-audio-tagging.cc alsa.cc) + add_executable(sherpa-onnx-alsa-offline-speaker-identification sherpa-onnx-alsa-offline-speaker-identification.cc alsa.cc) + add_executable(sherpa-onnx-keyword-spotter-alsa sherpa-onnx-keyword-spotter-alsa.cc alsa.cc) + add_executable(sherpa-onnx-vad-alsa sherpa-onnx-vad-alsa.cc alsa.cc) if(SHERPA_ONNX_ENABLE_TTS) diff --git a/sherpa-onnx/csrc/lexicon.cc b/sherpa-onnx/csrc/lexicon.cc index e3a87eba..2f176dde 100644 --- 
a/sherpa-onnx/csrc/lexicon.cc
+++ b/sherpa-onnx/csrc/lexicon.cc
@@ -74,11 +74,6 @@ static std::vector ProcessHeteronyms(
   return ans;
 }
 
-static void ToLowerCase(std::string *in_out) {
-  std::transform(in_out->begin(), in_out->end(), in_out->begin(),
-                 [](unsigned char c) { return std::tolower(c); });
-}
-
 // Note: We don't use SymbolTable here since tokens may contain a blank
 // in the first column
 static std::unordered_map ReadTokens(std::istream &is) {
diff --git a/sherpa-onnx/csrc/macros.h b/sherpa-onnx/csrc/macros.h
index 9ac3d302..b5dfb99e 100644
--- a/sherpa-onnx/csrc/macros.h
+++ b/sherpa-onnx/csrc/macros.h
@@ -118,6 +118,26 @@
     }                                                                          \
   } while (0)
 
+// Read a vector of strings separated by sep. Exits the process if the key
+// is missing or if the resulting vector is empty. Expects `meta_data` and
+// `allocator` to be in scope where the macro is expanded.
+#define SHERPA_ONNX_READ_META_DATA_VEC_STRING_SEP(dst, src_key, sep)          \
+  do {                                                                        \
+    auto value =                                                              \
+        meta_data.LookupCustomMetadataMapAllocated(src_key, allocator);       \
+    if (!value) {                                                             \
+      SHERPA_ONNX_LOGE("%s does not exist in the metadata", src_key);         \
+      exit(-1);                                                               \
+    }                                                                         \
+    SplitStringToVector(value.get(), sep, false, &dst);                       \
+                                                                              \
+    if (dst.empty()) {                                                        \
+      SHERPA_ONNX_LOGE("Invalid value %s for %s. Empty vector!", value.get(), \
+                       src_key);                                              \
+      exit(-1);                                                               \
+    }                                                                         \
+  } while (0)
+
 // Read a string
 #define SHERPA_ONNX_READ_META_DATA_STR(dst, src_key) \
   do {                                               \
diff --git a/sherpa-onnx/csrc/offline-ct-transformer-model-meta-data.h b/sherpa-onnx/csrc/offline-ct-transformer-model-meta-data.h
new file mode 100644
index 00000000..eea37d73
--- /dev/null
+++ b/sherpa-onnx/csrc/offline-ct-transformer-model-meta-data.h
@@ -0,0 +1,33 @@
+// sherpa-onnx/csrc/offline-ct-transformer-model-meta-data.h
+//
+// Copyright (c) 2024 Xiaomi Corporation
+#ifndef SHERPA_ONNX_CSRC_OFFLINE_CT_TRANSFORMER_MODEL_META_DATA_H_
+#define SHERPA_ONNX_CSRC_OFFLINE_CT_TRANSFORMER_MODEL_META_DATA_H_
+
+#include <cstdint>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+namespace sherpa_onnx {
+
+// Vocabulary and punctuation tables read from the custom metadata of a
+// CT transformer punctuation model.
+struct OfflineCtTransformerModelMetaData {
+  std::unordered_map<std::string, int32_t> token2id;
+  std::unordered_map<std::string, int32_t> punct2id;
+  std::vector<std::string> id2punct;
+
+  // Ids stay at -1 until the model has been loaded.
+  int32_t unk_id = -1;
+  int32_t dot_id = -1;
+  int32_t comma_id = -1;
+  int32_t quest_id = -1;
+  int32_t pause_id = -1;
+  int32_t underline_id = -1;
+  int32_t num_punctuations = 0;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_OFFLINE_CT_TRANSFORMER_MODEL_META_DATA_H_
diff --git a/sherpa-onnx/csrc/offline-ct-transformer-model.cc b/sherpa-onnx/csrc/offline-ct-transformer-model.cc
new file mode 100644
index 00000000..4452f7c7
--- /dev/null
+++ b/sherpa-onnx/csrc/offline-ct-transformer-model.cc
@@ -0,0 +1,189 @@
+// sherpa-onnx/csrc/offline-ct-transformer-model.cc
+//
+// Copyright (c) 2024 Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/offline-ct-transformer-model.h"
+
+#include <array>
+#include <memory>
+#include <sstream>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "sherpa-onnx/csrc/macros.h"
+#include "sherpa-onnx/csrc/onnx-utils.h"
+#include "sherpa-onnx/csrc/session.h"
+#include "sherpa-onnx/csrc/text-utils.h"
+
+namespace sherpa_onnx {
+
+class OfflineCtTransformerModel::Impl {
+ public:
+  explicit Impl(const OfflinePunctuationModelConfig &config)
+      : config_(config),
+        env_(ORT_LOGGING_LEVEL_ERROR),
+        sess_opts_(GetSessionOptions(config)),
+        allocator_{} {
+    auto buf = ReadFile(config_.ct_transformer);
+    Init(buf.data(), buf.size());
+  }
+
+#if __ANDROID_API__ >= 9
+  Impl(AAssetManager *mgr, const OfflinePunctuationModelConfig &config)
+      : config_(config),
+        env_(ORT_LOGGING_LEVEL_ERROR),
+        sess_opts_(GetSessionOptions(config)),
+        allocator_{} {
+    auto buf = ReadFile(mgr, config_.ct_transformer);
+    Init(buf.data(), buf.size());
+  }
+#endif
+
+  // Run the model on `text` and `text_len` and return its first output.
+  Ort::Value Forward(Ort::Value text, Ort::Value text_len) {
+    std::array<Ort::Value, 2> inputs = {std::move(text), std::move(text_len)};
+
+    auto ans =
+        sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
+                   output_names_ptr_.data(), output_names_ptr_.size());
+    return std::move(ans[0]);
+  }
+
+  OrtAllocator *Allocator() const { return allocator_; }
+
+  const OfflineCtTransformerModelMetaData &GetModelMetadata() const {
+    return meta_data_;
+  }
+
+ private:
+  // Return the id stored for `key` in `table`.
+  //
+  // unordered_map::at() would throw std::out_of_range without saying which
+  // symbol the model lacks, so log the symbol and exit like the macros do.
+  static int32_t LookupOrDie(
+      const std::unordered_map<std::string, int32_t> &table,
+      const std::string &key) {
+    auto it = table.find(key);
+    if (it == table.end()) {
+      SHERPA_ONNX_LOGE("'%s' does not exist in the metadata", key.c_str());
+      exit(-1);
+    }
+
+    return it->second;
+  }
+
+  // Create the session from the in-memory model and fill meta_data_ from
+  // the model's custom metadata. Exits the process on invalid metadata.
+  void Init(void *model_data, size_t model_data_length) {
+    sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
+                                           sess_opts_);
+
+    GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
+
+    GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);
+
+    // get meta data
+    Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
+
+    Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below
+
+    std::vector<std::string> tokens;
+    SHERPA_ONNX_READ_META_DATA_VEC_STRING_SEP(tokens, "tokens", "|");
+
+    int32_t vocab_size;
+    SHERPA_ONNX_READ_META_DATA(vocab_size, "vocab_size");
+    if (static_cast<int32_t>(tokens.size()) != vocab_size) {
+      SHERPA_ONNX_LOGE("tokens.size() %d != vocab_size %d",
+                       static_cast<int32_t>(tokens.size()), vocab_size);
+      exit(-1);
+    }
+
+    SHERPA_ONNX_READ_META_DATA_VEC_STRING_SEP(meta_data_.id2punct,
+                                              "punctuations", "|");
+
+    std::string unk_symbol;
+    SHERPA_ONNX_READ_META_DATA_STR(unk_symbol, "unk_symbol");
+
+    // output shape is (N, T, num_punctuations)
+    meta_data_.num_punctuations = static_cast<int32_t>(
+        sess_->GetOutputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape()[2]);
+
+    int32_t i = 0;
+    for (const auto &t : tokens) {
+      meta_data_.token2id[t] = i;
+      i += 1;
+    }
+
+    i = 0;
+    for (const auto &p : meta_data_.id2punct) {
+      meta_data_.punct2id[p] = i;
+      i += 1;
+    }
+
+    meta_data_.unk_id = LookupOrDie(meta_data_.token2id, unk_symbol);
+
+    meta_data_.dot_id = LookupOrDie(meta_data_.punct2id, "。");
+    meta_data_.comma_id = LookupOrDie(meta_data_.punct2id, ",");
+    meta_data_.quest_id = LookupOrDie(meta_data_.punct2id, "?");
+    meta_data_.pause_id = LookupOrDie(meta_data_.punct2id, "、");
+    meta_data_.underline_id = LookupOrDie(meta_data_.punct2id, "_");
+
+    if (config_.debug) {
+      std::ostringstream os;
+      os << "vocab_size: " << meta_data_.token2id.size() << "\n";
+      os << "num_punctuations: " << meta_data_.num_punctuations << "\n";
+      os << "punctuations: ";
+      for (const auto &s : meta_data_.id2punct) {
+        os << s << " ";
+      }
+      os << "\n";
+      SHERPA_ONNX_LOGE("\n%s\n", os.str().c_str());
+    }
+  }
+
+ private:
+  OfflinePunctuationModelConfig config_;
+  Ort::Env env_;
+  Ort::SessionOptions sess_opts_;
+  Ort::AllocatorWithDefaultOptions allocator_;
+
+  std::unique_ptr<Ort::Session> sess_;
+
+  std::vector<std::string> input_names_;
+  std::vector<const char *> input_names_ptr_;
+
+  std::vector<std::string> output_names_;
+  std::vector<const char *> output_names_ptr_;
+
+  OfflineCtTransformerModelMetaData meta_data_;
+};
+
+OfflineCtTransformerModel::OfflineCtTransformerModel(
+    const OfflinePunctuationModelConfig &config)
+    : impl_(std::make_unique<Impl>(config)) {}
+
+#if __ANDROID_API__ >= 9
+OfflineCtTransformerModel::OfflineCtTransformerModel(
+    AAssetManager *mgr, const OfflinePunctuationModelConfig &config)
+    : impl_(std::make_unique<Impl>(mgr, config)) {}
+#endif
+
+OfflineCtTransformerModel::~OfflineCtTransformerModel() = default;
+
+Ort::Value OfflineCtTransformerModel::Forward(Ort::Value text,
+                                              Ort::Value text_len) const {
+  return impl_->Forward(std::move(text), std::move(text_len));
+}
+
+OrtAllocator *OfflineCtTransformerModel::Allocator() const {
+  return impl_->Allocator();
+}
+
+const OfflineCtTransformerModelMetaData &
+OfflineCtTransformerModel::GetModelMetadata() const {
+  return impl_->GetModelMetadata();
+}
+
+}  // namespace sherpa_onnx
diff --git 
a/sherpa-onnx/csrc/offline-ct-transformer-model.h b/sherpa-onnx/csrc/offline-ct-transformer-model.h
new file mode 100644
index 00000000..06e14ec7
--- /dev/null
+++ b/sherpa-onnx/csrc/offline-ct-transformer-model.h
@@ -0,0 +1,59 @@
+// sherpa-onnx/csrc/offline-ct-transformer-model.h
+//
+// Copyright (c) 2024 Xiaomi Corporation
+#ifndef SHERPA_ONNX_CSRC_OFFLINE_CT_TRANSFORMER_MODEL_H_
+#define SHERPA_ONNX_CSRC_OFFLINE_CT_TRANSFORMER_MODEL_H_
+#include <memory>
+#include <vector>
+
+#if __ANDROID_API__ >= 9
+#include "android/asset_manager.h"
+#include "android/asset_manager_jni.h"
+#endif
+
+#include "onnxruntime_cxx_api.h"  // NOLINT
+#include "sherpa-onnx/csrc/offline-ct-transformer-model-meta-data.h"
+#include "sherpa-onnx/csrc/offline-punctuation-model-config.h"
+
+namespace sherpa_onnx {
+
+/** This class implements
+ * https://github.com/alibaba-damo-academy/FunASR/blob/main/runtime/python/onnxruntime/funasr_onnx/punc_bin.py#L17
+ * from FunASR
+ */
+class OfflineCtTransformerModel {
+ public:
+  explicit OfflineCtTransformerModel(
+      const OfflinePunctuationModelConfig &config);
+
+#if __ANDROID_API__ >= 9
+  OfflineCtTransformerModel(AAssetManager *mgr,
+                            const OfflinePunctuationModelConfig &config);
+#endif
+
+  ~OfflineCtTransformerModel();
+
+  /** Run the forward method of the model.
+   *
+   * @param text  A tensor of shape (N, T) of dtype int32.
+   * @param text_len  A tensor of shape (N) of dtype int32.
+   *
+   * @return Return a tensor
+   *  - scores: A 3-D float tensor of shape (N, T, num_punctuations).
+   */
+  Ort::Value Forward(Ort::Value text, Ort::Value text_len) const;
+
+  /** Return an allocator for allocating memory
+   */
+  OrtAllocator *Allocator() const;
+
+  const OfflineCtTransformerModelMetaData &GetModelMetadata() const;
+
+ private:
+  class Impl;
+  std::unique_ptr<Impl> impl_;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_OFFLINE_CT_TRANSFORMER_MODEL_H_
diff --git a/sherpa-onnx/csrc/offline-punctuation-ct-transformer-impl.h b/sherpa-onnx/csrc/offline-punctuation-ct-transformer-impl.h
new file mode 100644
index 00000000..134b8807
--- /dev/null
+++ b/sherpa-onnx/csrc/offline-punctuation-ct-transformer-impl.h
@@ -0,0 +1,194 @@
+// sherpa-onnx/csrc/offline-punctuation-ct-transformer-impl.h
+//
+// Copyright (c) 2024 Xiaomi Corporation
+#ifndef SHERPA_ONNX_CSRC_OFFLINE_PUNCTUATION_CT_TRANSFORMER_IMPL_H_
+#define SHERPA_ONNX_CSRC_OFFLINE_PUNCTUATION_CT_TRANSFORMER_IMPL_H_
+
+#include <algorithm>
+#include <array>
+#include <cassert>
+#include <iterator>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "sherpa-onnx/csrc/macros.h"
+#include "sherpa-onnx/csrc/math.h"
+#include "sherpa-onnx/csrc/offline-ct-transformer-model.h"
+#include "sherpa-onnx/csrc/offline-punctuation-impl.h"
+#include "sherpa-onnx/csrc/offline-punctuation.h"
+#include "sherpa-onnx/csrc/text-utils.h"
+
+namespace sherpa_onnx {
+
+class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl {
+ public:
+  explicit OfflinePunctuationCtTransformerImpl(
+      const OfflinePunctuationConfig &config)
+      : config_(config), model_(config.model) {}
+
+  // Add punctuation to `text` and return the result.
+  //
+  // The text is split into UTF-8 tokens that are sent to the model in
+  // segments of `segment_size` tokens. Tokens that follow the last sentence
+  // end (。 or ?) of a segment are carried over and sent again together with
+  // the next segment.
+  //
+  // Returns an empty string for empty input. Returns `text` unchanged, after
+  // logging, if the number of punctuations does not match the number of
+  // tokens.
+  std::string AddPunctuation(const std::string &text) const override {
+    if (text.empty()) {
+      return {};
+    }
+
+    std::vector<std::string> tokens = SplitUtf8(text);
+    std::vector<int32_t> token_ids;
+    token_ids.reserve(tokens.size());
+
+    const auto &meta_data = model_.GetModelMetadata();
+
+    for (const auto &t : tokens) {
+      // one lookup instead of count() followed by at()
+      auto it = meta_data.token2id.find(ToLowerCase(t));
+      token_ids.push_back(it != meta_data.token2id.end() ? it->second
+                                                         : meta_data.unk_id);
+    }
+
+    auto memory_info =
+        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
+
+    const int32_t num_tokens = static_cast<int32_t>(token_ids.size());
+    const int32_t segment_size = 20;
+    const int32_t max_len = 200;
+    const int32_t num_segments = (num_tokens + segment_size - 1) / segment_size;
+
+    std::vector<int32_t> punctuations;
+    punctuations.reserve(token_ids.size());
+
+    // Index of the first token that has no punctuation yet; -1 before the
+    // first segment has been processed.
+    int32_t last = -1;
+    for (int32_t i = 0; i != num_segments; ++i) {
+      int32_t this_start = i * segment_size;         // inclusive
+      int32_t this_end = this_start + segment_size;  // exclusive
+      if (this_end > num_tokens) {
+        this_end = num_tokens;
+      }
+
+      if (last != -1) {
+        this_start = last;
+      }
+      // token_ids[this_start:this_end] is sent to the model
+
+      std::array<int64_t, 2> x_shape = {1, this_end - this_start};
+      Ort::Value x =
+          Ort::Value::CreateTensor(memory_info, token_ids.data() + this_start,
+                                   x_shape[1], x_shape.data(), x_shape.size());
+
+      int64_t len_shape = 1;
+      int32_t len = x_shape[1];
+      Ort::Value x_len =
+          Ort::Value::CreateTensor(memory_info, &len, 1, &len_shape, 1);
+
+      Ort::Value out = model_.Forward(std::move(x), std::move(x_len));
+
+      // [N, T, num_punctuations]
+      std::vector<int64_t> out_shape =
+          out.GetTensorTypeAndShapeInfo().GetShape();
+
+      assert(out_shape[0] == 1);
+      assert(out_shape[1] == len);
+      assert(out_shape[2] == meta_data.num_punctuations);
+
+      std::vector<int32_t> this_punctuations;
+      this_punctuations.reserve(len);
+
+      const float *p = out.GetTensorData<float>();
+      for (int32_t k = 0; k != len; ++k, p += meta_data.num_punctuations) {
+        auto index = static_cast<int32_t>(std::distance(
+            p, std::max_element(p, p + meta_data.num_punctuations)));
+        this_punctuations.push_back(index);
+      }  // for (int32_t k = 0; k != len; ++k, p += meta_data.num_punctuations)
+
+      if (i == num_segments - 1) {
+        // Last segment: there is no later segment to carry tokens over to, so
+        // keep every prediction. Otherwise the tokens after the last sentence
+        // end would get no punctuation and the size check below would fail.
+        punctuations.insert(punctuations.end(), this_punctuations.begin(),
+                            this_punctuations.end());
+        break;
+      }
+
+      int32_t dot_index = -1;
+      int32_t comma_index = -1;
+
+      for (int32_t m = len - 1; m >= 1; --m) {
+        int32_t punct_id = this_punctuations[m];
+
+        if (punct_id == meta_data.dot_id || punct_id == meta_data.quest_id) {
+          dot_index = m;
+          break;
+        }
+
+        if (comma_index == -1 && punct_id == meta_data.comma_id) {
+          comma_index = m;
+        }
+      }  // for (int32_t m = len - 1; m >= 1; --m)
+
+      if (dot_index == -1 && len >= max_len && comma_index != -1) {
+        // The pending text is too long; end the sentence at the last comma.
+        dot_index = comma_index;
+        this_punctuations[dot_index] = meta_data.dot_id;
+      }
+
+      if (dot_index == -1) {
+        // No sentence end; resend this whole segment with the next one.
+        last = this_start;
+      } else {
+        last = this_start + dot_index + 1;
+
+        punctuations.insert(punctuations.end(), this_punctuations.begin(),
+                            this_punctuations.begin() + (dot_index + 1));
+      }
+    }  // for (int32_t i = 0; i != num_segments; ++i)
+
+    // Always end the text with a sentence end (。 or ?).
+    if (!punctuations.empty() && punctuations.back() != meta_data.dot_id &&
+        punctuations.back() != meta_data.quest_id) {
+      punctuations.back() = meta_data.dot_id;
+    }
+
+    if (punctuations.size() != token_ids.size()) {
+      SHERPA_ONNX_LOGE("%s, %d, %d. Some unexpected things happened",
+                       text.c_str(), static_cast<int32_t>(punctuations.size()),
+                       static_cast<int32_t>(token_ids.size()));
+      return text;
+    }
+
+    std::string ans;
+
+    for (int32_t i = 0; i != num_tokens; ++i) {
+      const std::string &w = tokens[i];
+      // put a space between two bytes that are both ASCII
+      if (i > 0 && !(ans.back() & 0x80) && !(w[0] & 0x80)) {
+        ans.push_back(' ');
+      }
+      ans.append(w);
+      if (punctuations[i] != meta_data.underline_id) {
+        ans.append(meta_data.id2punct[punctuations[i]]);
+      }
+    }
+
+    return ans;
+  }
+
+ private:
+  OfflinePunctuationConfig config_;
+  OfflineCtTransformerModel model_;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_OFFLINE_PUNCTUATION_CT_TRANSFORMER_IMPL_H_
diff --git a/sherpa-onnx/csrc/offline-punctuation-impl.cc b/sherpa-onnx/csrc/offline-punctuation-impl.cc
new file mode 100644
index 00000000..2eefdae3
--- /dev/null
+++ b/sherpa-onnx/csrc/offline-punctuation-impl.cc
@@ -0,0 +1,22 @@
+// sherpa-onnx/csrc/offline-punctuation-impl.cc
+//
+// Copyright (c) 2024 Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/offline-punctuation-impl.h"
+
+#include "sherpa-onnx/csrc/macros.h"
+#include "sherpa-onnx/csrc/offline-punctuation-ct-transformer-impl.h"
+
+namespace sherpa_onnx {
+
+std::unique_ptr<OfflinePunctuationImpl> OfflinePunctuationImpl::Create(
+    const OfflinePunctuationConfig &config) {
+  if (!config.model.ct_transformer.empty()) {
+    return std::make_unique<OfflinePunctuationCtTransformerImpl>(config);
+  }
+
+  SHERPA_ONNX_LOGE("Please specify a punctuation model! Return a null pointer");
+  return nullptr;
+}
+
+}  // namespace sherpa_onnx
diff --git a/sherpa-onnx/csrc/offline-punctuation-impl.h b/sherpa-onnx/csrc/offline-punctuation-impl.h
new file mode 100644
index 00000000..7e1c1c1b
--- /dev/null
+++ b/sherpa-onnx/csrc/offline-punctuation-impl.h
@@ -0,0 +1,29 @@
+// sherpa-onnx/csrc/offline-punctuation-impl.h
+//
+// Copyright (c) 2024 Xiaomi Corporation
+#ifndef SHERPA_ONNX_CSRC_OFFLINE_PUNCTUATION_IMPL_H_
+#define SHERPA_ONNX_CSRC_OFFLINE_PUNCTUATION_IMPL_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "sherpa-onnx/csrc/offline-punctuation.h"
+
+namespace sherpa_onnx {
+
+class OfflinePunctuationImpl {
+ public:
+  virtual ~OfflinePunctuationImpl() = default;
+
+  // Return an implementation for the model given in `config`, or a null
+  // pointer if no model is given.
+  static std::unique_ptr<OfflinePunctuationImpl> Create(
+      const OfflinePunctuationConfig &config);
+
+  virtual std::string AddPunctuation(const std::string &text) const = 0;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_OFFLINE_PUNCTUATION_IMPL_H_
diff --git a/sherpa-onnx/csrc/offline-punctuation-model-config.cc b/sherpa-onnx/csrc/offline-punctuation-model-config.cc
new file mode 100644
index 00000000..e98fe00b
--- /dev/null
+++ b/sherpa-onnx/csrc/offline-punctuation-model-config.cc
@@ -0,0 +1,56 @@
+// sherpa-onnx/csrc/offline-punctuation-model-config.cc
+//
+// Copyright (c) 2024 Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/offline-punctuation-model-config.h"
+
+#include <sstream>
+#include <string>
+
+#include "sherpa-onnx/csrc/file-utils.h"
+#include "sherpa-onnx/csrc/macros.h"
+
+namespace sherpa_onnx {
+
+void OfflinePunctuationModelConfig::Register(ParseOptions *po) {
+  po->Register("ct-transformer", &ct_transformer,
+               "Path to the controllable time-delay (CT) transformer model");
+
+  po->Register("num-threads", &num_threads,
+               "Number of threads to run the neural network");
+
+  po->Register("debug", &debug,
+               "true to print model information while loading it.");
+
+  po->Register("provider", &provider,
+               "Specify a provider to use: cpu, cuda, coreml");
+}
+
+bool OfflinePunctuationModelConfig::Validate() const {
+  if (ct_transformer.empty()) {
+    SHERPA_ONNX_LOGE("Please provide --ct-transformer");
+    return false;
+  }
+
+  if (!FileExists(ct_transformer)) {
+    SHERPA_ONNX_LOGE("--ct-transformer %s does not exist",
+                     ct_transformer.c_str());
+    return false;
+  }
+
+  return true;
+}
+
+std::string OfflinePunctuationModelConfig::ToString() const {
+  std::ostringstream os;
+
+  os << "OfflinePunctuationModelConfig(";
+  os << "ct_transformer=\"" << ct_transformer << "\", ";
+  os << "num_threads=" << num_threads << ", ";
+  os << "debug=" << (debug ? "True" : "False") << ", ";
+  os << "provider=\"" << provider << "\")";
+
+  return os.str();
+}
+
+}  // namespace sherpa_onnx
diff --git a/sherpa-onnx/csrc/offline-punctuation-model-config.h b/sherpa-onnx/csrc/offline-punctuation-model-config.h
new file mode 100644
index 00000000..aa294f3f
--- /dev/null
+++ b/sherpa-onnx/csrc/offline-punctuation-model-config.h
@@ -0,0 +1,39 @@
+// sherpa-onnx/csrc/offline-punctuation-model-config.h
+//
+// Copyright (c) 2024 Xiaomi Corporation
+#ifndef SHERPA_ONNX_CSRC_OFFLINE_PUNCTUATION_MODEL_CONFIG_H_
+#define SHERPA_ONNX_CSRC_OFFLINE_PUNCTUATION_MODEL_CONFIG_H_
+
+#include <cstdint>
+#include <string>
+
+#include "sherpa-onnx/csrc/parse-options.h"
+
+namespace sherpa_onnx {
+
+struct OfflinePunctuationModelConfig {
+  std::string ct_transformer;
+
+  int32_t num_threads = 1;
+  bool debug = false;
+  std::string provider = "cpu";
+
+  OfflinePunctuationModelConfig() = default;
+
+  OfflinePunctuationModelConfig(const std::string &ct_transformer,
+                                int32_t num_threads, bool debug,
+                                const std::string &provider)
+      : ct_transformer(ct_transformer),
+        num_threads(num_threads),
+        debug(debug),
+        provider(provider) {}
+
+  void Register(ParseOptions *po);
+  bool Validate() const;
+
+  std::string ToString() const;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_OFFLINE_PUNCTUATION_MODEL_CONFIG_H_
diff --git a/sherpa-onnx/csrc/offline-punctuation.cc b/sherpa-onnx/csrc/offline-punctuation.cc
new file mode 100644
index 00000000..292156ab
--- /dev/null
+++ b/sherpa-onnx/csrc/offline-punctuation.cc
@@ -0,0 +1,52 @@
+// sherpa-onnx/csrc/offline-punctuation.cc
+//
+// Copyright (c) 2024 Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/offline-punctuation.h"
+
+#include <sstream>
+#include <string>
+
+#include "sherpa-onnx/csrc/macros.h"
+#include "sherpa-onnx/csrc/offline-punctuation-impl.h"
+
+namespace sherpa_onnx {
+
+void OfflinePunctuationConfig::Register(ParseOptions *po) {
+  model.Register(po);
+}
+
+bool OfflinePunctuationConfig::Validate() const {
+  if (!model.Validate()) {
+    return false;
+  }
+
+  return true;
+}
+
+std::string OfflinePunctuationConfig::ToString() const {
+  std::ostringstream os;
+
+  os << "OfflinePunctuationConfig(";
+  os << "model=" << model.ToString() << ")";
+
+  return os.str();
+}
+
+OfflinePunctuation::OfflinePunctuation(const OfflinePunctuationConfig &config)
+    : impl_(OfflinePunctuationImpl::Create(config)) {}
+
+OfflinePunctuation::~OfflinePunctuation() = default;
+
+std::string OfflinePunctuation::AddPunctuation(const std::string &text) const {
+  // Create() returns a null pointer when no model is configured; return the
+  // text unchanged instead of dereferencing it.
+  if (!impl_) {
+    SHERPA_ONNX_LOGE("No punctuation model is loaded. Return the input text");
+    return text;
+  }
+
+  return impl_->AddPunctuation(text);
+}
+
+}  // namespace sherpa_onnx
diff --git a/sherpa-onnx/csrc/offline-punctuation.h b/sherpa-onnx/csrc/offline-punctuation.h
new file mode 100644
index 00000000..7be31e41
--- /dev/null
+++ b/sherpa-onnx/csrc/offline-punctuation.h
@@ -0,0 +1,48 @@
+// sherpa-onnx/csrc/offline-punctuation.h
+//
+// Copyright (c) 2024 Xiaomi Corporation
+#ifndef SHERPA_ONNX_CSRC_OFFLINE_PUNCTUATION_H_
+#define SHERPA_ONNX_CSRC_OFFLINE_PUNCTUATION_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "sherpa-onnx/csrc/offline-punctuation-model-config.h"
+#include "sherpa-onnx/csrc/parse-options.h"
+
+namespace sherpa_onnx {
+
+struct OfflinePunctuationConfig {
+  OfflinePunctuationModelConfig model;
+
+  OfflinePunctuationConfig() = default;
+
+  explicit OfflinePunctuationConfig(const OfflinePunctuationModelConfig &model)
+      : model(model) {}
+
+  void Register(ParseOptions *po);
+  bool Validate() const;
+
+  std::string ToString() const;
+};
+
+class OfflinePunctuationImpl;
+
+class OfflinePunctuation {
+ public:
+  explicit OfflinePunctuation(const OfflinePunctuationConfig &config);
+
+  ~OfflinePunctuation();
+
+  // Add punctuation to the input text and return it.
+  // Returns the text unchanged if no model is loaded.
+  std::string AddPunctuation(const std::string &text) const;
+
+ private:
+  std::unique_ptr<OfflinePunctuationImpl> impl_;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_OFFLINE_PUNCTUATION_H_
diff --git a/sherpa-onnx/csrc/online-websocket-server-impl.cc b/sherpa-onnx/csrc/online-websocket-server-impl.cc
index ca9f2bf8..d02a4913 100644
--- a/sherpa-onnx/csrc/online-websocket-server-impl.cc
+++ b/sherpa-onnx/csrc/online-websocket-server-impl.cc
@@ -29,7 +29,6 @@ void OnlineWebsocketDecoderConfig::Validate() const {
   SHERPA_ONNX_CHECK_GT(loop_interval_ms, 0);
   SHERPA_ONNX_CHECK_GT(max_batch_size, 0);
   SHERPA_ONNX_CHECK_GT(end_tail_padding, 0);
-
 }
 
 void OnlineWebsocketServerConfig::Register(sherpa_onnx::ParseOptions *po) {
@@ -87,7 +86,8 @@ void OnlineWebsocketDecoder::InputFinished(std::shared_ptr c) {
     c->samples.pop_front();
   }
 
-  std::vector tail_padding(static_cast(config_.end_tail_padding * sample_rate));
+  std::vector tail_padding(
+      static_cast(config_.end_tail_padding * sample_rate));
 
   c->s->AcceptWaveform(sample_rate, tail_padding.data(), tail_padding.size());
 
diff --git a/sherpa-onnx/csrc/session.cc b/sherpa-onnx/csrc/session.cc
index d555ed7a..d0a69740 100644
--- a/sherpa-onnx/csrc/session.cc
+++ b/sherpa-onnx/csrc/session.cc
@@ -160,4 +160,9 @@ Ort::SessionOptions GetSessionOptions(const AudioTaggingModelConfig &config) {
   return GetSessionOptionsImpl(config.num_threads, config.provider);
 }
 
+Ort::SessionOptions GetSessionOptions(
+    const OfflinePunctuationModelConfig &config) {
+  return 
GetSessionOptionsImpl(config.num_threads, config.provider); +} + } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/session.h b/sherpa-onnx/csrc/session.h index 94f263fd..a4121436 100644 --- a/sherpa-onnx/csrc/session.h +++ b/sherpa-onnx/csrc/session.h @@ -9,6 +9,7 @@ #include "sherpa-onnx/csrc/audio-tagging-model-config.h" #include "sherpa-onnx/csrc/offline-lm-config.h" #include "sherpa-onnx/csrc/offline-model-config.h" +#include "sherpa-onnx/csrc/offline-punctuation-model-config.h" #include "sherpa-onnx/csrc/online-lm-config.h" #include "sherpa-onnx/csrc/online-model-config.h" #include "sherpa-onnx/csrc/speaker-embedding-extractor.h" @@ -43,6 +44,9 @@ Ort::SessionOptions GetSessionOptions( Ort::SessionOptions GetSessionOptions(const AudioTaggingModelConfig &config); +Ort::SessionOptions GetSessionOptions( + const OfflinePunctuationModelConfig &config); + } // namespace sherpa_onnx #endif // SHERPA_ONNX_CSRC_SESSION_H_ diff --git a/sherpa-onnx/csrc/sherpa-onnx-offline-punctuation.cc b/sherpa-onnx/csrc/sherpa-onnx-offline-punctuation.cc new file mode 100644 index 00000000..7f220734 --- /dev/null +++ b/sherpa-onnx/csrc/sherpa-onnx-offline-punctuation.cc @@ -0,0 +1,68 @@ +// sherpa-onnx/csrc/sherpa-onnx-offline-punctuation.cc +// +// Copyright (c) 2022-2024 Xiaomi Corporation +#include + +#include // NOLINT + +#include "sherpa-onnx/csrc/offline-punctuation.h" +#include "sherpa-onnx/csrc/parse-options.h" + +int main(int32_t argc, char *argv[]) { + const char *kUsageMessage = R"usage( +Add punctuations to the input text. + +The input text can contain both Chinese and English words. 
+ +Usage: + +wget https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2 +tar xvf sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2 +rm sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2 + +./bin/sherpa-onnx-offline-punctuation \ + --ct-transformer=./sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12/model.onnx \ + "你好吗how are you Fantastic 谢谢我很好你怎么样呢" + +The output text should look like below: +)usage"; + + sherpa_onnx::ParseOptions po(kUsageMessage); + sherpa_onnx::OfflinePunctuationConfig config; + config.Register(&po); + po.Read(argc, argv); + if (po.NumArgs() != 1) { + fprintf(stderr, + "Error: Please provide only 1 positional argument containing the " + "input text.\n\n"); + po.PrintUsage(); + exit(EXIT_FAILURE); + } + + fprintf(stderr, "%s\n", config.ToString().c_str()); + + if (!config.Validate()) { + fprintf(stderr, "Errors in config!\n"); + return -1; + } + + fprintf(stderr, "Creating OfflinePunctuation ...\n"); + sherpa_onnx::OfflinePunctuation punct(config); + fprintf(stderr, "Started\n"); + const auto begin = std::chrono::steady_clock::now(); + + std::string text = po.GetArg(1); + std::string text_with_punct = punct.AddPunctuation(text); + fprintf(stderr, "Done\n"); + const auto end = std::chrono::steady_clock::now(); + + float elapsed_seconds = + std::chrono::duration_cast(end - begin) + .count() / + 1000.; + + fprintf(stderr, "Num threads: %d\n", config.model.num_threads); + fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds); + fprintf(stderr, "Input text: %s\n", text.c_str()); + fprintf(stderr, "Output text: %s\n", text_with_punct.c_str()); +} diff --git a/sherpa-onnx/csrc/sherpa-onnx-offline.cc b/sherpa-onnx/csrc/sherpa-onnx-offline.cc index e8c1a7b2..a84266c7 100644 --- a/sherpa-onnx/csrc/sherpa-onnx-offline.cc +++ b/sherpa-onnx/csrc/sherpa-onnx-offline.cc @@ -111,8 +111,8 @@ for a list of pre-trained
models to download. fprintf(stderr, "Creating recognizer ...\n"); sherpa_onnx::OfflineRecognizer recognizer(config); - const auto begin = std::chrono::steady_clock::now(); fprintf(stderr, "Started\n"); + const auto begin = std::chrono::steady_clock::now(); std::vector> ss; std::vector ss_pointers; diff --git a/sherpa-onnx/csrc/text-utils.cc b/sherpa-onnx/csrc/text-utils.cc index c01c31b3..04586dd8 100644 --- a/sherpa-onnx/csrc/text-utils.cc +++ b/sherpa-onnx/csrc/text-utils.cc @@ -385,4 +385,16 @@ std::vector SplitUtf8(const std::string &text) { return MergeCharactersIntoWords(ans); } +std::string ToLowerCase(const std::string &s) { + std::string ans(s.size(), 0); + std::transform(s.begin(), s.end(), ans.begin(), + [](unsigned char c) { return std::tolower(c); }); + return ans; +} + +void ToLowerCase(std::string *in_out) { + std::transform(in_out->begin(), in_out->end(), in_out->begin(), + [](unsigned char c) { return std::tolower(c); }); +} + } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/text-utils.h b/sherpa-onnx/csrc/text-utils.h index 07251eef..a0b968d8 100644 --- a/sherpa-onnx/csrc/text-utils.h +++ b/sherpa-onnx/csrc/text-utils.h @@ -121,6 +121,9 @@ bool ConvertStringToReal(const std::string &str, T *out); std::vector SplitUtf8(const std::string &text); +std::string ToLowerCase(const std::string &s); +void ToLowerCase(std::string *in_out); + } // namespace sherpa_onnx #endif // SHERPA_ONNX_CSRC_TEXT_UTILS_H_