Add HLG decoding for streaming CTC models (#731)
This commit is contained in:
22
.github/scripts/test-online-ctc.sh
vendored
22
.github/scripts/test-online-ctc.sh
vendored
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -e
|
||||
set -ex
|
||||
|
||||
log() {
|
||||
# This function is from espnet
|
||||
@@ -13,6 +13,26 @@ echo "PATH: $PATH"
|
||||
|
||||
which $EXE
|
||||
|
||||
log "------------------------------------------------------------"
|
||||
log "Run streaming Zipformer2 CTC HLG decoding "
|
||||
log "------------------------------------------------------------"
|
||||
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
|
||||
tar xvf sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
|
||||
rm sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
|
||||
repo=$PWD/sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18
|
||||
ls -lh $repo
|
||||
echo "pwd: $PWD"
|
||||
|
||||
$EXE \
|
||||
--zipformer2-ctc-model=$repo/ctc-epoch-30-avg-3-chunk-16-left-128.int8.onnx \
|
||||
--ctc-graph=$repo/HLG.fst \
|
||||
--tokens=$repo/tokens.txt \
|
||||
$repo/test_wavs/0.wav \
|
||||
$repo/test_wavs/1.wav \
|
||||
$repo/test_wavs/8k.wav
|
||||
|
||||
rm -rf sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18
|
||||
|
||||
log "------------------------------------------------------------"
|
||||
log "Run streaming Zipformer2 CTC "
|
||||
log "------------------------------------------------------------"
|
||||
|
||||
19
.github/scripts/test-python.sh
vendored
19
.github/scripts/test-python.sh
vendored
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -e
|
||||
set -ex
|
||||
|
||||
log() {
|
||||
# This function is from espnet
|
||||
@@ -8,6 +8,23 @@ log() {
|
||||
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||
}
|
||||
|
||||
log "test streaming zipformer2 ctc HLG decoding"
|
||||
|
||||
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
|
||||
tar xvf sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
|
||||
rm sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
|
||||
repo=sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18
|
||||
|
||||
python3 ./python-api-examples/online-zipformer-ctc-hlg-decode-file.py \
|
||||
--debug 1 \
|
||||
--tokens ./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/tokens.txt \
|
||||
--graph ./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/HLG.fst \
|
||||
--model ./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/ctc-epoch-30-avg-3-chunk-16-left-128.int8.onnx \
|
||||
./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/test_wavs/0.wav
|
||||
|
||||
rm -rf sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18
|
||||
|
||||
|
||||
mkdir -p /tmp/icefall-models
|
||||
dir=/tmp/icefall-models
|
||||
|
||||
|
||||
15
.github/workflows/linux.yaml
vendored
15
.github/workflows/linux.yaml
vendored
@@ -124,6 +124,14 @@ jobs:
|
||||
name: release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }}
|
||||
path: build/bin/*
|
||||
|
||||
- name: Test online CTC
|
||||
shell: bash
|
||||
run: |
|
||||
export PATH=$PWD/build/bin:$PATH
|
||||
export EXE=sherpa-onnx
|
||||
|
||||
.github/scripts/test-online-ctc.sh
|
||||
|
||||
- name: Test C API
|
||||
shell: bash
|
||||
run: |
|
||||
@@ -149,13 +157,6 @@ jobs:
|
||||
|
||||
.github/scripts/test-kws.sh
|
||||
|
||||
- name: Test online CTC
|
||||
shell: bash
|
||||
run: |
|
||||
export PATH=$PWD/build/bin:$PATH
|
||||
export EXE=sherpa-onnx
|
||||
|
||||
.github/scripts/test-online-ctc.sh
|
||||
|
||||
- name: Test offline Whisper
|
||||
if: matrix.build_type != 'Debug'
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
function(download_kaldi_decoder)
|
||||
include(FetchContent)
|
||||
|
||||
set(kaldi_decoder_URL "https://github.com/k2-fsa/kaldi-decoder/archive/refs/tags/v0.2.4.tar.gz")
|
||||
set(kaldi_decoder_URL2 "https://hub.nuaa.cf/k2-fsa/kaldi-decoder/archive/refs/tags/v0.2.4.tar.gz")
|
||||
set(kaldi_decoder_HASH "SHA256=136d96c2f1f8ec44de095205f81a6ce98981cd867fe4ba840f9415a0b58fe601")
|
||||
set(kaldi_decoder_URL "https://github.com/k2-fsa/kaldi-decoder/archive/refs/tags/v0.2.5.tar.gz")
|
||||
set(kaldi_decoder_URL2 "https://hub.nuaa.cf/k2-fsa/kaldi-decoder/archive/refs/tags/v0.2.5.tar.gz")
|
||||
set(kaldi_decoder_HASH "SHA256=f663e58aef31b33cd8086eaa09ff1383628039845f31300b5abef817d8cc2fff")
|
||||
|
||||
set(KALDI_DECODER_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
|
||||
set(KALDI_DECODER_ENABLE_TESTS OFF CACHE BOOL "" FORCE)
|
||||
@@ -12,11 +12,11 @@ function(download_kaldi_decoder)
|
||||
# If you don't have access to the Internet,
|
||||
# please pre-download kaldi-decoder
|
||||
set(possible_file_locations
|
||||
$ENV{HOME}/Downloads/kaldi-decoder-0.2.4.tar.gz
|
||||
${CMAKE_SOURCE_DIR}/kaldi-decoder-0.2.4.tar.gz
|
||||
${CMAKE_BINARY_DIR}/kaldi-decoder-0.2.4.tar.gz
|
||||
/tmp/kaldi-decoder-0.2.4.tar.gz
|
||||
/star-fj/fangjun/download/github/kaldi-decoder-0.2.4.tar.gz
|
||||
$ENV{HOME}/Downloads/kaldi-decoder-0.2.5.tar.gz
|
||||
${CMAKE_SOURCE_DIR}/kaldi-decoder-0.2.5.tar.gz
|
||||
${CMAKE_BINARY_DIR}/kaldi-decoder-0.2.5.tar.gz
|
||||
/tmp/kaldi-decoder-0.2.5.tar.gz
|
||||
/star-fj/fangjun/download/github/kaldi-decoder-0.2.5.tar.gz
|
||||
)
|
||||
|
||||
foreach(f IN LISTS possible_file_locations)
|
||||
|
||||
172
python-api-examples/online-zipformer-ctc-hlg-decode-file.py
Executable file
172
python-api-examples/online-zipformer-ctc-hlg-decode-file.py
Executable file
@@ -0,0 +1,172 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# This file shows how to use a streaming zipformer CTC model and an HLG
|
||||
# graph for decoding.
|
||||
#
|
||||
# We use the following model as an example
|
||||
#
|
||||
"""
|
||||
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
|
||||
tar xvf sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
|
||||
rm sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
|
||||
|
||||
python3 ./python-api-examples/online-zipformer-ctc-hlg-decode-file.py \
|
||||
--tokens ./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/tokens.txt \
|
||||
--graph ./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/HLG.fst \
|
||||
--model ./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/ctc-epoch-30-avg-3-chunk-16-left-128.int8.onnx \
|
||||
./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/test_wavs/0.wav
|
||||
|
||||
"""
|
||||
# (The above model is from https://github.com/k2-fsa/icefall/pull/1557)
|
||||
|
||||
import argparse
|
||||
import time
|
||||
import wave
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import sherpa_onnx
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tokens",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to tokens.txt",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the ONNX model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--graph",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to H.fst, HL.fst, or HLG.fst",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-threads",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of threads for neural network computation",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--provider",
|
||||
type=str,
|
||||
default="cpu",
|
||||
help="Valid values: cpu, cuda, coreml",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--debug",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Valid values: 1, 0",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_file",
|
||||
type=str,
|
||||
help="The input sound file to decode. It must be of WAVE"
|
||||
"format with a single channel, and each sample has 16-bit, "
|
||||
"i.e., int16_t. "
|
||||
"The sample rate of the file can be arbitrary and does not need to "
|
||||
"be 16 kHz",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def assert_file_exists(filename: str):
|
||||
assert Path(filename).is_file(), (
|
||||
f"{filename} does not exist!\n"
|
||||
"Please refer to "
|
||||
"https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it"
|
||||
)
|
||||
|
||||
|
||||
def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
|
||||
"""
|
||||
Args:
|
||||
wave_filename:
|
||||
Path to a wave file. It should be single channel and each sample should
|
||||
be 16-bit. Its sample rate does not need to be 16kHz.
|
||||
Returns:
|
||||
Return a tuple containing:
|
||||
- A 1-D array of dtype np.float32 containing the samples, which are
|
||||
normalized to the range [-1, 1].
|
||||
- sample rate of the wave file
|
||||
"""
|
||||
|
||||
with wave.open(wave_filename) as f:
|
||||
assert f.getnchannels() == 1, f.getnchannels()
|
||||
assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes
|
||||
num_samples = f.getnframes()
|
||||
samples = f.readframes(num_samples)
|
||||
samples_int16 = np.frombuffer(samples, dtype=np.int16)
|
||||
samples_float32 = samples_int16.astype(np.float32)
|
||||
|
||||
samples_float32 = samples_float32 / 32768
|
||||
return samples_float32, f.getframerate()
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
print(vars(args))
|
||||
|
||||
assert_file_exists(args.tokens)
|
||||
assert_file_exists(args.graph)
|
||||
assert_file_exists(args.model)
|
||||
|
||||
recognizer = sherpa_onnx.OnlineRecognizer.from_zipformer2_ctc(
|
||||
tokens=args.tokens,
|
||||
model=args.model,
|
||||
num_threads=args.num_threads,
|
||||
provider=args.provider,
|
||||
sample_rate=16000,
|
||||
feature_dim=80,
|
||||
ctc_graph=args.graph,
|
||||
)
|
||||
|
||||
wave_filename = args.sound_file
|
||||
assert_file_exists(wave_filename)
|
||||
samples, sample_rate = read_wave(wave_filename)
|
||||
duration = len(samples) / sample_rate
|
||||
|
||||
print("Started")
|
||||
|
||||
start_time = time.time()
|
||||
s = recognizer.create_stream()
|
||||
s.accept_waveform(sample_rate, samples)
|
||||
tail_paddings = np.zeros(int(0.66 * sample_rate), dtype=np.float32)
|
||||
s.accept_waveform(sample_rate, tail_paddings)
|
||||
s.input_finished()
|
||||
while recognizer.is_ready(s):
|
||||
recognizer.decode_stream(s)
|
||||
|
||||
result = recognizer.get_result(s).lower()
|
||||
end_time = time.time()
|
||||
|
||||
elapsed_seconds = end_time - start_time
|
||||
rtf = elapsed_seconds / duration
|
||||
print(f"num_threads: {args.num_threads}")
|
||||
print(f"Wave duration: {duration:.3f} s")
|
||||
print(f"Elapsed time: {elapsed_seconds:.3f} s")
|
||||
print(f"Real time factor (RTF): {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}")
|
||||
print(result)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -51,6 +51,8 @@ set(sources
|
||||
offline-zipformer-ctc-model-config.cc
|
||||
offline-zipformer-ctc-model.cc
|
||||
online-conformer-transducer-model.cc
|
||||
online-ctc-fst-decoder-config.cc
|
||||
online-ctc-fst-decoder.cc
|
||||
online-ctc-greedy-search-decoder.cc
|
||||
online-ctc-model.cc
|
||||
online-lm-config.cc
|
||||
|
||||
@@ -7,6 +7,9 @@
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
std::string OfflineCtcFstDecoderConfig::ToString() const {
|
||||
@@ -29,4 +32,12 @@ void OfflineCtcFstDecoderConfig::Register(ParseOptions *po) {
|
||||
"Decoder max active states. Larger->slower; more accurate");
|
||||
}
|
||||
|
||||
bool OfflineCtcFstDecoderConfig::Validate() const {
|
||||
if (!graph.empty() && !FileExists(graph)) {
|
||||
SHERPA_ONNX_LOGE("graph: %s does not exist", graph.c_str());
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -24,6 +24,7 @@ struct OfflineCtcFstDecoderConfig {
|
||||
std::string ToString() const;
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -20,7 +20,7 @@ namespace sherpa_onnx {
|
||||
// @param filename Path to a StdVectorFst or StdConstFst graph
|
||||
// @return The caller should free the returned pointer using `delete` to
|
||||
// avoid memory leak.
|
||||
static fst::Fst<fst::StdArc> *ReadGraph(const std::string &filename) {
|
||||
fst::Fst<fst::StdArc> *ReadGraph(const std::string &filename) {
|
||||
// read decoding network FST
|
||||
std::ifstream is(filename, std::ios::binary);
|
||||
if (!is.good()) {
|
||||
|
||||
@@ -67,6 +67,12 @@ bool OfflineRecognizerConfig::Validate() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!ctc_fst_decoder_config.graph.empty() &&
|
||||
!ctc_fst_decoder_config.Validate()) {
|
||||
SHERPA_ONNX_LOGE("Errors in fst_decoder");
|
||||
return false;
|
||||
}
|
||||
|
||||
return model_config.Validate();
|
||||
}
|
||||
|
||||
|
||||
@@ -5,12 +5,16 @@
|
||||
#ifndef SHERPA_ONNX_CSRC_ONLINE_CTC_DECODER_H_
|
||||
#define SHERPA_ONNX_CSRC_ONLINE_CTC_DECODER_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "kaldi-decoder/csrc/faster-decoder.h"
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class OnlineStream;
|
||||
|
||||
struct OnlineCtcDecoderResult {
|
||||
/// Number of frames after subsampling we have decoded so far
|
||||
int32_t frame_offset = 0;
|
||||
@@ -37,7 +41,13 @@ class OnlineCtcDecoder {
|
||||
* @param results Input & Output parameters..
|
||||
*/
|
||||
virtual void Decode(Ort::Value log_probs,
|
||||
std::vector<OnlineCtcDecoderResult> *results) = 0;
|
||||
std::vector<OnlineCtcDecoderResult> *results,
|
||||
OnlineStream **ss = nullptr, int32_t n = 0) = 0;
|
||||
|
||||
virtual std::unique_ptr<kaldi_decoder::FasterDecoder> CreateFasterDecoder()
|
||||
const {
|
||||
return nullptr;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
40
sherpa-onnx/csrc/online-ctc-fst-decoder-config.cc
Normal file
40
sherpa-onnx/csrc/online-ctc-fst-decoder-config.cc
Normal file
@@ -0,0 +1,40 @@
|
||||
// sherpa-onnx/csrc/online-ctc-fst-decoder-config.cc
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/online-ctc-fst-decoder-config.h"
|
||||
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
std::string OnlineCtcFstDecoderConfig::ToString() const {
|
||||
std::ostringstream os;
|
||||
|
||||
os << "OnlineCtcFstDecoderConfig(";
|
||||
os << "graph=\"" << graph << "\", ";
|
||||
os << "max_active=" << max_active << ")";
|
||||
|
||||
return os.str();
|
||||
}
|
||||
|
||||
void OnlineCtcFstDecoderConfig::Register(ParseOptions *po) {
|
||||
po->Register("ctc-graph", &graph, "Path to H.fst, HL.fst, or HLG.fst");
|
||||
|
||||
po->Register("ctc-max-active", &max_active,
|
||||
"Decoder max active states. Larger->slower; more accurate");
|
||||
}
|
||||
|
||||
bool OnlineCtcFstDecoderConfig::Validate() const {
|
||||
if (!graph.empty() && !FileExists(graph)) {
|
||||
SHERPA_ONNX_LOGE("graph: %s does not exist", graph.c_str());
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
32
sherpa-onnx/csrc/online-ctc-fst-decoder-config.h
Normal file
32
sherpa-onnx/csrc/online-ctc-fst-decoder-config.h
Normal file
@@ -0,0 +1,32 @@
|
||||
// sherpa-onnx/csrc/online-ctc-fst-decoder-config.h
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_ONLINE_CTC_FST_DECODER_CONFIG_H_
|
||||
#define SHERPA_ONNX_CSRC_ONLINE_CTC_FST_DECODER_CONFIG_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
struct OnlineCtcFstDecoderConfig {
|
||||
// Path to H.fst, HL.fst or HLG.fst
|
||||
std::string graph;
|
||||
int32_t max_active = 3000;
|
||||
|
||||
OnlineCtcFstDecoderConfig() = default;
|
||||
|
||||
OnlineCtcFstDecoderConfig(const std::string &graph, int32_t max_active)
|
||||
: graph(graph), max_active(max_active) {}
|
||||
|
||||
std::string ToString() const;
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_ONLINE_CTC_FST_DECODER_CONFIG_H_
|
||||
125
sherpa-onnx/csrc/online-ctc-fst-decoder.cc
Normal file
125
sherpa-onnx/csrc/online-ctc-fst-decoder.cc
Normal file
@@ -0,0 +1,125 @@
|
||||
// sherpa-onnx/csrc/online-ctc-fst-decoder.cc
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/online-ctc-fst-decoder.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "fst/fstlib.h"
|
||||
#include "kaldi-decoder/csrc/decodable-ctc.h"
|
||||
#include "kaldifst/csrc/fstext-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/online-stream.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// defined in ./offline-ctc-fst-decoder.cc
|
||||
fst::Fst<fst::StdArc> *ReadGraph(const std::string &filename);
|
||||
|
||||
OnlineCtcFstDecoder::OnlineCtcFstDecoder(
|
||||
const OnlineCtcFstDecoderConfig &config, int32_t blank_id)
|
||||
: config_(config), fst_(ReadGraph(config.graph)), blank_id_(blank_id) {
|
||||
options_.max_active = config_.max_active;
|
||||
}
|
||||
|
||||
std::unique_ptr<kaldi_decoder::FasterDecoder>
|
||||
OnlineCtcFstDecoder::CreateFasterDecoder() const {
|
||||
return std::make_unique<kaldi_decoder::FasterDecoder>(*fst_, options_);
|
||||
}
|
||||
|
||||
static void DecodeOne(const float *log_probs, int32_t num_rows,
|
||||
int32_t num_cols, OnlineCtcDecoderResult *result,
|
||||
OnlineStream *s, int32_t blank_id) {
|
||||
int32_t &processed_frames = s->GetFasterDecoderProcessedFrames();
|
||||
kaldi_decoder::DecodableCtc decodable(log_probs, num_rows, num_cols,
|
||||
processed_frames);
|
||||
|
||||
kaldi_decoder::FasterDecoder *decoder = s->GetFasterDecoder();
|
||||
if (processed_frames == 0) {
|
||||
decoder->InitDecoding();
|
||||
}
|
||||
|
||||
decoder->AdvanceDecoding(&decodable);
|
||||
|
||||
if (decoder->ReachedFinal()) {
|
||||
fst::VectorFst<fst::LatticeArc> fst_out;
|
||||
bool ok = decoder->GetBestPath(&fst_out);
|
||||
if (ok) {
|
||||
std::vector<int32_t> isymbols_out;
|
||||
std::vector<int32_t> osymbols_out_unused;
|
||||
ok = fst::GetLinearSymbolSequence(fst_out, &isymbols_out,
|
||||
&osymbols_out_unused, nullptr);
|
||||
std::vector<int64_t> tokens;
|
||||
tokens.reserve(isymbols_out.size());
|
||||
|
||||
std::vector<int32_t> timestamps;
|
||||
timestamps.reserve(isymbols_out.size());
|
||||
|
||||
std::ostringstream os;
|
||||
int32_t prev_id = -1;
|
||||
int32_t num_trailing_blanks = 0;
|
||||
int32_t f = 0; // frame number
|
||||
|
||||
for (auto i : isymbols_out) {
|
||||
i -= 1;
|
||||
|
||||
if (i == blank_id) {
|
||||
num_trailing_blanks += 1;
|
||||
} else {
|
||||
num_trailing_blanks = 0;
|
||||
}
|
||||
|
||||
if (i != blank_id && i != prev_id) {
|
||||
tokens.push_back(i);
|
||||
timestamps.push_back(f);
|
||||
}
|
||||
prev_id = i;
|
||||
f += 1;
|
||||
}
|
||||
|
||||
result->tokens = std::move(tokens);
|
||||
result->timestamps = std::move(timestamps);
|
||||
// no need to set frame_offset
|
||||
}
|
||||
}
|
||||
|
||||
processed_frames += num_rows;
|
||||
}
|
||||
|
||||
void OnlineCtcFstDecoder::Decode(Ort::Value log_probs,
|
||||
std::vector<OnlineCtcDecoderResult> *results,
|
||||
OnlineStream **ss, int32_t n) {
|
||||
std::vector<int64_t> log_probs_shape =
|
||||
log_probs.GetTensorTypeAndShapeInfo().GetShape();
|
||||
|
||||
if (log_probs_shape[0] != results->size()) {
|
||||
SHERPA_ONNX_LOGE("Size mismatch! log_probs.size(0) %d, results.size(0): %d",
|
||||
static_cast<int32_t>(log_probs_shape[0]),
|
||||
static_cast<int32_t>(results->size()));
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
if (log_probs_shape[0] != n) {
|
||||
SHERPA_ONNX_LOGE("Size mismatch! log_probs.size(0) %d, n: %d",
|
||||
static_cast<int32_t>(log_probs_shape[0]), n);
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
int32_t batch_size = static_cast<int32_t>(log_probs_shape[0]);
|
||||
int32_t num_frames = static_cast<int32_t>(log_probs_shape[1]);
|
||||
int32_t vocab_size = static_cast<int32_t>(log_probs_shape[2]);
|
||||
|
||||
const float *p = log_probs.GetTensorData<float>();
|
||||
|
||||
for (int32_t i = 0; i != batch_size; ++i) {
|
||||
DecodeOne(p + i * num_frames * vocab_size, num_frames, vocab_size,
|
||||
&(*results)[i], ss[i], blank_id_);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
39
sherpa-onnx/csrc/online-ctc-fst-decoder.h
Normal file
39
sherpa-onnx/csrc/online-ctc-fst-decoder.h
Normal file
@@ -0,0 +1,39 @@
|
||||
// sherpa-onnx/csrc/online-ctc-fst-decoder.h
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_ONLINE_CTC_FST_DECODER_H_
|
||||
#define SHERPA_ONNX_CSRC_ONLINE_CTC_FST_DECODER_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "fst/fst.h"
|
||||
#include "sherpa-onnx/csrc/online-ctc-decoder.h"
|
||||
#include "sherpa-onnx/csrc/online-ctc-fst-decoder-config.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class OnlineCtcFstDecoder : public OnlineCtcDecoder {
|
||||
public:
|
||||
OnlineCtcFstDecoder(const OnlineCtcFstDecoderConfig &config,
|
||||
int32_t blank_id);
|
||||
|
||||
void Decode(Ort::Value log_probs,
|
||||
std::vector<OnlineCtcDecoderResult> *results,
|
||||
OnlineStream **ss = nullptr, int32_t n = 0) override;
|
||||
|
||||
std::unique_ptr<kaldi_decoder::FasterDecoder> CreateFasterDecoder()
|
||||
const override;
|
||||
|
||||
private:
|
||||
OnlineCtcFstDecoderConfig config_;
|
||||
kaldi_decoder::FasterDecoderOptions options_;
|
||||
|
||||
std::unique_ptr<fst::Fst<fst::StdArc>> fst_;
|
||||
int32_t blank_id_ = 0;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_ONLINE_CTC_FST_DECODER_H_
|
||||
@@ -13,7 +13,8 @@
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void OnlineCtcGreedySearchDecoder::Decode(
|
||||
Ort::Value log_probs, std::vector<OnlineCtcDecoderResult> *results) {
|
||||
Ort::Value log_probs, std::vector<OnlineCtcDecoderResult> *results,
|
||||
OnlineStream ** /*ss=nullptr*/, int32_t /*n = 0*/) {
|
||||
std::vector<int64_t> log_probs_shape =
|
||||
log_probs.GetTensorTypeAndShapeInfo().GetShape();
|
||||
|
||||
|
||||
@@ -17,7 +17,8 @@ class OnlineCtcGreedySearchDecoder : public OnlineCtcDecoder {
|
||||
: blank_id_(blank_id) {}
|
||||
|
||||
void Decode(Ort::Value log_probs,
|
||||
std::vector<OnlineCtcDecoderResult> *results) override;
|
||||
std::vector<OnlineCtcDecoderResult> *results,
|
||||
OnlineStream **ss = nullptr, int32_t n = 0) override;
|
||||
|
||||
private:
|
||||
int32_t blank_id_;
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/online-ctc-decoder.h"
|
||||
#include "sherpa-onnx/csrc/online-ctc-fst-decoder.h"
|
||||
#include "sherpa-onnx/csrc/online-ctc-greedy-search-decoder.h"
|
||||
#include "sherpa-onnx/csrc/online-ctc-model.h"
|
||||
#include "sherpa-onnx/csrc/online-recognizer-impl.h"
|
||||
@@ -99,6 +100,7 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
|
||||
std::unique_ptr<OnlineStream> CreateStream() const override {
|
||||
auto stream = std::make_unique<OnlineStream>(config_.feat_config);
|
||||
stream->SetStates(model_->GetInitStates());
|
||||
stream->SetFasterDecoder(decoder_->CreateFasterDecoder());
|
||||
|
||||
return stream;
|
||||
}
|
||||
@@ -165,7 +167,7 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
|
||||
std::vector<std::vector<Ort::Value>> next_states =
|
||||
model_->UnStackStates(std::move(out_states));
|
||||
|
||||
decoder_->Decode(std::move(out[0]), &results);
|
||||
decoder_->Decode(std::move(out[0]), &results, ss, n);
|
||||
|
||||
for (int32_t k = 0; k != n; ++k) {
|
||||
ss[k]->SetCtcResult(results[k]);
|
||||
@@ -221,30 +223,34 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
|
||||
|
||||
private:
|
||||
void InitDecoder() {
|
||||
if (config_.decoding_method == "greedy_search") {
|
||||
if (!sym_.contains("<blk>") && !sym_.contains("<eps>") &&
|
||||
!sym_.contains("<blank>")) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"We expect that tokens.txt contains "
|
||||
"the symbol <blk> or <eps> or <blank> and its ID.");
|
||||
exit(-1);
|
||||
}
|
||||
if (!sym_.contains("<blk>") && !sym_.contains("<eps>") &&
|
||||
!sym_.contains("<blank>")) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"We expect that tokens.txt contains "
|
||||
"the symbol <blk> or <eps> or <blank> and its ID.");
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
int32_t blank_id = 0;
|
||||
if (sym_.contains("<blk>")) {
|
||||
blank_id = sym_["<blk>"];
|
||||
} else if (sym_.contains("<eps>")) {
|
||||
// for tdnn models of the yesno recipe from icefall
|
||||
blank_id = sym_["<eps>"];
|
||||
} else if (sym_.contains("<blank>")) {
|
||||
// for WeNet CTC models
|
||||
blank_id = sym_["<blank>"];
|
||||
}
|
||||
int32_t blank_id = 0;
|
||||
if (sym_.contains("<blk>")) {
|
||||
blank_id = sym_["<blk>"];
|
||||
} else if (sym_.contains("<eps>")) {
|
||||
// for tdnn models of the yesno recipe from icefall
|
||||
blank_id = sym_["<eps>"];
|
||||
} else if (sym_.contains("<blank>")) {
|
||||
// for WeNet CTC models
|
||||
blank_id = sym_["<blank>"];
|
||||
}
|
||||
|
||||
if (!config_.ctc_fst_decoder_config.graph.empty()) {
|
||||
decoder_ = std::make_unique<OnlineCtcFstDecoder>(
|
||||
config_.ctc_fst_decoder_config, blank_id);
|
||||
} else if (config_.decoding_method == "greedy_search") {
|
||||
decoder_ = std::make_unique<OnlineCtcGreedySearchDecoder>(blank_id);
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
|
||||
config_.decoding_method.c_str());
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Unsupported decoding method: %s for streaming CTC models",
|
||||
config_.decoding_method.c_str());
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
@@ -281,7 +287,7 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl {
|
||||
std::vector<OnlineCtcDecoderResult> results(1);
|
||||
results[0] = std::move(s->GetCtcResult());
|
||||
|
||||
decoder_->Decode(std::move(out[0]), &results);
|
||||
decoder_->Decode(std::move(out[0]), &results, &s, 1);
|
||||
s->SetCtcResult(results[0]);
|
||||
}
|
||||
|
||||
|
||||
@@ -19,13 +19,13 @@
|
||||
namespace sherpa_onnx {
|
||||
|
||||
/// Helper for `OnlineRecognizerResult::AsJsonString()`
|
||||
template<typename T>
|
||||
std::string VecToString(const std::vector<T>& vec, int32_t precision = 6) {
|
||||
template <typename T>
|
||||
std::string VecToString(const std::vector<T> &vec, int32_t precision = 6) {
|
||||
std::ostringstream oss;
|
||||
oss << std::fixed << std::setprecision(precision);
|
||||
oss << "[ ";
|
||||
std::string sep = "";
|
||||
for (const auto& item : vec) {
|
||||
for (const auto &item : vec) {
|
||||
oss << sep << item;
|
||||
sep = ", ";
|
||||
}
|
||||
@@ -34,13 +34,13 @@ std::string VecToString(const std::vector<T>& vec, int32_t precision = 6) {
|
||||
}
|
||||
|
||||
/// Helper for `OnlineRecognizerResult::AsJsonString()`
|
||||
template<> // explicit specialization for T = std::string
|
||||
std::string VecToString<std::string>(const std::vector<std::string>& vec,
|
||||
template <> // explicit specialization for T = std::string
|
||||
std::string VecToString<std::string>(const std::vector<std::string> &vec,
|
||||
int32_t) { // ignore 2nd arg
|
||||
std::ostringstream oss;
|
||||
oss << "[ ";
|
||||
std::string sep = "";
|
||||
for (const auto& item : vec) {
|
||||
for (const auto &item : vec) {
|
||||
oss << sep << "\"" << item << "\"";
|
||||
sep = ", ";
|
||||
}
|
||||
@@ -51,15 +51,17 @@ std::string VecToString<std::string>(const std::vector<std::string>& vec,
|
||||
std::string OnlineRecognizerResult::AsJsonString() const {
|
||||
std::ostringstream os;
|
||||
os << "{ ";
|
||||
os << "\"text\": " << "\"" << text << "\"" << ", ";
|
||||
os << "\"text\": "
|
||||
<< "\"" << text << "\""
|
||||
<< ", ";
|
||||
os << "\"tokens\": " << VecToString(tokens) << ", ";
|
||||
os << "\"timestamps\": " << VecToString(timestamps, 2) << ", ";
|
||||
os << "\"ys_probs\": " << VecToString(ys_probs, 6) << ", ";
|
||||
os << "\"lm_probs\": " << VecToString(lm_probs, 6) << ", ";
|
||||
os << "\"context_scores\": " << VecToString(context_scores, 6) << ", ";
|
||||
os << "\"segment\": " << segment << ", ";
|
||||
os << "\"start_time\": " << std::fixed << std::setprecision(2)
|
||||
<< start_time << ", ";
|
||||
os << "\"start_time\": " << std::fixed << std::setprecision(2) << start_time
|
||||
<< ", ";
|
||||
os << "\"is_final\": " << (is_final ? "true" : "false");
|
||||
os << "}";
|
||||
return os.str();
|
||||
@@ -70,6 +72,7 @@ void OnlineRecognizerConfig::Register(ParseOptions *po) {
|
||||
model_config.Register(po);
|
||||
endpoint_config.Register(po);
|
||||
lm_config.Register(po);
|
||||
ctc_fst_decoder_config.Register(po);
|
||||
|
||||
po->Register("enable-endpoint", &enable_endpoint,
|
||||
"True to enable endpoint detection. False to disable it.");
|
||||
@@ -116,6 +119,12 @@ bool OnlineRecognizerConfig::Validate() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!ctc_fst_decoder_config.graph.empty() &&
|
||||
!ctc_fst_decoder_config.Validate()) {
|
||||
SHERPA_ONNX_LOGE("Errors in ctc_fst_decoder_config");
|
||||
return false;
|
||||
}
|
||||
|
||||
return model_config.Validate();
|
||||
}
|
||||
|
||||
@@ -127,6 +136,7 @@ std::string OnlineRecognizerConfig::ToString() const {
|
||||
os << "model_config=" << model_config.ToString() << ", ";
|
||||
os << "lm_config=" << lm_config.ToString() << ", ";
|
||||
os << "endpoint_config=" << endpoint_config.ToString() << ", ";
|
||||
os << "ctc_fst_decoder_config=" << ctc_fst_decoder_config.ToString() << ", ";
|
||||
os << "enable_endpoint=" << (enable_endpoint ? "True" : "False") << ", ";
|
||||
os << "max_active_paths=" << max_active_paths << ", ";
|
||||
os << "hotwords_score=" << hotwords_score << ", ";
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
|
||||
#include "sherpa-onnx/csrc/endpoint.h"
|
||||
#include "sherpa-onnx/csrc/features.h"
|
||||
#include "sherpa-onnx/csrc/online-ctc-fst-decoder-config.h"
|
||||
#include "sherpa-onnx/csrc/online-lm-config.h"
|
||||
#include "sherpa-onnx/csrc/online-model-config.h"
|
||||
#include "sherpa-onnx/csrc/online-stream.h"
|
||||
@@ -80,6 +81,7 @@ struct OnlineRecognizerConfig {
|
||||
OnlineModelConfig model_config;
|
||||
OnlineLMConfig lm_config;
|
||||
EndpointConfig endpoint_config;
|
||||
OnlineCtcFstDecoderConfig ctc_fst_decoder_config;
|
||||
bool enable_endpoint = true;
|
||||
|
||||
std::string decoding_method = "greedy_search";
|
||||
@@ -96,19 +98,19 @@ struct OnlineRecognizerConfig {
|
||||
|
||||
OnlineRecognizerConfig() = default;
|
||||
|
||||
OnlineRecognizerConfig(const FeatureExtractorConfig &feat_config,
|
||||
const OnlineModelConfig &model_config,
|
||||
const OnlineLMConfig &lm_config,
|
||||
const EndpointConfig &endpoint_config,
|
||||
bool enable_endpoint,
|
||||
const std::string &decoding_method,
|
||||
int32_t max_active_paths,
|
||||
const std::string &hotwords_file, float hotwords_score,
|
||||
float blank_penalty)
|
||||
OnlineRecognizerConfig(
|
||||
const FeatureExtractorConfig &feat_config,
|
||||
const OnlineModelConfig &model_config, const OnlineLMConfig &lm_config,
|
||||
const EndpointConfig &endpoint_config,
|
||||
const OnlineCtcFstDecoderConfig &ctc_fst_decoder_config,
|
||||
bool enable_endpoint, const std::string &decoding_method,
|
||||
int32_t max_active_paths, const std::string &hotwords_file,
|
||||
float hotwords_score, float blank_penalty)
|
||||
: feat_config(feat_config),
|
||||
model_config(model_config),
|
||||
lm_config(lm_config),
|
||||
endpoint_config(endpoint_config),
|
||||
ctc_fst_decoder_config(ctc_fst_decoder_config),
|
||||
enable_endpoint(enable_endpoint),
|
||||
decoding_method(decoding_method),
|
||||
max_active_paths(max_active_paths),
|
||||
|
||||
@@ -104,6 +104,18 @@ class OnlineStream::Impl {
|
||||
return paraformer_alpha_cache_;
|
||||
}
|
||||
|
||||
void SetFasterDecoder(std::unique_ptr<kaldi_decoder::FasterDecoder> decoder) {
|
||||
faster_decoder_ = std::move(decoder);
|
||||
}
|
||||
|
||||
kaldi_decoder::FasterDecoder *GetFasterDecoder() const {
|
||||
return faster_decoder_.get();
|
||||
}
|
||||
|
||||
int32_t &GetFasterDecoderProcessedFrames() {
|
||||
return faster_decoder_processed_frames_;
|
||||
}
|
||||
|
||||
private:
|
||||
FeatureExtractor feat_extractor_;
|
||||
/// For contextual-biasing
|
||||
@@ -121,6 +133,8 @@ class OnlineStream::Impl {
|
||||
std::vector<float> paraformer_encoder_out_cache_;
|
||||
std::vector<float> paraformer_alpha_cache_;
|
||||
OnlineParaformerDecoderResult paraformer_result_;
|
||||
std::unique_ptr<kaldi_decoder::FasterDecoder> faster_decoder_;
|
||||
int32_t faster_decoder_processed_frames_ = 0;
|
||||
};
|
||||
|
||||
OnlineStream::OnlineStream(const FeatureExtractorConfig &config /*= {}*/,
|
||||
@@ -208,6 +222,19 @@ const ContextGraphPtr &OnlineStream::GetContextGraph() const {
|
||||
return impl_->GetContextGraph();
|
||||
}
|
||||
|
||||
void OnlineStream::SetFasterDecoder(
|
||||
std::unique_ptr<kaldi_decoder::FasterDecoder> decoder) {
|
||||
impl_->SetFasterDecoder(std::move(decoder));
|
||||
}
|
||||
|
||||
kaldi_decoder::FasterDecoder *OnlineStream::GetFasterDecoder() const {
|
||||
return impl_->GetFasterDecoder();
|
||||
}
|
||||
|
||||
int32_t &OnlineStream::GetFasterDecoderProcessedFrames() {
|
||||
return impl_->GetFasterDecoderProcessedFrames();
|
||||
}
|
||||
|
||||
std::vector<float> &OnlineStream::GetParaformerFeatCache() {
|
||||
return impl_->GetParaformerFeatCache();
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "kaldi-decoder/csrc/faster-decoder.h"
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/context-graph.h"
|
||||
#include "sherpa-onnx/csrc/features.h"
|
||||
@@ -97,6 +98,11 @@ class OnlineStream {
|
||||
*/
|
||||
const ContextGraphPtr &GetContextGraph() const;
|
||||
|
||||
// for online ctc decoder
|
||||
void SetFasterDecoder(std::unique_ptr<kaldi_decoder::FasterDecoder> decoder);
|
||||
kaldi_decoder::FasterDecoder *GetFasterDecoder() const;
|
||||
int32_t &GetFasterDecoderProcessedFrames();
|
||||
|
||||
// for streaming paraformer
|
||||
std::vector<float> &GetParaformerFeatCache();
|
||||
std::vector<float> &GetParaformerEncoderOutCache();
|
||||
|
||||
@@ -18,6 +18,7 @@ set(srcs
|
||||
offline-wenet-ctc-model-config.cc
|
||||
offline-whisper-model-config.cc
|
||||
offline-zipformer-ctc-model-config.cc
|
||||
online-ctc-fst-decoder-config.cc
|
||||
online-lm-config.cc
|
||||
online-model-config.cc
|
||||
online-paraformer-model-config.cc
|
||||
|
||||
23
sherpa-onnx/python/csrc/online-ctc-fst-decoder-config.cc
Normal file
23
sherpa-onnx/python/csrc/online-ctc-fst-decoder-config.cc
Normal file
@@ -0,0 +1,23 @@
|
||||
// sherpa-onnx/python/csrc/online-ctc-fst-decoder-config.cc
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/python/csrc/online-ctc-fst-decoder-config.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/online-ctc-fst-decoder-config.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void PybindOnlineCtcFstDecoderConfig(py::module *m) {
|
||||
using PyClass = OnlineCtcFstDecoderConfig;
|
||||
py::class_<PyClass>(*m, "OnlineCtcFstDecoderConfig")
|
||||
.def(py::init<const std::string &, int32_t>(), py::arg("graph") = "",
|
||||
py::arg("max_active") = 3000)
|
||||
.def_readwrite("graph", &PyClass::graph)
|
||||
.def_readwrite("max_active", &PyClass::max_active)
|
||||
.def("__str__", &PyClass::ToString);
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
16
sherpa-onnx/python/csrc/online-ctc-fst-decoder-config.h
Normal file
16
sherpa-onnx/python/csrc/online-ctc-fst-decoder-config.h
Normal file
@@ -0,0 +1,16 @@
|
||||
// sherpa-onnx/python/csrc/online-ctc-fst-decoder-config.h
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_PYTHON_CSRC_ONLINE_CTC_FST_DECODER_CONFIG_H_
|
||||
#define SHERPA_ONNX_PYTHON_CSRC_ONLINE_CTC_FST_DECODER_CONFIG_H_
|
||||
|
||||
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void PybindOnlineCtcFstDecoderConfig(py::module *m);
|
||||
|
||||
}
|
||||
|
||||
#endif // SHERPA_ONNX_PYTHON_CSRC_ONLINE_CTC_FST_DECODER_CONFIG_H_
|
||||
@@ -24,8 +24,7 @@ static void PybindOnlineRecognizerResult(py::module *m) {
|
||||
"tokens",
|
||||
[](PyClass &self) -> std::vector<std::string> { return self.tokens; })
|
||||
.def_property_readonly(
|
||||
"start_time",
|
||||
[](PyClass &self) -> float { return self.start_time; })
|
||||
"start_time", [](PyClass &self) -> float { return self.start_time; })
|
||||
.def_property_readonly(
|
||||
"timestamps",
|
||||
[](PyClass &self) -> std::vector<float> { return self.timestamps; })
|
||||
@@ -35,37 +34,38 @@ static void PybindOnlineRecognizerResult(py::module *m) {
|
||||
.def_property_readonly(
|
||||
"lm_probs",
|
||||
[](PyClass &self) -> std::vector<float> { return self.lm_probs; })
|
||||
.def_property_readonly("context_scores",
|
||||
[](PyClass &self) -> std::vector<float> {
|
||||
return self.context_scores;
|
||||
})
|
||||
.def_property_readonly(
|
||||
"context_scores",
|
||||
[](PyClass &self) -> std::vector<float> {
|
||||
return self.context_scores;
|
||||
})
|
||||
"segment", [](PyClass &self) -> int32_t { return self.segment; })
|
||||
.def_property_readonly(
|
||||
"segment",
|
||||
[](PyClass &self) -> int32_t { return self.segment; })
|
||||
.def_property_readonly(
|
||||
"is_final",
|
||||
[](PyClass &self) -> bool { return self.is_final; })
|
||||
"is_final", [](PyClass &self) -> bool { return self.is_final; })
|
||||
.def("as_json_string", &PyClass::AsJsonString,
|
||||
py::call_guard<py::gil_scoped_release>());
|
||||
py::call_guard<py::gil_scoped_release>());
|
||||
}
|
||||
|
||||
static void PybindOnlineRecognizerConfig(py::module *m) {
|
||||
using PyClass = OnlineRecognizerConfig;
|
||||
py::class_<PyClass>(*m, "OnlineRecognizerConfig")
|
||||
.def(py::init<const FeatureExtractorConfig &, const OnlineModelConfig &,
|
||||
const OnlineLMConfig &, const EndpointConfig &, bool,
|
||||
const std::string &, int32_t, const std::string &, float,
|
||||
float>(),
|
||||
py::arg("feat_config"), py::arg("model_config"),
|
||||
py::arg("lm_config") = OnlineLMConfig(), py::arg("endpoint_config"),
|
||||
py::arg("enable_endpoint"), py::arg("decoding_method"),
|
||||
py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "",
|
||||
py::arg("hotwords_score") = 0, py::arg("blank_penalty") = 0.0)
|
||||
.def(
|
||||
py::init<const FeatureExtractorConfig &, const OnlineModelConfig &,
|
||||
const OnlineLMConfig &, const EndpointConfig &,
|
||||
const OnlineCtcFstDecoderConfig &, bool, const std::string &,
|
||||
int32_t, const std::string &, float, float>(),
|
||||
py::arg("feat_config"), py::arg("model_config"),
|
||||
py::arg("lm_config") = OnlineLMConfig(),
|
||||
py::arg("endpoint_config") = EndpointConfig(),
|
||||
py::arg("ctc_fst_decoder_config") = OnlineCtcFstDecoderConfig(),
|
||||
py::arg("enable_endpoint"), py::arg("decoding_method"),
|
||||
py::arg("max_active_paths") = 4, py::arg("hotwords_file") = "",
|
||||
py::arg("hotwords_score") = 0, py::arg("blank_penalty") = 0.0)
|
||||
.def_readwrite("feat_config", &PyClass::feat_config)
|
||||
.def_readwrite("model_config", &PyClass::model_config)
|
||||
.def_readwrite("lm_config", &PyClass::lm_config)
|
||||
.def_readwrite("endpoint_config", &PyClass::endpoint_config)
|
||||
.def_readwrite("ctc_fst_decoder_config", &PyClass::ctc_fst_decoder_config)
|
||||
.def_readwrite("enable_endpoint", &PyClass::enable_endpoint)
|
||||
.def_readwrite("decoding_method", &PyClass::decoding_method)
|
||||
.def_readwrite("max_active_paths", &PyClass::max_active_paths)
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
#include "sherpa-onnx/python/csrc/offline-model-config.h"
|
||||
#include "sherpa-onnx/python/csrc/offline-recognizer.h"
|
||||
#include "sherpa-onnx/python/csrc/offline-stream.h"
|
||||
#include "sherpa-onnx/python/csrc/online-ctc-fst-decoder-config.h"
|
||||
#include "sherpa-onnx/python/csrc/online-lm-config.h"
|
||||
#include "sherpa-onnx/python/csrc/online-model-config.h"
|
||||
#include "sherpa-onnx/python/csrc/online-recognizer.h"
|
||||
@@ -36,6 +37,7 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
|
||||
m.doc() = "pybind11 binding of sherpa-onnx";
|
||||
|
||||
PybindFeatures(&m);
|
||||
PybindOnlineCtcFstDecoderConfig(&m);
|
||||
PybindOnlineModelConfig(&m);
|
||||
PybindOnlineLMConfig(&m);
|
||||
PybindOnlineStream(&m);
|
||||
|
||||
@@ -16,6 +16,7 @@ from _sherpa_onnx import (
|
||||
OnlineTransducerModelConfig,
|
||||
OnlineWenetCtcModelConfig,
|
||||
OnlineZipformer2CtcModelConfig,
|
||||
OnlineCtcFstDecoderConfig,
|
||||
)
|
||||
|
||||
|
||||
@@ -314,6 +315,8 @@ class OnlineRecognizer(object):
|
||||
rule2_min_trailing_silence: float = 1.2,
|
||||
rule3_min_utterance_length: float = 20.0,
|
||||
decoding_method: str = "greedy_search",
|
||||
ctc_graph: str = "",
|
||||
ctc_max_active: int = 3000,
|
||||
provider: str = "cpu",
|
||||
):
|
||||
"""
|
||||
@@ -355,6 +358,12 @@ class OnlineRecognizer(object):
|
||||
is detected.
|
||||
decoding_method:
|
||||
The only valid value is greedy_search.
|
||||
ctc_graph:
|
||||
If not empty, decoding_method is ignored. It contains the path to
|
||||
H.fst, HL.fst, or HLG.fst
|
||||
ctc_max_active:
|
||||
Used only when ctc_graph is not empty. It specifies the maximum
|
||||
active paths at a time.
|
||||
provider:
|
||||
onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
|
||||
"""
|
||||
@@ -384,10 +393,16 @@ class OnlineRecognizer(object):
|
||||
rule3_min_utterance_length=rule3_min_utterance_length,
|
||||
)
|
||||
|
||||
ctc_fst_decoder_config = OnlineCtcFstDecoderConfig(
|
||||
graph=ctc_graph,
|
||||
max_active=ctc_max_active,
|
||||
)
|
||||
|
||||
recognizer_config = OnlineRecognizerConfig(
|
||||
feat_config=feat_config,
|
||||
model_config=model_config,
|
||||
endpoint_config=endpoint_config,
|
||||
ctc_fst_decoder_config=ctc_fst_decoder_config,
|
||||
enable_endpoint=enable_endpoint_detection,
|
||||
decoding_method=decoding_method,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user