Begin to support CTC models (#119)
Please see https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/nemo/index.html for a list of pre-trained CTC models from NeMo.
This commit is contained in:
47
.github/scripts/test-offline-ctc.sh
vendored
Executable file
47
.github/scripts/test-offline-ctc.sh
vendored
Executable file
@@ -0,0 +1,47 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -e
|
||||
|
||||
log() {
|
||||
# This function is from espnet
|
||||
local fname=${BASH_SOURCE[1]##*/}
|
||||
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||
}
|
||||
|
||||
echo "EXE is $EXE"
|
||||
echo "PATH: $PATH"
|
||||
|
||||
which $EXE
|
||||
|
||||
log "------------------------------------------------------------"
|
||||
log "Run Citrinet (stt_en_citrinet_512, English)"
|
||||
log "------------------------------------------------------------"
|
||||
|
||||
repo_url=http://huggingface.co/csukuangfj/sherpa-onnx-nemo-ctc-en-citrinet-512
|
||||
log "Start testing ${repo_url}"
|
||||
repo=$(basename $repo_url)
|
||||
log "Download pretrained model and test-data from $repo_url"
|
||||
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||
pushd $repo
|
||||
git lfs pull --include "*.onnx"
|
||||
ls -lh *.onnx
|
||||
popd
|
||||
|
||||
time $EXE \
|
||||
--tokens=$repo/tokens.txt \
|
||||
--nemo-ctc-model=$repo/model.onnx \
|
||||
--num-threads=2 \
|
||||
$repo/test_wavs/0.wav \
|
||||
$repo/test_wavs/1.wav \
|
||||
$repo/test_wavs/8k.wav
|
||||
|
||||
time $EXE \
|
||||
--tokens=$repo/tokens.txt \
|
||||
--nemo-ctc-model=$repo/model.int8.onnx \
|
||||
--num-threads=2 \
|
||||
$repo/test_wavs/0.wav \
|
||||
$repo/test_wavs/1.wav \
|
||||
$repo/test_wavs/8k.wav
|
||||
|
||||
rm -rf $repo
|
||||
38
.github/scripts/test-python.sh
vendored
38
.github/scripts/test-python.sh
vendored
@@ -95,6 +95,8 @@ python3 ./python-api-examples/offline-decode-files.py \
|
||||
|
||||
python3 sherpa-onnx/python/tests/test_offline_recognizer.py --verbose
|
||||
|
||||
rm -rf $repo
|
||||
|
||||
log "Test non-streaming paraformer models"
|
||||
|
||||
pushd $dir
|
||||
@@ -128,3 +130,39 @@ python3 ./python-api-examples/offline-decode-files.py \
|
||||
$repo/test_wavs/8k.wav
|
||||
|
||||
python3 sherpa-onnx/python/tests/test_offline_recognizer.py --verbose
|
||||
|
||||
rm -rf $repo
|
||||
|
||||
log "Test non-streaming NeMo CTC models"
|
||||
|
||||
pushd $dir
|
||||
repo_url=http://huggingface.co/csukuangfj/sherpa-onnx-nemo-ctc-en-citrinet-512
|
||||
|
||||
log "Start testing ${repo_url}"
|
||||
repo=$dir/$(basename $repo_url)
|
||||
log "Download pretrained model and test-data from $repo_url"
|
||||
|
||||
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
|
||||
cd $repo
|
||||
git lfs pull --include "*.onnx"
|
||||
popd
|
||||
|
||||
ls -lh $repo
|
||||
|
||||
python3 ./python-api-examples/offline-decode-files.py \
|
||||
--tokens=$repo/tokens.txt \
|
||||
--nemo-ctc=$repo/model.onnx \
|
||||
$repo/test_wavs/0.wav \
|
||||
$repo/test_wavs/1.wav \
|
||||
$repo/test_wavs/8k.wav
|
||||
|
||||
python3 ./python-api-examples/offline-decode-files.py \
|
||||
--tokens=$repo/tokens.txt \
|
||||
--nemo-ctc=$repo/model.int8.onnx \
|
||||
$repo/test_wavs/0.wav \
|
||||
$repo/test_wavs/1.wav \
|
||||
$repo/test_wavs/8k.wav
|
||||
|
||||
python3 sherpa-onnx/python/tests/test_offline_recognizer.py --verbose
|
||||
|
||||
rm -rf $repo
|
||||
|
||||
10
.github/workflows/linux.yaml
vendored
10
.github/workflows/linux.yaml
vendored
@@ -8,6 +8,7 @@ on:
|
||||
- '.github/workflows/linux.yaml'
|
||||
- '.github/scripts/test-online-transducer.sh'
|
||||
- '.github/scripts/test-offline-transducer.sh'
|
||||
- '.github/scripts/test-offline-ctc.sh'
|
||||
- 'CMakeLists.txt'
|
||||
- 'cmake/**'
|
||||
- 'sherpa-onnx/csrc/*'
|
||||
@@ -20,6 +21,7 @@ on:
|
||||
- '.github/workflows/linux.yaml'
|
||||
- '.github/scripts/test-online-transducer.sh'
|
||||
- '.github/scripts/test-offline-transducer.sh'
|
||||
- '.github/scripts/test-offline-ctc.sh'
|
||||
- 'CMakeLists.txt'
|
||||
- 'cmake/**'
|
||||
- 'sherpa-onnx/csrc/*'
|
||||
@@ -68,6 +70,14 @@ jobs:
|
||||
file build/bin/sherpa-onnx
|
||||
readelf -d build/bin/sherpa-onnx
|
||||
|
||||
- name: Test offline CTC
|
||||
shell: bash
|
||||
run: |
|
||||
export PATH=$PWD/build/bin:$PATH
|
||||
export EXE=sherpa-onnx-offline
|
||||
|
||||
.github/scripts/test-offline-ctc.sh
|
||||
|
||||
- name: Test offline transducer
|
||||
shell: bash
|
||||
run: |
|
||||
|
||||
10
.github/workflows/macos.yaml
vendored
10
.github/workflows/macos.yaml
vendored
@@ -8,6 +8,7 @@ on:
|
||||
- '.github/workflows/macos.yaml'
|
||||
- '.github/scripts/test-online-transducer.sh'
|
||||
- '.github/scripts/test-offline-transducer.sh'
|
||||
- '.github/scripts/test-offline-ctc.sh'
|
||||
- 'CMakeLists.txt'
|
||||
- 'cmake/**'
|
||||
- 'sherpa-onnx/csrc/*'
|
||||
@@ -18,6 +19,7 @@ on:
|
||||
- '.github/workflows/macos.yaml'
|
||||
- '.github/scripts/test-online-transducer.sh'
|
||||
- '.github/scripts/test-offline-transducer.sh'
|
||||
- '.github/scripts/test-offline-ctc.sh'
|
||||
- 'CMakeLists.txt'
|
||||
- 'cmake/**'
|
||||
- 'sherpa-onnx/csrc/*'
|
||||
@@ -67,6 +69,14 @@ jobs:
|
||||
otool -L build/bin/sherpa-onnx
|
||||
otool -l build/bin/sherpa-onnx
|
||||
|
||||
- name: Test offline CTC
|
||||
shell: bash
|
||||
run: |
|
||||
export PATH=$PWD/build/bin:$PATH
|
||||
export EXE=sherpa-onnx-offline
|
||||
|
||||
.github/scripts/test-offline-ctc.sh
|
||||
|
||||
- name: Test offline transducer
|
||||
shell: bash
|
||||
run: |
|
||||
|
||||
10
.github/workflows/windows-x64.yaml
vendored
10
.github/workflows/windows-x64.yaml
vendored
@@ -8,6 +8,7 @@ on:
|
||||
- '.github/workflows/windows-x64.yaml'
|
||||
- '.github/scripts/test-online-transducer.sh'
|
||||
- '.github/scripts/test-offline-transducer.sh'
|
||||
- '.github/scripts/test-offline-ctc.sh'
|
||||
- 'CMakeLists.txt'
|
||||
- 'cmake/**'
|
||||
- 'sherpa-onnx/csrc/*'
|
||||
@@ -18,6 +19,7 @@ on:
|
||||
- '.github/workflows/windows-x64.yaml'
|
||||
- '.github/scripts/test-online-transducer.sh'
|
||||
- '.github/scripts/test-offline-transducer.sh'
|
||||
- '.github/scripts/test-offline-ctc.sh'
|
||||
- 'CMakeLists.txt'
|
||||
- 'cmake/**'
|
||||
- 'sherpa-onnx/csrc/*'
|
||||
@@ -73,6 +75,14 @@ jobs:
|
||||
|
||||
ls -lh ./bin/Release/sherpa-onnx.exe
|
||||
|
||||
- name: Test offline CTC for windows x64
|
||||
shell: bash
|
||||
run: |
|
||||
export PATH=$PWD/build/bin/Release:$PATH
|
||||
export EXE=sherpa-onnx-offline.exe
|
||||
|
||||
.github/scripts/test-offline-ctc.sh
|
||||
|
||||
- name: Test offline transducer for Windows x64
|
||||
shell: bash
|
||||
run: |
|
||||
|
||||
11
.github/workflows/windows-x86.yaml
vendored
11
.github/workflows/windows-x86.yaml
vendored
@@ -8,6 +8,7 @@ on:
|
||||
- '.github/workflows/windows-x86.yaml'
|
||||
- '.github/scripts/test-online-transducer.sh'
|
||||
- '.github/scripts/test-offline-transducer.sh'
|
||||
- '.github/scripts/test-offline-ctc.sh'
|
||||
- 'CMakeLists.txt'
|
||||
- 'cmake/**'
|
||||
- 'sherpa-onnx/csrc/*'
|
||||
@@ -18,6 +19,7 @@ on:
|
||||
- '.github/workflows/windows-x86.yaml'
|
||||
- '.github/scripts/test-online-transducer.sh'
|
||||
- '.github/scripts/test-offline-transducer.sh'
|
||||
- '.github/scripts/test-offline-ctc.sh'
|
||||
- 'CMakeLists.txt'
|
||||
- 'cmake/**'
|
||||
- 'sherpa-onnx/csrc/*'
|
||||
@@ -31,6 +33,7 @@ permissions:
|
||||
|
||||
jobs:
|
||||
windows_x86:
|
||||
if: false # disable windows x86 CI for now
|
||||
runs-on: ${{ matrix.os }}
|
||||
name: ${{ matrix.vs-version }}
|
||||
strategy:
|
||||
@@ -73,6 +76,14 @@ jobs:
|
||||
|
||||
ls -lh ./bin/Release/sherpa-onnx.exe
|
||||
|
||||
- name: Test offline CTC for windows x86
|
||||
shell: bash
|
||||
run: |
|
||||
export PATH=$PWD/build/bin/Release:$PATH
|
||||
export EXE=sherpa-onnx-offline.exe
|
||||
|
||||
.github/scripts/test-offline-ctc.sh
|
||||
|
||||
- name: Test offline transducer for Windows x86
|
||||
shell: bash
|
||||
run: |
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -52,3 +52,6 @@ run-offline-websocket-client-*.sh
|
||||
run-sherpa-onnx-*.sh
|
||||
sherpa-onnx-zipformer-en-2023-03-30
|
||||
sherpa-onnx-zipformer-en-2023-04-01
|
||||
run-offline-decode-files.sh
|
||||
sherpa-onnx-nemo-ctc-en-citrinet-512
|
||||
run-offline-decode-files-nemo-ctc.sh
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
This file demonstrates how to use sherpa-onnx Python API to transcribe
|
||||
file(s) with a non-streaming model.
|
||||
|
||||
paraformer Usage:
|
||||
(1) For paraformer
|
||||
./python-api-examples/offline-decode-files.py \
|
||||
--tokens=/path/to/tokens.txt \
|
||||
--paraformer=/path/to/paraformer.onnx \
|
||||
@@ -18,7 +18,7 @@ paraformer Usage:
|
||||
/path/to/0.wav \
|
||||
/path/to/1.wav
|
||||
|
||||
transducer Usage:
|
||||
(2) For transducer models from icefall
|
||||
./python-api-examples/offline-decode-files.py \
|
||||
--tokens=/path/to/tokens.txt \
|
||||
--encoder=/path/to/encoder.onnx \
|
||||
@@ -32,6 +32,8 @@ transducer Usage:
|
||||
/path/to/0.wav \
|
||||
/path/to/1.wav
|
||||
|
||||
(3) For CTC models from NeMo
|
||||
|
||||
Please refer to
|
||||
https://k2-fsa.github.io/sherpa/onnx/index.html
|
||||
to install sherpa-onnx and to download the pre-trained models
|
||||
@@ -83,7 +85,14 @@ def get_args():
|
||||
"--paraformer",
|
||||
default="",
|
||||
type=str,
|
||||
help="Path to the paraformer model",
|
||||
help="Path to the model.onnx from Paraformer",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--nemo-ctc",
|
||||
default="",
|
||||
type=str,
|
||||
help="Path to the model.onnx from NeMo CTC",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@@ -171,11 +180,14 @@ def main():
|
||||
args = get_args()
|
||||
assert_file_exists(args.tokens)
|
||||
assert args.num_threads > 0, args.num_threads
|
||||
if len(args.encoder) > 0:
|
||||
if args.encoder:
|
||||
assert len(args.paraformer) == 0, args.paraformer
|
||||
assert len(args.nemo_ctc) == 0, args.nemo_ctc
|
||||
|
||||
assert_file_exists(args.encoder)
|
||||
assert_file_exists(args.decoder)
|
||||
assert_file_exists(args.joiner)
|
||||
assert len(args.paraformer) == 0, args.paraformer
|
||||
|
||||
recognizer = sherpa_onnx.OfflineRecognizer.from_transducer(
|
||||
encoder=args.encoder,
|
||||
decoder=args.decoder,
|
||||
@@ -187,8 +199,10 @@ def main():
|
||||
decoding_method=args.decoding_method,
|
||||
debug=args.debug,
|
||||
)
|
||||
else:
|
||||
elif args.paraformer:
|
||||
assert len(args.nemo_ctc) == 0, args.nemo_ctc
|
||||
assert_file_exists(args.paraformer)
|
||||
|
||||
recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
|
||||
paraformer=args.paraformer,
|
||||
tokens=args.tokens,
|
||||
@@ -198,6 +212,19 @@ def main():
|
||||
decoding_method=args.decoding_method,
|
||||
debug=args.debug,
|
||||
)
|
||||
elif args.nemo_ctc:
|
||||
recognizer = sherpa_onnx.OfflineRecognizer.from_nemo_ctc(
|
||||
model=args.nemo_ctc,
|
||||
tokens=args.tokens,
|
||||
num_threads=args.num_threads,
|
||||
sample_rate=args.sample_rate,
|
||||
feature_dim=args.feature_dim,
|
||||
decoding_method=args.decoding_method,
|
||||
debug=args.debug,
|
||||
)
|
||||
else:
|
||||
print("Please specify at least one model")
|
||||
return
|
||||
|
||||
print("Started!")
|
||||
start_time = time.time()
|
||||
@@ -225,12 +252,14 @@ def main():
|
||||
print("-" * 10)
|
||||
|
||||
elapsed_seconds = end_time - start_time
|
||||
rtf = elapsed_seconds / duration
|
||||
rtf = elapsed_seconds / total_duration
|
||||
print(f"num_threads: {args.num_threads}")
|
||||
print(f"decoding_method: {args.decoding_method}")
|
||||
print(f"Wave duration: {duration:.3f} s")
|
||||
print(f"Wave duration: {total_duration:.3f} s")
|
||||
print(f"Elapsed time: {elapsed_seconds:.3f} s")
|
||||
print(f"Real time factor (RTF): {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}")
|
||||
print(
|
||||
f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -172,12 +172,14 @@ def main():
|
||||
print("-" * 10)
|
||||
|
||||
elapsed_seconds = end_time - start_time
|
||||
rtf = elapsed_seconds / duration
|
||||
rtf = elapsed_seconds / total_duration
|
||||
print(f"num_threads: {args.num_threads}")
|
||||
print(f"decoding_method: {args.decoding_method}")
|
||||
print(f"Wave duration: {duration:.3f} s")
|
||||
print(f"Wave duration: {total_duration:.3f} s")
|
||||
print(f"Elapsed time: {elapsed_seconds:.3f} s")
|
||||
print(f"Real time factor (RTF): {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}")
|
||||
print(
|
||||
f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -16,7 +16,11 @@ set(sources
|
||||
features.cc
|
||||
file-utils.cc
|
||||
hypothesis.cc
|
||||
offline-ctc-greedy-search-decoder.cc
|
||||
offline-ctc-model.cc
|
||||
offline-model-config.cc
|
||||
offline-nemo-enc-dec-ctc-model-config.cc
|
||||
offline-nemo-enc-dec-ctc-model.cc
|
||||
offline-paraformer-greedy-search-decoder.cc
|
||||
offline-paraformer-model-config.cc
|
||||
offline-paraformer-model.cc
|
||||
|
||||
@@ -11,15 +11,19 @@
|
||||
#include "android/log.h"
|
||||
#define SHERPA_ONNX_LOGE(...) \
|
||||
do { \
|
||||
fprintf(stderr, "%s:%s:%d ", __FILE__, __func__, \
|
||||
static_cast<int>(__LINE__)); \
|
||||
fprintf(stderr, ##__VA_ARGS__); \
|
||||
fprintf(stderr, "\n"); \
|
||||
__android_log_print(ANDROID_LOG_WARN, "sherpa-onnx", ##__VA_ARGS__); \
|
||||
} while (0)
|
||||
#else
|
||||
#define SHERPA_ONNX_LOGE(...) \
|
||||
do { \
|
||||
fprintf(stderr, ##__VA_ARGS__); \
|
||||
fprintf(stderr, "\n"); \
|
||||
#define SHERPA_ONNX_LOGE(...) \
|
||||
do { \
|
||||
fprintf(stderr, "%s:%s:%d ", __FILE__, __func__, \
|
||||
static_cast<int>(__LINE__)); \
|
||||
fprintf(stderr, ##__VA_ARGS__); \
|
||||
fprintf(stderr, "\n"); \
|
||||
} while (0)
|
||||
#endif
|
||||
|
||||
|
||||
42
sherpa-onnx/csrc/offline-ctc-decoder.h
Normal file
42
sherpa-onnx/csrc/offline-ctc-decoder.h
Normal file
@@ -0,0 +1,42 @@
|
||||
// sherpa-onnx/csrc/offline-ctc-decoder.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_CTC_DECODER_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_CTC_DECODER_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
struct OfflineCtcDecoderResult {
|
||||
/// The decoded token IDs
|
||||
std::vector<int64_t> tokens;
|
||||
|
||||
/// timestamps[i] contains the output frame index where tokens[i] is decoded.
|
||||
/// Note: The index is after subsampling
|
||||
std::vector<int32_t> timestamps;
|
||||
};
|
||||
|
||||
class OfflineCtcDecoder {
|
||||
public:
|
||||
virtual ~OfflineCtcDecoder() = default;
|
||||
|
||||
/** Run CTC decoding given the output from the encoder model.
|
||||
*
|
||||
* @param log_probs A 3-D tensor of shape (N, T, vocab_size) containing
|
||||
* lob_probs.
|
||||
* @param log_probs_length A 1-D tensor of shape (N,) containing number
|
||||
* of valid frames in log_probs before padding.
|
||||
*
|
||||
* @return Return a vector of size `N` containing the decoded results.
|
||||
*/
|
||||
virtual std::vector<OfflineCtcDecoderResult> Decode(
|
||||
Ort::Value log_probs, Ort::Value log_probs_length) = 0;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_CTC_DECODER_H_
|
||||
54
sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.cc
Normal file
54
sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.cc
Normal file
@@ -0,0 +1,54 @@
|
||||
// sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
std::vector<OfflineCtcDecoderResult> OfflineCtcGreedySearchDecoder::Decode(
|
||||
Ort::Value log_probs, Ort::Value log_probs_length) {
|
||||
std::vector<int64_t> shape = log_probs.GetTensorTypeAndShapeInfo().GetShape();
|
||||
int32_t batch_size = static_cast<int32_t>(shape[0]);
|
||||
int32_t num_frames = static_cast<int32_t>(shape[1]);
|
||||
int32_t vocab_size = static_cast<int32_t>(shape[2]);
|
||||
|
||||
const int64_t *p_log_probs_length = log_probs_length.GetTensorData<int64_t>();
|
||||
|
||||
std::vector<OfflineCtcDecoderResult> ans;
|
||||
ans.reserve(batch_size);
|
||||
|
||||
for (int32_t b = 0; b != batch_size; ++b) {
|
||||
const float *p_log_probs =
|
||||
log_probs.GetTensorData<float>() + b * num_frames * vocab_size;
|
||||
|
||||
OfflineCtcDecoderResult r;
|
||||
int64_t prev_id = -1;
|
||||
|
||||
for (int32_t t = 0; t != static_cast<int32_t>(p_log_probs_length[b]); ++t) {
|
||||
auto y = static_cast<int64_t>(std::distance(
|
||||
static_cast<const float *>(p_log_probs),
|
||||
std::max_element(
|
||||
static_cast<const float *>(p_log_probs),
|
||||
static_cast<const float *>(p_log_probs) + vocab_size)));
|
||||
p_log_probs += vocab_size;
|
||||
|
||||
if (y != blank_id_ && y != prev_id) {
|
||||
r.tokens.push_back(y);
|
||||
r.timestamps.push_back(t);
|
||||
prev_id = y;
|
||||
}
|
||||
} // for (int32_t t = 0; ...)
|
||||
|
||||
ans.push_back(std::move(r));
|
||||
}
|
||||
return ans;
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
28
sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.h
Normal file
28
sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.h
Normal file
@@ -0,0 +1,28 @@
|
||||
// sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_CTC_GREEDY_SEARCH_DECODER_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_CTC_GREEDY_SEARCH_DECODER_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-ctc-decoder.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class OfflineCtcGreedySearchDecoder : public OfflineCtcDecoder {
|
||||
public:
|
||||
explicit OfflineCtcGreedySearchDecoder(int32_t blank_id)
|
||||
: blank_id_(blank_id) {}
|
||||
|
||||
std::vector<OfflineCtcDecoderResult> Decode(
|
||||
Ort::Value log_probs, Ort::Value log_probs_length) override;
|
||||
|
||||
private:
|
||||
int32_t blank_id_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_CTC_GREEDY_SEARCH_DECODER_H_
|
||||
86
sherpa-onnx/csrc/offline-ctc-model.cc
Normal file
86
sherpa-onnx/csrc/offline-ctc-model.cc
Normal file
@@ -0,0 +1,86 @@
|
||||
// sherpa-onnx/csrc/offline-ctc-model.cc
|
||||
//
|
||||
// Copyright (c) 2022-2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-ctc-model.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
|
||||
namespace {
|
||||
|
||||
enum class ModelType {
|
||||
kEncDecCTCModelBPE,
|
||||
kUnkown,
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
static ModelType GetModelType(char *model_data, size_t model_data_length,
|
||||
bool debug) {
|
||||
Ort::Env env(ORT_LOGGING_LEVEL_WARNING);
|
||||
Ort::SessionOptions sess_opts;
|
||||
|
||||
auto sess = std::make_unique<Ort::Session>(env, model_data, model_data_length,
|
||||
sess_opts);
|
||||
|
||||
Ort::ModelMetadata meta_data = sess->GetModelMetadata();
|
||||
if (debug) {
|
||||
std::ostringstream os;
|
||||
PrintModelMetadata(os, meta_data);
|
||||
SHERPA_ONNX_LOGE("%s", os.str().c_str());
|
||||
}
|
||||
|
||||
Ort::AllocatorWithDefaultOptions allocator;
|
||||
auto model_type =
|
||||
meta_data.LookupCustomMetadataMapAllocated("model_type", allocator);
|
||||
if (!model_type) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"No model_type in the metadata!\n"
|
||||
"If you are using models from NeMo, please refer to\n"
|
||||
"https://huggingface.co/csukuangfj/"
|
||||
"sherpa-onnx-nemo-ctc-en-citrinet-512/blob/main/add-model-metadata.py"
|
||||
"\n"
|
||||
"for how to add metadta to model.onnx\n");
|
||||
return ModelType::kUnkown;
|
||||
}
|
||||
|
||||
if (model_type.get() == std::string("EncDecCTCModelBPE")) {
|
||||
return ModelType::kEncDecCTCModelBPE;
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
|
||||
return ModelType::kUnkown;
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<OfflineCtcModel> OfflineCtcModel::Create(
|
||||
const OfflineModelConfig &config) {
|
||||
ModelType model_type = ModelType::kUnkown;
|
||||
|
||||
{
|
||||
auto buffer = ReadFile(config.nemo_ctc.model);
|
||||
|
||||
model_type = GetModelType(buffer.data(), buffer.size(), config.debug);
|
||||
}
|
||||
|
||||
switch (model_type) {
|
||||
case ModelType::kEncDecCTCModelBPE:
|
||||
return std::make_unique<OfflineNemoEncDecCtcModel>(config);
|
||||
break;
|
||||
case ModelType::kUnkown:
|
||||
SHERPA_ONNX_LOGE("Unknown model type in offline CTC!");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
59
sherpa-onnx/csrc/offline-ctc-model.h
Normal file
59
sherpa-onnx/csrc/offline-ctc-model.h
Normal file
@@ -0,0 +1,59 @@
|
||||
// sherpa-onnx/csrc/offline-ctc-model.h
|
||||
//
|
||||
// Copyright (c) 2022-2023 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_CTC_MODEL_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_CTC_MODEL_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/offline-model-config.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class OfflineCtcModel {
|
||||
public:
|
||||
virtual ~OfflineCtcModel() = default;
|
||||
static std::unique_ptr<OfflineCtcModel> Create(
|
||||
const OfflineModelConfig &config);
|
||||
|
||||
/** Run the forward method of the model.
|
||||
*
|
||||
* @param features A tensor of shape (N, T, C). It is changed in-place.
|
||||
* @param features_length A 1-D tensor of shape (N,) containing number of
|
||||
* valid frames in `features` before padding.
|
||||
* Its dtype is int64_t.
|
||||
*
|
||||
* @return Return a pair containing:
|
||||
* - log_probs: A 3-D tensor of shape (N, T', vocab_size).
|
||||
* - log_probs_length A 1-D tensor of shape (N,). Its dtype is int64_t
|
||||
*/
|
||||
virtual std::pair<Ort::Value, Ort::Value> Forward(
|
||||
Ort::Value features, Ort::Value features_length) = 0;
|
||||
|
||||
/** Return the vocabulary size of the model
|
||||
*/
|
||||
virtual int32_t VocabSize() const = 0;
|
||||
|
||||
/** SubsamplingFactor of the model
|
||||
*
|
||||
* For Citrinet, the subsampling factor is usually 4.
|
||||
* For Conformer CTC, the subsampling factor is usually 8.
|
||||
*/
|
||||
virtual int32_t SubsamplingFactor() const = 0;
|
||||
|
||||
/** Return an allocator for allocating memory
|
||||
*/
|
||||
virtual OrtAllocator *Allocator() const = 0;
|
||||
|
||||
/** For some models, e.g., those from NeMo, they require some preprocessing
|
||||
* for the features.
|
||||
*/
|
||||
virtual std::string FeatureNormalizationMethod() const { return {}; }
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_CTC_MODEL_H_
|
||||
@@ -13,6 +13,7 @@ namespace sherpa_onnx {
|
||||
void OfflineModelConfig::Register(ParseOptions *po) {
|
||||
transducer.Register(po);
|
||||
paraformer.Register(po);
|
||||
nemo_ctc.Register(po);
|
||||
|
||||
po->Register("tokens", &tokens, "Path to tokens.txt");
|
||||
|
||||
@@ -38,6 +39,10 @@ bool OfflineModelConfig::Validate() const {
|
||||
return paraformer.Validate();
|
||||
}
|
||||
|
||||
if (!nemo_ctc.model.empty()) {
|
||||
return nemo_ctc.Validate();
|
||||
}
|
||||
|
||||
return transducer.Validate();
|
||||
}
|
||||
|
||||
@@ -47,6 +52,7 @@ std::string OfflineModelConfig::ToString() const {
|
||||
os << "OfflineModelConfig(";
|
||||
os << "transducer=" << transducer.ToString() << ", ";
|
||||
os << "paraformer=" << paraformer.ToString() << ", ";
|
||||
os << "nemo_ctc=" << nemo_ctc.ToString() << ", ";
|
||||
os << "tokens=\"" << tokens << "\", ";
|
||||
os << "num_threads=" << num_threads << ", ";
|
||||
os << "debug=" << (debug ? "True" : "False") << ")";
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h"
|
||||
#include "sherpa-onnx/csrc/offline-paraformer-model-config.h"
|
||||
#include "sherpa-onnx/csrc/offline-transducer-model-config.h"
|
||||
|
||||
@@ -14,6 +15,7 @@ namespace sherpa_onnx {
|
||||
struct OfflineModelConfig {
|
||||
OfflineTransducerModelConfig transducer;
|
||||
OfflineParaformerModelConfig paraformer;
|
||||
OfflineNemoEncDecCtcModelConfig nemo_ctc;
|
||||
|
||||
std::string tokens;
|
||||
int32_t num_threads = 2;
|
||||
@@ -22,9 +24,11 @@ struct OfflineModelConfig {
|
||||
OfflineModelConfig() = default;
|
||||
OfflineModelConfig(const OfflineTransducerModelConfig &transducer,
|
||||
const OfflineParaformerModelConfig ¶former,
|
||||
const OfflineNemoEncDecCtcModelConfig &nemo_ctc,
|
||||
const std::string &tokens, int32_t num_threads, bool debug)
|
||||
: transducer(transducer),
|
||||
paraformer(paraformer),
|
||||
nemo_ctc(nemo_ctc),
|
||||
tokens(tokens),
|
||||
num_threads(num_threads),
|
||||
debug(debug) {}
|
||||
|
||||
35
sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.cc
Normal file
35
sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.cc
Normal file
@@ -0,0 +1,35 @@
|
||||
// sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h"
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void OfflineNemoEncDecCtcModelConfig::Register(ParseOptions *po) {
|
||||
po->Register("nemo-ctc-model", &model,
|
||||
"Path to model.onnx of Nemo EncDecCtcModel.");
|
||||
}
|
||||
|
||||
bool OfflineNemoEncDecCtcModelConfig::Validate() const {
|
||||
if (!FileExists(model)) {
|
||||
SHERPA_ONNX_LOGE("%s does not exist", model.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string OfflineNemoEncDecCtcModelConfig::ToString() const {
|
||||
std::ostringstream os;
|
||||
|
||||
os << "OfflineNemoEncDecCtcModelConfig(";
|
||||
os << "model=\"" << model << "\")";
|
||||
|
||||
return os.str();
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
28
sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h
Normal file
28
sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h
Normal file
@@ -0,0 +1,28 @@
|
||||
// sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_NEMO_ENC_DEC_CTC_MODEL_CONFIG_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_NEMO_ENC_DEC_CTC_MODEL_CONFIG_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
struct OfflineNemoEncDecCtcModelConfig {
|
||||
std::string model;
|
||||
|
||||
OfflineNemoEncDecCtcModelConfig() = default;
|
||||
explicit OfflineNemoEncDecCtcModelConfig(const std::string &model)
|
||||
: model(model) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
|
||||
std::string ToString() const;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_NEMO_ENC_DEC_CTC_MODEL_CONFIG_H_
|
||||
131
sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.cc
Normal file
131
sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.cc
Normal file
@@ -0,0 +1,131 @@
|
||||
// sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h"
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/text-utils.h"
|
||||
#include "sherpa-onnx/csrc/transpose.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class OfflineNemoEncDecCtcModel::Impl {
|
||||
public:
|
||||
explicit Impl(const OfflineModelConfig &config)
|
||||
: config_(config),
|
||||
env_(ORT_LOGGING_LEVEL_ERROR),
|
||||
sess_opts_{},
|
||||
allocator_{} {
|
||||
sess_opts_.SetIntraOpNumThreads(config_.num_threads);
|
||||
sess_opts_.SetInterOpNumThreads(config_.num_threads);
|
||||
|
||||
Init();
|
||||
}
|
||||
|
||||
std::pair<Ort::Value, Ort::Value> Forward(Ort::Value features,
|
||||
Ort::Value features_length) {
|
||||
std::vector<int64_t> shape =
|
||||
features_length.GetTensorTypeAndShapeInfo().GetShape();
|
||||
|
||||
Ort::Value out_features_length = Ort::Value::CreateTensor<int64_t>(
|
||||
allocator_, shape.data(), shape.size());
|
||||
|
||||
const int64_t *src = features_length.GetTensorData<int64_t>();
|
||||
int64_t *dst = out_features_length.GetTensorMutableData<int64_t>();
|
||||
for (int64_t i = 0; i != shape[0]; ++i) {
|
||||
dst[i] = src[i] / subsampling_factor_;
|
||||
}
|
||||
|
||||
// (B, T, C) -> (B, C, T)
|
||||
features = Transpose12(allocator_, &features);
|
||||
|
||||
std::array<Ort::Value, 2> inputs = {std::move(features),
|
||||
std::move(features_length)};
|
||||
auto out =
|
||||
sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
|
||||
output_names_ptr_.data(), output_names_ptr_.size());
|
||||
|
||||
return {std::move(out[0]), std::move(out_features_length)};
|
||||
}
|
||||
|
||||
int32_t VocabSize() const { return vocab_size_; }
|
||||
|
||||
int32_t SubsamplingFactor() const { return subsampling_factor_; }
|
||||
|
||||
OrtAllocator *Allocator() const { return allocator_; }
|
||||
|
||||
std::string FeatureNormalizationMethod() const { return normalize_type_; }
|
||||
|
||||
private:
|
||||
void Init() {
|
||||
auto buf = ReadFile(config_.nemo_ctc.model);
|
||||
|
||||
sess_ = std::make_unique<Ort::Session>(env_, buf.data(), buf.size(),
|
||||
sess_opts_);
|
||||
|
||||
GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
|
||||
|
||||
GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);
|
||||
|
||||
// get meta data
|
||||
Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
|
||||
if (config_.debug) {
|
||||
std::ostringstream os;
|
||||
PrintModelMetadata(os, meta_data);
|
||||
SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
|
||||
}
|
||||
|
||||
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
|
||||
SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size");
|
||||
SHERPA_ONNX_READ_META_DATA(subsampling_factor_, "subsampling_factor");
|
||||
SHERPA_ONNX_READ_META_DATA_STR(normalize_type_, "normalize_type");
|
||||
}
|
||||
|
||||
private:
|
||||
OfflineModelConfig config_;
|
||||
Ort::Env env_;
|
||||
Ort::SessionOptions sess_opts_;
|
||||
Ort::AllocatorWithDefaultOptions allocator_;
|
||||
|
||||
std::unique_ptr<Ort::Session> sess_;
|
||||
|
||||
std::vector<std::string> input_names_;
|
||||
std::vector<const char *> input_names_ptr_;
|
||||
|
||||
std::vector<std::string> output_names_;
|
||||
std::vector<const char *> output_names_ptr_;
|
||||
|
||||
int32_t vocab_size_ = 0;
|
||||
int32_t subsampling_factor_ = 0;
|
||||
std::string normalize_type_;
|
||||
};
|
||||
|
||||
OfflineNemoEncDecCtcModel::OfflineNemoEncDecCtcModel(
|
||||
const OfflineModelConfig &config)
|
||||
: impl_(std::make_unique<Impl>(config)) {}
|
||||
|
||||
OfflineNemoEncDecCtcModel::~OfflineNemoEncDecCtcModel() = default;
|
||||
|
||||
std::pair<Ort::Value, Ort::Value> OfflineNemoEncDecCtcModel::Forward(
|
||||
Ort::Value features, Ort::Value features_length) {
|
||||
return impl_->Forward(std::move(features), std::move(features_length));
|
||||
}
|
||||
|
||||
int32_t OfflineNemoEncDecCtcModel::VocabSize() const {
|
||||
return impl_->VocabSize();
|
||||
}
|
||||
int32_t OfflineNemoEncDecCtcModel::SubsamplingFactor() const {
|
||||
return impl_->SubsamplingFactor();
|
||||
}
|
||||
|
||||
OrtAllocator *OfflineNemoEncDecCtcModel::Allocator() const {
|
||||
return impl_->Allocator();
|
||||
}
|
||||
|
||||
std::string OfflineNemoEncDecCtcModel::FeatureNormalizationMethod() const {
|
||||
return impl_->FeatureNormalizationMethod();
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
75
sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h
Normal file
75
sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h
Normal file
@@ -0,0 +1,75 @@
|
||||
// sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_NEMO_ENC_DEC_CTC_MODEL_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_NEMO_ENC_DEC_CTC_MODEL_H_
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/offline-ctc-model.h"
|
||||
#include "sherpa-onnx/csrc/offline-model-config.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
/** This class implements the EncDecCTCModelBPE model from NeMo.
|
||||
*
|
||||
* See
|
||||
* https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/models/ctc_bpe_models.py
|
||||
* https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/models/ctc_models.py
|
||||
*/
|
||||
class OfflineNemoEncDecCtcModel : public OfflineCtcModel {
|
||||
public:
|
||||
explicit OfflineNemoEncDecCtcModel(const OfflineModelConfig &config);
|
||||
~OfflineNemoEncDecCtcModel() override;
|
||||
|
||||
/** Run the forward method of the model.
|
||||
*
|
||||
* @param features A tensor of shape (N, T, C). It is changed in-place.
|
||||
* @param features_length A 1-D tensor of shape (N,) containing number of
|
||||
* valid frames in `features` before padding.
|
||||
* Its dtype is int64_t.
|
||||
*
|
||||
* @return Return a pair containing:
|
||||
* - log_probs: A 3-D tensor of shape (N, T', vocab_size).
|
||||
* - log_probs_length A 1-D tensor of shape (N,). Its dtype is int64_t
|
||||
*/
|
||||
std::pair<Ort::Value, Ort::Value> Forward(
|
||||
Ort::Value features, Ort::Value features_length) override;
|
||||
|
||||
/** Return the vocabulary size of the model
|
||||
*/
|
||||
int32_t VocabSize() const override;
|
||||
|
||||
/** SubsamplingFactor of the model
|
||||
*
|
||||
* For Citrinet, the subsampling factor is usually 4.
|
||||
* For Conformer CTC, the subsampling factor is usually 8.
|
||||
*/
|
||||
int32_t SubsamplingFactor() const override;
|
||||
|
||||
/** Return an allocator for allocating memory
|
||||
*/
|
||||
OrtAllocator *Allocator() const override;
|
||||
|
||||
// Possible values:
|
||||
// - per_feature
|
||||
// - all_features (not implemented yet)
|
||||
// - fixed_mean (not implemented)
|
||||
// - fixed_std (not implemented)
|
||||
// - or just leave it to empty
|
||||
// See
|
||||
// https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/parts/preprocessing/features.py#L59
|
||||
// for details
|
||||
std::string FeatureNormalizationMethod() const override;
|
||||
|
||||
private:
|
||||
class Impl;
|
||||
std::unique_ptr<Impl> impl_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_NEMO_ENC_DEC_CTC_MODEL_H_
|
||||
128
sherpa-onnx/csrc/offline-recognizer-ctc-impl.h
Normal file
128
sherpa-onnx/csrc/offline-recognizer-ctc-impl.h
Normal file
@@ -0,0 +1,128 @@
|
||||
// sherpa-onnx/csrc/offline-recognizer-ctc-impl.h
|
||||
//
|
||||
// Copyright (c) 2022-2023 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_CTC_IMPL_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_CTC_IMPL_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-ctc-decoder.h"
|
||||
#include "sherpa-onnx/csrc/offline-ctc-greedy-search-decoder.h"
|
||||
#include "sherpa-onnx/csrc/offline-ctc-model.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer-impl.h"
|
||||
#include "sherpa-onnx/csrc/pad-sequence.h"
|
||||
#include "sherpa-onnx/csrc/symbol-table.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
static OfflineRecognitionResult Convert(const OfflineCtcDecoderResult &src,
|
||||
const SymbolTable &sym_table) {
|
||||
OfflineRecognitionResult r;
|
||||
r.tokens.reserve(src.tokens.size());
|
||||
|
||||
std::string text;
|
||||
|
||||
for (int32_t i = 0; i != src.tokens.size(); ++i) {
|
||||
auto sym = sym_table[src.tokens[i]];
|
||||
text.append(sym);
|
||||
r.tokens.push_back(std::move(sym));
|
||||
}
|
||||
r.text = std::move(text);
|
||||
|
||||
return r;
|
||||
}
|
||||
|
||||
class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl {
|
||||
public:
|
||||
explicit OfflineRecognizerCtcImpl(const OfflineRecognizerConfig &config)
|
||||
: config_(config),
|
||||
symbol_table_(config_.model_config.tokens),
|
||||
model_(OfflineCtcModel::Create(config_.model_config)) {
|
||||
config_.feat_config.nemo_normalize_type =
|
||||
model_->FeatureNormalizationMethod();
|
||||
|
||||
if (config.decoding_method == "greedy_search") {
|
||||
if (!symbol_table_.contains("<blk>")) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"We expect that tokens.txt contains "
|
||||
"the symbol <blk> and its ID.");
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
int32_t blank_id = symbol_table_["<blk>"];
|
||||
decoder_ = std::make_unique<OfflineCtcGreedySearchDecoder>(blank_id);
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE("Only greedy_search is supported at present. Given %s",
|
||||
config.decoding_method.c_str());
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<OfflineStream> CreateStream() const override {
|
||||
return std::make_unique<OfflineStream>(config_.feat_config);
|
||||
}
|
||||
|
||||
void DecodeStreams(OfflineStream **ss, int32_t n) const override {
|
||||
auto memory_info =
|
||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||
|
||||
int32_t feat_dim = config_.feat_config.feature_dim;
|
||||
|
||||
std::vector<Ort::Value> features;
|
||||
features.reserve(n);
|
||||
|
||||
std::vector<std::vector<float>> features_vec(n);
|
||||
std::vector<int64_t> features_length_vec(n);
|
||||
|
||||
for (int32_t i = 0; i != n; ++i) {
|
||||
std::vector<float> f = ss[i]->GetFrames();
|
||||
|
||||
int32_t num_frames = f.size() / feat_dim;
|
||||
features_vec[i] = std::move(f);
|
||||
|
||||
features_length_vec[i] = num_frames;
|
||||
|
||||
std::array<int64_t, 2> shape = {num_frames, feat_dim};
|
||||
|
||||
Ort::Value x = Ort::Value::CreateTensor(
|
||||
memory_info, features_vec[i].data(), features_vec[i].size(),
|
||||
shape.data(), shape.size());
|
||||
features.push_back(std::move(x));
|
||||
} // for (int32_t i = 0; i != n; ++i)
|
||||
|
||||
std::vector<const Ort::Value *> features_pointer(n);
|
||||
for (int32_t i = 0; i != n; ++i) {
|
||||
features_pointer[i] = &features[i];
|
||||
}
|
||||
|
||||
std::array<int64_t, 1> features_length_shape = {n};
|
||||
Ort::Value x_length = Ort::Value::CreateTensor(
|
||||
memory_info, features_length_vec.data(), n,
|
||||
features_length_shape.data(), features_length_shape.size());
|
||||
|
||||
Ort::Value x = PadSequence(model_->Allocator(), features_pointer,
|
||||
-23.025850929940457f);
|
||||
auto t = model_->Forward(std::move(x), std::move(x_length));
|
||||
|
||||
auto results = decoder_->Decode(std::move(t.first), std::move(t.second));
|
||||
|
||||
for (int32_t i = 0; i != n; ++i) {
|
||||
auto r = Convert(results[i], symbol_table_);
|
||||
ss[i]->SetResult(r);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
OfflineRecognizerConfig config_;
|
||||
SymbolTable symbol_table_;
|
||||
std::unique_ptr<OfflineCtcModel> model_;
|
||||
std::unique_ptr<OfflineCtcDecoder> decoder_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_CTC_IMPL_H_
|
||||
@@ -8,6 +8,7 @@
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer-ctc-impl.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer-transducer-impl.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
@@ -25,6 +26,8 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
||||
model_filename = config.model_config.transducer.encoder_filename;
|
||||
} else if (!config.model_config.paraformer.model.empty()) {
|
||||
model_filename = config.model_config.paraformer.model;
|
||||
} else if (!config.model_config.nemo_ctc.model.empty()) {
|
||||
model_filename = config.model_config.nemo_ctc.model;
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE("Please provide a model");
|
||||
exit(-1);
|
||||
@@ -39,8 +42,30 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
||||
|
||||
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
|
||||
|
||||
std::string model_type;
|
||||
SHERPA_ONNX_READ_META_DATA_STR(model_type, "model_type");
|
||||
auto model_type_ptr =
|
||||
meta_data.LookupCustomMetadataMapAllocated("model_type", allocator);
|
||||
if (!model_type_ptr) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"No model_type in the metadata!\n\n"
|
||||
"Please refer to the following URLs to add metadata"
|
||||
"\n"
|
||||
"(0) Transducer models from icefall"
|
||||
"\n "
|
||||
"https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/"
|
||||
"pruned_transducer_stateless7/export-onnx.py#L303"
|
||||
"\n"
|
||||
"(1) Nemo CTC models\n "
|
||||
"https://huggingface.co/csukuangfj/"
|
||||
"sherpa-onnx-nemo-ctc-en-citrinet-512/blob/main/add-model-metadata.py"
|
||||
"\n"
|
||||
"(2) Paraformer"
|
||||
"\n "
|
||||
"https://huggingface.co/csukuangfj/"
|
||||
"paraformer-onnxruntime-python-example/blob/main/add-model-metadata.py"
|
||||
"\n");
|
||||
exit(-1);
|
||||
}
|
||||
std::string model_type(model_type_ptr.get());
|
||||
|
||||
if (model_type == "conformer" || model_type == "zipformer") {
|
||||
return std::make_unique<OfflineRecognizerTransducerImpl>(config);
|
||||
@@ -50,11 +75,16 @@ std::unique_ptr<OfflineRecognizerImpl> OfflineRecognizerImpl::Create(
|
||||
return std::make_unique<OfflineRecognizerParaformerImpl>(config);
|
||||
}
|
||||
|
||||
if (model_type == "EncDecCTCModelBPE") {
|
||||
return std::make_unique<OfflineRecognizerCtcImpl>(config);
|
||||
}
|
||||
|
||||
SHERPA_ONNX_LOGE(
|
||||
"\nUnsupported model_type: %s\n"
|
||||
"We support only the following model types at present: \n"
|
||||
" - transducer models from icefall\n"
|
||||
" - Paraformer models from FunASR\n",
|
||||
" - Non-streaming transducer models from icefall\n"
|
||||
" - Non-streaming Paraformer models from FunASR\n"
|
||||
" - EncDecCTCModelBPE models from NeMo\n",
|
||||
model_type.c_str());
|
||||
|
||||
exit(-1);
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include <assert.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
|
||||
#include "kaldi-native-fbank/csrc/online-feature.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
@@ -15,6 +16,41 @@
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
/* Compute mean and inverse stddev over rows.
|
||||
*
|
||||
* @param p A pointer to a 2-d array of shape (num_rows, num_cols)
|
||||
* @param num_rows Number of rows
|
||||
* @param num_cols Number of columns
|
||||
* @param mean On return, it contains p.mean(axis=0)
|
||||
* @param inv_stddev On return, it contains 1/p.std(axis=0)
|
||||
*/
|
||||
static void ComputeMeanAndInvStd(const float *p, int32_t num_rows,
|
||||
int32_t num_cols, std::vector<float> *mean,
|
||||
std::vector<float> *inv_stddev) {
|
||||
std::vector<float> sum(num_cols);
|
||||
std::vector<float> sum_sq(num_cols);
|
||||
|
||||
for (int32_t i = 0; i != num_rows; ++i) {
|
||||
for (int32_t c = 0; c != num_cols; ++c) {
|
||||
auto t = p[c];
|
||||
sum[c] += t;
|
||||
sum_sq[c] += t * t;
|
||||
}
|
||||
p += num_cols;
|
||||
}
|
||||
|
||||
mean->resize(num_cols);
|
||||
inv_stddev->resize(num_cols);
|
||||
|
||||
for (int32_t i = 0; i != num_cols; ++i) {
|
||||
auto t = sum[i] / num_rows;
|
||||
(*mean)[i] = t;
|
||||
|
||||
float stddev = std::sqrt(sum_sq[i] / num_rows - t * t);
|
||||
(*inv_stddev)[i] = 1.0f / (stddev + 1e-5f);
|
||||
}
|
||||
}
|
||||
|
||||
void OfflineFeatureExtractorConfig::Register(ParseOptions *po) {
|
||||
po->Register("sample-rate", &sampling_rate,
|
||||
"Sampling rate of the input waveform. "
|
||||
@@ -106,6 +142,8 @@ class OfflineStream::Impl {
|
||||
p += feature_dim;
|
||||
}
|
||||
|
||||
NemoNormalizeFeatures(features.data(), n, feature_dim);
|
||||
|
||||
return features;
|
||||
}
|
||||
|
||||
@@ -113,6 +151,38 @@ class OfflineStream::Impl {
|
||||
|
||||
const OfflineRecognitionResult &GetResult() const { return r_; }
|
||||
|
||||
private:
|
||||
void NemoNormalizeFeatures(float *p, int32_t num_frames,
|
||||
int32_t feature_dim) const {
|
||||
if (config_.nemo_normalize_type.empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (config_.nemo_normalize_type != "per_feature") {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Only normalize_type=per_feature is implemented. Given: %s",
|
||||
config_.nemo_normalize_type.c_str());
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
NemoNormalizePerFeature(p, num_frames, feature_dim);
|
||||
}
|
||||
|
||||
static void NemoNormalizePerFeature(float *p, int32_t num_frames,
|
||||
int32_t feature_dim) {
|
||||
std::vector<float> mean;
|
||||
std::vector<float> inv_stddev;
|
||||
|
||||
ComputeMeanAndInvStd(p, num_frames, feature_dim, &mean, &inv_stddev);
|
||||
|
||||
for (int32_t n = 0; n != num_frames; ++n) {
|
||||
for (int32_t i = 0; i != feature_dim; ++i) {
|
||||
p[i] = (p[i] - mean[i]) * inv_stddev[i];
|
||||
}
|
||||
p += feature_dim;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
OfflineFeatureExtractorConfig config_;
|
||||
std::unique_ptr<knf::OnlineFbank> fbank_;
|
||||
|
||||
@@ -37,13 +37,26 @@ struct OfflineFeatureExtractorConfig {
|
||||
// Feature dimension
|
||||
int32_t feature_dim = 80;
|
||||
|
||||
// Set internally by some models, e.g., paraformer
|
||||
// Set internally by some models, e.g., paraformer sets it to false.
|
||||
// This parameter is not exposed to users from the commandline
|
||||
// If true, the feature extractor expects inputs to be normalized to
|
||||
// the range [-1, 1].
|
||||
// If false, we will multiply the inputs by 32768
|
||||
bool normalize_samples = true;
|
||||
|
||||
// For models from NeMo
|
||||
// This option is not exposed and is set internally when loading models.
|
||||
// Possible values:
|
||||
// - per_feature
|
||||
// - all_features (not implemented yet)
|
||||
// - fixed_mean (not implemented)
|
||||
// - fixed_std (not implemented)
|
||||
// - or just leave it to empty
|
||||
// See
|
||||
// https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/asr/parts/preprocessing/features.py#L59
|
||||
// for details
|
||||
std::string nemo_normalize_type;
|
||||
|
||||
std::string ToString() const;
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
|
||||
@@ -14,10 +14,12 @@
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/online-lstm-transducer-model.h"
|
||||
#include "sherpa-onnx/csrc/online-zipformer-transducer-model.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
namespace sherpa_onnx {
|
||||
|
||||
namespace {
|
||||
|
||||
enum class ModelType {
|
||||
kLstm,
|
||||
@@ -25,6 +27,10 @@ enum class ModelType {
|
||||
kUnkown,
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
static ModelType GetModelType(char *model_data, size_t model_data_length,
|
||||
bool debug) {
|
||||
Ort::Env env(ORT_LOGGING_LEVEL_WARNING);
|
||||
@@ -37,14 +43,17 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
|
||||
if (debug) {
|
||||
std::ostringstream os;
|
||||
PrintModelMetadata(os, meta_data);
|
||||
fprintf(stderr, "%s\n", os.str().c_str());
|
||||
SHERPA_ONNX_LOGE("%s", os.str().c_str());
|
||||
}
|
||||
|
||||
Ort::AllocatorWithDefaultOptions allocator;
|
||||
auto model_type =
|
||||
meta_data.LookupCustomMetadataMapAllocated("model_type", allocator);
|
||||
if (!model_type) {
|
||||
fprintf(stderr, "No model_type in the metadata!\n");
|
||||
SHERPA_ONNX_LOGE(
|
||||
"No model_type in the metadata!\n"
|
||||
"Please make sure you are using the latest export-onnx.py from icefall "
|
||||
"to export your transducer models");
|
||||
return ModelType::kUnkown;
|
||||
}
|
||||
|
||||
@@ -53,7 +62,7 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
|
||||
} else if (model_type.get() == std::string("zipformer")) {
|
||||
return ModelType::kZipformer;
|
||||
} else {
|
||||
fprintf(stderr, "Unsupported model_type: %s\n", model_type.get());
|
||||
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
|
||||
return ModelType::kUnkown;
|
||||
}
|
||||
}
|
||||
@@ -74,6 +83,7 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
|
||||
case ModelType::kZipformer:
|
||||
return std::make_unique<OnlineZipformerTransducerModel>(config);
|
||||
case ModelType::kUnkown:
|
||||
SHERPA_ONNX_LOGE("Unknown model type in online transducer!");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
@@ -127,6 +137,7 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
|
||||
case ModelType::kZipformer:
|
||||
return std::make_unique<OnlineZipformerTransducerModel>(mgr, config);
|
||||
case ModelType::kUnkown:
|
||||
SHERPA_ONNX_LOGE("Unknown model type in online transducer!");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
||||
@@ -35,4 +35,28 @@ TEST(Tranpose, Tranpose01) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST(Tranpose, Tranpose12) {
|
||||
Ort::AllocatorWithDefaultOptions allocator;
|
||||
std::array<int64_t, 3> shape{3, 2, 5};
|
||||
Ort::Value v =
|
||||
Ort::Value::CreateTensor<float>(allocator, shape.data(), shape.size());
|
||||
float *p = v.GetTensorMutableData<float>();
|
||||
|
||||
std::iota(p, p + shape[0] * shape[1] * shape[2], 0);
|
||||
|
||||
auto ans = Transpose12(allocator, &v);
|
||||
auto v2 = Transpose12(allocator, &ans);
|
||||
|
||||
Print3D(&v);
|
||||
Print3D(&ans);
|
||||
Print3D(&v2);
|
||||
|
||||
const float *q = v2.GetTensorData<float>();
|
||||
|
||||
for (int32_t i = 0; i != static_cast<int32_t>(shape[0] * shape[1] * shape[2]);
|
||||
++i) {
|
||||
EXPECT_EQ(p[i], q[i]);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -17,8 +17,8 @@ Ort::Value Transpose01(OrtAllocator *allocator, const Ort::Value *v) {
|
||||
assert(shape.size() == 3);
|
||||
|
||||
std::array<int64_t, 3> ans_shape{shape[1], shape[0], shape[2]};
|
||||
Ort::Value ans = Ort::Value::CreateTensor<float>(allocator, ans_shape.data(),
|
||||
ans_shape.size());
|
||||
Ort::Value ans = Ort::Value::CreateTensor<T>(allocator, ans_shape.data(),
|
||||
ans_shape.size());
|
||||
|
||||
T *dst = ans.GetTensorMutableData<T>();
|
||||
auto plane_offset = shape[1] * shape[2];
|
||||
@@ -35,7 +35,32 @@ Ort::Value Transpose01(OrtAllocator *allocator, const Ort::Value *v) {
|
||||
return ans;
|
||||
}
|
||||
|
||||
template <typename T /*= float*/>
|
||||
Ort::Value Transpose12(OrtAllocator *allocator, const Ort::Value *v) {
|
||||
std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
|
||||
assert(shape.size() == 3);
|
||||
|
||||
std::array<int64_t, 3> ans_shape{shape[0], shape[2], shape[1]};
|
||||
Ort::Value ans = Ort::Value::CreateTensor<T>(allocator, ans_shape.data(),
|
||||
ans_shape.size());
|
||||
T *dst = ans.GetTensorMutableData<T>();
|
||||
auto row_stride = shape[2];
|
||||
for (int64_t b = 0; b != ans_shape[0]; ++b) {
|
||||
const T *src = v->GetTensorData<T>() + b * shape[1] * shape[2];
|
||||
for (int64_t i = 0; i != ans_shape[1]; ++i) {
|
||||
for (int64_t k = 0; k != ans_shape[2]; ++k, ++dst) {
|
||||
*dst = (src + k * row_stride)[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
template Ort::Value Transpose01<float>(OrtAllocator *allocator,
|
||||
const Ort::Value *v);
|
||||
|
||||
template Ort::Value Transpose12<float>(OrtAllocator *allocator,
|
||||
const Ort::Value *v);
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -10,13 +10,23 @@ namespace sherpa_onnx {
|
||||
/** Transpose a 3-D tensor from shape (B, T, C) to (T, B, C).
|
||||
*
|
||||
* @param allocator
|
||||
* @param v A 3-D tensor of shape (B, T, C). Its dataype is T.
|
||||
* @param v A 3-D tensor of shape (B, T, C). Its dataype is type.
|
||||
*
|
||||
* @return Return a 3-D tensor of shape (T, B, C). Its datatype is T.
|
||||
* @return Return a 3-D tensor of shape (T, B, C). Its datatype is type.
|
||||
*/
|
||||
template <typename T = float>
|
||||
template <typename type = float>
|
||||
Ort::Value Transpose01(OrtAllocator *allocator, const Ort::Value *v);
|
||||
|
||||
/** Transpose a 3-D tensor from shape (B, T, C) to (B, C, T).
|
||||
*
|
||||
* @param allocator
|
||||
* @param v A 3-D tensor of shape (B, T, C). Its dataype is type.
|
||||
*
|
||||
* @return Return a 3-D tensor of shape (B, C, T). Its datatype is type.
|
||||
*/
|
||||
template <typename type = float>
|
||||
Ort::Value Transpose12(OrtAllocator *allocator, const Ort::Value *v);
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_TRANSPOSE_H_
|
||||
|
||||
@@ -5,6 +5,7 @@ pybind11_add_module(_sherpa_onnx
|
||||
endpoint.cc
|
||||
features.cc
|
||||
offline-model-config.cc
|
||||
offline-nemo-enc-dec-ctc-model-config.cc
|
||||
offline-paraformer-model-config.cc
|
||||
offline-recognizer.cc
|
||||
offline-stream.cc
|
||||
|
||||
@@ -7,26 +7,31 @@
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/python/csrc/offline-transducer-model-config.h"
|
||||
#include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h"
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-model-config.h"
|
||||
#include "sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.h"
|
||||
#include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h"
|
||||
#include "sherpa-onnx/python/csrc/offline-transducer-model-config.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void PybindOfflineModelConfig(py::module *m) {
|
||||
PybindOfflineTransducerModelConfig(m);
|
||||
PybindOfflineParaformerModelConfig(m);
|
||||
PybindOfflineNemoEncDecCtcModelConfig(m);
|
||||
|
||||
using PyClass = OfflineModelConfig;
|
||||
py::class_<PyClass>(*m, "OfflineModelConfig")
|
||||
.def(py::init<OfflineTransducerModelConfig &,
|
||||
OfflineParaformerModelConfig &,
|
||||
const std::string &, int32_t, bool>(),
|
||||
py::arg("transducer"), py::arg("paraformer"), py::arg("tokens"),
|
||||
py::arg("num_threads"), py::arg("debug") = false)
|
||||
.def(py::init<const OfflineTransducerModelConfig &,
|
||||
const OfflineParaformerModelConfig &,
|
||||
const OfflineNemoEncDecCtcModelConfig &,
|
||||
const std::string &, int32_t, bool>(),
|
||||
py::arg("transducer") = OfflineTransducerModelConfig(),
|
||||
py::arg("paraformer") = OfflineParaformerModelConfig(),
|
||||
py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(),
|
||||
py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false)
|
||||
.def_readwrite("transducer", &PyClass::transducer)
|
||||
.def_readwrite("paraformer", &PyClass::paraformer)
|
||||
.def_readwrite("nemo_ctc", &PyClass::nemo_ctc)
|
||||
.def_readwrite("tokens", &PyClass::tokens)
|
||||
.def_readwrite("num_threads", &PyClass::num_threads)
|
||||
.def_readwrite("debug", &PyClass::debug)
|
||||
|
||||
@@ -0,0 +1,22 @@
|
||||
// sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.h"
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void PybindOfflineNemoEncDecCtcModelConfig(py::module *m) {
|
||||
using PyClass = OfflineNemoEncDecCtcModelConfig;
|
||||
py::class_<PyClass>(*m, "OfflineNemoEncDecCtcModelConfig")
|
||||
.def(py::init<const std::string &>(), py::arg("model"))
|
||||
.def_readwrite("model", &PyClass::model)
|
||||
.def("__str__", &PyClass::ToString);
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
@@ -0,0 +1,16 @@
|
||||
// sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_NEMO_ENC_DEC_CTC_MODEL_CONFIG_H_
|
||||
#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_NEMO_ENC_DEC_CTC_MODEL_CONFIG_H_
|
||||
|
||||
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void PybindOfflineNemoEncDecCtcModelConfig(py::module *m);
|
||||
|
||||
}
|
||||
|
||||
#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_NEMO_ENC_DEC_CTC_MODEL_CONFIG_H_
|
||||
@@ -4,7 +4,6 @@
|
||||
|
||||
#include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h"
|
||||
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
@@ -15,8 +14,7 @@ namespace sherpa_onnx {
|
||||
void PybindOfflineParaformerModelConfig(py::module *m) {
|
||||
using PyClass = OfflineParaformerModelConfig;
|
||||
py::class_<PyClass>(*m, "OfflineParaformerModelConfig")
|
||||
.def(py::init<const std::string &>(),
|
||||
py::arg("model"))
|
||||
.def(py::init<const std::string &>(), py::arg("model"))
|
||||
.def_readwrite("model", &PyClass::model)
|
||||
.def("__str__", &PyClass::ToString);
|
||||
}
|
||||
|
||||
@@ -11,8 +11,6 @@
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
|
||||
|
||||
static void PybindOfflineRecognizerConfig(py::module *m) {
|
||||
using PyClass = OfflineRecognizerConfig;
|
||||
py::class_<PyClass>(*m, "OfflineRecognizerConfig")
|
||||
|
||||
@@ -31,7 +31,6 @@ static void PybindOfflineRecognitionResult(py::module *m) { // NOLINT
|
||||
"timestamps", [](const PyClass &self) { return self.timestamps; });
|
||||
}
|
||||
|
||||
|
||||
static void PybindOfflineFeatureExtractorConfig(py::module *m) {
|
||||
using PyClass = OfflineFeatureExtractorConfig;
|
||||
py::class_<PyClass>(*m, "OfflineFeatureExtractorConfig")
|
||||
@@ -42,7 +41,6 @@ static void PybindOfflineFeatureExtractorConfig(py::module *m) {
|
||||
.def("__str__", &PyClass::ToString);
|
||||
}
|
||||
|
||||
|
||||
void PybindOfflineStream(py::module *m) {
|
||||
PybindOfflineFeatureExtractorConfig(m);
|
||||
PybindOfflineRecognitionResult(m);
|
||||
@@ -55,7 +53,7 @@ void PybindOfflineStream(py::module *m) {
|
||||
self.AcceptWaveform(sample_rate, waveform.data(), waveform.size());
|
||||
},
|
||||
py::arg("sample_rate"), py::arg("waveform"), kAcceptWaveformUsage)
|
||||
.def_property_readonly("result", &PyClass::GetResult);
|
||||
.def_property_readonly("result", &PyClass::GetResult);
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -7,16 +7,13 @@
|
||||
#include "sherpa-onnx/python/csrc/display.h"
|
||||
#include "sherpa-onnx/python/csrc/endpoint.h"
|
||||
#include "sherpa-onnx/python/csrc/features.h"
|
||||
#include "sherpa-onnx/python/csrc/offline-model-config.h"
|
||||
#include "sherpa-onnx/python/csrc/offline-recognizer.h"
|
||||
#include "sherpa-onnx/python/csrc/offline-stream.h"
|
||||
#include "sherpa-onnx/python/csrc/online-recognizer.h"
|
||||
#include "sherpa-onnx/python/csrc/online-stream.h"
|
||||
#include "sherpa-onnx/python/csrc/online-transducer-model-config.h"
|
||||
|
||||
#include "sherpa-onnx/python/csrc/offline-model-config.h"
|
||||
#include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h"
|
||||
#include "sherpa-onnx/python/csrc/offline-recognizer.h"
|
||||
#include "sherpa-onnx/python/csrc/offline-stream.h"
|
||||
#include "sherpa-onnx/python/csrc/offline-transducer-model-config.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
PYBIND11_MODULE(_sherpa_onnx, m) {
|
||||
|
||||
@@ -4,12 +4,15 @@ from typing import List
|
||||
|
||||
from _sherpa_onnx import (
|
||||
OfflineFeatureExtractorConfig,
|
||||
OfflineRecognizer as _Recognizer,
|
||||
OfflineModelConfig,
|
||||
OfflineNemoEncDecCtcModelConfig,
|
||||
OfflineParaformerModelConfig,
|
||||
)
|
||||
from _sherpa_onnx import OfflineRecognizer as _Recognizer
|
||||
from _sherpa_onnx import (
|
||||
OfflineRecognizerConfig,
|
||||
OfflineStream,
|
||||
OfflineModelConfig,
|
||||
OfflineTransducerModelConfig,
|
||||
OfflineParaformerModelConfig,
|
||||
)
|
||||
|
||||
|
||||
@@ -75,7 +78,6 @@ class OfflineRecognizer(object):
|
||||
decoder_filename=decoder,
|
||||
joiner_filename=joiner,
|
||||
),
|
||||
paraformer=OfflineParaformerModelConfig(model=""),
|
||||
tokens=tokens,
|
||||
num_threads=num_threads,
|
||||
debug=debug,
|
||||
@@ -119,7 +121,7 @@ class OfflineRecognizer(object):
|
||||
symbol integer_id
|
||||
|
||||
paraformer:
|
||||
Path to ``paraformer.onnx``.
|
||||
Path to ``model.onnx``.
|
||||
num_threads:
|
||||
Number of threads for neural network computation.
|
||||
sample_rate:
|
||||
@@ -133,9 +135,6 @@ class OfflineRecognizer(object):
|
||||
"""
|
||||
self = cls.__new__(cls)
|
||||
model_config = OfflineModelConfig(
|
||||
transducer=OfflineTransducerModelConfig(
|
||||
encoder_filename="", decoder_filename="", joiner_filename=""
|
||||
),
|
||||
paraformer=OfflineParaformerModelConfig(model=paraformer),
|
||||
tokens=tokens,
|
||||
num_threads=num_threads,
|
||||
@@ -155,6 +154,64 @@ class OfflineRecognizer(object):
|
||||
self.recognizer = _Recognizer(recognizer_config)
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def from_nemo_ctc(
|
||||
cls,
|
||||
model: str,
|
||||
tokens: str,
|
||||
num_threads: int,
|
||||
sample_rate: int = 16000,
|
||||
feature_dim: int = 80,
|
||||
decoding_method: str = "greedy_search",
|
||||
debug: bool = False,
|
||||
):
|
||||
"""
|
||||
Please refer to
|
||||
`<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html>`_
|
||||
to download pre-trained models for different languages, e.g., Chinese,
|
||||
English, etc.
|
||||
|
||||
Args:
|
||||
tokens:
|
||||
Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
|
||||
columns::
|
||||
|
||||
symbol integer_id
|
||||
|
||||
model:
|
||||
Path to ``model.onnx``.
|
||||
num_threads:
|
||||
Number of threads for neural network computation.
|
||||
sample_rate:
|
||||
Sample rate of the training data used to train the model.
|
||||
feature_dim:
|
||||
Dimension of the feature used to train the model.
|
||||
decoding_method:
|
||||
Valid values are greedy_search, modified_beam_search.
|
||||
debug:
|
||||
True to show debug messages.
|
||||
"""
|
||||
self = cls.__new__(cls)
|
||||
model_config = OfflineModelConfig(
|
||||
nemo_ctc=OfflineNemoEncDecCtcModelConfig(model=model),
|
||||
tokens=tokens,
|
||||
num_threads=num_threads,
|
||||
debug=debug,
|
||||
)
|
||||
|
||||
feat_config = OfflineFeatureExtractorConfig(
|
||||
sampling_rate=sample_rate,
|
||||
feature_dim=feature_dim,
|
||||
)
|
||||
|
||||
recognizer_config = OfflineRecognizerConfig(
|
||||
feat_config=feat_config,
|
||||
model_config=model_config,
|
||||
decoding_method=decoding_method,
|
||||
)
|
||||
self.recognizer = _Recognizer(recognizer_config)
|
||||
return self
|
||||
|
||||
def create_stream(self):
|
||||
return self.recognizer.create_stream()
|
||||
|
||||
|
||||
@@ -196,6 +196,71 @@ class TestOfflineRecognizer(unittest.TestCase):
|
||||
print(s2.result.text)
|
||||
print(s3.result.text)
|
||||
|
||||
def test_nemo_ctc_single_file(self):
|
||||
for use_int8 in [True, False]:
|
||||
if use_int8:
|
||||
model = f"{d}/sherpa-onnx-nemo-ctc-en-citrinet-512/model.int8.onnx"
|
||||
else:
|
||||
model = f"{d}/sherpa-onnx-nemo-ctc-en-citrinet-512/model.onnx"
|
||||
|
||||
tokens = f"{d}/sherpa-onnx-nemo-ctc-en-citrinet-512/tokens.txt"
|
||||
wave0 = f"{d}/sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/0.wav"
|
||||
|
||||
if not Path(model).is_file():
|
||||
print("skipping test_nemo_ctc_single_file()")
|
||||
return
|
||||
|
||||
recognizer = sherpa_onnx.OfflineRecognizer.from_nemo_ctc(
|
||||
model=model,
|
||||
tokens=tokens,
|
||||
num_threads=1,
|
||||
)
|
||||
|
||||
s = recognizer.create_stream()
|
||||
samples, sample_rate = read_wave(wave0)
|
||||
s.accept_waveform(sample_rate, samples)
|
||||
recognizer.decode_stream(s)
|
||||
print(s.result.text)
|
||||
|
||||
def test_nemo_ctc_multiple_files(self):
|
||||
for use_int8 in [True, False]:
|
||||
if use_int8:
|
||||
model = f"{d}/sherpa-onnx-nemo-ctc-en-citrinet-512/model.int8.onnx"
|
||||
else:
|
||||
model = f"{d}/sherpa-onnx-nemo-ctc-en-citrinet-512/model.onnx"
|
||||
|
||||
tokens = f"{d}/sherpa-onnx-nemo-ctc-en-citrinet-512/tokens.txt"
|
||||
wave0 = f"{d}/sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/0.wav"
|
||||
wave1 = f"{d}/sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/1.wav"
|
||||
wave2 = f"{d}/sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/8k.wav"
|
||||
|
||||
if not Path(model).is_file():
|
||||
print("skipping test_nemo_ctc_multiple_files()")
|
||||
return
|
||||
|
||||
recognizer = sherpa_onnx.OfflineRecognizer.from_nemo_ctc(
|
||||
model=model,
|
||||
tokens=tokens,
|
||||
num_threads=1,
|
||||
)
|
||||
|
||||
s0 = recognizer.create_stream()
|
||||
samples0, sample_rate0 = read_wave(wave0)
|
||||
s0.accept_waveform(sample_rate0, samples0)
|
||||
|
||||
s1 = recognizer.create_stream()
|
||||
samples1, sample_rate1 = read_wave(wave1)
|
||||
s1.accept_waveform(sample_rate1, samples1)
|
||||
|
||||
s2 = recognizer.create_stream()
|
||||
samples2, sample_rate2 = read_wave(wave2)
|
||||
s2.accept_waveform(sample_rate2, samples2)
|
||||
|
||||
recognizer.decode_streams([s0, s1, s2])
|
||||
print(s0.result.text)
|
||||
print(s1.result.text)
|
||||
print(s2.result.text)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user