diff --git a/.github/scripts/test-offline-ctc.sh b/.github/scripts/test-offline-ctc.sh index 1648a18a..f5301686 100755 --- a/.github/scripts/test-offline-ctc.sh +++ b/.github/scripts/test-offline-ctc.sh @@ -89,3 +89,48 @@ time $EXE \ $repo/test_wavs/8k.wav rm -rf $repo + +log "------------------------------------------------------------" +log "Run Librispeech zipformer CTC H/HL/HLG decoding (English) " +log "------------------------------------------------------------" +repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-ctc-en-2023-10-02 +log "Start testing ${repo_url}" +repo=$(basename $repo_url) +log "Download pretrained model and test-data from $repo_url" + +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +pushd $repo +git lfs pull --include "*.onnx" +git lfs pull --include "*.fst" +ls -lh +popd + +graphs=( +$repo/H.fst +$repo/HL.fst +$repo/HLG.fst +) + +for graph in ${graphs[@]}; do + log "test float32 models with $graph" + time $EXE \ + --model-type=zipformer2_ctc \ + --ctc.graph=$graph \ + --zipformer-ctc-model=$repo/model.onnx \ + --tokens=$repo/tokens.txt \ + $repo/test_wavs/0.wav \ + $repo/test_wavs/1.wav \ + $repo/test_wavs/2.wav + + log "test int8 models with $graph" + time $EXE \ + --model-type=zipformer2_ctc \ + --ctc.graph=$graph \ + --zipformer-ctc-model=$repo/model.int8.onnx \ + --tokens=$repo/tokens.txt \ + $repo/test_wavs/0.wav \ + $repo/test_wavs/1.wav \ + $repo/test_wavs/2.wav +done + +rm -rf $repo diff --git a/.github/workflows/test-python-online-websocket-server.yaml b/.github/workflows/test-python-online-websocket-server.yaml index 15f81778..428440ed 100644 --- a/.github/workflows/test-python-online-websocket-server.yaml +++ b/.github/workflows/test-python-online-websocket-server.yaml @@ -18,7 +18,7 @@ permissions: jobs: python_online_websocket_server: runs-on: ${{ matrix.os }} - name: ${{ matrix.os }} ${{ matrix.python-version }} + name: ${{ matrix.os }} ${{ matrix.python-version }} ${{ matrix.model_type }} strategy: fail-fast: false 
matrix: diff --git a/CMakeLists.txt b/CMakeLists.txt index 41a12718..4ccd881e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -154,6 +154,7 @@ if(NOT BUILD_SHARED_LIBS AND CMAKE_SYSTEM_NAME STREQUAL Linux) endif() include(kaldi-native-fbank) +include(kaldi-decoder) include(onnxruntime) if(SHERPA_ONNX_ENABLE_PORTAUDIO) diff --git a/cmake/eigen.cmake b/cmake/eigen.cmake new file mode 100644 index 00000000..7491bbc1 --- /dev/null +++ b/cmake/eigen.cmake @@ -0,0 +1,55 @@ +# Fetch eigen 3.4.0 and add it as a subdirectory (EXCLUDE_FROM_ALL). A +# pre-downloaded archive in one of the listed local paths is used instead of +# the network when present. +function(download_eigen) + include(FetchContent) + + set(eigen_URL "https://gitlab.com/libeigen/eigen/-/archive/3.4.0/eigen-3.4.0.tar.gz") + set(eigen_URL2 "https://huggingface.co/csukuangfj/kaldi-hmm-gmm-cmake-deps/resolve/main/eigen-3.4.0.tar.gz") + set(eigen_HASH "SHA256=8586084f71f9bde545ee7fa6d00288b264a2b7ac3607b974e54d13e7162c1c72") + + # If you don't have access to the Internet, + # please pre-download eigen + set(possible_file_locations + $ENV{HOME}/Downloads/eigen-3.4.0.tar.gz + ${PROJECT_SOURCE_DIR}/eigen-3.4.0.tar.gz + ${PROJECT_BINARY_DIR}/eigen-3.4.0.tar.gz + /tmp/eigen-3.4.0.tar.gz + /star-fj/fangjun/download/github/eigen-3.4.0.tar.gz + ) + + foreach(f IN LISTS possible_file_locations) + if(EXISTS ${f}) + set(eigen_URL "${f}") + file(TO_CMAKE_PATH "${eigen_URL}" eigen_URL) + message(STATUS "Found local downloaded eigen: ${eigen_URL}") + set(eigen_URL2) + break() + endif() + endforeach() + + set(BUILD_TESTING OFF CACHE BOOL "" FORCE) + set(EIGEN_BUILD_DOC OFF CACHE BOOL "" FORCE) + + # eigen_URL2 is the fallback mirror; it is empty when a local archive was + # found above, so only the local file is used in that case. + FetchContent_Declare(eigen + URL + ${eigen_URL} + ${eigen_URL2} + URL_HASH ${eigen_HASH} + ) + + FetchContent_GetProperties(eigen) + if(NOT eigen_POPULATED) + message(STATUS "Downloading eigen from ${eigen_URL}") + FetchContent_Populate(eigen) + endif() + message(STATUS "eigen is downloaded to ${eigen_SOURCE_DIR}") + message(STATUS "eigen's binary dir is ${eigen_BINARY_DIR}") + + add_subdirectory(${eigen_SOURCE_DIR} ${eigen_BINARY_DIR} EXCLUDE_FROM_ALL) +endfunction() + +download_eigen() + diff --git 
a/cmake/kaldi-decoder.cmake b/cmake/kaldi-decoder.cmake new file mode 100644 index 00000000..ac2482bd --- /dev/null +++ b/cmake/kaldi-decoder.cmake @@ -0,0 +1,82 @@ +# Fetch kaldi-decoder v0.2.3, add it as a subdirectory and install the +# kaldi-decoder-core, kaldifst_core and fst targets. A pre-downloaded archive +# in one of the listed local paths is used instead of the network when present. +function(download_kaldi_decoder) + include(FetchContent) + + set(kaldi_decoder_URL "https://github.com/k2-fsa/kaldi-decoder/archive/refs/tags/v0.2.3.tar.gz") + set(kaldi_decoder_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-decoder-0.2.3.tar.gz") + set(kaldi_decoder_HASH "SHA256=98bf445a5b7961ccf3c3522317d900054eaadb6a9cdcf4531e7d9caece94a56d") + + set(KALDI_DECODER_BUILD_PYTHON OFF CACHE BOOL "" FORCE) + set(KALDIFST_BUILD_PYTHON OFF CACHE BOOL "" FORCE) + + # If you don't have access to the Internet, + # please pre-download kaldi-decoder + set(possible_file_locations + $ENV{HOME}/Downloads/kaldi-decoder-0.2.3.tar.gz + ${PROJECT_SOURCE_DIR}/kaldi-decoder-0.2.3.tar.gz + ${PROJECT_BINARY_DIR}/kaldi-decoder-0.2.3.tar.gz + /tmp/kaldi-decoder-0.2.3.tar.gz + /star-fj/fangjun/download/github/kaldi-decoder-0.2.3.tar.gz + ) + + foreach(f IN LISTS possible_file_locations) + if(EXISTS ${f}) + set(kaldi_decoder_URL "${f}") + file(TO_CMAKE_PATH "${kaldi_decoder_URL}" kaldi_decoder_URL) + message(STATUS "Found local downloaded kaldi-decoder: ${kaldi_decoder_URL}") + set(kaldi_decoder_URL2 ) + break() + endif() + endforeach() + + FetchContent_Declare(kaldi_decoder + URL + ${kaldi_decoder_URL} + ${kaldi_decoder_URL2} + URL_HASH ${kaldi_decoder_HASH} + ) + + FetchContent_GetProperties(kaldi_decoder) + if(NOT kaldi_decoder_POPULATED) + message(STATUS "Downloading kaldi-decoder from ${kaldi_decoder_URL}") + FetchContent_Populate(kaldi_decoder) + endif() + message(STATUS "kaldi-decoder is downloaded to ${kaldi_decoder_SOURCE_DIR}") + message(STATUS "kaldi-decoder's binary dir is ${kaldi_decoder_BINARY_DIR}") + + include_directories(${kaldi_decoder_SOURCE_DIR}) + add_subdirectory(${kaldi_decoder_SOURCE_DIR} 
${kaldi_decoder_BINARY_DIR} EXCLUDE_FROM_ALL) + + # Use the variable populated by FetchContent above (the one already used for + # add_subdirectory) rather than a differently spelled name. + target_include_directories(kaldi-decoder-core + INTERFACE + ${kaldi_decoder_SOURCE_DIR}/ + ) + if(SHERPA_ONNX_ENABLE_PYTHON AND WIN32) + install(TARGETS + kaldi-decoder-core + kaldifst_core + fst + DESTINATION ..) + else() + install(TARGETS + kaldi-decoder-core + kaldifst_core + fst + DESTINATION lib) + endif() + + if(WIN32 AND BUILD_SHARED_LIBS) + install(TARGETS + kaldi-decoder-core + kaldifst_core + fst + DESTINATION bin) + endif() +endfunction() + +download_kaldi_decoder() + diff --git a/cmake/kaldifst.cmake b/cmake/kaldifst.cmake new file mode 100644 index 00000000..7f9fceef --- /dev/null +++ b/cmake/kaldifst.cmake @@ -0,0 +1,69 @@ +# Fetch kaldifst v1.7.6 and add it as a subdirectory (EXCLUDE_FROM_ALL). A +# pre-downloaded archive in one of the listed local paths is used instead of +# the network when present. +function(download_kaldifst) + include(FetchContent) + + set(kaldifst_URL "https://github.com/k2-fsa/kaldifst/archive/refs/tags/v1.7.6.tar.gz") + set(kaldifst_URL2 "https://huggingface.co/csukuangfj/kaldi-hmm-gmm-cmake-deps/resolve/main/kaldifst-1.7.6.tar.gz") + set(kaldifst_HASH "SHA256=79280c0bb08b5ed1a2ab7c21320a2b071f1f0eb10d2f047e8d6f027f0d32b4d2") + + # If you don't have access to the Internet, + # please pre-download kaldifst + set(possible_file_locations + $ENV{HOME}/Downloads/kaldifst-1.7.6.tar.gz + ${PROJECT_SOURCE_DIR}/kaldifst-1.7.6.tar.gz + ${PROJECT_BINARY_DIR}/kaldifst-1.7.6.tar.gz + /tmp/kaldifst-1.7.6.tar.gz + /star-fj/fangjun/download/github/kaldifst-1.7.6.tar.gz + ) + + foreach(f IN LISTS possible_file_locations) + if(EXISTS ${f}) + set(kaldifst_URL "${f}") + file(TO_CMAKE_PATH "${kaldifst_URL}" kaldifst_URL) + message(STATUS "Found local downloaded kaldifst: ${kaldifst_URL}") + set(kaldifst_URL2) + break() + endif() + endforeach() + + set(KALDIFST_BUILD_TESTS OFF CACHE BOOL "" FORCE) + set(KALDIFST_BUILD_PYTHON OFF CACHE BOOL "" FORCE) + + # kaldifst_URL2 is the fallback mirror; it is empty when a local archive was + # found above, so only the local file is used in that case. + FetchContent_Declare(kaldifst + URL + ${kaldifst_URL} + ${kaldifst_URL2} + URL_HASH ${kaldifst_HASH} + ) + + FetchContent_GetProperties(kaldifst) + if(NOT kaldifst_POPULATED) + message(STATUS "Downloading kaldifst from ${kaldifst_URL}") + 
FetchContent_Populate(kaldifst) + endif() + message(STATUS "kaldifst is downloaded to ${kaldifst_SOURCE_DIR}") + message(STATUS "kaldifst's binary dir is ${kaldifst_BINARY_DIR}") + + list(APPEND CMAKE_MODULE_PATH ${kaldifst_SOURCE_DIR}/cmake) + + add_subdirectory(${kaldifst_SOURCE_DIR} ${kaldifst_BINARY_DIR} EXCLUDE_FROM_ALL) + + target_include_directories(kaldifst_core + PUBLIC + ${kaldifst_SOURCE_DIR}/ + ) + + target_include_directories(fst + PUBLIC + ${openfst_SOURCE_DIR}/src/include + ) + + set_target_properties(kaldifst_core PROPERTIES OUTPUT_NAME "sherpa-onnx-kaldifst-core") + set_target_properties(fst PROPERTIES OUTPUT_NAME "sherpa-onnx-fst") +endfunction() + +download_kaldifst() diff --git a/cmake/sherpa-onnx.pc.in b/cmake/sherpa-onnx.pc.in index 2e640ed0..1d38622e 100644 --- a/cmake/sherpa-onnx.pc.in +++ b/cmake/sherpa-onnx.pc.in @@ -13,4 +13,4 @@ Cflags: -I"${includedir}" # Note: -lcargs is required only for the following file # https://github.com/k2-fsa/sherpa-onnx/blob/master/c-api-examples/decode-file-c-api.c # We add it here so that users don't need to specify -lcargs when compiling decode-file-c-api.c -Libs: -L"${libdir}" -lsherpa-onnx-c-api -lsherpa-onnx-core -lonnxruntime -lkaldi-native-fbank-core -lcargs -Wl,-rpath,${libdir} @SHERPA_ONNX_PKG_CONFIG_EXTRA_LIBS@ +Libs: -L"${libdir}" -lsherpa-onnx-c-api -lsherpa-onnx-core -lonnxruntime -lkaldi-decoder-core -lsherpa-onnx-kaldifst-core -lsherpa-onnx-fst -lkaldi-native-fbank-core -lcargs -Wl,-rpath,${libdir} @SHERPA_ONNX_PKG_CONFIG_EXTRA_LIBS@ diff --git a/mfc-examples/NonStreamingSpeechRecognition/sherpa-onnx-deps.props b/mfc-examples/NonStreamingSpeechRecognition/sherpa-onnx-deps.props index f0e609d3..4c144708 100644 --- a/mfc-examples/NonStreamingSpeechRecognition/sherpa-onnx-deps.props +++ b/mfc-examples/NonStreamingSpeechRecognition/sherpa-onnx-deps.props @@ -9,6 +9,9 @@ sherpa-onnx-portaudio_static.lib; sherpa-onnx-c-api.lib; sherpa-onnx-core.lib; + kaldi-decoder-core.lib; + 
sherpa-onnx-kaldifst-core.lib; + sherpa-onnx-fst.lib; kaldi-native-fbank-core.lib; absl_base.lib; absl_city.lib; diff --git a/mfc-examples/StreamingSpeechRecognition/sherpa-onnx-deps.props b/mfc-examples/StreamingSpeechRecognition/sherpa-onnx-deps.props index f0e609d3..4c144708 100644 --- a/mfc-examples/StreamingSpeechRecognition/sherpa-onnx-deps.props +++ b/mfc-examples/StreamingSpeechRecognition/sherpa-onnx-deps.props @@ -9,6 +9,9 @@ sherpa-onnx-portaudio_static.lib; sherpa-onnx-c-api.lib; sherpa-onnx-core.lib; + kaldi-decoder-core.lib; + sherpa-onnx-kaldifst-core.lib; + sherpa-onnx-fst.lib; kaldi-native-fbank-core.lib; absl_base.lib; absl_city.lib; diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index 386b81ba..e3f6ff73 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -19,6 +19,8 @@ set(sources features.cc file-utils.cc hypothesis.cc + offline-ctc-fst-decoder-config.cc + offline-ctc-fst-decoder.cc offline-ctc-greedy-search-decoder.cc offline-ctc-model.cc offline-lm-config.cc @@ -42,6 +44,8 @@ set(sources offline-whisper-greedy-search-decoder.cc offline-whisper-model-config.cc offline-whisper-model.cc + offline-zipformer-ctc-model-config.cc + offline-zipformer-ctc-model.cc online-conformer-transducer-model.cc online-lm-config.cc online-lm.cc @@ -97,6 +101,8 @@ endif() target_link_libraries(sherpa-onnx-core kaldi-native-fbank-core) +target_link_libraries(sherpa-onnx-core kaldi-decoder-core) + if(BUILD_SHARED_LIBS OR APPLE OR CMAKE_SYSTEM_PROCESSOR STREQUAL aarch64 OR CMAKE_SYSTEM_PROCESSOR STREQUAL arm) target_link_libraries(sherpa-onnx-core onnxruntime) else() diff --git a/sherpa-onnx/csrc/offline-ctc-fst-decoder-config.cc b/sherpa-onnx/csrc/offline-ctc-fst-decoder-config.cc new file mode 100644 index 00000000..bd412668 --- /dev/null +++ b/sherpa-onnx/csrc/offline-ctc-fst-decoder-config.cc @@ -0,0 +1,32 @@ +// sherpa-onnx/csrc/offline-ctc-fst-decoder-config.cc +// +// Copyright (c) 2023 Xiaomi 
Corporation + +#include "sherpa-onnx/csrc/offline-ctc-fst-decoder-config.h" + +#include +#include + +namespace sherpa_onnx { + +std::string OfflineCtcFstDecoderConfig::ToString() const { + std::ostringstream os; + + os << "OfflineCtcFstDecoderConfig("; + os << "graph=\"" << graph << "\", "; + os << "max_active=" << max_active << ")"; + + return os.str(); +} + +void OfflineCtcFstDecoderConfig::Register(ParseOptions *po) { + std::string prefix = "ctc"; + ParseOptions p(prefix, po); + + p.Register("graph", &graph, "Path to H.fst, HL.fst, or HLG.fst"); + + p.Register("max-active", &max_active, + "Decoder max active states. Larger->slower; more accurate"); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-ctc-fst-decoder-config.h b/sherpa-onnx/csrc/offline-ctc-fst-decoder-config.h new file mode 100644 index 00000000..6d7f70ae --- /dev/null +++ b/sherpa-onnx/csrc/offline-ctc-fst-decoder-config.h @@ -0,0 +1,31 @@ +// sherpa-onnx/csrc/offline-ctc-fst-decoder-config.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_CTC_FST_DECODER_CONFIG_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_CTC_FST_DECODER_CONFIG_H_ + +#include + +#include "sherpa-onnx/csrc/parse-options.h" + +namespace sherpa_onnx { + +struct OfflineCtcFstDecoderConfig { + // Path to H.fst, HL.fst or HLG.fst + std::string graph; + int32_t max_active = 3000; + + OfflineCtcFstDecoderConfig() = default; + + OfflineCtcFstDecoderConfig(const std::string &graph, int32_t max_active) + : graph(graph), max_active(max_active) {} + + std::string ToString() const; + + void Register(ParseOptions *po); +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_CTC_FST_DECODER_CONFIG_H_ diff --git a/sherpa-onnx/csrc/offline-ctc-fst-decoder.cc b/sherpa-onnx/csrc/offline-ctc-fst-decoder.cc new file mode 100644 index 00000000..efee65a7 --- /dev/null +++ b/sherpa-onnx/csrc/offline-ctc-fst-decoder.cc @@ -0,0 +1,157 @@ +// sherpa-onnx/csrc/offline-ctc-fst-decoder.cc +// +// 
Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-ctc-fst-decoder.h" + +#include +#include + +#include "fst/fstlib.h" +#include "kaldi-decoder/csrc/decodable-ctc.h" +#include "kaldi-decoder/csrc/eigen.h" +#include "kaldi-decoder/csrc/faster-decoder.h" +#include "sherpa-onnx/csrc/macros.h" + +namespace sherpa_onnx { + +// This function is copied from kaldi. +// +// @param filename Path to a StdVectorFst or StdConstFst graph +// @return The caller should free the returned pointer using `delete` to +// avoid memory leak. +static fst::Fst *ReadGraph(const std::string &filename) { + // read decoding network FST + std::ifstream is(filename, std::ios::binary); + if (!is.good()) { + SHERPA_ONNX_LOGE("Could not open decoding-graph FST %s", filename.c_str()); + } + + fst::FstHeader hdr; + if (!hdr.Read(is, "")) { + SHERPA_ONNX_LOGE("Reading FST: error reading FST header."); + } + + if (hdr.ArcType() != fst::StdArc::Type()) { + SHERPA_ONNX_LOGE("FST with arc type %s not supported", + hdr.ArcType().c_str()); + } + fst::FstReadOptions ropts("", &hdr); + + fst::Fst *decode_fst = nullptr; + + if (hdr.FstType() == "vector") { + decode_fst = fst::VectorFst::Read(is, ropts); + } else if (hdr.FstType() == "const") { + decode_fst = fst::ConstFst::Read(is, ropts); + } else { + SHERPA_ONNX_LOGE("Reading FST: unsupported FST type: %s", + hdr.FstType().c_str()); + } + + if (decode_fst == nullptr) { // fst code will warn. + SHERPA_ONNX_LOGE("Error reading FST (after reading header)."); + return nullptr; + } else { + return decode_fst; + } +} + +/** + * @param decoder + * @param p Pointer to a 2-d array of shape (num_frames, vocab_size) + * @param num_frames Number of rows in the 2-d array. + * @param vocab_size Number of columns in the 2-d array. + * @return Return the decoded result. 
+ */ +static OfflineCtcDecoderResult DecodeOne(kaldi_decoder::FasterDecoder *decoder, + const float *p, int32_t num_frames, + int32_t vocab_size) { + OfflineCtcDecoderResult r; + kaldi_decoder::DecodableCtc decodable(p, num_frames, vocab_size); + + decoder->Decode(&decodable); + + if (!decoder->ReachedFinal()) { + SHERPA_ONNX_LOGE("Not reached final!"); + return r; + } + + fst::VectorFst decoded; // linear FST. + decoder->GetBestPath(&decoded); + + if (decoded.NumStates() == 0) { + SHERPA_ONNX_LOGE("Empty best path!"); + return r; + } + + auto cur_state = decoded.Start(); + + int32_t blank_id = 0; + + for (int32_t t = 0, prev = -1; decoded.NumArcs(cur_state) == 1; ++t) { + fst::ArcIterator> iter(decoded, cur_state); + const auto &arc = iter.Value(); + + cur_state = arc.nextstate; + + if (arc.ilabel == prev) { + continue; + } + + // 0 is epsilon here + if (arc.ilabel == 0 || arc.ilabel == blank_id + 1) { + prev = arc.ilabel; + continue; + } + + // -1 here since the input labels are incremented during graph + // construction + r.tokens.push_back(arc.ilabel - 1); + + r.timestamps.push_back(t); + prev = arc.ilabel; + } + + return r; +} + +OfflineCtcFstDecoder::OfflineCtcFstDecoder( + const OfflineCtcFstDecoderConfig &config) + : config_(config), fst_(ReadGraph(config_.graph)) {} + +std::vector OfflineCtcFstDecoder::Decode( + Ort::Value log_probs, Ort::Value log_probs_length) { + std::vector shape = log_probs.GetTensorTypeAndShapeInfo().GetShape(); + + assert(static_cast(shape.size()) == 3); + int32_t batch_size = shape[0]; + int32_t T = shape[1]; + int32_t vocab_size = shape[2]; + + std::vector length_shape = + log_probs_length.GetTensorTypeAndShapeInfo().GetShape(); + assert(static_cast(length_shape.size()) == 1); + + assert(shape[0] == length_shape[0]); + + kaldi_decoder::FasterDecoderOptions opts; + opts.max_active = config_.max_active; + kaldi_decoder::FasterDecoder faster_decoder(*fst_, opts); + + const float *start = log_probs.GetTensorData(); + + std::vector 
ans; + ans.reserve(batch_size); + + for (int32_t i = 0; i != batch_size; ++i) { + const float *p = start + i * T * vocab_size; + int32_t num_frames = log_probs_length.GetTensorData()[i]; + auto r = DecodeOne(&faster_decoder, p, num_frames, vocab_size); + ans.push_back(std::move(r)); + } + + return ans; +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-ctc-fst-decoder.h b/sherpa-onnx/csrc/offline-ctc-fst-decoder.h new file mode 100644 index 00000000..2b33c14e --- /dev/null +++ b/sherpa-onnx/csrc/offline-ctc-fst-decoder.h @@ -0,0 +1,33 @@ +// sherpa-onnx/csrc/offline-ctc-fst-decoder.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_CTC_FST_DECODER_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_CTC_FST_DECODER_H_ + +#include +#include + +#include "fst/fst.h" +#include "sherpa-onnx/csrc/offline-ctc-decoder.h" +#include "sherpa-onnx/csrc/offline-ctc-fst-decoder-config.h" +#include "sherpa-onnx/csrc/parse-options.h" + +namespace sherpa_onnx { + +class OfflineCtcFstDecoder : public OfflineCtcDecoder { + public: + explicit OfflineCtcFstDecoder(const OfflineCtcFstDecoderConfig &config); + + std::vector Decode( + Ort::Value log_probs, Ort::Value log_probs_length) override; + + private: + OfflineCtcFstDecoderConfig config_; + + std::unique_ptr> fst_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_CTC_FST_DECODER_H_ diff --git a/sherpa-onnx/csrc/offline-ctc-model.cc b/sherpa-onnx/csrc/offline-ctc-model.cc index 0be943dd..d8864404 100644 --- a/sherpa-onnx/csrc/offline-ctc-model.cc +++ b/sherpa-onnx/csrc/offline-ctc-model.cc @@ -12,6 +12,7 @@ #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h" #include "sherpa-onnx/csrc/offline-tdnn-ctc-model.h" +#include "sherpa-onnx/csrc/offline-zipformer-ctc-model.h" #include "sherpa-onnx/csrc/onnx-utils.h" namespace { @@ -19,6 +20,7 @@ namespace { enum class ModelType { kEncDecCTCModelBPE, kTdnn, + kZipformerCtc, kUnkown, }; @@ 
-59,6 +61,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, return ModelType::kEncDecCTCModelBPE; } else if (model_type.get() == std::string("tdnn")) { return ModelType::kTdnn; + } else if (model_type.get() == std::string("zipformer2_ctc")) { + return ModelType::kZipformerCtc; } else { SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get()); return ModelType::kUnkown; @@ -74,6 +78,8 @@ std::unique_ptr OfflineCtcModel::Create( filename = config.nemo_ctc.model; } else if (!config.tdnn.model.empty()) { filename = config.tdnn.model; + } else if (!config.zipformer_ctc.model.empty()) { + filename = config.zipformer_ctc.model; } else { SHERPA_ONNX_LOGE("Please specify a CTC model"); exit(-1); @@ -92,6 +98,9 @@ std::unique_ptr OfflineCtcModel::Create( case ModelType::kTdnn: return std::make_unique(config); break; + case ModelType::kZipformerCtc: + return std::make_unique(config); + break; case ModelType::kUnkown: SHERPA_ONNX_LOGE("Unknown model type in offline CTC!"); return nullptr; @@ -111,6 +120,8 @@ std::unique_ptr OfflineCtcModel::Create( filename = config.nemo_ctc.model; } else if (!config.tdnn.model.empty()) { filename = config.tdnn.model; + } else if (!config.zipformer_ctc.model.empty()) { + filename = config.zipformer_ctc.model; } else { SHERPA_ONNX_LOGE("Please specify a CTC model"); exit(-1); @@ -129,6 +140,9 @@ std::unique_ptr OfflineCtcModel::Create( case ModelType::kTdnn: return std::make_unique(mgr, config); break; + case ModelType::kZipformerCtc: + return std::make_unique(mgr, config); + break; case ModelType::kUnkown: SHERPA_ONNX_LOGE("Unknown model type in offline CTC!"); return nullptr; diff --git a/sherpa-onnx/csrc/offline-ctc-model.h b/sherpa-onnx/csrc/offline-ctc-model.h index 95e0279a..e2947f95 100644 --- a/sherpa-onnx/csrc/offline-ctc-model.h +++ b/sherpa-onnx/csrc/offline-ctc-model.h @@ -6,7 +6,7 @@ #include #include -#include +#include #if __ANDROID_API__ >= 9 #include "android/asset_manager.h" @@ -32,17 
+32,17 @@ class OfflineCtcModel { /** Run the forward method of the model. * - * @param features A tensor of shape (N, T, C). It is changed in-place. + * @param features A tensor of shape (N, T, C). * @param features_length A 1-D tensor of shape (N,) containing number of * valid frames in `features` before padding. * Its dtype is int64_t. * - * @return Return a pair containing: + * @return Return a vector containing: * - log_probs: A 3-D tensor of shape (N, T', vocab_size). * - log_probs_length A 1-D tensor of shape (N,). Its dtype is int64_t */ - virtual std::pair Forward( - Ort::Value features, Ort::Value features_length) = 0; + virtual std::vector Forward(Ort::Value features, + Ort::Value features_length) = 0; /** Return the vocabulary size of the model */ diff --git a/sherpa-onnx/csrc/offline-model-config.cc b/sherpa-onnx/csrc/offline-model-config.cc index c491ed55..02c799d6 100644 --- a/sherpa-onnx/csrc/offline-model-config.cc +++ b/sherpa-onnx/csrc/offline-model-config.cc @@ -16,6 +16,7 @@ void OfflineModelConfig::Register(ParseOptions *po) { nemo_ctc.Register(po); whisper.Register(po); tdnn.Register(po); + zipformer_ctc.Register(po); po->Register("tokens", &tokens, "Path to tokens.txt"); @@ -31,7 +32,7 @@ void OfflineModelConfig::Register(ParseOptions *po) { po->Register("model-type", &model_type, "Specify it to reduce model initialization time. " "Valid values are: transducer, paraformer, nemo_ctc, whisper, " - "tdnn." 
+ "tdnn, zipformer2_ctc" "All other values lead to loading the model twice."); } @@ -62,6 +63,10 @@ bool OfflineModelConfig::Validate() const { return tdnn.Validate(); } + if (!zipformer_ctc.model.empty()) { + return zipformer_ctc.Validate(); + } + return transducer.Validate(); } @@ -74,6 +79,7 @@ std::string OfflineModelConfig::ToString() const { os << "nemo_ctc=" << nemo_ctc.ToString() << ", "; os << "whisper=" << whisper.ToString() << ", "; os << "tdnn=" << tdnn.ToString() << ", "; + os << "zipformer_ctc=" << zipformer_ctc.ToString() << ", "; os << "tokens=\"" << tokens << "\", "; os << "num_threads=" << num_threads << ", "; os << "debug=" << (debug ? "True" : "False") << ", "; diff --git a/sherpa-onnx/csrc/offline-model-config.h b/sherpa-onnx/csrc/offline-model-config.h index 2664db31..55a063f9 100644 --- a/sherpa-onnx/csrc/offline-model-config.h +++ b/sherpa-onnx/csrc/offline-model-config.h @@ -11,6 +11,7 @@ #include "sherpa-onnx/csrc/offline-tdnn-model-config.h" #include "sherpa-onnx/csrc/offline-transducer-model-config.h" #include "sherpa-onnx/csrc/offline-whisper-model-config.h" +#include "sherpa-onnx/csrc/offline-zipformer-ctc-model-config.h" namespace sherpa_onnx { @@ -20,6 +21,7 @@ struct OfflineModelConfig { OfflineNemoEncDecCtcModelConfig nemo_ctc; OfflineWhisperModelConfig whisper; OfflineTdnnModelConfig tdnn; + OfflineZipformerCtcModelConfig zipformer_ctc; std::string tokens; int32_t num_threads = 2; @@ -43,6 +45,7 @@ struct OfflineModelConfig { const OfflineNemoEncDecCtcModelConfig &nemo_ctc, const OfflineWhisperModelConfig &whisper, const OfflineTdnnModelConfig &tdnn, + const OfflineZipformerCtcModelConfig &zipformer_ctc, const std::string &tokens, int32_t num_threads, bool debug, const std::string &provider, const std::string &model_type) : transducer(transducer), @@ -50,6 +53,7 @@ struct OfflineModelConfig { nemo_ctc(nemo_ctc), whisper(whisper), tdnn(tdnn), + zipformer_ctc(zipformer_ctc), tokens(tokens), num_threads(num_threads), debug(debug), 
diff --git a/sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.cc b/sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.cc index 0b685c43..23123399 100644 --- a/sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.cc +++ b/sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.cc @@ -34,8 +34,8 @@ class OfflineNemoEncDecCtcModel::Impl { } #endif - std::pair Forward(Ort::Value features, - Ort::Value features_length) { + std::vector Forward(Ort::Value features, + Ort::Value features_length) { std::vector shape = features_length.GetTensorTypeAndShapeInfo().GetShape(); @@ -57,7 +57,11 @@ class OfflineNemoEncDecCtcModel::Impl { sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(), output_names_ptr_.data(), output_names_ptr_.size()); - return {std::move(out[0]), std::move(out_features_length)}; + std::vector ans; + ans.reserve(2); + ans.push_back(std::move(out[0])); + ans.push_back(std::move(out_features_length)); + return ans; } int32_t VocabSize() const { return vocab_size_; } @@ -122,7 +126,7 @@ OfflineNemoEncDecCtcModel::OfflineNemoEncDecCtcModel( OfflineNemoEncDecCtcModel::~OfflineNemoEncDecCtcModel() = default; -std::pair OfflineNemoEncDecCtcModel::Forward( +std::vector OfflineNemoEncDecCtcModel::Forward( Ort::Value features, Ort::Value features_length) { return impl_->Forward(std::move(features), std::move(features_length)); } diff --git a/sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h b/sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h index d4c9aadc..9cdfe199 100644 --- a/sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h +++ b/sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h @@ -38,17 +38,17 @@ class OfflineNemoEncDecCtcModel : public OfflineCtcModel { /** Run the forward method of the model. * - * @param features A tensor of shape (N, T, C). It is changed in-place. + * @param features A tensor of shape (N, T, C). * @param features_length A 1-D tensor of shape (N,) containing number of * valid frames in `features` before padding. * Its dtype is int64_t. 
* - * @return Return a pair containing: + * @return Return a vector containing: * - log_probs: A 3-D tensor of shape (N, T', vocab_size). * - log_probs_length A 1-D tensor of shape (N,). Its dtype is int64_t */ - std::pair Forward( - Ort::Value features, Ort::Value features_length) override; + std::vector Forward(Ort::Value features, + Ort::Value features_length) override; /** Return the vocabulary size of the model */ diff --git a/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h b/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h index da337000..98d220ba 100644 --- a/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h +++ b/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h @@ -16,6 +16,7 @@ #endif #include "sherpa-onnx/csrc/offline-ctc-decoder.h" +#include "sherpa-onnx/csrc/offline-ctc-fst-decoder.h" #include "sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.h" #include "sherpa-onnx/csrc/offline-ctc-model.h" #include "sherpa-onnx/csrc/offline-recognizer-impl.h" @@ -25,9 +26,12 @@ namespace sherpa_onnx { static OfflineRecognitionResult Convert(const OfflineCtcDecoderResult &src, - const SymbolTable &sym_table) { + const SymbolTable &sym_table, + int32_t frame_shift_ms, + int32_t subsampling_factor) { OfflineRecognitionResult r; r.tokens.reserve(src.tokens.size()); + r.timestamps.reserve(src.timestamps.size()); std::string text; @@ -42,6 +46,12 @@ static OfflineRecognitionResult Convert(const OfflineCtcDecoderResult &src, } r.text = std::move(text); + float frame_shift_s = frame_shift_ms / 1000. 
* subsampling_factor; + for (auto t : src.timestamps) { + float time = frame_shift_s * t; + r.timestamps.push_back(time); + } + return r; } @@ -68,7 +78,12 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { config_.feat_config.nemo_normalize_type = model_->FeatureNormalizationMethod(); - if (config_.decoding_method == "greedy_search") { + if (!config_.ctc_fst_decoder_config.graph.empty()) { + // TODO(fangjun): Support android to read the graph from + // asset_manager + decoder_ = std::make_unique( + config_.ctc_fst_decoder_config); + } else if (config_.decoding_method == "greedy_search") { if (!symbol_table_.contains("") && !symbol_table_.contains("")) { SHERPA_ONNX_LOGE( @@ -139,10 +154,12 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { -23.025850929940457f); auto t = model_->Forward(std::move(x), std::move(x_length)); - auto results = decoder_->Decode(std::move(t.first), std::move(t.second)); + auto results = decoder_->Decode(std::move(t[0]), std::move(t[1])); + int32_t frame_shift_ms = 10; for (int32_t i = 0; i != n; ++i) { - auto r = Convert(results[i], symbol_table_); + auto r = Convert(results[i], symbol_table_, frame_shift_ms, + model_->SubsamplingFactor()); ss[i]->SetResult(r); } } diff --git a/sherpa-onnx/csrc/offline-recognizer-impl.cc b/sherpa-onnx/csrc/offline-recognizer-impl.cc index 0a4db68e..31e16133 100644 --- a/sherpa-onnx/csrc/offline-recognizer-impl.cc +++ b/sherpa-onnx/csrc/offline-recognizer-impl.cc @@ -25,9 +25,8 @@ std::unique_ptr OfflineRecognizerImpl::Create( return std::make_unique(config); } else if (model_type == "paraformer") { return std::make_unique(config); - } else if (model_type == "nemo_ctc") { - return std::make_unique(config); - } else if (model_type == "tdnn") { + } else if (model_type == "nemo_ctc" || model_type == "tdnn" || + model_type == "zipformer2_ctc") { return std::make_unique(config); } else if (model_type == "whisper") { return std::make_unique(config); @@ -50,6 +49,8 @@ 
std::unique_ptr OfflineRecognizerImpl::Create( model_filename = config.model_config.nemo_ctc.model; } else if (!config.model_config.tdnn.model.empty()) { model_filename = config.model_config.tdnn.model; + } else if (!config.model_config.zipformer_ctc.model.empty()) { + model_filename = config.model_config.zipformer_ctc.model; } else if (!config.model_config.whisper.encoder.empty()) { model_filename = config.model_config.whisper.encoder; } else { @@ -93,6 +94,11 @@ std::unique_ptr OfflineRecognizerImpl::Create( "\n " "https://github.com/k2-fsa/icefall/tree/master/egs/yesno/ASR/tdnn" "\n" + "(5) Zipformer CTC models from icefall" + "\n " + "https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/" + "zipformer/export-onnx-ctc.py" + "\n" "\n"); exit(-1); } @@ -107,11 +113,8 @@ std::unique_ptr OfflineRecognizerImpl::Create( return std::make_unique(config); } - if (model_type == "EncDecCTCModelBPE") { - return std::make_unique(config); - } - - if (model_type == "tdnn") { + if (model_type == "EncDecCTCModelBPE" || model_type == "tdnn" || + model_type == "zipformer2_ctc") { return std::make_unique(config); } @@ -126,7 +129,8 @@ std::unique_ptr OfflineRecognizerImpl::Create( " - Non-streaming Paraformer models from FunASR\n" " - EncDecCTCModelBPE models from NeMo\n" " - Whisper models\n" - " - Tdnn models\n", + " - Tdnn models\n" + " - Zipformer CTC models\n", model_type.c_str()); exit(-1); @@ -141,9 +145,8 @@ std::unique_ptr OfflineRecognizerImpl::Create( return std::make_unique(mgr, config); } else if (model_type == "paraformer") { return std::make_unique(mgr, config); - } else if (model_type == "nemo_ctc") { - return std::make_unique(mgr, config); - } else if (model_type == "tdnn") { + } else if (model_type == "nemo_ctc" || model_type == "tdnn" || + model_type == "zipformer2_ctc") { return std::make_unique(mgr, config); } else if (model_type == "whisper") { return std::make_unique(mgr, config); @@ -166,6 +169,8 @@ std::unique_ptr OfflineRecognizerImpl::Create( 
model_filename = config.model_config.nemo_ctc.model; } else if (!config.model_config.tdnn.model.empty()) { model_filename = config.model_config.tdnn.model; + } else if (!config.model_config.zipformer_ctc.model.empty()) { + model_filename = config.model_config.zipformer_ctc.model; } else if (!config.model_config.whisper.encoder.empty()) { model_filename = config.model_config.whisper.encoder; } else { @@ -209,6 +214,11 @@ std::unique_ptr OfflineRecognizerImpl::Create( "\n " "https://github.com/k2-fsa/icefall/tree/master/egs/yesno/ASR/tdnn" "\n" + "(5) Zipformer CTC models from icefall" + "\n " + "https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/" + "zipformer/export-onnx-ctc.py" + "\n" "\n"); exit(-1); } @@ -223,11 +233,8 @@ std::unique_ptr OfflineRecognizerImpl::Create( return std::make_unique(mgr, config); } - if (model_type == "EncDecCTCModelBPE") { - return std::make_unique(mgr, config); - } - - if (model_type == "tdnn") { + if (model_type == "EncDecCTCModelBPE" || model_type == "tdnn" || + model_type == "zipformer2_ctc") { return std::make_unique(mgr, config); } @@ -242,7 +249,8 @@ std::unique_ptr OfflineRecognizerImpl::Create( " - Non-streaming Paraformer models from FunASR\n" " - EncDecCTCModelBPE models from NeMo\n" " - Whisper models\n" - " - Tdnn models\n", + " - Tdnn models\n" + " - Zipformer CTC models\n", model_type.c_str()); exit(-1); diff --git a/sherpa-onnx/csrc/offline-recognizer.cc b/sherpa-onnx/csrc/offline-recognizer.cc index 5e4835b8..7ab7849c 100644 --- a/sherpa-onnx/csrc/offline-recognizer.cc +++ b/sherpa-onnx/csrc/offline-recognizer.cc @@ -17,6 +17,7 @@ void OfflineRecognizerConfig::Register(ParseOptions *po) { feat_config.Register(po); model_config.Register(po); lm_config.Register(po); + ctc_fst_decoder_config.Register(po); po->Register( "decoding-method", &decoding_method, @@ -69,6 +70,7 @@ std::string OfflineRecognizerConfig::ToString() const { os << "feat_config=" << feat_config.ToString() << ", "; os << "model_config=" << 
model_config.ToString() << ", "; os << "lm_config=" << lm_config.ToString() << ", "; + os << "ctc_fst_decoder_config=" << ctc_fst_decoder_config.ToString() << ", "; os << "decoding_method=\"" << decoding_method << "\", "; os << "max_active_paths=" << max_active_paths << ", "; os << "hotwords_file=\"" << hotwords_file << "\", "; diff --git a/sherpa-onnx/csrc/offline-recognizer.h b/sherpa-onnx/csrc/offline-recognizer.h index 63c23bc2..16d2aa92 100644 --- a/sherpa-onnx/csrc/offline-recognizer.h +++ b/sherpa-onnx/csrc/offline-recognizer.h @@ -14,6 +14,7 @@ #include "android/asset_manager_jni.h" #endif +#include "sherpa-onnx/csrc/offline-ctc-fst-decoder-config.h" #include "sherpa-onnx/csrc/offline-lm-config.h" #include "sherpa-onnx/csrc/offline-model-config.h" #include "sherpa-onnx/csrc/offline-stream.h" @@ -28,6 +29,7 @@ struct OfflineRecognizerConfig { OfflineFeatureExtractorConfig feat_config; OfflineModelConfig model_config; OfflineLMConfig lm_config; + OfflineCtcFstDecoderConfig ctc_fst_decoder_config; std::string decoding_method = "greedy_search"; int32_t max_active_paths = 4; @@ -39,16 +41,16 @@ struct OfflineRecognizerConfig { // TODO(fangjun): Implement modified_beam_search OfflineRecognizerConfig() = default; - OfflineRecognizerConfig(const OfflineFeatureExtractorConfig &feat_config, - const OfflineModelConfig &model_config, - const OfflineLMConfig &lm_config, - const std::string &decoding_method, - int32_t max_active_paths, - const std::string &hotwords_file, - float hotwords_score) + OfflineRecognizerConfig( + const OfflineFeatureExtractorConfig &feat_config, + const OfflineModelConfig &model_config, const OfflineLMConfig &lm_config, + const OfflineCtcFstDecoderConfig &ctc_fst_decoder_config, + const std::string &decoding_method, int32_t max_active_paths, + const std::string &hotwords_file, float hotwords_score) : feat_config(feat_config), model_config(model_config), lm_config(lm_config), + ctc_fst_decoder_config(ctc_fst_decoder_config), 
decoding_method(decoding_method), max_active_paths(max_active_paths), hotwords_file(hotwords_file), diff --git a/sherpa-onnx/csrc/offline-tdnn-ctc-model.cc b/sherpa-onnx/csrc/offline-tdnn-ctc-model.cc index 824e113c..ea91d1c5 100644 --- a/sherpa-onnx/csrc/offline-tdnn-ctc-model.cc +++ b/sherpa-onnx/csrc/offline-tdnn-ctc-model.cc @@ -4,6 +4,8 @@ #include "sherpa-onnx/csrc/offline-tdnn-ctc-model.h" +#include + #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/onnx-utils.h" #include "sherpa-onnx/csrc/session.h" @@ -34,7 +36,7 @@ class OfflineTdnnCtcModel::Impl { } #endif - std::pair Forward(Ort::Value features) { + std::vector Forward(Ort::Value features) { auto nnet_out = sess_->Run({}, input_names_ptr_.data(), &features, 1, output_names_ptr_.data(), output_names_ptr_.size()); @@ -52,7 +54,11 @@ class OfflineTdnnCtcModel::Impl { memory_info, out_length_vec.data(), out_length_vec.size(), out_length_shape.data(), out_length_shape.size()); - return {std::move(nnet_out[0]), Clone(Allocator(), &nnet_out_length)}; + std::vector ans; + ans.reserve(2); + ans.push_back(std::move(nnet_out[0])); + ans.push_back(Clone(Allocator(), &nnet_out_length)); + return ans; } int32_t VocabSize() const { return vocab_size_; } @@ -108,7 +114,7 @@ OfflineTdnnCtcModel::OfflineTdnnCtcModel(AAssetManager *mgr, OfflineTdnnCtcModel::~OfflineTdnnCtcModel() = default; -std::pair OfflineTdnnCtcModel::Forward( +std::vector OfflineTdnnCtcModel::Forward( Ort::Value features, Ort::Value /*features_length*/) { return impl_->Forward(std::move(features)); } diff --git a/sherpa-onnx/csrc/offline-tdnn-ctc-model.h b/sherpa-onnx/csrc/offline-tdnn-ctc-model.h index 0527f503..b6b5c7e5 100644 --- a/sherpa-onnx/csrc/offline-tdnn-ctc-model.h +++ b/sherpa-onnx/csrc/offline-tdnn-ctc-model.h @@ -5,7 +5,6 @@ #define SHERPA_ONNX_CSRC_OFFLINE_TDNN_CTC_MODEL_H_ #include #include -#include #include #if __ANDROID_API__ >= 9 @@ -36,7 +35,7 @@ class OfflineTdnnCtcModel : public OfflineCtcModel { /** Run the 
forward method of the model. * - * @param features A tensor of shape (N, T, C). It is changed in-place. + * @param features A tensor of shape (N, T, C). * @param features_length A 1-D tensor of shape (N,) containing number of * valid frames in `features` before padding. * Its dtype is int64_t. @@ -45,8 +44,8 @@ class OfflineTdnnCtcModel : public OfflineCtcModel { * - log_probs: A 3-D tensor of shape (N, T', vocab_size). * - log_probs_length A 1-D tensor of shape (N,). Its dtype is int64_t */ - std::pair Forward( - Ort::Value features, Ort::Value /*features_length*/) override; + std::vector Forward(Ort::Value features, + Ort::Value /*features_length*/) override; /** Return the vocabulary size of the model */ diff --git a/sherpa-onnx/csrc/offline-zipformer-ctc-model-config.cc b/sherpa-onnx/csrc/offline-zipformer-ctc-model-config.cc new file mode 100644 index 00000000..1c661fcc --- /dev/null +++ b/sherpa-onnx/csrc/offline-zipformer-ctc-model-config.cc @@ -0,0 +1,35 @@ +// sherpa-onnx/csrc/offline-zipformer-ctc-model-config.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-zipformer-ctc-model-config.h" + +#include "sherpa-onnx/csrc/file-utils.h" +#include "sherpa-onnx/csrc/macros.h" + +namespace sherpa_onnx { + +void OfflineZipformerCtcModelConfig::Register(ParseOptions *po) { + po->Register("zipformer-ctc-model", &model, "Path to zipformer CTC model"); +} + +bool OfflineZipformerCtcModelConfig::Validate() const { + if (!FileExists(model)) { + SHERPA_ONNX_LOGE("zipformer CTC model file %s does not exist", + model.c_str()); + return false; + } + + return true; +} + +std::string OfflineZipformerCtcModelConfig::ToString() const { + std::ostringstream os; + + os << "OfflineZipformerCtcModelConfig("; + os << "model=\"" << model << "\")"; + + return os.str(); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-zipformer-ctc-model-config.h b/sherpa-onnx/csrc/offline-zipformer-ctc-model-config.h new file mode 100644 index 
00000000..702575e7 --- /dev/null +++ b/sherpa-onnx/csrc/offline-zipformer-ctc-model-config.h @@ -0,0 +1,32 @@ +// sherpa-onnx/csrc/offline-zipformer-ctc-model-config.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_CTC_MODEL_CONFIG_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_CTC_MODEL_CONFIG_H_ + +#include + +#include "sherpa-onnx/csrc/parse-options.h" + +namespace sherpa_onnx { + +// for +// https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/zipformer/export-onnx-ctc.py +struct OfflineZipformerCtcModelConfig { + std::string model; + + OfflineZipformerCtcModelConfig() = default; + + explicit OfflineZipformerCtcModelConfig(const std::string &model) + : model(model) {} + + void Register(ParseOptions *po); + + bool Validate() const; + + std::string ToString() const; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_CTC_MODEL_CONFIG_H_ diff --git a/sherpa-onnx/csrc/offline-zipformer-ctc-model.cc b/sherpa-onnx/csrc/offline-zipformer-ctc-model.cc new file mode 100644 index 00000000..a82ef625 --- /dev/null +++ b/sherpa-onnx/csrc/offline-zipformer-ctc-model.cc @@ -0,0 +1,119 @@ +// sherpa-onnx/csrc/offline-zipformer-ctc-model.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-zipformer-ctc-model.h" + +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/onnx-utils.h" +#include "sherpa-onnx/csrc/session.h" +#include "sherpa-onnx/csrc/text-utils.h" +#include "sherpa-onnx/csrc/transpose.h" + +namespace sherpa_onnx { + +class OfflineZipformerCtcModel::Impl { + public: + explicit Impl(const OfflineModelConfig &config) + : config_(config), + env_(ORT_LOGGING_LEVEL_ERROR), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + auto buf = ReadFile(config_.zipformer_ctc.model); + Init(buf.data(), buf.size()); + } + +#if __ANDROID_API__ >= 9 + Impl(AAssetManager *mgr, const OfflineModelConfig &config) + : config_(config), + 
env_(ORT_LOGGING_LEVEL_ERROR), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + auto buf = ReadFile(mgr, config_.zipformer_ctc.model); + Init(buf.data(), buf.size()); + } +#endif + + std::vector Forward(Ort::Value features, + Ort::Value features_length) { + std::array inputs = {std::move(features), + std::move(features_length)}; + + return sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(), + output_names_ptr_.data(), output_names_ptr_.size()); + } + + int32_t VocabSize() const { return vocab_size_; } + int32_t SubsamplingFactor() const { return 4; } + + OrtAllocator *Allocator() const { return allocator_; } + + private: + void Init(void *model_data, size_t model_data_length) { + sess_ = std::make_unique(env_, model_data, model_data_length, + sess_opts_); + + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); + + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); + + // get meta data + Ort::ModelMetadata meta_data = sess_->GetModelMetadata(); + if (config_.debug) { + std::ostringstream os; + PrintModelMetadata(os, meta_data); + SHERPA_ONNX_LOGE("%s\n", os.str().c_str()); + } + + // get vocab size from the output[0].shape, which is (N, T, vocab_size) + vocab_size_ = + sess_->GetOutputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape()[2]; + } + + private: + OfflineModelConfig config_; + Ort::Env env_; + Ort::SessionOptions sess_opts_; + Ort::AllocatorWithDefaultOptions allocator_; + + std::unique_ptr sess_; + + std::vector input_names_; + std::vector input_names_ptr_; + + std::vector output_names_; + std::vector output_names_ptr_; + + int32_t vocab_size_ = 0; +}; + +OfflineZipformerCtcModel::OfflineZipformerCtcModel( + const OfflineModelConfig &config) + : impl_(std::make_unique(config)) {} + +#if __ANDROID_API__ >= 9 +OfflineZipformerCtcModel::OfflineZipformerCtcModel( + AAssetManager *mgr, const OfflineModelConfig &config) + : impl_(std::make_unique(mgr, config)) {} +#endif + 
+OfflineZipformerCtcModel::~OfflineZipformerCtcModel() = default; + +std::vector OfflineZipformerCtcModel::Forward( + Ort::Value features, Ort::Value features_length) { + return impl_->Forward(std::move(features), std::move(features_length)); +} + +int32_t OfflineZipformerCtcModel::VocabSize() const { + return impl_->VocabSize(); +} + +OrtAllocator *OfflineZipformerCtcModel::Allocator() const { + return impl_->Allocator(); +} + +int32_t OfflineZipformerCtcModel::SubsamplingFactor() const { + return impl_->SubsamplingFactor(); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-zipformer-ctc-model.h b/sherpa-onnx/csrc/offline-zipformer-ctc-model.h new file mode 100644 index 00000000..e3b9a05c --- /dev/null +++ b/sherpa-onnx/csrc/offline-zipformer-ctc-model.h @@ -0,0 +1,70 @@ +// sherpa-onnx/csrc/offline-zipformer-ctc-model.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_CTC_MODEL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_CTC_MODEL_H_ +#include +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#include "onnxruntime_cxx_api.h" // NOLINT +#include "sherpa-onnx/csrc/offline-ctc-model.h" +#include "sherpa-onnx/csrc/offline-model-config.h" + +namespace sherpa_onnx { + +/** This class implements the zipformer CTC model of the librispeech recipe + * from icefall. + * + * See + * https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/zipformer/export-onnx-ctc.py + */ +class OfflineZipformerCtcModel : public OfflineCtcModel { + public: + explicit OfflineZipformerCtcModel(const OfflineModelConfig &config); + +#if __ANDROID_API__ >= 9 + OfflineZipformerCtcModel(AAssetManager *mgr, + const OfflineModelConfig &config); +#endif + + ~OfflineZipformerCtcModel() override; + + /** Run the forward method of the model. + * + * @param features A tensor of shape (N, T, C). 
+ * @param features_length A 1-D tensor of shape (N,) containing number of + * valid frames in `features` before padding. + * Its dtype is int64_t. + * + * @return Return a vector containing: + * - log_probs: A 3-D tensor of shape (N, T', vocab_size). + * - log_probs_length A 1-D tensor of shape (N,). Its dtype is int64_t + */ + std::vector Forward(Ort::Value features, + Ort::Value features_length) override; + + /** Return the vocabulary size of the model + */ + int32_t VocabSize() const override; + + /** Return an allocator for allocating memory + */ + OrtAllocator *Allocator() const override; + + int32_t SubsamplingFactor() const override; + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_CTC_MODEL_H_ diff --git a/sherpa-onnx/python/csrc/CMakeLists.txt b/sherpa-onnx/python/csrc/CMakeLists.txt index 1973a1c7..a5832d22 100644 --- a/sherpa-onnx/python/csrc/CMakeLists.txt +++ b/sherpa-onnx/python/csrc/CMakeLists.txt @@ -5,6 +5,7 @@ pybind11_add_module(_sherpa_onnx display.cc endpoint.cc features.cc + offline-ctc-fst-decoder-config.cc offline-lm-config.cc offline-model-config.cc offline-nemo-enc-dec-ctc-model-config.cc @@ -14,6 +15,7 @@ pybind11_add_module(_sherpa_onnx offline-tdnn-model-config.cc offline-transducer-model-config.cc offline-whisper-model-config.cc + offline-zipformer-ctc-model-config.cc online-lm-config.cc online-model-config.cc online-paraformer-model-config.cc diff --git a/sherpa-onnx/python/csrc/offline-ctc-fst-decoder-config.cc b/sherpa-onnx/python/csrc/offline-ctc-fst-decoder-config.cc new file mode 100644 index 00000000..cbd18341 --- /dev/null +++ b/sherpa-onnx/python/csrc/offline-ctc-fst-decoder-config.cc @@ -0,0 +1,23 @@ +// sherpa-onnx/python/csrc/offline-ctc-fst-decoder-config.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/python/csrc/offline-ctc-fst-decoder-config.h" + +#include + +#include 
"sherpa-onnx/csrc/offline-ctc-fst-decoder-config.h" + +namespace sherpa_onnx { + +void PybindOfflineCtcFstDecoderConfig(py::module *m) { + using PyClass = OfflineCtcFstDecoderConfig; + py::class_(*m, "OfflineCtcFstDecoderConfig") + .def(py::init(), py::arg("graph") = "", + py::arg("max_active") = 3000) + .def_readwrite("graph", &PyClass::graph) + .def_readwrite("max_active", &PyClass::max_active) + .def("__str__", &PyClass::ToString); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/python/csrc/offline-ctc-fst-decoder-config.h b/sherpa-onnx/python/csrc/offline-ctc-fst-decoder-config.h new file mode 100644 index 00000000..4941c82f --- /dev/null +++ b/sherpa-onnx/python/csrc/offline-ctc-fst-decoder-config.h @@ -0,0 +1,16 @@ +// sherpa-onnx/python/csrc/offline-ctc-fst-decoder-config.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_CTC_FST_DECODER_CONFIG_H_ +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_CTC_FST_DECODER_CONFIG_H_ + +#include "sherpa-onnx/python/csrc/sherpa-onnx.h" + +namespace sherpa_onnx { + +void PybindOfflineCtcFstDecoderConfig(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_CTC_FST_DECODER_CONFIG_H_ diff --git a/sherpa-onnx/python/csrc/offline-model-config.cc b/sherpa-onnx/python/csrc/offline-model-config.cc index 4ed0483c..970d83b2 100644 --- a/sherpa-onnx/python/csrc/offline-model-config.cc +++ b/sherpa-onnx/python/csrc/offline-model-config.cc @@ -13,6 +13,7 @@ #include "sherpa-onnx/python/csrc/offline-tdnn-model-config.h" #include "sherpa-onnx/python/csrc/offline-transducer-model-config.h" #include "sherpa-onnx/python/csrc/offline-whisper-model-config.h" +#include "sherpa-onnx/python/csrc/offline-zipformer-ctc-model-config.h" namespace sherpa_onnx { @@ -22,6 +23,7 @@ void PybindOfflineModelConfig(py::module *m) { PybindOfflineNemoEncDecCtcModelConfig(m); PybindOfflineWhisperModelConfig(m); PybindOfflineTdnnModelConfig(m); + PybindOfflineZipformerCtcModelConfig(m); using PyClass = 
OfflineModelConfig; py::class_(*m, "OfflineModelConfig") @@ -29,20 +31,23 @@ void PybindOfflineModelConfig(py::module *m) { const OfflineParaformerModelConfig &, const OfflineNemoEncDecCtcModelConfig &, const OfflineWhisperModelConfig &, - const OfflineTdnnModelConfig &, const std::string &, + const OfflineTdnnModelConfig &, + const OfflineZipformerCtcModelConfig &, const std::string &, int32_t, bool, const std::string &, const std::string &>(), py::arg("transducer") = OfflineTransducerModelConfig(), py::arg("paraformer") = OfflineParaformerModelConfig(), py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(), py::arg("whisper") = OfflineWhisperModelConfig(), - py::arg("tdnn") = OfflineTdnnModelConfig(), py::arg("tokens"), - py::arg("num_threads"), py::arg("debug") = false, + py::arg("tdnn") = OfflineTdnnModelConfig(), + py::arg("zipformer_ctc") = OfflineZipformerCtcModelConfig(), + py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false, py::arg("provider") = "cpu", py::arg("model_type") = "") .def_readwrite("transducer", &PyClass::transducer) .def_readwrite("paraformer", &PyClass::paraformer) .def_readwrite("nemo_ctc", &PyClass::nemo_ctc) .def_readwrite("whisper", &PyClass::whisper) .def_readwrite("tdnn", &PyClass::tdnn) + .def_readwrite("zipformer_ctc", &PyClass::zipformer_ctc) .def_readwrite("tokens", &PyClass::tokens) .def_readwrite("num_threads", &PyClass::num_threads) .def_readwrite("debug", &PyClass::debug) diff --git a/sherpa-onnx/python/csrc/offline-recognizer.cc b/sherpa-onnx/python/csrc/offline-recognizer.cc index dbeec96c..3c3cc043 100644 --- a/sherpa-onnx/python/csrc/offline-recognizer.cc +++ b/sherpa-onnx/python/csrc/offline-recognizer.cc @@ -16,15 +16,18 @@ static void PybindOfflineRecognizerConfig(py::module *m) { py::class_(*m, "OfflineRecognizerConfig") .def(py::init(), + const OfflineCtcFstDecoderConfig &, const std::string &, + int32_t, const std::string &, float>(), py::arg("feat_config"), py::arg("model_config"), 
py::arg("lm_config") = OfflineLMConfig(), + py::arg("ctc_fst_decoder_config") = OfflineCtcFstDecoderConfig(), py::arg("decoding_method") = "greedy_search", py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "", py::arg("hotwords_score") = 1.5) .def_readwrite("feat_config", &PyClass::feat_config) .def_readwrite("model_config", &PyClass::model_config) .def_readwrite("lm_config", &PyClass::lm_config) + .def_readwrite("ctc_fst_decoder_config", &PyClass::ctc_fst_decoder_config) .def_readwrite("decoding_method", &PyClass::decoding_method) .def_readwrite("max_active_paths", &PyClass::max_active_paths) .def_readwrite("hotwords_file", &PyClass::hotwords_file) diff --git a/sherpa-onnx/python/csrc/offline-zipformer-ctc-model-config.cc b/sherpa-onnx/python/csrc/offline-zipformer-ctc-model-config.cc new file mode 100644 index 00000000..75409225 --- /dev/null +++ b/sherpa-onnx/python/csrc/offline-zipformer-ctc-model-config.cc @@ -0,0 +1,22 @@ +// sherpa-onnx/python/csrc/offline-zipformer-ctc-model-config.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/python/csrc/offline-zipformer-ctc-model-config.h" + +#include + +#include "sherpa-onnx/csrc/offline-zipformer-ctc-model-config.h" + +namespace sherpa_onnx { + +void PybindOfflineZipformerCtcModelConfig(py::module *m) { + using PyClass = OfflineZipformerCtcModelConfig; + py::class_(*m, "OfflineZipformerCtcModelConfig") + .def(py::init<>()) + .def(py::init(), py::arg("model")) + .def_readwrite("model", &PyClass::model) + .def("__str__", &PyClass::ToString); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/python/csrc/offline-zipformer-ctc-model-config.h b/sherpa-onnx/python/csrc/offline-zipformer-ctc-model-config.h new file mode 100644 index 00000000..716d6a5b --- /dev/null +++ b/sherpa-onnx/python/csrc/offline-zipformer-ctc-model-config.h @@ -0,0 +1,16 @@ +// sherpa-onnx/python/csrc/offline-zipformer-ctc-model-config.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef 
SHERPA_ONNX_PYTHON_CSRC_OFFLINE_ZIPFORMER_CTC_MODEL_CONFIG_H_ +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_ZIPFORMER_CTC_MODEL_CONFIG_H_ + +#include "sherpa-onnx/python/csrc/sherpa-onnx.h" + +namespace sherpa_onnx { + +void PybindOfflineZipformerCtcModelConfig(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_ZIPFORMER_CTC_MODEL_CONFIG_H_ diff --git a/sherpa-onnx/python/csrc/sherpa-onnx.cc b/sherpa-onnx/python/csrc/sherpa-onnx.cc index 98547df8..27f5f827 100644 --- a/sherpa-onnx/python/csrc/sherpa-onnx.cc +++ b/sherpa-onnx/python/csrc/sherpa-onnx.cc @@ -8,6 +8,7 @@ #include "sherpa-onnx/python/csrc/display.h" #include "sherpa-onnx/python/csrc/endpoint.h" #include "sherpa-onnx/python/csrc/features.h" +#include "sherpa-onnx/python/csrc/offline-ctc-fst-decoder-config.h" #include "sherpa-onnx/python/csrc/offline-lm-config.h" #include "sherpa-onnx/python/csrc/offline-model-config.h" #include "sherpa-onnx/python/csrc/offline-recognizer.h" @@ -37,6 +38,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) { PybindOfflineStream(&m); PybindOfflineLMConfig(&m); PybindOfflineModelConfig(&m); + PybindOfflineCtcFstDecoderConfig(&m); PybindOfflineRecognizer(&m); PybindVadModelConfig(&m); diff --git a/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py b/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py index e1c82279..21bd8d58 100644 --- a/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py +++ b/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py @@ -4,12 +4,14 @@ from pathlib import Path from typing import List, Optional from _sherpa_onnx import ( + OfflineCtcFstDecoderConfig, OfflineFeatureExtractorConfig, OfflineModelConfig, OfflineNemoEncDecCtcModelConfig, OfflineParaformerModelConfig, OfflineTdnnModelConfig, OfflineWhisperModelConfig, + OfflineZipformerCtcModelConfig, ) from _sherpa_onnx import OfflineRecognizer as _Recognizer from _sherpa_onnx import (