Add CTC HLG decoding using OpenFst (#349)
This commit is contained in:
45
.github/scripts/test-offline-ctc.sh
vendored
45
.github/scripts/test-offline-ctc.sh
vendored
@@ -89,3 +89,48 @@ time $EXE \
|
||||
$repo/test_wavs/8k.wav
|
||||
|
||||
rm -rf $repo
|
||||
|
||||
log "------------------------------------------------------------"
log "Run Librispeech zipformer CTC H/HL/HLG decoding (English) "
log "------------------------------------------------------------"

# Pre-trained zipformer CTC model together with its H/HL/HLG decoding graphs.
repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-zipformer-ctc-en-2023-10-02
log "Start testing ${repo_url}"
repo=$(basename "$repo_url")
log "Download pretrained model and test-data from $repo_url"

# Clone without the LFS payload, then fetch only the files this test needs.
GIT_LFS_SKIP_SMUDGE=1 git clone "$repo_url"
pushd "$repo"
git lfs pull --include "*.onnx"
git lfs pull --include "*.fst"
ls -lh
popd

# Each graph is decoded with both the float32 and the int8 model.
graphs=(
  "$repo/H.fst"
  "$repo/HL.fst"
  "$repo/HLG.fst"
)

# Expansions are quoted so that paths are never word-split or glob-expanded.
# $EXE is left unquoted on purpose: it is defined outside this section and may
# carry more than one word.
for graph in "${graphs[@]}"; do
  log "test float32 models with $graph"
  time $EXE \
    --model-type=zipformer2_ctc \
    --ctc.graph="$graph" \
    --zipformer-ctc-model="$repo/model.onnx" \
    --tokens="$repo/tokens.txt" \
    "$repo/test_wavs/0.wav" \
    "$repo/test_wavs/1.wav" \
    "$repo/test_wavs/2.wav"

  log "test int8 models with $graph"
  time $EXE \
    --model-type=zipformer2_ctc \
    --ctc.graph="$graph" \
    --zipformer-ctc-model="$repo/model.int8.onnx" \
    --tokens="$repo/tokens.txt" \
    "$repo/test_wavs/0.wav" \
    "$repo/test_wavs/1.wav" \
    "$repo/test_wavs/2.wav"
done

# Remove the downloaded repository once all graphs have been tested.
rm -rf "$repo"
|
||||
|
||||
@@ -18,7 +18,7 @@ permissions:
|
||||
jobs:
|
||||
python_online_websocket_server:
|
||||
runs-on: ${{ matrix.os }}
|
||||
name: ${{ matrix.os }} ${{ matrix.python-version }}
|
||||
name: ${{ matrix.os }} ${{ matrix.python-version }} ${{ matrix.model_type }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
|
||||
@@ -154,6 +154,7 @@ if(NOT BUILD_SHARED_LIBS AND CMAKE_SYSTEM_NAME STREQUAL Linux)
|
||||
endif()
|
||||
|
||||
include(kaldi-native-fbank)
|
||||
include(kaldi-decoder)
|
||||
include(onnxruntime)
|
||||
|
||||
if(SHERPA_ONNX_ENABLE_PORTAUDIO)
|
||||
|
||||
48
cmake/eigen.cmake
Normal file
48
cmake/eigen.cmake
Normal file
@@ -0,0 +1,48 @@
|
||||
# Fetch eigen 3.4.0 and add it to the build via add_subdirectory().
#
# A pre-downloaded tarball found in one of `possible_file_locations` is used
# in preference to the network. Otherwise the primary URL is tried first and
# the mirror (eigen_URL2) second.
function(download_eigen)
  include(FetchContent)

  set(eigen_URL "https://gitlab.com/libeigen/eigen/-/archive/3.4.0/eigen-3.4.0.tar.gz")
  set(eigen_URL2 "https://huggingface.co/csukuangfj/kaldi-hmm-gmm-cmake-deps/resolve/main/eigen-3.4.0.tar.gz")
  set(eigen_HASH "SHA256=8586084f71f9bde545ee7fa6d00288b264a2b7ac3607b974e54d13e7162c1c72")

  # If you don't have access to the Internet,
  # please pre-download eigen
  set(possible_file_locations
    $ENV{HOME}/Downloads/eigen-3.4.0.tar.gz
    ${PROJECT_SOURCE_DIR}/eigen-3.4.0.tar.gz
    ${PROJECT_BINARY_DIR}/eigen-3.4.0.tar.gz
    /tmp/eigen-3.4.0.tar.gz
    /star-fj/fangjun/download/github/eigen-3.4.0.tar.gz
  )

  foreach(f IN LISTS possible_file_locations)
    if(EXISTS "${f}")
      set(eigen_URL "${f}")
      file(TO_CMAKE_PATH "${eigen_URL}" eigen_URL)
      message(STATUS "Found local downloaded eigen: ${eigen_URL}")
      # A local tarball needs no mirror; unset it so it expands to nothing below.
      set(eigen_URL2)
      break()
    endif()
  endforeach()

  # Cache options set before the sub-project is configured.
  set(BUILD_TESTING OFF CACHE BOOL "" FORCE)
  set(EIGEN_BUILD_DOC OFF CACHE BOOL "" FORCE)

  # Pass both URLs so the mirror is actually used as a fallback; previously
  # eigen_URL2 was defined but never handed to FetchContent.
  FetchContent_Declare(eigen
    URL
      ${eigen_URL}
      ${eigen_URL2}
    URL_HASH ${eigen_HASH}
  )

  FetchContent_GetProperties(eigen)
  if(NOT eigen_POPULATED)
    message(STATUS "Downloading eigen from ${eigen_URL}")
    FetchContent_Populate(eigen)
  endif()
  message(STATUS "eigen is downloaded to ${eigen_SOURCE_DIR}")
  message(STATUS "eigen's binary dir is ${eigen_BINARY_DIR}")

  add_subdirectory(${eigen_SOURCE_DIR} ${eigen_BINARY_DIR} EXCLUDE_FROM_ALL)
endfunction()
|
||||
|
||||
download_eigen()
|
||||
|
||||
78
cmake/kaldi-decoder.cmake
Normal file
78
cmake/kaldi-decoder.cmake
Normal file
@@ -0,0 +1,78 @@
|
||||
# Fetch kaldi-decoder v0.2.3, add it to the build and install its libraries.
#
# A pre-downloaded tarball found in one of `possible_file_locations` is used
# in preference to the network. Otherwise the primary URL is tried first and
# the mirror (kaldi_decoder_URL2) second.
function(download_kaldi_decoder)
  include(FetchContent)

  set(kaldi_decoder_URL "https://github.com/k2-fsa/kaldi-decoder/archive/refs/tags/v0.2.3.tar.gz")
  set(kaldi_decoder_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-decoder-0.2.3.tar.gz")
  set(kaldi_decoder_HASH "SHA256=98bf445a5b7961ccf3c3522317d900054eaadb6a9cdcf4531e7d9caece94a56d")

  # Disable the Python bindings of the fetched projects.
  # (This option used to be set twice; once is enough.)
  set(KALDI_DECODER_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
  set(KALDIFST_BUILD_PYTHON OFF CACHE BOOL "" FORCE)

  # If you don't have access to the Internet,
  # please pre-download kaldi-decoder
  set(possible_file_locations
    $ENV{HOME}/Downloads/kaldi-decoder-0.2.3.tar.gz
    ${PROJECT_SOURCE_DIR}/kaldi-decoder-0.2.3.tar.gz
    ${PROJECT_BINARY_DIR}/kaldi-decoder-0.2.3.tar.gz
    /tmp/kaldi-decoder-0.2.3.tar.gz
    /star-fj/fangjun/download/github/kaldi-decoder-0.2.3.tar.gz
  )

  foreach(f IN LISTS possible_file_locations)
    if(EXISTS "${f}")
      set(kaldi_decoder_URL "${f}")
      file(TO_CMAKE_PATH "${kaldi_decoder_URL}" kaldi_decoder_URL)
      message(STATUS "Found local downloaded kaldi-decoder: ${kaldi_decoder_URL}")
      # A local tarball needs no mirror; unset it so it expands to nothing below.
      set(kaldi_decoder_URL2 )
      break()
    endif()
  endforeach()

  FetchContent_Declare(kaldi_decoder
    URL
      ${kaldi_decoder_URL}
      ${kaldi_decoder_URL2}
    URL_HASH ${kaldi_decoder_HASH}
  )

  FetchContent_GetProperties(kaldi_decoder)
  if(NOT kaldi_decoder_POPULATED)
    message(STATUS "Downloading kaldi-decoder from ${kaldi_decoder_URL}")
    FetchContent_Populate(kaldi_decoder)
  endif()
  message(STATUS "kaldi-decoder is downloaded to ${kaldi_decoder_SOURCE_DIR}")
  message(STATUS "kaldi-decoder's binary dir is ${kaldi_decoder_BINARY_DIR}")

  include_directories(${kaldi_decoder_SOURCE_DIR})
  add_subdirectory(${kaldi_decoder_SOURCE_DIR} ${kaldi_decoder_BINARY_DIR} EXCLUDE_FROM_ALL)

  # Use the variable populated by FetchContent (underscore). The hyphenated
  # spelling `kaldi-decoder_SOURCE_DIR` is only defined if the sub-project's
  # project() call happens to use that exact name.
  target_include_directories(kaldi-decoder-core
    INTERFACE
      ${kaldi_decoder_SOURCE_DIR}/
  )

  # NOTE(review): the kaldifst_core and fst targets are not created in this
  # file; presumably kaldi-decoder's own build brings them in -- confirm.
  if(SHERPA_ONNX_ENABLE_PYTHON AND WIN32)
    install(TARGETS
      kaldi-decoder-core
      kaldifst_core
      fst
    DESTINATION ..)
  else()
    install(TARGETS
      kaldi-decoder-core
      kaldifst_core
      fst
    DESTINATION lib)
  endif()

  # Shared builds on Windows additionally install the targets into bin.
  if(WIN32 AND BUILD_SHARED_LIBS)
    install(TARGETS
      kaldi-decoder-core
      kaldifst_core
      fst
    DESTINATION bin)
  endif()
endfunction()
|
||||
|
||||
download_kaldi_decoder()
|
||||
|
||||
62
cmake/kaldifst.cmake
Normal file
62
cmake/kaldifst.cmake
Normal file
@@ -0,0 +1,62 @@
|
||||
# Fetch kaldifst v1.7.6, add it to the build and rename its output libraries.
#
# A pre-downloaded tarball found in one of `possible_file_locations` is used
# in preference to the network. Otherwise the primary URL is tried first and
# the mirror (kaldifst_URL2) second.
function(download_kaldifst)
  include(FetchContent)

  set(kaldifst_URL "https://github.com/k2-fsa/kaldifst/archive/refs/tags/v1.7.6.tar.gz")
  set(kaldifst_URL2 "https://huggingface.co/csukuangfj/kaldi-hmm-gmm-cmake-deps/resolve/main/kaldifst-1.7.6.tar.gz")
  set(kaldifst_HASH "SHA256=79280c0bb08b5ed1a2ab7c21320a2b071f1f0eb10d2f047e8d6f027f0d32b4d2")

  # If you don't have access to the Internet,
  # please pre-download kaldifst
  set(possible_file_locations
    $ENV{HOME}/Downloads/kaldifst-1.7.6.tar.gz
    ${PROJECT_SOURCE_DIR}/kaldifst-1.7.6.tar.gz
    ${PROJECT_BINARY_DIR}/kaldifst-1.7.6.tar.gz
    /tmp/kaldifst-1.7.6.tar.gz
    /star-fj/fangjun/download/github/kaldifst-1.7.6.tar.gz
  )

  foreach(f IN LISTS possible_file_locations)
    if(EXISTS "${f}")
      set(kaldifst_URL "${f}")
      file(TO_CMAKE_PATH "${kaldifst_URL}" kaldifst_URL)
      message(STATUS "Found local downloaded kaldifst: ${kaldifst_URL}")
      # A local tarball needs no mirror; unset it so it expands to nothing below.
      set(kaldifst_URL2)
      break()
    endif()
  endforeach()

  # Cache options set before the sub-project is configured.
  set(KALDIFST_BUILD_TESTS OFF CACHE BOOL "" FORCE)
  set(KALDIFST_BUILD_PYTHON OFF CACHE BOOL "" FORCE)

  # Pass both URLs so the mirror is actually used as a fallback; previously
  # kaldifst_URL2 was defined but never handed to FetchContent.
  FetchContent_Declare(kaldifst
    URL
      ${kaldifst_URL}
      ${kaldifst_URL2}
    URL_HASH ${kaldifst_HASH}
  )

  FetchContent_GetProperties(kaldifst)
  if(NOT kaldifst_POPULATED)
    message(STATUS "Downloading kaldifst from ${kaldifst_URL}")
    FetchContent_Populate(kaldifst)
  endif()
  message(STATUS "kaldifst is downloaded to ${kaldifst_SOURCE_DIR}")
  message(STATUS "kaldifst's binary dir is ${kaldifst_BINARY_DIR}")

  # Make the cmake modules shipped in kaldifst's cmake/ directory findable.
  list(APPEND CMAKE_MODULE_PATH ${kaldifst_SOURCE_DIR}/cmake)

  add_subdirectory(${kaldifst_SOURCE_DIR} ${kaldifst_BINARY_DIR} EXCLUDE_FROM_ALL)

  target_include_directories(kaldifst_core
    PUBLIC
      ${kaldifst_SOURCE_DIR}/
  )

  # NOTE(review): openfst_SOURCE_DIR is not set in this file; presumably it is
  # defined while kaldifst's sub-build is configured -- confirm.
  target_include_directories(fst
    PUBLIC
      ${openfst_SOURCE_DIR}/src/include
  )

  # Give the libraries sherpa-onnx specific file names.
  set_target_properties(kaldifst_core PROPERTIES OUTPUT_NAME "sherpa-onnx-kaldifst-core")
  set_target_properties(fst PROPERTIES OUTPUT_NAME "sherpa-onnx-fst")
endfunction()
|
||||
|
||||
download_kaldifst()
|
||||
@@ -13,4 +13,4 @@ Cflags: -I"${includedir}"
|
||||
# Note: -lcargs is required only for the following file
|
||||
# https://github.com/k2-fsa/sherpa-onnx/blob/master/c-api-examples/decode-file-c-api.c
|
||||
# We add it here so that users don't need to specify -lcargs when compiling decode-file-c-api.c
|
||||
Libs: -L"${libdir}" -lsherpa-onnx-c-api -lsherpa-onnx-core -lonnxruntime -lkaldi-native-fbank-core -lcargs -Wl,-rpath,${libdir} @SHERPA_ONNX_PKG_CONFIG_EXTRA_LIBS@
|
||||
Libs: -L"${libdir}" -lsherpa-onnx-c-api -lsherpa-onnx-core -lonnxruntime -lkaldi-decoder-core -lsherpa-onnx-kaldifst-core -lsherpa-onnx-fst -lkaldi-native-fbank-core -lcargs -Wl,-rpath,${libdir} @SHERPA_ONNX_PKG_CONFIG_EXTRA_LIBS@
|
||||
|
||||
@@ -9,6 +9,9 @@
|
||||
sherpa-onnx-portaudio_static.lib;
|
||||
sherpa-onnx-c-api.lib;
|
||||
sherpa-onnx-core.lib;
|
||||
kaldi-decoder-core.lib;
|
||||
sherpa-onnx-kaldifst-core.lib;
|
||||
sherpa-onnx-fst.lib;
|
||||
kaldi-native-fbank-core.lib;
|
||||
absl_base.lib;
|
||||
absl_city.lib;
|
||||
|
||||
@@ -9,6 +9,9 @@
|
||||
sherpa-onnx-portaudio_static.lib;
|
||||
sherpa-onnx-c-api.lib;
|
||||
sherpa-onnx-core.lib;
|
||||
kaldi-decoder-core.lib;
|
||||
sherpa-onnx-kaldifst-core.lib;
|
||||
sherpa-onnx-fst.lib;
|
||||
kaldi-native-fbank-core.lib;
|
||||
absl_base.lib;
|
||||
absl_city.lib;
|
||||
|
||||
@@ -19,6 +19,8 @@ set(sources
|
||||
features.cc
|
||||
file-utils.cc
|
||||
hypothesis.cc
|
||||
offline-ctc-fst-decoder-config.cc
|
||||
offline-ctc-fst-decoder.cc
|
||||
offline-ctc-greedy-search-decoder.cc
|
||||
offline-ctc-model.cc
|
||||
offline-lm-config.cc
|
||||
@@ -42,6 +44,8 @@ set(sources
|
||||
offline-whisper-greedy-search-decoder.cc
|
||||
offline-whisper-model-config.cc
|
||||
offline-whisper-model.cc
|
||||
offline-zipformer-ctc-model-config.cc
|
||||
offline-zipformer-ctc-model.cc
|
||||
online-conformer-transducer-model.cc
|
||||
online-lm-config.cc
|
||||
online-lm.cc
|
||||
@@ -97,6 +101,8 @@ endif()
|
||||
|
||||
target_link_libraries(sherpa-onnx-core kaldi-native-fbank-core)
|
||||
|
||||
target_link_libraries(sherpa-onnx-core kaldi-decoder-core)
|
||||
|
||||
if(BUILD_SHARED_LIBS OR APPLE OR CMAKE_SYSTEM_PROCESSOR STREQUAL aarch64 OR CMAKE_SYSTEM_PROCESSOR STREQUAL arm)
|
||||
target_link_libraries(sherpa-onnx-core onnxruntime)
|
||||
else()
|
||||
|
||||
32
sherpa-onnx/csrc/offline-ctc-fst-decoder-config.cc
Normal file
32
sherpa-onnx/csrc/offline-ctc-fst-decoder-config.cc
Normal file
@@ -0,0 +1,32 @@
|
||||
// sherpa-onnx/csrc/offline-ctc-fst-decoder-config.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-ctc-fst-decoder-config.h"
|
||||
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// Return a human readable, single-line description of this config in the
// form OfflineCtcFstDecoderConfig(graph="...", max_active=...).
std::string OfflineCtcFstDecoderConfig::ToString() const {
  std::ostringstream oss;

  oss << "OfflineCtcFstDecoderConfig("
      << "graph=\"" << graph << "\", "
      << "max_active=" << max_active << ")";

  return oss.str();
}
|
||||
|
||||
// Register the options of this config with `po`. The options are registered
// through a ParseOptions object constructed with the prefix "ctc".
void OfflineCtcFstDecoderConfig::Register(ParseOptions *po) {
  std::string option_prefix = "ctc";
  ParseOptions prefixed_po(option_prefix, po);

  prefixed_po.Register("graph", &graph, "Path to H.fst, HL.fst, or HLG.fst");

  prefixed_po.Register(
      "max-active", &max_active,
      "Decoder max active states. Larger->slower; more accurate");
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
31
sherpa-onnx/csrc/offline-ctc-fst-decoder-config.h
Normal file
31
sherpa-onnx/csrc/offline-ctc-fst-decoder-config.h
Normal file
@@ -0,0 +1,31 @@
|
||||
// sherpa-onnx/csrc/offline-ctc-fst-decoder-config.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_CTC_FST_DECODER_CONFIG_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_CTC_FST_DECODER_CONFIG_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// Options for CTC decoding with an FST decoding graph.
struct OfflineCtcFstDecoderConfig {
  // Path to H.fst, HL.fst or HLG.fst
  std::string graph;
  // Decoder max active states. Larger->slower; more accurate
  int32_t max_active = 3000;

  OfflineCtcFstDecoderConfig() = default;

  // Construct a config from explicit values for both fields.
  OfflineCtcFstDecoderConfig(const std::string &graph, int32_t max_active)
      : graph(graph), max_active(max_active) {}

  // Return a string representation of this config.
  std::string ToString() const;

  // Register the fields of this config as options of `po`.
  void Register(ParseOptions *po);
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_CTC_FST_DECODER_CONFIG_H_
|
||||
157
sherpa-onnx/csrc/offline-ctc-fst-decoder.cc
Normal file
157
sherpa-onnx/csrc/offline-ctc-fst-decoder.cc
Normal file
@@ -0,0 +1,157 @@
|
||||
// sherpa-onnx/csrc/offline-ctc-fst-decoder.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-ctc-fst-decoder.h"
|
||||
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "fst/fstlib.h"
|
||||
#include "kaldi-decoder/csrc/decodable-ctc.h"
|
||||
#include "kaldi-decoder/csrc/eigen.h"
|
||||
#include "kaldi-decoder/csrc/faster-decoder.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// This function is copied from kaldi.
|
||||
//
|
||||
// @param filename Path to a StdVectorFst or StdConstFst graph
|
||||
// @return The caller should free the returned pointer using `delete` to
|
||||
// avoid memory leak.
|
||||
static fst::Fst<fst::StdArc> *ReadGraph(const std::string &filename) {
|
||||
// read decoding network FST
|
||||
std::ifstream is(filename, std::ios::binary);
|
||||
if (!is.good()) {
|
||||
SHERPA_ONNX_LOGE("Could not open decoding-graph FST %s", filename.c_str());
|
||||
}
|
||||
|
||||
fst::FstHeader hdr;
|
||||
if (!hdr.Read(is, "<unknown>")) {
|
||||
SHERPA_ONNX_LOGE("Reading FST: error reading FST header.");
|
||||
}
|
||||
|
||||
if (hdr.ArcType() != fst::StdArc::Type()) {
|
||||
SHERPA_ONNX_LOGE("FST with arc type %s not supported",
|
||||
hdr.ArcType().c_str());
|
||||
}
|
||||
fst::FstReadOptions ropts("<unspecified>", &hdr);
|
||||
|
||||
fst::Fst<fst::StdArc> *decode_fst = nullptr;
|
||||
|
||||
if (hdr.FstType() == "vector") {
|
||||
decode_fst = fst::VectorFst<fst::StdArc>::Read(is, ropts);
|
||||
} else if (hdr.FstType() == "const") {
|
||||
decode_fst = fst::ConstFst<fst::StdArc>::Read(is, ropts);
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE("Reading FST: unsupported FST type: %s",
|
||||
hdr.FstType().c_str());
|
||||
}
|
||||
|
||||
if (decode_fst == nullptr) { // fst code will warn.
|
||||
SHERPA_ONNX_LOGE("Error reading FST (after reading header).");
|
||||
return nullptr;
|
||||
} else {
|
||||
return decode_fst;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @param decoder
|
||||
* @param p Pointer to a 2-d array of shape (num_frames, vocab_size)
|
||||
* @param num_frames Number of rows in the 2-d array.
|
||||
* @param vocab_size Number of columns in the 2-d array.
|
||||
* @return Return the decoded result.
|
||||
*/
|
||||
static OfflineCtcDecoderResult DecodeOne(kaldi_decoder::FasterDecoder *decoder,
|
||||
const float *p, int32_t num_frames,
|
||||
int32_t vocab_size) {
|
||||
OfflineCtcDecoderResult r;
|
||||
kaldi_decoder::DecodableCtc decodable(p, num_frames, vocab_size);
|
||||
|
||||
decoder->Decode(&decodable);
|
||||
|
||||
if (!decoder->ReachedFinal()) {
|
||||
SHERPA_ONNX_LOGE("Not reached final!");
|
||||
return r;
|
||||
}
|
||||
|
||||
fst::VectorFst<fst::LatticeArc> decoded; // linear FST.
|
||||
decoder->GetBestPath(&decoded);
|
||||
|
||||
if (decoded.NumStates() == 0) {
|
||||
SHERPA_ONNX_LOGE("Empty best path!");
|
||||
return r;
|
||||
}
|
||||
|
||||
auto cur_state = decoded.Start();
|
||||
|
||||
int32_t blank_id = 0;
|
||||
|
||||
for (int32_t t = 0, prev = -1; decoded.NumArcs(cur_state) == 1; ++t) {
|
||||
fst::ArcIterator<fst::Fst<fst::LatticeArc>> iter(decoded, cur_state);
|
||||
const auto &arc = iter.Value();
|
||||
|
||||
cur_state = arc.nextstate;
|
||||
|
||||
if (arc.ilabel == prev) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// 0 is epsilon here
|
||||
if (arc.ilabel == 0 || arc.ilabel == blank_id + 1) {
|
||||
prev = arc.ilabel;
|
||||
continue;
|
||||
}
|
||||
|
||||
// -1 here since the input labels are incremented during graph
|
||||
// construction
|
||||
r.tokens.push_back(arc.ilabel - 1);
|
||||
|
||||
r.timestamps.push_back(t);
|
||||
prev = arc.ilabel;
|
||||
}
|
||||
|
||||
return r;
|
||||
}
|
||||
|
||||
OfflineCtcFstDecoder::OfflineCtcFstDecoder(
|
||||
const OfflineCtcFstDecoderConfig &config)
|
||||
: config_(config), fst_(ReadGraph(config_.graph)) {}
|
||||
|
||||
std::vector<OfflineCtcDecoderResult> OfflineCtcFstDecoder::Decode(
|
||||
Ort::Value log_probs, Ort::Value log_probs_length) {
|
||||
std::vector<int64_t> shape = log_probs.GetTensorTypeAndShapeInfo().GetShape();
|
||||
|
||||
assert(static_cast<int32_t>(shape.size()) == 3);
|
||||
int32_t batch_size = shape[0];
|
||||
int32_t T = shape[1];
|
||||
int32_t vocab_size = shape[2];
|
||||
|
||||
std::vector<int64_t> length_shape =
|
||||
log_probs_length.GetTensorTypeAndShapeInfo().GetShape();
|
||||
assert(static_cast<int32_t>(length_shape.size()) == 1);
|
||||
|
||||
assert(shape[0] == length_shape[0]);
|
||||
|
||||
kaldi_decoder::FasterDecoderOptions opts;
|
||||
opts.max_active = config_.max_active;
|
||||
kaldi_decoder::FasterDecoder faster_decoder(*fst_, opts);
|
||||
|
||||
const float *start = log_probs.GetTensorData<float>();
|
||||
|
||||
std::vector<OfflineCtcDecoderResult> ans;
|
||||
ans.reserve(batch_size);
|
||||
|
||||
for (int32_t i = 0; i != batch_size; ++i) {
|
||||
const float *p = start + i * T * vocab_size;
|
||||
int32_t num_frames = log_probs_length.GetTensorData<int64_t>()[i];
|
||||
auto r = DecodeOne(&faster_decoder, p, num_frames, vocab_size);
|
||||
ans.push_back(std::move(r));
|
||||
}
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
33
sherpa-onnx/csrc/offline-ctc-fst-decoder.h
Normal file
33
sherpa-onnx/csrc/offline-ctc-fst-decoder.h
Normal file
@@ -0,0 +1,33 @@
|
||||
// sherpa-onnx/csrc/offline-ctc-fst-decoder.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_CTC_FST_DECODER_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_CTC_FST_DECODER_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "fst/fst.h"
|
||||
#include "sherpa-onnx/csrc/offline-ctc-decoder.h"
|
||||
#include "sherpa-onnx/csrc/offline-ctc-fst-decoder-config.h"
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// A CTC decoder that holds an FST decoding graph; the graph path comes from
// OfflineCtcFstDecoderConfig::graph (H.fst, HL.fst or HLG.fst).
class OfflineCtcFstDecoder : public OfflineCtcDecoder {
 public:
  // @param config Supplies the options of this decoder.
  explicit OfflineCtcFstDecoder(const OfflineCtcFstDecoderConfig &config);

  // Decode a batch.
  //
  // @param log_probs The log probabilities to decode.
  // @param log_probs_length The lengths belonging to log_probs.
  // @return Return the decoded results.
  std::vector<OfflineCtcDecoderResult> Decode(
      Ort::Value log_probs, Ort::Value log_probs_length) override;

 private:
  // Declared before fst_ on purpose of ordering: members are initialized in
  // declaration order.
  OfflineCtcFstDecoderConfig config_;

  // The decoding graph owned by this decoder.
  std::unique_ptr<fst::Fst<fst::StdArc>> fst_;
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_CTC_FST_DECODER_H_
|
||||
@@ -12,6 +12,7 @@
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h"
|
||||
#include "sherpa-onnx/csrc/offline-tdnn-ctc-model.h"
|
||||
#include "sherpa-onnx/csrc/offline-zipformer-ctc-model.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
|
||||
namespace {
|
||||
@@ -19,6 +20,7 @@ namespace {
|
||||
enum class ModelType {
|
||||
kEncDecCTCModelBPE,
|
||||
kTdnn,
|
||||
kZipformerCtc,
|
||||
kUnkown,
|
||||
};
|
||||
|
||||
@@ -59,6 +61,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
|
||||
return ModelType::kEncDecCTCModelBPE;
|
||||
} else if (model_type.get() == std::string("tdnn")) {
|
||||
return ModelType::kTdnn;
|
||||
} else if (model_type.get() == std::string("zipformer2_ctc")) {
|
||||
return ModelType::kZipformerCtc;
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
|
||||
return ModelType::kUnkown;
|
||||
@@ -74,6 +78,8 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
|
||||
filename = config.nemo_ctc.model;
|
||||
} else if (!config.tdnn.model.empty()) {
|
||||
filename = config.tdnn.model;
|
||||
} else if (!config.zipformer_ctc.model.empty()) {
|
||||
filename = config.zipformer_ctc.model;
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE("Please specify a CTC model");
|
||||
exit(-1);
|
||||
@@ -92,6 +98,9 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
|
||||
case ModelType::kTdnn:
|
||||
return std::make_unique<OfflineTdnnCtcModel>(config);
|
||||
break;
|
||||
case ModelType::kZipformerCtc:
|
||||
return std::make_unique<OfflineZipformerCtcModel>(config);
|
||||
break;
|
||||
case ModelType::kUnkown:
|
||||
SHERPA_ONNX_LOGE("Unknown model type in offline CTC!");
|
||||
return nullptr;
|
||||
@@ -111,6 +120,8 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
|
||||
filename = config.nemo_ctc.model;
|
||||
} else if (!config.tdnn.model.empty()) {
|
||||
filename = config.tdnn.model;
|
||||
} else if (!config.zipformer_ctc.model.empty()) {
|
||||
filename = config.zipformer_ctc.model;
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE("Please specify a CTC model");
|
||||
exit(-1);
|
||||
@@ -129,6 +140,9 @@ std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
|
||||
case ModelType::kTdnn:
|
||||
return std::make_unique<OfflineTdnnCtcModel>(mgr, config);
|
||||
break;
|
||||
case ModelType::kZipformerCtc:
|
||||
return std::make_unique<OfflineZipformerCtcModel>(mgr, config);
|
||||
break;
|
||||
case ModelType::kUnkown:
|
||||
SHERPA_ONNX_LOGE("Unknown model type in offline CTC!");
|
||||
return nullptr;
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
@@ -32,17 +32,17 @@ class OfflineCtcModel {
|
||||
|
||||
/** Run the forward method of the model.
|
||||
*
|
||||
* @param features A tensor of shape (N, T, C). It is changed in-place.
|
||||
* @param features A tensor of shape (N, T, C).
|
||||
* @param features_length A 1-D tensor of shape (N,) containing number of
|
||||
* valid frames in `features` before padding.
|
||||
* Its dtype is int64_t.
|
||||
*
|
||||
* @return Return a pair containing:
|
||||
* @return Return a vector containing:
|
||||
* - log_probs: A 3-D tensor of shape (N, T', vocab_size).
|
||||
* - log_probs_length A 1-D tensor of shape (N,). Its dtype is int64_t
|
||||
*/
|
||||
virtual std::pair<Ort::Value, Ort::Value> Forward(
|
||||
Ort::Value features, Ort::Value features_length) = 0;
|
||||
virtual std::vector<Ort::Value> Forward(Ort::Value features,
|
||||
Ort::Value features_length) = 0;
|
||||
|
||||
/** Return the vocabulary size of the model
|
||||
*/
|
||||
|
||||
@@ -16,6 +16,7 @@ void OfflineModelConfig::Register(ParseOptions *po) {
|
||||
nemo_ctc.Register(po);
|
||||
whisper.Register(po);
|
||||
tdnn.Register(po);
|
||||
zipformer_ctc.Register(po);
|
||||
|
||||
po->Register("tokens", &tokens, "Path to tokens.txt");
|
||||
|
||||
@@ -31,7 +32,7 @@ void OfflineModelConfig::Register(ParseOptions *po) {
|
||||
po->Register("model-type", &model_type,
|
||||
"Specify it to reduce model initialization time. "
|
||||
"Valid values are: transducer, paraformer, nemo_ctc, whisper, "
|
||||
"tdnn."
|
||||
"tdnn, zipformer2_ctc"
|
||||
"All other values lead to loading the model twice.");
|
||||
}
|
||||
|
||||
@@ -62,6 +63,10 @@ bool OfflineModelConfig::Validate() const {
|
||||
return tdnn.Validate();
|
||||
}
|
||||
|
||||
if (!zipformer_ctc.model.empty()) {
|
||||
return zipformer_ctc.Validate();
|
||||
}
|
||||
|
||||
return transducer.Validate();
|
||||
}
|
||||
|
||||
@@ -74,6 +79,7 @@ std::string OfflineModelConfig::ToString() const {
|
||||
os << "nemo_ctc=" << nemo_ctc.ToString() << ", ";
|
||||
os << "whisper=" << whisper.ToString() << ", ";
|
||||
os << "tdnn=" << tdnn.ToString() << ", ";
|
||||
os << "zipformer_ctc=" << zipformer_ctc.ToString() << ", ";
|
||||
os << "tokens=\"" << tokens << "\", ";
|
||||
os << "num_threads=" << num_threads << ", ";
|
||||
os << "debug=" << (debug ? "True" : "False") << ", ";
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#include "sherpa-onnx/csrc/offline-tdnn-model-config.h"
|
||||
#include "sherpa-onnx/csrc/offline-transducer-model-config.h"
|
||||
#include "sherpa-onnx/csrc/offline-whisper-model-config.h"
|
||||
#include "sherpa-onnx/csrc/offline-zipformer-ctc-model-config.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
@@ -20,6 +21,7 @@ struct OfflineModelConfig {
|
||||
OfflineNemoEncDecCtcModelConfig nemo_ctc;
|
||||
OfflineWhisperModelConfig whisper;
|
||||
OfflineTdnnModelConfig tdnn;
|
||||
OfflineZipformerCtcModelConfig zipformer_ctc;
|
||||
|
||||
std::string tokens;
|
||||
int32_t num_threads = 2;
|
||||
@@ -43,6 +45,7 @@ struct OfflineModelConfig {
|
||||
const OfflineNemoEncDecCtcModelConfig &nemo_ctc,
|
||||
const OfflineWhisperModelConfig &whisper,
|
||||
const OfflineTdnnModelConfig &tdnn,
|
||||
const OfflineZipformerCtcModelConfig &zipformer_ctc,
|
||||
const std::string &tokens, int32_t num_threads, bool debug,
|
||||
const std::string &provider, const std::string &model_type)
|
||||
: transducer(transducer),
|
||||
@@ -50,6 +53,7 @@ struct OfflineModelConfig {
|
||||
nemo_ctc(nemo_ctc),
|
||||
whisper(whisper),
|
||||
tdnn(tdnn),
|
||||
zipformer_ctc(zipformer_ctc),
|
||||
tokens(tokens),
|
||||
num_threads(num_threads),
|
||||
debug(debug),
|
||||
|
||||
@@ -34,8 +34,8 @@ class OfflineNemoEncDecCtcModel::Impl {
|
||||
}
|
||||
#endif
|
||||
|
||||
std::pair<Ort::Value, Ort::Value> Forward(Ort::Value features,
|
||||
Ort::Value features_length) {
|
||||
std::vector<Ort::Value> Forward(Ort::Value features,
|
||||
Ort::Value features_length) {
|
||||
std::vector<int64_t> shape =
|
||||
features_length.GetTensorTypeAndShapeInfo().GetShape();
|
||||
|
||||
@@ -57,7 +57,11 @@ class OfflineNemoEncDecCtcModel::Impl {
|
||||
sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
|
||||
output_names_ptr_.data(), output_names_ptr_.size());
|
||||
|
||||
return {std::move(out[0]), std::move(out_features_length)};
|
||||
std::vector<Ort::Value> ans;
|
||||
ans.reserve(2);
|
||||
ans.push_back(std::move(out[0]));
|
||||
ans.push_back(std::move(out_features_length));
|
||||
return ans;
|
||||
}
|
||||
|
||||
int32_t VocabSize() const { return vocab_size_; }
|
||||
@@ -122,7 +126,7 @@ OfflineNemoEncDecCtcModel::OfflineNemoEncDecCtcModel(
|
||||
|
||||
OfflineNemoEncDecCtcModel::~OfflineNemoEncDecCtcModel() = default;
|
||||
|
||||
std::pair<Ort::Value, Ort::Value> OfflineNemoEncDecCtcModel::Forward(
|
||||
std::vector<Ort::Value> OfflineNemoEncDecCtcModel::Forward(
|
||||
Ort::Value features, Ort::Value features_length) {
|
||||
return impl_->Forward(std::move(features), std::move(features_length));
|
||||
}
|
||||
|
||||
@@ -38,17 +38,17 @@ class OfflineNemoEncDecCtcModel : public OfflineCtcModel {
|
||||
|
||||
/** Run the forward method of the model.
|
||||
*
|
||||
* @param features A tensor of shape (N, T, C). It is changed in-place.
|
||||
* @param features A tensor of shape (N, T, C).
|
||||
* @param features_length A 1-D tensor of shape (N,) containing number of
|
||||
* valid frames in `features` before padding.
|
||||
* Its dtype is int64_t.
|
||||
*
|
||||
* @return Return a pair containing:
|
||||
* @return Return a vector containing:
|
||||
* - log_probs: A 3-D tensor of shape (N, T', vocab_size).
|
||||
* - log_probs_length A 1-D tensor of shape (N,). Its dtype is int64_t
|
||||
*/
|
||||
std::pair<Ort::Value, Ort::Value> Forward(
|
||||
Ort::Value features, Ort::Value features_length) override;
|
||||
std::vector<Ort::Value> Forward(Ort::Value features,
|
||||
Ort::Value features_length) override;
|
||||
|
||||
/** Return the vocabulary size of the model
|
||||
*/
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-ctc-decoder.h"
|
||||
#include "sherpa-onnx/csrc/offline-ctc-fst-decoder.h"
|
||||
#include "sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.h"
|
||||
#include "sherpa-onnx/csrc/offline-ctc-model.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
|
||||
@@ -25,9 +26,12 @@
|
||||
namespace sherpa_onnx {
|
||||
|
||||
static OfflineRecognitionResult Convert(const OfflineCtcDecoderResult &src,
|
||||
const SymbolTable &sym_table) {
|
||||
const SymbolTable &sym_table,
|
||||
int32_t frame_shift_ms,
|
||||
int32_t subsampling_factor) {
|
||||
OfflineRecognitionResult r;
|
||||
r.tokens.reserve(src.tokens.size());
|
||||
r.timestamps.reserve(src.timestamps.size());
|
||||
|
||||
std::string text;
|
||||
|
||||
@@ -42,6 +46,12 @@ static OfflineRecognitionResult Convert(const OfflineCtcDecoderResult &src,
|
||||
}
|
||||
r.text = std::move(text);
|
||||
|
||||
float frame_shift_s = frame_shift_ms / 1000. * subsampling_factor;
|
||||
for (auto t : src.timestamps) {
|
||||
float time = frame_shift_s * t;
|
||||
r.timestamps.push_back(time);
|
||||
}
|
||||
|
||||
return r;
|
||||
}
|
||||
|
||||
@@ -68,7 +78,12 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
|
||||
config_.feat_config.nemo_normalize_type =
|
||||
model_->FeatureNormalizationMethod();
|
||||
|
||||
if (config_.decoding_method == "greedy_search") {
|
||||
if (!config_.ctc_fst_decoder_config.graph.empty()) {
|
||||
// TODO(fangjun): Support android to read the graph from
|
||||
// asset_manager
|
||||
decoder_ = std::make_unique<OfflineCtcFstDecoder>(
|
||||
config_.ctc_fst_decoder_config);
|
||||
} else if (config_.decoding_method == "greedy_search") {
|
||||
if (!symbol_table_.contains("<blk>") &&
|
||||
!symbol_table_.contains("<eps>")) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
@@ -139,10 +154,12 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
|
||||
-23.025850929940457f);
|
||||
auto t = model_->Forward(std::move(x), std::move(x_length));
|
||||
|
||||
auto results = decoder_->Decode(std::move(t.first), std::move(t.second));
|
||||
auto results = decoder_->Decode(std::move(t[0]), std::move(t[1]));
|
||||
|
||||
int32_t frame_shift_ms = 10;
|
||||
for (int32_t i = 0; i != n; ++i) {
|
||||
auto r = Convert(results[i], symbol_table_);
|
||||
auto r = Convert(results[i], symbol_table_, frame_shift_ms,
|
||||
model_->SubsamplingFactor());
|
||||
ss[i]->SetResult(r);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -25,9 +25,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
||||
return std::make_unique<OfflineRecognizerTransducerImpl>(config);
|
||||
} else if (model_type == "paraformer") {
|
||||
return std::make_unique<OfflineRecognizerParaformerImpl>(config);
|
||||
} else if (model_type == "nemo_ctc") {
|
||||
return std::make_unique<OfflineRecognizerCtcImpl>(config);
|
||||
} else if (model_type == "tdnn") {
|
||||
} else if (model_type == "nemo_ctc" || model_type == "tdnn" ||
|
||||
model_type == "zipformer2_ctc") {
|
||||
return std::make_unique<OfflineRecognizerCtcImpl>(config);
|
||||
} else if (model_type == "whisper") {
|
||||
return std::make_unique<OfflineRecognizerWhisperImpl>(config);
|
||||
@@ -50,6 +49,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
||||
model_filename = config.model_config.nemo_ctc.model;
|
||||
} else if (!config.model_config.tdnn.model.empty()) {
|
||||
model_filename = config.model_config.tdnn.model;
|
||||
} else if (!config.model_config.zipformer_ctc.model.empty()) {
|
||||
model_filename = config.model_config.zipformer_ctc.model;
|
||||
} else if (!config.model_config.whisper.encoder.empty()) {
|
||||
model_filename = config.model_config.whisper.encoder;
|
||||
} else {
|
||||
@@ -93,6 +94,11 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
||||
"\n "
|
||||
"https://github.com/k2-fsa/icefall/tree/master/egs/yesno/ASR/tdnn"
|
||||
"\n"
|
||||
"(5) Zipformer CTC models from icefall"
|
||||
"\n "
|
||||
"https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/"
|
||||
"zipformer/export-onnx-ctc.py"
|
||||
"\n"
|
||||
"\n");
|
||||
exit(-1);
|
||||
}
|
||||
@@ -107,11 +113,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
||||
return std::make_unique<OfflineRecognizerParaformerImpl>(config);
|
||||
}
|
||||
|
||||
if (model_type == "EncDecCTCModelBPE") {
|
||||
return std::make_unique<OfflineRecognizerCtcImpl>(config);
|
||||
}
|
||||
|
||||
if (model_type == "tdnn") {
|
||||
if (model_type == "EncDecCTCModelBPE" || model_type == "tdnn" ||
|
||||
model_type == "zipformer2_ctc") {
|
||||
return std::make_unique<OfflineRecognizerCtcImpl>(config);
|
||||
}
|
||||
|
||||
@@ -126,7 +129,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
||||
" - Non-streaming Paraformer models from FunASR\n"
|
||||
" - EncDecCTCModelBPE models from NeMo\n"
|
||||
" - Whisper models\n"
|
||||
" - Tdnn models\n",
|
||||
" - Tdnn models\n"
|
||||
" - Zipformer CTC models\n",
|
||||
model_type.c_str());
|
||||
|
||||
exit(-1);
|
||||
@@ -141,9 +145,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
||||
return std::make_unique<OfflineRecognizerTransducerImpl>(mgr, config);
|
||||
} else if (model_type == "paraformer") {
|
||||
return std::make_unique<OfflineRecognizerParaformerImpl>(mgr, config);
|
||||
} else if (model_type == "nemo_ctc") {
|
||||
return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config);
|
||||
} else if (model_type == "tdnn") {
|
||||
} else if (model_type == "nemo_ctc" || model_type == "tdnn" ||
|
||||
model_type == "zipformer2_ctc") {
|
||||
return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config);
|
||||
} else if (model_type == "whisper") {
|
||||
return std::make_unique<OfflineRecognizerWhisperImpl>(mgr, config);
|
||||
@@ -166,6 +169,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
||||
model_filename = config.model_config.nemo_ctc.model;
|
||||
} else if (!config.model_config.tdnn.model.empty()) {
|
||||
model_filename = config.model_config.tdnn.model;
|
||||
} else if (!config.model_config.zipformer_ctc.model.empty()) {
|
||||
model_filename = config.model_config.zipformer_ctc.model;
|
||||
} else if (!config.model_config.whisper.encoder.empty()) {
|
||||
model_filename = config.model_config.whisper.encoder;
|
||||
} else {
|
||||
@@ -209,6 +214,11 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
||||
"\n "
|
||||
"https://github.com/k2-fsa/icefall/tree/master/egs/yesno/ASR/tdnn"
|
||||
"\n"
|
||||
"(5) Zipformer CTC models from icefall"
|
||||
"\n "
|
||||
"https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/"
|
||||
"zipformer/export-onnx-ctc.py"
|
||||
"\n"
|
||||
"\n");
|
||||
exit(-1);
|
||||
}
|
||||
@@ -223,11 +233,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
||||
return std::make_unique<OfflineRecognizerParaformerImpl>(mgr, config);
|
||||
}
|
||||
|
||||
if (model_type == "EncDecCTCModelBPE") {
|
||||
return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config);
|
||||
}
|
||||
|
||||
if (model_type == "tdnn") {
|
||||
if (model_type == "EncDecCTCModelBPE" || model_type == "tdnn" ||
|
||||
model_type == "zipformer2_ctc") {
|
||||
return std::make_unique<OfflineRecognizerCtcImpl>(mgr, config);
|
||||
}
|
||||
|
||||
@@ -242,7 +249,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
||||
" - Non-streaming Paraformer models from FunASR\n"
|
||||
" - EncDecCTCModelBPE models from NeMo\n"
|
||||
" - Whisper models\n"
|
||||
" - Tdnn models\n",
|
||||
" - Tdnn models\n"
|
||||
" - Zipformer CTC models\n",
|
||||
model_type.c_str());
|
||||
|
||||
exit(-1);
|
||||
|
||||
@@ -17,6 +17,7 @@ void OfflineRecognizerConfig::Register(ParseOptions *po) {
|
||||
feat_config.Register(po);
|
||||
model_config.Register(po);
|
||||
lm_config.Register(po);
|
||||
ctc_fst_decoder_config.Register(po);
|
||||
|
||||
po->Register(
|
||||
"decoding-method", &decoding_method,
|
||||
@@ -69,6 +70,7 @@ std::string OfflineRecognizerConfig::ToString() const {
|
||||
os << "feat_config=" << feat_config.ToString() << ", ";
|
||||
os << "model_config=" << model_config.ToString() << ", ";
|
||||
os << "lm_config=" << lm_config.ToString() << ", ";
|
||||
os << "ctc_fst_decoder_config=" << ctc_fst_decoder_config.ToString() << ", ";
|
||||
os << "decoding_method=\"" << decoding_method << "\", ";
|
||||
os << "max_active_paths=" << max_active_paths << ", ";
|
||||
os << "hotwords_file=\"" << hotwords_file << "\", ";
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-ctc-fst-decoder-config.h"
|
||||
#include "sherpa-onnx/csrc/offline-lm-config.h"
|
||||
#include "sherpa-onnx/csrc/offline-model-config.h"
|
||||
#include "sherpa-onnx/csrc/offline-stream.h"
|
||||
@@ -28,6 +29,7 @@ struct OfflineRecognizerConfig {
|
||||
OfflineFeatureExtractorConfig feat_config;
|
||||
OfflineModelConfig model_config;
|
||||
OfflineLMConfig lm_config;
|
||||
OfflineCtcFstDecoderConfig ctc_fst_decoder_config;
|
||||
|
||||
std::string decoding_method = "greedy_search";
|
||||
int32_t max_active_paths = 4;
|
||||
@@ -39,16 +41,16 @@ struct OfflineRecognizerConfig {
|
||||
// TODO(fangjun): Implement modified_beam_search
|
||||
|
||||
OfflineRecognizerConfig() = default;
|
||||
OfflineRecognizerConfig(const OfflineFeatureExtractorConfig &feat_config,
|
||||
const OfflineModelConfig &model_config,
|
||||
const OfflineLMConfig &lm_config,
|
||||
const std::string &decoding_method,
|
||||
int32_t max_active_paths,
|
||||
const std::string &hotwords_file,
|
||||
float hotwords_score)
|
||||
OfflineRecognizerConfig(
|
||||
const OfflineFeatureExtractorConfig &feat_config,
|
||||
const OfflineModelConfig &model_config, const OfflineLMConfig &lm_config,
|
||||
const OfflineCtcFstDecoderConfig &ctc_fst_decoder_config,
|
||||
const std::string &decoding_method, int32_t max_active_paths,
|
||||
const std::string &hotwords_file, float hotwords_score)
|
||||
: feat_config(feat_config),
|
||||
model_config(model_config),
|
||||
lm_config(lm_config),
|
||||
ctc_fst_decoder_config(ctc_fst_decoder_config),
|
||||
decoding_method(decoding_method),
|
||||
max_active_paths(max_active_paths),
|
||||
hotwords_file(hotwords_file),
|
||||
|
||||
@@ -4,6 +4,8 @@
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-tdnn-ctc-model.h"
|
||||
|
||||
#include <utility>
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/session.h"
|
||||
@@ -34,7 +36,7 @@ class OfflineTdnnCtcModel::Impl {
|
||||
}
|
||||
#endif
|
||||
|
||||
std::pair<Ort::Value, Ort::Value> Forward(Ort::Value features) {
|
||||
std::vector<Ort::Value> Forward(Ort::Value features) {
|
||||
auto nnet_out =
|
||||
sess_->Run({}, input_names_ptr_.data(), &features, 1,
|
||||
output_names_ptr_.data(), output_names_ptr_.size());
|
||||
@@ -52,7 +54,11 @@ class OfflineTdnnCtcModel::Impl {
|
||||
memory_info, out_length_vec.data(), out_length_vec.size(),
|
||||
out_length_shape.data(), out_length_shape.size());
|
||||
|
||||
return {std::move(nnet_out[0]), Clone(Allocator(), &nnet_out_length)};
|
||||
std::vector<Ort::Value> ans;
|
||||
ans.reserve(2);
|
||||
ans.push_back(std::move(nnet_out[0]));
|
||||
ans.push_back(Clone(Allocator(), &nnet_out_length));
|
||||
return ans;
|
||||
}
|
||||
|
||||
int32_t VocabSize() const { return vocab_size_; }
|
||||
@@ -108,7 +114,7 @@ OfflineTdnnCtcModel::OfflineTdnnCtcModel(AAssetManager *mgr,
|
||||
|
||||
OfflineTdnnCtcModel::~OfflineTdnnCtcModel() = default;
|
||||
|
||||
std::pair<Ort::Value, Ort::Value> OfflineTdnnCtcModel::Forward(
|
||||
std::vector<Ort::Value> OfflineTdnnCtcModel::Forward(
|
||||
Ort::Value features, Ort::Value /*features_length*/) {
|
||||
return impl_->Forward(std::move(features));
|
||||
}
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_TDNN_CTC_MODEL_H_
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
@@ -36,7 +35,7 @@ class OfflineTdnnCtcModel : public OfflineCtcModel {
|
||||
|
||||
/** Run the forward method of the model.
|
||||
*
|
||||
* @param features A tensor of shape (N, T, C). It is changed in-place.
|
||||
* @param features A tensor of shape (N, T, C).
|
||||
* @param features_length A 1-D tensor of shape (N,) containing number of
|
||||
* valid frames in `features` before padding.
|
||||
* Its dtype is int64_t.
|
||||
@@ -45,8 +44,8 @@ class OfflineTdnnCtcModel : public OfflineCtcModel {
|
||||
* - log_probs: A 3-D tensor of shape (N, T', vocab_size).
|
||||
* - log_probs_length A 1-D tensor of shape (N,). Its dtype is int64_t
|
||||
*/
|
||||
std::pair<Ort::Value, Ort::Value> Forward(
|
||||
Ort::Value features, Ort::Value /*features_length*/) override;
|
||||
std::vector<Ort::Value> Forward(Ort::Value features,
|
||||
Ort::Value /*features_length*/) override;
|
||||
|
||||
/** Return the vocabulary size of the model
|
||||
*/
|
||||
|
||||
35
sherpa-onnx/csrc/offline-zipformer-ctc-model-config.cc
Normal file
35
sherpa-onnx/csrc/offline-zipformer-ctc-model-config.cc
Normal file
@@ -0,0 +1,35 @@
|
||||
// sherpa-onnx/csrc/offline-zipformer-ctc-model-config.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-zipformer-ctc-model-config.h"
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void OfflineZipformerCtcModelConfig::Register(ParseOptions *po) {
|
||||
po->Register("zipformer-ctc-model", &model, "Path to zipformer CTC model");
|
||||
}
|
||||
|
||||
bool OfflineZipformerCtcModelConfig::Validate() const {
|
||||
if (!FileExists(model)) {
|
||||
SHERPA_ONNX_LOGE("zipformer CTC model file %s does not exist",
|
||||
model.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string OfflineZipformerCtcModelConfig::ToString() const {
|
||||
std::ostringstream os;
|
||||
|
||||
os << "OfflineZipformerCtcModelConfig(";
|
||||
os << "model=\"" << model << "\")";
|
||||
|
||||
return os.str();
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
32
sherpa-onnx/csrc/offline-zipformer-ctc-model-config.h
Normal file
32
sherpa-onnx/csrc/offline-zipformer-ctc-model-config.h
Normal file
@@ -0,0 +1,32 @@
|
||||
// sherpa-onnx/csrc/offline-zipformer-ctc-model-config.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_CTC_MODEL_CONFIG_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_CTC_MODEL_CONFIG_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// for
|
||||
// https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/zipformer/export-onnx-ctc.py
|
||||
struct OfflineZipformerCtcModelConfig {
|
||||
std::string model;
|
||||
|
||||
OfflineZipformerCtcModelConfig() = default;
|
||||
|
||||
explicit OfflineZipformerCtcModelConfig(const std::string &model)
|
||||
: model(model) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
|
||||
bool Validate() const;
|
||||
|
||||
std::string ToString() const;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_CTC_MODEL_CONFIG_H_
|
||||
119
sherpa-onnx/csrc/offline-zipformer-ctc-model.cc
Normal file
119
sherpa-onnx/csrc/offline-zipformer-ctc-model.cc
Normal file
@@ -0,0 +1,119 @@
|
||||
// sherpa-onnx/csrc/offline-zipformer-ctc-model.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-zipformer-ctc-model.h"
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/session.h"
|
||||
#include "sherpa-onnx/csrc/text-utils.h"
|
||||
#include "sherpa-onnx/csrc/transpose.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class OfflineZipformerCtcModel::Impl {
|
||||
public:
|
||||
explicit Impl(const OfflineModelConfig &config)
|
||||
: config_(config),
|
||||
env_(ORT_LOGGING_LEVEL_ERROR),
|
||||
sess_opts_(GetSessionOptions(config)),
|
||||
allocator_{} {
|
||||
auto buf = ReadFile(config_.zipformer_ctc.model);
|
||||
Init(buf.data(), buf.size());
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
Impl(AAssetManager *mgr, const OfflineModelConfig &config)
|
||||
: config_(config),
|
||||
env_(ORT_LOGGING_LEVEL_ERROR),
|
||||
sess_opts_(GetSessionOptions(config)),
|
||||
allocator_{} {
|
||||
auto buf = ReadFile(mgr, config_.zipformer_ctc.model);
|
||||
Init(buf.data(), buf.size());
|
||||
}
|
||||
#endif
|
||||
|
||||
std::vector<Ort::Value> Forward(Ort::Value features,
|
||||
Ort::Value features_length) {
|
||||
std::array<Ort::Value, 2> inputs = {std::move(features),
|
||||
std::move(features_length)};
|
||||
|
||||
return sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
|
||||
output_names_ptr_.data(), output_names_ptr_.size());
|
||||
}
|
||||
|
||||
int32_t VocabSize() const { return vocab_size_; }
|
||||
int32_t SubsamplingFactor() const { return 4; }
|
||||
|
||||
OrtAllocator *Allocator() const { return allocator_; }
|
||||
|
||||
private:
|
||||
void Init(void *model_data, size_t model_data_length) {
|
||||
sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
|
||||
sess_opts_);
|
||||
|
||||
GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
|
||||
|
||||
GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);
|
||||
|
||||
// get meta data
|
||||
Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
|
||||
if (config_.debug) {
|
||||
std::ostringstream os;
|
||||
PrintModelMetadata(os, meta_data);
|
||||
SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
|
||||
}
|
||||
|
||||
// get vocab size from the output[0].shape, which is (N, T, vocab_size)
|
||||
vocab_size_ =
|
||||
sess_->GetOutputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape()[2];
|
||||
}
|
||||
|
||||
private:
|
||||
OfflineModelConfig config_;
|
||||
Ort::Env env_;
|
||||
Ort::SessionOptions sess_opts_;
|
||||
Ort::AllocatorWithDefaultOptions allocator_;
|
||||
|
||||
std::unique_ptr<Ort::Session> sess_;
|
||||
|
||||
std::vector<std::string> input_names_;
|
||||
std::vector<const char *> input_names_ptr_;
|
||||
|
||||
std::vector<std::string> output_names_;
|
||||
std::vector<const char *> output_names_ptr_;
|
||||
|
||||
int32_t vocab_size_ = 0;
|
||||
};
|
||||
|
||||
OfflineZipformerCtcModel::OfflineZipformerCtcModel(
|
||||
const OfflineModelConfig &config)
|
||||
: impl_(std::make_unique<Impl>(config)) {}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
OfflineZipformerCtcModel::OfflineZipformerCtcModel(
|
||||
AAssetManager *mgr, const OfflineModelConfig &config)
|
||||
: impl_(std::make_unique<Impl>(mgr, config)) {}
|
||||
#endif
|
||||
|
||||
OfflineZipformerCtcModel::~OfflineZipformerCtcModel() = default;
|
||||
|
||||
std::vector<Ort::Value> OfflineZipformerCtcModel::Forward(
|
||||
Ort::Value features, Ort::Value features_length) {
|
||||
return impl_->Forward(std::move(features), std::move(features_length));
|
||||
}
|
||||
|
||||
int32_t OfflineZipformerCtcModel::VocabSize() const {
|
||||
return impl_->VocabSize();
|
||||
}
|
||||
|
||||
OrtAllocator *OfflineZipformerCtcModel::Allocator() const {
|
||||
return impl_->Allocator();
|
||||
}
|
||||
|
||||
int32_t OfflineZipformerCtcModel::SubsamplingFactor() const {
|
||||
return impl_->SubsamplingFactor();
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
70
sherpa-onnx/csrc/offline-zipformer-ctc-model.h
Normal file
70
sherpa-onnx/csrc/offline-zipformer-ctc-model.h
Normal file
@@ -0,0 +1,70 @@
|
||||
// sherpa-onnx/csrc/offline-zipformer-ctc-model.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_CTC_MODEL_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_CTC_MODEL_H_
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/offline-ctc-model.h"
|
||||
#include "sherpa-onnx/csrc/offline-model-config.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
/** This class implements the zipformer CTC model of the librispeech recipe
|
||||
* from icefall.
|
||||
*
|
||||
* See
|
||||
* https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/zipformer/export-onnx-ctc.py
|
||||
*/
|
||||
class OfflineZipformerCtcModel : public OfflineCtcModel {
|
||||
public:
|
||||
explicit OfflineZipformerCtcModel(const OfflineModelConfig &config);
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
OfflineZipformerCtcModel(AAssetManager *mgr,
|
||||
const OfflineModelConfig &config);
|
||||
#endif
|
||||
|
||||
~OfflineZipformerCtcModel() override;
|
||||
|
||||
/** Run the forward method of the model.
|
||||
*
|
||||
* @param features A tensor of shape (N, T, C).
|
||||
* @param features_length A 1-D tensor of shape (N,) containing number of
|
||||
* valid frames in `features` before padding.
|
||||
* Its dtype is int64_t.
|
||||
*
|
||||
* @return Return a vector containing:
|
||||
* - log_probs: A 3-D tensor of shape (N, T', vocab_size).
|
||||
* - log_probs_length A 1-D tensor of shape (N,). Its dtype is int64_t
|
||||
*/
|
||||
std::vector<Ort::Value> Forward(Ort::Value features,
|
||||
Ort::Value features_length) override;
|
||||
|
||||
/** Return the vocabulary size of the model
|
||||
*/
|
||||
int32_t VocabSize() const override;
|
||||
|
||||
/** Return an allocator for allocating memory
|
||||
*/
|
||||
OrtAllocator *Allocator() const override;
|
||||
|
||||
int32_t SubsamplingFactor() const override;
|
||||
|
||||
private:
|
||||
class Impl;
|
||||
std::unique_ptr<Impl> impl_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_ZIPFORMER_CTC_MODEL_H_
|
||||
@@ -5,6 +5,7 @@ pybind11_add_module(_sherpa_onnx
|
||||
display.cc
|
||||
endpoint.cc
|
||||
features.cc
|
||||
offline-ctc-fst-decoder-config.cc
|
||||
offline-lm-config.cc
|
||||
offline-model-config.cc
|
||||
offline-nemo-enc-dec-ctc-model-config.cc
|
||||
@@ -14,6 +15,7 @@ pybind11_add_module(_sherpa_onnx
|
||||
offline-tdnn-model-config.cc
|
||||
offline-transducer-model-config.cc
|
||||
offline-whisper-model-config.cc
|
||||
offline-zipformer-ctc-model-config.cc
|
||||
online-lm-config.cc
|
||||
online-model-config.cc
|
||||
online-paraformer-model-config.cc
|
||||
|
||||
23
sherpa-onnx/python/csrc/offline-ctc-fst-decoder-config.cc
Normal file
23
sherpa-onnx/python/csrc/offline-ctc-fst-decoder-config.cc
Normal file
@@ -0,0 +1,23 @@
|
||||
// sherpa-onnx/python/csrc/offline-ctc-fst-decoder-config.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/python/csrc/offline-ctc-fst-decoder-config.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-ctc-fst-decoder-config.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void PybindOfflineCtcFstDecoderConfig(py::module *m) {
|
||||
using PyClass = OfflineCtcFstDecoderConfig;
|
||||
py::class_<PyClass>(*m, "OfflineCtcFstDecoderConfig")
|
||||
.def(py::init<const std::string &, int32_t>(), py::arg("graph") = "",
|
||||
py::arg("max_active") = 3000)
|
||||
.def_readwrite("graph", &PyClass::graph)
|
||||
.def_readwrite("max_active", &PyClass::max_active)
|
||||
.def("__str__", &PyClass::ToString);
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
16
sherpa-onnx/python/csrc/offline-ctc-fst-decoder-config.h
Normal file
16
sherpa-onnx/python/csrc/offline-ctc-fst-decoder-config.h
Normal file
@@ -0,0 +1,16 @@
|
||||
// sherpa-onnx/python/csrc/offline-ctc-fst-decoder-config.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_CTC_FST_DECODER_CONFIG_H_
|
||||
#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_CTC_FST_DECODER_CONFIG_H_
|
||||
|
||||
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void PybindOfflineCtcFstDecoderConfig(py::module *m);
|
||||
|
||||
}
|
||||
|
||||
#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_CTC_FST_DECODER_CONFIG_H_
|
||||
@@ -13,6 +13,7 @@
|
||||
#include "sherpa-onnx/python/csrc/offline-tdnn-model-config.h"
|
||||
#include "sherpa-onnx/python/csrc/offline-transducer-model-config.h"
|
||||
#include "sherpa-onnx/python/csrc/offline-whisper-model-config.h"
|
||||
#include "sherpa-onnx/python/csrc/offline-zipformer-ctc-model-config.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
@@ -22,6 +23,7 @@ void PybindOfflineModelConfig(py::module *m) {
|
||||
PybindOfflineNemoEncDecCtcModelConfig(m);
|
||||
PybindOfflineWhisperModelConfig(m);
|
||||
PybindOfflineTdnnModelConfig(m);
|
||||
PybindOfflineZipformerCtcModelConfig(m);
|
||||
|
||||
using PyClass = OfflineModelConfig;
|
||||
py::class_<PyClass>(*m, "OfflineModelConfig")
|
||||
@@ -29,20 +31,23 @@ void PybindOfflineModelConfig(py::module *m) {
|
||||
const OfflineParaformerModelConfig &,
|
||||
const OfflineNemoEncDecCtcModelConfig &,
|
||||
const OfflineWhisperModelConfig &,
|
||||
const OfflineTdnnModelConfig &, const std::string &,
|
||||
const OfflineTdnnModelConfig &,
|
||||
const OfflineZipformerCtcModelConfig &, const std::string &,
|
||||
int32_t, bool, const std::string &, const std::string &>(),
|
||||
py::arg("transducer") = OfflineTransducerModelConfig(),
|
||||
py::arg("paraformer") = OfflineParaformerModelConfig(),
|
||||
py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(),
|
||||
py::arg("whisper") = OfflineWhisperModelConfig(),
|
||||
py::arg("tdnn") = OfflineTdnnModelConfig(), py::arg("tokens"),
|
||||
py::arg("num_threads"), py::arg("debug") = false,
|
||||
py::arg("tdnn") = OfflineTdnnModelConfig(),
|
||||
py::arg("zipformer_ctc") = OfflineZipformerCtcModelConfig(),
|
||||
py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false,
|
||||
py::arg("provider") = "cpu", py::arg("model_type") = "")
|
||||
.def_readwrite("transducer", &PyClass::transducer)
|
||||
.def_readwrite("paraformer", &PyClass::paraformer)
|
||||
.def_readwrite("nemo_ctc", &PyClass::nemo_ctc)
|
||||
.def_readwrite("whisper", &PyClass::whisper)
|
||||
.def_readwrite("tdnn", &PyClass::tdnn)
|
||||
.def_readwrite("zipformer_ctc", &PyClass::zipformer_ctc)
|
||||
.def_readwrite("tokens", &PyClass::tokens)
|
||||
.def_readwrite("num_threads", &PyClass::num_threads)
|
||||
.def_readwrite("debug", &PyClass::debug)
|
||||
|
||||
@@ -16,15 +16,18 @@ static void PybindOfflineRecognizerConfig(py::module *m) {
|
||||
py::class_<PyClass>(*m, "OfflineRecognizerConfig")
|
||||
.def(py::init<const OfflineFeatureExtractorConfig &,
|
||||
const OfflineModelConfig &, const OfflineLMConfig &,
|
||||
const std::string &, int32_t, const std::string &, float>(),
|
||||
const OfflineCtcFstDecoderConfig &, const std::string &,
|
||||
int32_t, const std::string &, float>(),
|
||||
py::arg("feat_config"), py::arg("model_config"),
|
||||
py::arg("lm_config") = OfflineLMConfig(),
|
||||
py::arg("ctc_fst_decoder_config") = OfflineCtcFstDecoderConfig(),
|
||||
py::arg("decoding_method") = "greedy_search",
|
||||
py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "",
|
||||
py::arg("hotwords_score") = 1.5)
|
||||
.def_readwrite("feat_config", &PyClass::feat_config)
|
||||
.def_readwrite("model_config", &PyClass::model_config)
|
||||
.def_readwrite("lm_config", &PyClass::lm_config)
|
||||
.def_readwrite("ctc_fst_decoder_config", &PyClass::ctc_fst_decoder_config)
|
||||
.def_readwrite("decoding_method", &PyClass::decoding_method)
|
||||
.def_readwrite("max_active_paths", &PyClass::max_active_paths)
|
||||
.def_readwrite("hotwords_file", &PyClass::hotwords_file)
|
||||
|
||||
@@ -0,0 +1,22 @@
|
||||
// sherpa-onnx/python/csrc/offline-zipformer-ctc-model-config.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/python/csrc/offline-zipformer-ctc-model-config.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-zipformer-ctc-model-config.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void PybindOfflineZipformerCtcModelConfig(py::module *m) {
|
||||
using PyClass = OfflineZipformerCtcModelConfig;
|
||||
py::class_<PyClass>(*m, "OfflineZipformerCtcModelConfig")
|
||||
.def(py::init<>())
|
||||
.def(py::init<const std::string &>(), py::arg("model"))
|
||||
.def_readwrite("model", &PyClass::model)
|
||||
.def("__str__", &PyClass::ToString);
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
16
sherpa-onnx/python/csrc/offline-zipformer-ctc-model-config.h
Normal file
16
sherpa-onnx/python/csrc/offline-zipformer-ctc-model-config.h
Normal file
@@ -0,0 +1,16 @@
|
||||
// sherpa-onnx/python/csrc/offline-zipformer-ctc-model-config.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_ZIPFORMER_CTC_MODEL_CONFIG_H_
|
||||
#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_ZIPFORMER_CTC_MODEL_CONFIG_H_
|
||||
|
||||
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void PybindOfflineZipformerCtcModelConfig(py::module *m);
|
||||
|
||||
}
|
||||
|
||||
#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_ZIPFORMER_CTC_MODEL_CONFIG_H_
|
||||
@@ -8,6 +8,7 @@
|
||||
#include "sherpa-onnx/python/csrc/display.h"
|
||||
#include "sherpa-onnx/python/csrc/endpoint.h"
|
||||
#include "sherpa-onnx/python/csrc/features.h"
|
||||
#include "sherpa-onnx/python/csrc/offline-ctc-fst-decoder-config.h"
|
||||
#include "sherpa-onnx/python/csrc/offline-lm-config.h"
|
||||
#include "sherpa-onnx/python/csrc/offline-model-config.h"
|
||||
#include "sherpa-onnx/python/csrc/offline-recognizer.h"
|
||||
@@ -37,6 +38,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
|
||||
PybindOfflineStream(&m);
|
||||
PybindOfflineLMConfig(&m);
|
||||
PybindOfflineModelConfig(&m);
|
||||
PybindOfflineCtcFstDecoderConfig(&m);
|
||||
PybindOfflineRecognizer(&m);
|
||||
|
||||
PybindVadModelConfig(&m);
|
||||
|
||||
@@ -4,12 +4,14 @@ from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
from _sherpa_onnx import (
|
||||
OfflineCtcFstDecoderConfig,
|
||||
OfflineFeatureExtractorConfig,
|
||||
OfflineModelConfig,
|
||||
OfflineNemoEncDecCtcModelConfig,
|
||||
OfflineParaformerModelConfig,
|
||||
OfflineTdnnModelConfig,
|
||||
OfflineWhisperModelConfig,
|
||||
OfflineZipformerCtcModelConfig,
|
||||
)
|
||||
from _sherpa_onnx import OfflineRecognizer as _Recognizer
|
||||
from _sherpa_onnx import (
|
||||
|
||||
Reference in New Issue
Block a user