Add non-streaming ASR (#92)
This commit is contained in:
10
.github/scripts/Main.kt
vendored
10
.github/scripts/Main.kt
vendored
@@ -33,18 +33,20 @@ fun main() {
|
||||
config = config,
|
||||
)
|
||||
|
||||
var samples = WaveReader.readWave(
|
||||
var objArray = WaveReader.readWave(
|
||||
assetManager = AssetManager(),
|
||||
filename = "./sherpa-onnx-streaming-zipformer-en-2023-02-21/test_wavs/1089-134686-0001.wav",
|
||||
)
|
||||
var samples : FloatArray = objArray[0] as FloatArray
|
||||
var sampleRate : Int = objArray[1] as Int
|
||||
|
||||
model.acceptWaveform(samples!!, sampleRate=16000)
|
||||
model.acceptWaveform(samples, sampleRate=sampleRate)
|
||||
while (model.isReady()) {
|
||||
model.decode()
|
||||
}
|
||||
|
||||
var tail_paddings = FloatArray(8000) // 0.5 seconds
|
||||
model.acceptWaveform(tail_paddings, sampleRate=16000)
|
||||
var tail_paddings = FloatArray((sampleRate * 0.5).toInt()) // 0.5 seconds
|
||||
model.acceptWaveform(tail_paddings, sampleRate=sampleRate)
|
||||
model.inputFinished()
|
||||
while (model.isReady()) {
|
||||
model.decode()
|
||||
|
||||
60
.github/scripts/test-offline-transducer.sh
vendored
Executable file
60
.github/scripts/test-offline-transducer.sh
vendored
Executable file
@@ -0,0 +1,60 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -e
|
||||
|
||||
log() {
  # Adapted from espnet.
  # Prints "<timestamp> (<file>:<line>:<function>) <message>". The location
  # fields come from the caller's entry in bash's call-stack arrays. The
  # message is passed through `echo -e`, so backslash escapes are interpreted.
  local caller_file=${BASH_SOURCE[1]##*/}
  local location="${caller_file}:${BASH_LINENO[0]}:${FUNCNAME[1]}"
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${location}) $*"
}
|
||||
|
||||
# EXE is the name of the offline decoding binary under test. It is supplied by
# the calling workflow and must be resolvable through PATH.
echo "EXE is $EXE"
echo "PATH: $PATH"

which "$EXE"

log "------------------------------------------------------------"
log "Run Conformer transducer (English)"
log "------------------------------------------------------------"

repo_url=https://huggingface.co/csukuangfj/sherpa-onnx-conformer-en-2023-03-18
log "Start testing ${repo_url}"
repo=$(basename "$repo_url")
log "Download pretrained model and test-data from $repo_url"

# Clone without the LFS payloads first, then fetch only the ONNX model files.
GIT_LFS_SKIP_SMUDGE=1 git clone "$repo_url"
pushd "$repo"
git lfs pull --include "*.onnx"
# Fail early (the script runs under `set -e`) if the test data is missing.
if [ ! -d test_wavs ]; then
  log "test_wavs not found in $repo"
  exit 1
fi
popd

waves=(
  "$repo/test_wavs/0.wav"
  "$repo/test_wavs/1.wav"
  "$repo/test_wavs/2.wav"
)

# NOTE(review): the trailing "2" is presumably the number of threads; confirm
# against the usage message of the binary.
for wave in "${waves[@]}"; do
  time "$EXE" \
    "$repo/tokens.txt" \
    "$repo/encoder-epoch-99-avg-1.onnx" \
    "$repo/decoder-epoch-99-avg-1.onnx" \
    "$repo/joiner-epoch-99-avg-1.onnx" \
    "$wave" \
    2
done

# When sox is available, also decode an 8 kHz copy of the first test wave.
if command -v sox &> /dev/null; then
  echo "test 8kHz"
  sox "$repo/test_wavs/0.wav" -r 8000 8k.wav
  time "$EXE" \
    "$repo/tokens.txt" \
    "$repo/encoder-epoch-99-avg-1.onnx" \
    "$repo/decoder-epoch-99-avg-1.onnx" \
    "$repo/joiner-epoch-99-avg-1.onnx" \
    8k.wav \
    2
  # Do not leave the generated file behind in the working directory.
  rm -f 8k.wav
fi

rm -rf "$repo"
|
||||
10
.github/scripts/test-online-transducer.sh
vendored
10
.github/scripts/test-online-transducer.sh
vendored
@@ -40,7 +40,7 @@ for wave in ${waves[@]}; do
|
||||
$repo/decoder-epoch-99-avg-1.onnx \
|
||||
$repo/joiner-epoch-99-avg-1.onnx \
|
||||
$wave \
|
||||
4
|
||||
2
|
||||
done
|
||||
|
||||
rm -rf $repo
|
||||
@@ -72,7 +72,7 @@ for wave in ${waves[@]}; do
|
||||
$repo/decoder-epoch-11-avg-1.onnx \
|
||||
$repo/joiner-epoch-11-avg-1.onnx \
|
||||
$wave \
|
||||
4
|
||||
2
|
||||
done
|
||||
|
||||
rm -rf $repo
|
||||
@@ -104,7 +104,7 @@ for wave in ${waves[@]}; do
|
||||
$repo/decoder-epoch-99-avg-1.onnx \
|
||||
$repo/joiner-epoch-99-avg-1.onnx \
|
||||
$wave \
|
||||
4
|
||||
2
|
||||
done
|
||||
|
||||
rm -rf $repo
|
||||
@@ -138,7 +138,7 @@ for wave in ${waves[@]}; do
|
||||
$repo/decoder-epoch-99-avg-1.onnx \
|
||||
$repo/joiner-epoch-99-avg-1.onnx \
|
||||
$wave \
|
||||
4
|
||||
2
|
||||
done
|
||||
|
||||
# Decode a URL
|
||||
@@ -149,7 +149,7 @@ if [ $EXE == "sherpa-onnx-ffmpeg" ]; then
|
||||
$repo/decoder-epoch-99-avg-1.onnx \
|
||||
$repo/joiner-epoch-99-avg-1.onnx \
|
||||
https://huggingface.co/csukuangfj/sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20/resolve/main/test_wavs/4.wav \
|
||||
4
|
||||
2
|
||||
fi
|
||||
|
||||
rm -rf $repo
|
||||
|
||||
37
.github/workflows/linux.yaml
vendored
37
.github/workflows/linux.yaml
vendored
@@ -7,11 +7,11 @@ on:
|
||||
paths:
|
||||
- '.github/workflows/linux.yaml'
|
||||
- '.github/scripts/test-online-transducer.sh'
|
||||
- '.github/scripts/test-offline-transducer.sh'
|
||||
- 'CMakeLists.txt'
|
||||
- 'cmake/**'
|
||||
- 'sherpa-onnx/csrc/*'
|
||||
- 'sherpa-onnx/c-api/*'
|
||||
- 'ffmpeg-examples/**'
|
||||
- 'c-api-examples/**'
|
||||
pull_request:
|
||||
branches:
|
||||
@@ -19,11 +19,11 @@ on:
|
||||
paths:
|
||||
- '.github/workflows/linux.yaml'
|
||||
- '.github/scripts/test-online-transducer.sh'
|
||||
- '.github/scripts/test-offline-transducer.sh'
|
||||
- 'CMakeLists.txt'
|
||||
- 'cmake/**'
|
||||
- 'sherpa-onnx/csrc/*'
|
||||
- 'sherpa-onnx/c-api/*'
|
||||
- 'ffmpeg-examples/**'
|
||||
|
||||
concurrency:
|
||||
group: linux-${{ github.ref }}
|
||||
@@ -39,35 +39,26 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [ubuntu-latest]
|
||||
build_type: [Release, Debug]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Install ffmpeg
|
||||
- name: Install sox
|
||||
shell: bash
|
||||
run: |
|
||||
sudo apt-get install -y software-properties-common
|
||||
sudo add-apt-repository ppa:savoury1/ffmpeg4
|
||||
sudo add-apt-repository ppa:savoury1/ffmpeg5
|
||||
|
||||
sudo apt-get install -y libavdevice-dev libavutil-dev ffmpeg
|
||||
pkg-config --modversion libavutil
|
||||
ffmpeg -version
|
||||
|
||||
- name: Show ffmpeg version
|
||||
shell: bash
|
||||
run: |
|
||||
pkg-config --modversion libavutil
|
||||
ffmpeg -version
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y sox
|
||||
sox -h
|
||||
|
||||
- name: Configure CMake
|
||||
shell: bash
|
||||
run: |
|
||||
mkdir build
|
||||
cd build
|
||||
cmake -D CMAKE_BUILD_TYPE=Release ..
|
||||
cmake -D CMAKE_BUILD_TYPE=${{ matrix.build_type }} ..
|
||||
|
||||
- name: Build sherpa-onnx for ubuntu
|
||||
shell: bash
|
||||
@@ -78,21 +69,19 @@ jobs:
|
||||
ls -lh lib
|
||||
ls -lh bin
|
||||
|
||||
cd ../ffmpeg-examples
|
||||
make
|
||||
|
||||
- name: Display dependencies of sherpa-onnx for linux
|
||||
shell: bash
|
||||
run: |
|
||||
file build/bin/sherpa-onnx
|
||||
readelf -d build/bin/sherpa-onnx
|
||||
|
||||
- name: Test sherpa-onnx-ffmpeg
|
||||
- name: Test offline transducer
|
||||
shell: bash
|
||||
run: |
|
||||
export PATH=$PWD/ffmpeg-examples:$PATH
|
||||
export EXE=sherpa-onnx-ffmpeg
|
||||
export PATH=$PWD/build/bin:$PATH
|
||||
export EXE=sherpa-onnx-offline
|
||||
|
||||
.github/scripts/test-online-transducer.sh
|
||||
.github/scripts/test-offline-transducer.sh
|
||||
|
||||
- name: Test online transducer
|
||||
shell: bash
|
||||
|
||||
19
.github/workflows/macos.yaml
vendored
19
.github/workflows/macos.yaml
vendored
@@ -7,6 +7,7 @@ on:
|
||||
paths:
|
||||
- '.github/workflows/macos.yaml'
|
||||
- '.github/scripts/test-online-transducer.sh'
|
||||
- '.github/scripts/test-offline-transducer.sh'
|
||||
- 'CMakeLists.txt'
|
||||
- 'cmake/**'
|
||||
- 'sherpa-onnx/csrc/*'
|
||||
@@ -16,6 +17,7 @@ on:
|
||||
paths:
|
||||
- '.github/workflows/macos.yaml'
|
||||
- '.github/scripts/test-online-transducer.sh'
|
||||
- '.github/scripts/test-offline-transducer.sh'
|
||||
- 'CMakeLists.txt'
|
||||
- 'cmake/**'
|
||||
- 'sherpa-onnx/csrc/*'
|
||||
@@ -34,18 +36,25 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [macos-latest]
|
||||
build_type: [Release, Debug]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Install sox
|
||||
shell: bash
|
||||
run: |
|
||||
brew install sox
|
||||
sox -h
|
||||
|
||||
- name: Configure CMake
|
||||
shell: bash
|
||||
run: |
|
||||
mkdir build
|
||||
cd build
|
||||
cmake -D CMAKE_BUILD_TYPE=Release ..
|
||||
cmake -D CMAKE_BUILD_TYPE=${{ matrix.build_type }} ..
|
||||
|
||||
- name: Build sherpa-onnx for macos
|
||||
shell: bash
|
||||
@@ -64,6 +73,14 @@ jobs:
|
||||
otool -L build/bin/sherpa-onnx
|
||||
otool -l build/bin/sherpa-onnx
|
||||
|
||||
- name: Test offline transducer
|
||||
shell: bash
|
||||
run: |
|
||||
export PATH=$PWD/build/bin:$PATH
|
||||
export EXE=sherpa-onnx-offline
|
||||
|
||||
.github/scripts/test-offline-transducer.sh
|
||||
|
||||
- name: Test online transducer
|
||||
shell: bash
|
||||
run: |
|
||||
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -39,3 +39,5 @@ tags
|
||||
run-decode-file-python.sh
|
||||
android/SherpaOnnx/app/src/main/assets/
|
||||
*.ncnn.*
|
||||
run-sherpa-onnx-offline.sh
|
||||
sherpa-onnx-conformer-en-2023-03-18
|
||||
|
||||
@@ -13,7 +13,7 @@ endif()
|
||||
|
||||
option(SHERPA_ONNX_ENABLE_PYTHON "Whether to build Python" OFF)
|
||||
option(SHERPA_ONNX_ENABLE_TESTS "Whether to build tests" OFF)
|
||||
option(SHERPA_ONNX_ENABLE_CHECK "Whether to build with assert" ON)
|
||||
option(SHERPA_ONNX_ENABLE_CHECK "Whether to build with assert" OFF)
|
||||
option(BUILD_SHARED_LIBS "Whether to build shared libraries" OFF)
|
||||
option(SHERPA_ONNX_ENABLE_PORTAUDIO "Whether to build with portaudio" ON)
|
||||
option(SHERPA_ONNX_ENABLE_JNI "Whether to build JNI interface" OFF)
|
||||
|
||||
@@ -121,7 +121,7 @@ class MainActivity : AppCompatActivity() {
|
||||
val ret = audioRecord?.read(buffer, 0, buffer.size)
|
||||
if (ret != null && ret > 0) {
|
||||
val samples = FloatArray(ret) { buffer[it] / 32768.0f }
|
||||
model.acceptWaveform(samples, sampleRate=16000)
|
||||
model.acceptWaveform(samples, sampleRate=sampleRateInHz)
|
||||
while (model.isReady()) {
|
||||
model.decode()
|
||||
}
|
||||
@@ -180,7 +180,7 @@ class MainActivity : AppCompatActivity() {
|
||||
val type = 0
|
||||
println("Select model type ${type}")
|
||||
val config = OnlineRecognizerConfig(
|
||||
featConfig = getFeatureConfig(sampleRate = 16000, featureDim = 80),
|
||||
featConfig = getFeatureConfig(sampleRate = sampleRateInHz, featureDim = 80),
|
||||
modelConfig = getModelConfig(type = type)!!,
|
||||
endpointConfig = getEndpointConfig(),
|
||||
enableEndpoint = true,
|
||||
|
||||
@@ -8,7 +8,7 @@ class WaveReader {
|
||||
// No resampling is made.
|
||||
external fun readWave(
|
||||
assetManager: AssetManager, filename: String, expected_sample_rate: Float = 16000.0f
|
||||
): FloatArray?
|
||||
): Array<Any>
|
||||
|
||||
init {
|
||||
System.loadLibrary("sherpa-onnx-jni")
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
function(download_kaldi_native_fbank)
|
||||
include(FetchContent)
|
||||
|
||||
set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.13.tar.gz")
|
||||
set(kaldi_native_fbank_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.13.tar.gz")
|
||||
set(kaldi_native_fbank_HASH "SHA256=1f4d228f9fe3e3e9f92a74a7eecd2489071a03982e4ba6d7c70fc5fa7444df57")
|
||||
set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.14.tar.gz")
|
||||
set(kaldi_native_fbank_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.14.tar.gz")
|
||||
set(kaldi_native_fbank_HASH "SHA256=6a66638a111d3ce21fe6f29cbf9ab3dbcae2331c77391bf825927df5cbf2babe")
|
||||
|
||||
set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE)
|
||||
set(KALDI_NATIVE_FBANK_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
|
||||
@@ -12,11 +12,11 @@ function(download_kaldi_native_fbank)
|
||||
# If you don't have access to the Internet,
|
||||
# please pre-download kaldi-native-fbank
|
||||
set(possible_file_locations
|
||||
$ENV{HOME}/Downloads/kaldi-native-fbank-1.13.tar.gz
|
||||
${PROJECT_SOURCE_DIR}/kaldi-native-fbank-1.13.tar.gz
|
||||
${PROJECT_BINARY_DIR}/kaldi-native-fbank-1.13.tar.gz
|
||||
/tmp/kaldi-native-fbank-1.13.tar.gz
|
||||
/star-fj/fangjun/download/github/kaldi-native-fbank-1.13.tar.gz
|
||||
$ENV{HOME}/Downloads/kaldi-native-fbank-1.14.tar.gz
|
||||
${PROJECT_SOURCE_DIR}/kaldi-native-fbank-1.14.tar.gz
|
||||
${PROJECT_BINARY_DIR}/kaldi-native-fbank-1.14.tar.gz
|
||||
/tmp/kaldi-native-fbank-1.14.tar.gz
|
||||
/star-fj/fangjun/download/github/kaldi-native-fbank-1.14.tar.gz
|
||||
)
|
||||
|
||||
foreach(f IN LISTS possible_file_locations)
|
||||
|
||||
@@ -91,7 +91,6 @@ def create_recognizer():
|
||||
rule2_min_trailing_silence=1.2,
|
||||
rule3_min_utterance_length=300, # it essentially disables this rule
|
||||
decoding_method=args.decoding_method,
|
||||
max_feature_vectors=100, # 1 second
|
||||
)
|
||||
return recognizer
|
||||
|
||||
|
||||
@@ -86,7 +86,6 @@ def create_recognizer():
|
||||
sample_rate=16000,
|
||||
feature_dim=80,
|
||||
decoding_method=args.decoding_method,
|
||||
max_feature_vectors=100, # 1 second
|
||||
)
|
||||
return recognizer
|
||||
|
||||
|
||||
@@ -6,6 +6,11 @@ set(sources
|
||||
features.cc
|
||||
file-utils.cc
|
||||
hypothesis.cc
|
||||
offline-stream.cc
|
||||
offline-transducer-greedy-search-decoder.cc
|
||||
offline-transducer-model-config.cc
|
||||
offline-transducer-model.cc
|
||||
offline-recognizer.cc
|
||||
online-lstm-transducer-model.cc
|
||||
online-recognizer.cc
|
||||
online-stream.cc
|
||||
@@ -56,10 +61,13 @@ if(SHERPA_ONNX_ENABLE_CHECK)
|
||||
endif()
|
||||
|
||||
add_executable(sherpa-onnx sherpa-onnx.cc)
|
||||
add_executable(sherpa-onnx-offline sherpa-onnx-offline.cc)
|
||||
|
||||
target_link_libraries(sherpa-onnx sherpa-onnx-core)
|
||||
target_link_libraries(sherpa-onnx-offline sherpa-onnx-core)
|
||||
if(NOT WIN32)
|
||||
target_link_libraries(sherpa-onnx "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../lib")
|
||||
target_link_libraries(sherpa-onnx-offline "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../lib")
|
||||
endif()
|
||||
|
||||
if(SHERPA_ONNX_ENABLE_PYTHON AND WIN32)
|
||||
@@ -68,7 +76,13 @@ else()
|
||||
install(TARGETS sherpa-onnx-core DESTINATION lib)
|
||||
endif()
|
||||
|
||||
install(TARGETS sherpa-onnx DESTINATION bin)
|
||||
install(
|
||||
TARGETS
|
||||
sherpa-onnx
|
||||
sherpa-onnx-offline
|
||||
DESTINATION
|
||||
bin
|
||||
)
|
||||
|
||||
if(SHERPA_ONNX_HAS_ALSA)
|
||||
add_executable(sherpa-onnx-alsa sherpa-onnx-alsa.cc alsa.cc)
|
||||
|
||||
@@ -19,7 +19,9 @@ namespace sherpa_onnx {
|
||||
void FeatureExtractorConfig::Register(ParseOptions *po) {
|
||||
po->Register("sample-rate", &sampling_rate,
|
||||
"Sampling rate of the input waveform. Must match the one "
|
||||
"expected by the model.");
|
||||
"expected by the model. Note: You can have a different "
|
||||
"sample rate for the input waveform. We will do resampling "
|
||||
"inside the feature extractor");
|
||||
|
||||
po->Register("feat-dim", &feature_dim,
|
||||
"Feature dimension. Must match the one expected by the model.");
|
||||
@@ -30,8 +32,7 @@ std::string FeatureExtractorConfig::ToString() const {
|
||||
|
||||
os << "FeatureExtractorConfig(";
|
||||
os << "sampling_rate=" << sampling_rate << ", ";
|
||||
os << "feature_dim=" << feature_dim << ", ";
|
||||
os << "max_feature_vectors=" << max_feature_vectors << ")";
|
||||
os << "feature_dim=" << feature_dim << ")";
|
||||
|
||||
return os.str();
|
||||
}
|
||||
@@ -43,8 +44,6 @@ class FeatureExtractor::Impl {
|
||||
opts_.frame_opts.snip_edges = false;
|
||||
opts_.frame_opts.samp_freq = config.sampling_rate;
|
||||
|
||||
opts_.frame_opts.max_feature_vectors = config.max_feature_vectors;
|
||||
|
||||
opts_.mel_opts.num_bins = config.feature_dim;
|
||||
|
||||
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
|
||||
@@ -95,7 +94,7 @@ class FeatureExtractor::Impl {
|
||||
fbank_->AcceptWaveform(sampling_rate, waveform, n);
|
||||
}
|
||||
|
||||
void InputFinished() {
|
||||
void InputFinished() const {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
fbank_->InputFinished();
|
||||
}
|
||||
@@ -110,12 +109,21 @@ class FeatureExtractor::Impl {
|
||||
return fbank_->IsLastFrame(frame);
|
||||
}
|
||||
|
||||
std::vector<float> GetFrames(int32_t frame_index, int32_t n) const {
|
||||
if (frame_index + n > NumFramesReady()) {
|
||||
fprintf(stderr, "%d + %d > %d\n", frame_index, n, NumFramesReady());
|
||||
std::vector<float> GetFrames(int32_t frame_index, int32_t n) {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
if (frame_index + n > fbank_->NumFramesReady()) {
|
||||
SHERPA_ONNX_LOGE("%d + %d > %d\n", frame_index, n,
|
||||
fbank_->NumFramesReady());
|
||||
exit(-1);
|
||||
}
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
|
||||
int32_t discard_num = frame_index - last_frame_index_;
|
||||
if (discard_num < 0) {
|
||||
SHERPA_ONNX_LOGE("last_frame_index_: %d, frame_index_: %d",
|
||||
last_frame_index_, frame_index);
|
||||
exit(-1);
|
||||
}
|
||||
fbank_->Pop(discard_num);
|
||||
|
||||
int32_t feature_dim = fbank_->Dim();
|
||||
std::vector<float> features(feature_dim * n);
|
||||
@@ -128,12 +136,9 @@ class FeatureExtractor::Impl {
|
||||
p += feature_dim;
|
||||
}
|
||||
|
||||
return features;
|
||||
}
|
||||
last_frame_index_ = frame_index;
|
||||
|
||||
void Reset() {
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
|
||||
return features;
|
||||
}
|
||||
|
||||
int32_t FeatureDim() const { return opts_.mel_opts.num_bins; }
|
||||
@@ -143,6 +148,7 @@ class FeatureExtractor::Impl {
|
||||
knf::FbankOptions opts_;
|
||||
mutable std::mutex mutex_;
|
||||
std::unique_ptr<LinearResample> resampler_;
|
||||
int32_t last_frame_index_ = 0;
|
||||
};
|
||||
|
||||
FeatureExtractor::FeatureExtractor(const FeatureExtractorConfig &config /*={}*/)
|
||||
@@ -151,11 +157,11 @@ FeatureExtractor::FeatureExtractor(const FeatureExtractorConfig &config /*={}*/)
|
||||
FeatureExtractor::~FeatureExtractor() = default;
|
||||
|
||||
void FeatureExtractor::AcceptWaveform(int32_t sampling_rate,
|
||||
const float *waveform, int32_t n) {
|
||||
const float *waveform, int32_t n) const {
|
||||
impl_->AcceptWaveform(sampling_rate, waveform, n);
|
||||
}
|
||||
|
||||
void FeatureExtractor::InputFinished() { impl_->InputFinished(); }
|
||||
void FeatureExtractor::InputFinished() const { impl_->InputFinished(); }
|
||||
|
||||
int32_t FeatureExtractor::NumFramesReady() const {
|
||||
return impl_->NumFramesReady();
|
||||
@@ -170,8 +176,6 @@ std::vector<float> FeatureExtractor::GetFrames(int32_t frame_index,
|
||||
return impl_->GetFrames(frame_index, n);
|
||||
}
|
||||
|
||||
void FeatureExtractor::Reset() { impl_->Reset(); }
|
||||
|
||||
int32_t FeatureExtractor::FeatureDim() const { return impl_->FeatureDim(); }
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -14,9 +14,12 @@
|
||||
namespace sherpa_onnx {
|
||||
|
||||
struct FeatureExtractorConfig {
|
||||
// Sampling rate used by the feature extractor. If it is different from
|
||||
// the sampling rate of the input waveform, we will do resampling inside.
|
||||
int32_t sampling_rate = 16000;
|
||||
|
||||
// Feature dimension
|
||||
int32_t feature_dim = 80;
|
||||
int32_t max_feature_vectors = -1;
|
||||
|
||||
std::string ToString() const;
|
||||
|
||||
@@ -36,7 +39,8 @@ class FeatureExtractor {
|
||||
the range [-1, 1].
|
||||
@param n Number of entries in waveform
|
||||
*/
|
||||
void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n);
|
||||
void AcceptWaveform(int32_t sampling_rate, const float *waveform,
|
||||
int32_t n) const;
|
||||
|
||||
/**
|
||||
* InputFinished() tells the class you won't be providing any
|
||||
@@ -44,7 +48,7 @@ class FeatureExtractor {
|
||||
* of features, in the case where snip-edges == false; it also
|
||||
* affects the return value of IsLastFrame().
|
||||
*/
|
||||
void InputFinished();
|
||||
void InputFinished() const;
|
||||
|
||||
int32_t NumFramesReady() const;
|
||||
|
||||
@@ -62,8 +66,6 @@ class FeatureExtractor {
|
||||
*/
|
||||
std::vector<float> GetFrames(int32_t frame_index, int32_t n) const;
|
||||
|
||||
void Reset();
|
||||
|
||||
/// Return feature dim of this extractor
|
||||
int32_t FeatureDim() const;
|
||||
|
||||
|
||||
163
sherpa-onnx/csrc/offline-recognizer.cc
Normal file
163
sherpa-onnx/csrc/offline-recognizer.cc
Normal file
@@ -0,0 +1,163 @@
|
||||
// sherpa-onnx/csrc/offline-recognizer.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-recognizer.h"
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/offline-transducer-decoder.h"
|
||||
#include "sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.h"
|
||||
#include "sherpa-onnx/csrc/offline-transducer-model.h"
|
||||
#include "sherpa-onnx/csrc/pad-sequence.h"
|
||||
#include "sherpa-onnx/csrc/symbol-table.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// Build an OfflineRecognitionResult from the raw decoder output.
//
// @param src                 Decoder output holding token IDs and timestamps.
// @param sym_table           Looked up with each token ID to get its symbol.
// @param frame_shift_ms      Frame shift in milliseconds.
// @param subsampling_factor  Multiplier applied to the frame shift.
// @return A result whose `tokens` are the looked-up symbols, whose `text` is
//         their concatenation and whose `timestamps` are the entries of
//         src.timestamps scaled by
//         frame_shift_ms / 1000 * subsampling_factor.
static OfflineRecognitionResult Convert(
    const OfflineTransducerDecoderResult &src, const SymbolTable &sym_table,
    int32_t frame_shift_ms, int32_t subsampling_factor) {
  OfflineRecognitionResult result;
  result.tokens.reserve(src.tokens.size());
  result.timestamps.reserve(src.timestamps.size());

  for (auto token_id : src.tokens) {
    auto symbol = sym_table[token_id];
    // Append to the text before the symbol is moved into the token list.
    result.text.append(symbol);
    result.tokens.push_back(std::move(symbol));
  }

  const float seconds_per_step = frame_shift_ms / 1000. * subsampling_factor;
  for (auto step : src.timestamps) {
    const float seconds = seconds_per_step * step;
    result.timestamps.push_back(seconds);
  }

  return result;
}
|
||||
|
||||
// Register the options of this config with the given option parser.
//
// The feature-extractor and model sub-configs register their own options
// first; then the "decoding-method" option is bound to `decoding_method`.
//
// @param po  Parser that receives the option registrations.
void OfflineRecognizerConfig::Register(ParseOptions *po) {
  feat_config.Register(po);
  model_config.Register(po);

  // Adjacent string literals are concatenated verbatim, so the separator
  // between the two sentences has to be written out explicitly.
  po->Register("decoding-method", &decoding_method,
               "Decoding method. "
               "Valid values: greedy_search.");
}
|
||||
|
||||
bool OfflineRecognizerConfig::Validate() const {
|
||||
return model_config.Validate();
|
||||
}
|
||||
|
||||
// Return a single-line, human-readable description of this config, of the
// form OfflineRecognizerConfig(feat_config=..., model_config=...,
// decoding_method="...").
std::string OfflineRecognizerConfig::ToString() const {
  std::ostringstream oss;

  oss << "OfflineRecognizerConfig("
      << "feat_config=" << feat_config.ToString() << ", "
      << "model_config=" << model_config.ToString() << ", "
      << "decoding_method=\"" << decoding_method << "\")";

  return oss.str();
}
|
||||
|
||||
// Implementation of OfflineRecognizer. It owns the symbol table, the
// transducer model and the decoder selected by config.decoding_method.
class OfflineRecognizer::Impl {
 public:
  // Load the symbol table and the model described by `config` and create the
  // decoder. The process exits with -1 if the decoding method is not
  // "greedy_search".
  explicit Impl(const OfflineRecognizerConfig &config)
      : config_(config),
        symbol_table_(config_.model_config.tokens),
        model_(std::make_unique<OfflineTransducerModel>(config_.model_config)) {
    if (config_.decoding_method == "greedy_search") {
      decoder_ =
          std::make_unique<OfflineTransducerGreedySearchDecoder>(model_.get());
    } else if (config_.decoding_method == "modified_beam_search") {
      SHERPA_ONNX_LOGE("TODO: modified_beam_search is to be implemented");
      exit(-1);
    } else {
      SHERPA_ONNX_LOGE("Unsupported decoding method: %s",
                       config_.decoding_method.c_str());
      exit(-1);
    }
  }

  // Create a stream that uses the feature configuration of this recognizer.
  std::unique_ptr<OfflineStream> CreateStream() const {
    return std::make_unique<OfflineStream>(config_.feat_config);
  }

  // Decode `n` streams as one batch and store the result in each stream.
  //
  // @param ss  Pointer to an array of `n` streams.
  // @param n   Number of streams. Nothing is decoded if it is not positive.
  void DecodeStreams(OfflineStream **ss, int32_t n) const {
    // ss[0] is read below, so an empty or missing batch must be rejected
    // before touching the array.
    if (ss == nullptr || n <= 0) {
      SHERPA_ONNX_LOGE("DecodeStreams() requires at least one stream. n: %d",
                       n);
      return;
    }

    auto memory_info =
        Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

    // The feature dimension of the first stream is used for the whole batch.
    int32_t feat_dim = ss[0]->FeatureDim();

    std::vector<Ort::Value> features;
    features.reserve(n);

    // features_vec owns the feature data; the tensors in `features` only
    // reference it, so it has to outlive them.
    std::vector<std::vector<float>> features_vec(n);
    std::vector<int64_t> features_length_vec(n);
    for (int32_t i = 0; i != n; ++i) {
      auto f = ss[i]->GetFrames();
      int32_t num_frames = f.size() / feat_dim;

      features_length_vec[i] = num_frames;
      features_vec[i] = std::move(f);

      std::array<int64_t, 2> shape = {num_frames, feat_dim};

      Ort::Value x = Ort::Value::CreateTensor(
          memory_info, features_vec[i].data(), features_vec[i].size(),
          shape.data(), shape.size());
      features.push_back(std::move(x));
    }

    std::vector<const Ort::Value *> features_pointer(n);
    for (int32_t i = 0; i != n; ++i) {
      features_pointer[i] = &features[i];
    }

    std::array<int64_t, 1> features_length_shape = {n};
    Ort::Value x_length = Ort::Value::CreateTensor(
        memory_info, features_length_vec.data(), n,
        features_length_shape.data(), features_length_shape.size());

    // The padding value equals log(1e-10).
    // NOTE(review): presumably the log-mel value of near-silence; confirm.
    Ort::Value x = PadSequence(model_->Allocator(), features_pointer,
                               -23.025850929940457f);

    auto t = model_->RunEncoder(std::move(x), std::move(x_length));
    auto results = decoder_->Decode(std::move(t.first), std::move(t.second));

    int32_t frame_shift_ms = 10;
    for (int32_t i = 0; i != n; ++i) {
      auto r = Convert(results[i], symbol_table_, frame_shift_ms,
                       model_->SubsamplingFactor());

      ss[i]->SetResult(r);
    }
  }

 private:
  OfflineRecognizerConfig config_;
  SymbolTable symbol_table_;
  std::unique_ptr<OfflineTransducerModel> model_;
  // Holds a raw pointer to *model_; declared after model_ so that it is
  // destroyed first.
  std::unique_ptr<OfflineTransducerDecoder> decoder_;
};
|
||||
|
||||
OfflineRecognizer::OfflineRecognizer(const OfflineRecognizerConfig &config)
|
||||
: impl_(std::make_unique<Impl>(config)) {}
|
||||
|
||||
OfflineRecognizer::~OfflineRecognizer() = default;
|
||||
|
||||
std::unique_ptr<OfflineStream> OfflineRecognizer::CreateStream() const {
|
||||
return impl_->CreateStream();
|
||||
}
|
||||
|
||||
void OfflineRecognizer::DecodeStreams(OfflineStream **ss, int32_t n) const {
|
||||
impl_->DecodeStreams(ss, n);
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
87
sherpa-onnx/csrc/offline-recognizer.h
Normal file
87
sherpa-onnx/csrc/offline-recognizer.h
Normal file
@@ -0,0 +1,87 @@
|
||||
// sherpa-onnx/csrc/offline-recognizer.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-stream.h"
|
||||
#include "sherpa-onnx/csrc/offline-transducer-model-config.h"
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// Result of decoding one stream.
struct OfflineRecognitionResult {
  // The recognized text.
  // For English, it consists of space separated words.
  // For Chinese, it consists of Chinese words without spaces.
  std::string text;

  // The decoded result at the token level.
  // For instance, for BPE-based models it is a list of BPE tokens.
  std::vector<std::string> tokens;

  /// Has the same number of entries as `tokens`.
  /// timestamps[i] is the time in seconds at which tokens[i] is decoded.
  std::vector<float> timestamps;
};
|
||||
|
||||
struct OfflineRecognizerConfig {
|
||||
OfflineFeatureExtractorConfig feat_config;
|
||||
OfflineTransducerModelConfig model_config;
|
||||
|
||||
std::string decoding_method = "greedy_search";
|
||||
// only greedy_search is implemented
|
||||
// TODO(fangjun): Implement modified_beam_search
|
||||
|
||||
OfflineRecognizerConfig() = default;
|
||||
OfflineRecognizerConfig(const OfflineFeatureExtractorConfig &feat_config,
|
||||
const OfflineTransducerModelConfig &model_config,
|
||||
const std::string &decoding_method)
|
||||
: feat_config(feat_config),
|
||||
model_config(model_config),
|
||||
decoding_method(decoding_method) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
|
||||
std::string ToString() const;
|
||||
};
|
||||
|
||||
class OfflineRecognizer {
|
||||
public:
|
||||
~OfflineRecognizer();
|
||||
|
||||
explicit OfflineRecognizer(const OfflineRecognizerConfig &config);
|
||||
|
||||
/// Create a stream for decoding.
|
||||
std::unique_ptr<OfflineStream> CreateStream() const;
|
||||
|
||||
/** Decode a single stream
|
||||
*
|
||||
* @param s The stream to decode.
|
||||
*/
|
||||
void DecodeStream(OfflineStream *s) const {
|
||||
OfflineStream *ss[1] = {s};
|
||||
DecodeStreams(ss, 1);
|
||||
}
|
||||
|
||||
/** Decode a list of streams.
|
||||
*
|
||||
* @param ss Pointer to an array of streams.
|
||||
* @param n Size of the input array.
|
||||
*/
|
||||
void DecodeStreams(OfflineStream **ss, int32_t n) const;
|
||||
|
||||
private:
|
||||
class Impl;
|
||||
std::unique_ptr<Impl> impl_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_H_
|
||||
134
sherpa-onnx/csrc/offline-stream.cc
Normal file
134
sherpa-onnx/csrc/offline-stream.cc
Normal file
@@ -0,0 +1,134 @@
|
||||
// sherpa-onnx/csrc/offline-stream.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-stream.h"
|
||||
|
||||
#include <assert.h>
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include "kaldi-native-fbank/csrc/online-feature.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer.h"
|
||||
#include "sherpa-onnx/csrc/resample.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// Register the "sample-rate" and "feat-dim" options with the option parser.
//
// @param po  Parser that receives the option registrations.
void OfflineFeatureExtractorConfig::Register(ParseOptions *po) {
  // The help text describes the rate the feature extractor works at; the
  // input waveform may use another rate and is then resampled.
  po->Register("sample-rate", &sampling_rate,
               "Sampling rate used by the feature extractor. Must match the "
               "one expected by the model. Note: The input waveform can have "
               "a different sample rate. We will do resampling "
               "inside the feature extractor");

  po->Register("feat-dim", &feature_dim,
               "Feature dimension. Must match the one expected by the model.");
}
|
||||
|
||||
// Return a single-line, human-readable description of this config, of the
// form OfflineFeatureExtractorConfig(sampling_rate=..., feature_dim=...).
std::string OfflineFeatureExtractorConfig::ToString() const {
  std::ostringstream oss;

  oss << "OfflineFeatureExtractorConfig("
      << "sampling_rate=" << sampling_rate << ", "
      << "feature_dim=" << feature_dim << ")";

  return oss.str();
}
|
||||
|
||||
// Implementation of OfflineStream. It owns the fbank feature computer and
// the recognition result attached to the stream.
class OfflineStream::Impl {
 public:
  // Configure the fbank options from `config` (dither is set to 0 and
  // snip_edges to false) and create the fbank computer.
  explicit Impl(const OfflineFeatureExtractorConfig &config) {
    opts_.frame_opts.dither = 0;
    opts_.frame_opts.snip_edges = false;
    opts_.frame_opts.samp_freq = config.sampling_rate;
    opts_.mel_opts.num_bins = config.feature_dim;

    fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
  }

  // Feed `n` samples to the fbank computer and mark the input as finished.
  // If `sampling_rate` differs from the configured one, the samples are
  // resampled first.
  // NOTE(review): InputFinished() is called on every invocation, so this is
  // presumably meant to be called once per stream; confirm with callers.
  void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
    if (sampling_rate == opts_.frame_opts.samp_freq) {
      // Rates already agree: pass the samples through unchanged.
      fbank_->AcceptWaveform(sampling_rate, waveform, n);
      fbank_->InputFinished();
      return;
    }

    SHERPA_ONNX_LOGE(
        "Creating a resampler:\n"
        "   in_sample_rate: %d\n"
        "   output_sample_rate: %d\n",
        sampling_rate, static_cast<int32_t>(opts_.frame_opts.samp_freq));

    // The cutoff is 99% of half of the lower of the two rates.
    float lower_rate =
        std::min<int32_t>(sampling_rate, opts_.frame_opts.samp_freq);
    float cutoff = 0.99 * 0.5 * lower_rate;
    int32_t filter_width = 6;

    auto resampler = std::make_unique<LinearResample>(
        sampling_rate, opts_.frame_opts.samp_freq, cutoff, filter_width);

    // NOTE(review): the third argument is presumably the flush flag; confirm
    // against LinearResample::Resample.
    std::vector<float> resampled;
    resampler->Resample(waveform, n, true, &resampled);

    fbank_->AcceptWaveform(opts_.frame_opts.samp_freq, resampled.data(),
                           resampled.size());
    fbank_->InputFinished();
  }

  // Return the configured number of mel bins.
  int32_t FeatureDim() const { return opts_.mel_opts.num_bins; }

  // Return all ready frames, concatenated frame after frame into one flat
  // vector of size num_frames * FeatureDim().
  std::vector<float> GetFrames() const {
    const int32_t num_frames = fbank_->NumFramesReady();
    assert(num_frames > 0 && "Please first call AcceptWaveform()");

    const int32_t dim = FeatureDim();
    std::vector<float> features(num_frames * dim);

    auto out = features.begin();
    for (int32_t i = 0; i < num_frames; ++i, out += dim) {
      const float *frame = fbank_->GetFrame(i);
      std::copy(frame, frame + dim, out);
    }

    return features;
  }

  // Store a copy of the recognition result in this stream.
  void SetResult(const OfflineRecognitionResult &r) { r_ = r; }

  // Return the result stored by the last call to SetResult().
  const OfflineRecognitionResult &GetResult() const { return r_; }

 private:
  std::unique_ptr<knf::OnlineFbank> fbank_;
  knf::FbankOptions opts_;
  OfflineRecognitionResult r_;
};
|
||||
|
||||
OfflineStream::OfflineStream(
|
||||
const OfflineFeatureExtractorConfig &config /*= {}*/)
|
||||
: impl_(std::make_unique<Impl>(config)) {}
|
||||
|
||||
OfflineStream::~OfflineStream() = default;
|
||||
|
||||
void OfflineStream::AcceptWaveform(int32_t sampling_rate, const float *waveform,
|
||||
int32_t n) const {
|
||||
impl_->AcceptWaveform(sampling_rate, waveform, n);
|
||||
}
|
||||
|
||||
int32_t OfflineStream::FeatureDim() const { return impl_->FeatureDim(); }
|
||||
|
||||
std::vector<float> OfflineStream::GetFrames() const {
|
||||
return impl_->GetFrames();
|
||||
}
|
||||
|
||||
void OfflineStream::SetResult(const OfflineRecognitionResult &r) {
|
||||
impl_->SetResult(r);
|
||||
}
|
||||
|
||||
const OfflineRecognitionResult &OfflineStream::GetResult() const {
|
||||
return impl_->GetResult();
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
70
sherpa-onnx/csrc/offline-stream.h
Normal file
70
sherpa-onnx/csrc/offline-stream.h
Normal file
@@ -0,0 +1,70 @@
|
||||
// sherpa-onnx/csrc/offline-stream.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_STREAM_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_STREAM_H_
|
||||
#include <stdint.h>
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
struct OfflineRecognitionResult;
|
||||
|
||||
struct OfflineFeatureExtractorConfig {
|
||||
// Sampling rate used by the feature extractor. If it is different from
|
||||
// the sampling rate of the input waveform, we will do resampling inside.
|
||||
int32_t sampling_rate = 16000;
|
||||
|
||||
// Feature dimension
|
||||
int32_t feature_dim = 80;
|
||||
|
||||
std::string ToString() const;
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
};
|
||||
|
||||
class OfflineStream {
|
||||
public:
|
||||
explicit OfflineStream(const OfflineFeatureExtractorConfig &config = {});
|
||||
~OfflineStream();
|
||||
|
||||
/**
|
||||
@param sampling_rate The sampling_rate of the input waveform. If it does
|
||||
not equal to config.sampling_rate, we will do
|
||||
resampling inside.
|
||||
@param waveform Pointer to a 1-D array of size n. It must be normalized to
|
||||
the range [-1, 1].
|
||||
@param n Number of entries in waveform
|
||||
|
||||
Caution: You can only invoke this function once so you have to input
|
||||
all the samples at once
|
||||
*/
|
||||
void AcceptWaveform(int32_t sampling_rate, const float *waveform,
|
||||
int32_t n) const;
|
||||
|
||||
/// Return feature dim of this extractor
|
||||
int32_t FeatureDim() const;
|
||||
|
||||
// Get all the feature frames of this stream in a 1-D array, which is
|
||||
// flattened from a 2-D array of shape (num_frames, feat_dim).
|
||||
std::vector<float> GetFrames() const;
|
||||
|
||||
/** Set the recognition result for this stream. */
|
||||
void SetResult(const OfflineRecognitionResult &r);
|
||||
|
||||
/** Get the recognition result of this stream */
|
||||
const OfflineRecognitionResult &GetResult() const;
|
||||
|
||||
private:
|
||||
class Impl;
|
||||
std::unique_ptr<Impl> impl_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_STREAM_H_
|
||||
41
sherpa-onnx/csrc/offline-transducer-decoder.h
Normal file
41
sherpa-onnx/csrc/offline-transducer-decoder.h
Normal file
@@ -0,0 +1,41 @@
|
||||
// sherpa-onnx/csrc/offline-transducer-decoder.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_DECODER_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_DECODER_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// Decoding result of one utterance produced by a transducer decoder.
struct OfflineTransducerDecoderResult {
  /// IDs of the decoded tokens, in the order they were emitted.
  std::vector<int64_t> tokens;

  /// timestamps[i] is the index of the output frame at which tokens[i] was
  /// decoded. Note: the index refers to frames after subsampling.
  std::vector<int32_t> timestamps;
};
|
||||
|
||||
class OfflineTransducerDecoder {
|
||||
public:
|
||||
virtual ~OfflineTransducerDecoder() = default;
|
||||
|
||||
/** Run transducer beam search given the output from the encoder model.
|
||||
*
|
||||
* @param encoder_out A 3-D tensor of shape (N, T, joiner_dim)
|
||||
* @param encoder_out_length A 1-D tensor of shape (N,) containing number
|
||||
* of valid frames in encoder_out before padding.
|
||||
*
|
||||
* @return Return a vector of size `N` containing the decoded results.
|
||||
*/
|
||||
virtual std::vector<OfflineTransducerDecoderResult> Decode(
|
||||
Ort::Value encoder_out, Ort::Value encoder_out_length) = 0;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_DECODER_H_
|
||||
79
sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.cc
Normal file
79
sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.cc
Normal file
@@ -0,0 +1,79 @@
|
||||
// sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <iterator>
|
||||
#include <utility>
|
||||
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/packed-sequence.h"
|
||||
#include "sherpa-onnx/csrc/slice.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
std::vector<OfflineTransducerDecoderResult>
|
||||
OfflineTransducerGreedySearchDecoder::Decode(Ort::Value encoder_out,
|
||||
Ort::Value encoder_out_length) {
|
||||
PackedSequence packed_encoder_out = PackPaddedSequence(
|
||||
model_->Allocator(), &encoder_out, &encoder_out_length);
|
||||
|
||||
int32_t batch_size =
|
||||
static_cast<int32_t>(packed_encoder_out.sorted_indexes.size());
|
||||
|
||||
int32_t vocab_size = model_->VocabSize();
|
||||
int32_t context_size = model_->ContextSize();
|
||||
|
||||
std::vector<OfflineTransducerDecoderResult> ans(batch_size);
|
||||
for (auto &r : ans) {
|
||||
// 0 is the ID of the blank token
|
||||
r.tokens.resize(context_size, 0);
|
||||
}
|
||||
|
||||
auto decoder_input = model_->BuildDecoderInput(ans, ans.size());
|
||||
Ort::Value decoder_out = model_->RunDecoder(std::move(decoder_input));
|
||||
|
||||
int32_t start = 0;
|
||||
int32_t t = 0;
|
||||
for (auto n : packed_encoder_out.batch_sizes) {
|
||||
Ort::Value cur_encoder_out = packed_encoder_out.Get(start, n);
|
||||
Ort::Value cur_decoder_out = Slice(model_->Allocator(), &decoder_out, 0, n);
|
||||
start += n;
|
||||
Ort::Value logit = model_->RunJoiner(std::move(cur_encoder_out),
|
||||
std::move(cur_decoder_out));
|
||||
const float *p_logit = logit.GetTensorData<float>();
|
||||
bool emitted = false;
|
||||
for (int32_t i = 0; i != n; ++i) {
|
||||
auto y = static_cast<int32_t>(std::distance(
|
||||
static_cast<const float *>(p_logit),
|
||||
std::max_element(static_cast<const float *>(p_logit),
|
||||
static_cast<const float *>(p_logit) + vocab_size)));
|
||||
p_logit += vocab_size;
|
||||
if (y != 0) {
|
||||
ans[i].tokens.push_back(y);
|
||||
ans[i].timestamps.push_back(t);
|
||||
emitted = true;
|
||||
}
|
||||
}
|
||||
if (emitted) {
|
||||
Ort::Value decoder_input = model_->BuildDecoderInput(ans, n);
|
||||
decoder_out = model_->RunDecoder(std::move(decoder_input));
|
||||
}
|
||||
++t;
|
||||
}
|
||||
|
||||
for (auto &r : ans) {
|
||||
r.tokens = {r.tokens.begin() + context_size, r.tokens.end()};
|
||||
}
|
||||
|
||||
std::vector<OfflineTransducerDecoderResult> unsorted_ans(batch_size);
|
||||
for (int32_t i = 0; i != batch_size; ++i) {
|
||||
unsorted_ans[packed_encoder_out.sorted_indexes[i]] = std::move(ans[i]);
|
||||
}
|
||||
|
||||
return unsorted_ans;
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
29
sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.h
Normal file
29
sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.h
Normal file
@@ -0,0 +1,29 @@
|
||||
// sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_GREEDY_SEARCH_DECODER_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_GREEDY_SEARCH_DECODER_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-transducer-decoder.h"
|
||||
#include "sherpa-onnx/csrc/offline-transducer-model.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class OfflineTransducerGreedySearchDecoder : public OfflineTransducerDecoder {
|
||||
public:
|
||||
explicit OfflineTransducerGreedySearchDecoder(OfflineTransducerModel *model)
|
||||
: model_(model) {}
|
||||
|
||||
std::vector<OfflineTransducerDecoderResult> Decode(
|
||||
Ort::Value encoder_out, Ort::Value encoder_out_length) override;
|
||||
|
||||
private:
|
||||
OfflineTransducerModel *model_; // Not owned
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_GREEDY_SEARCH_DECODER_H_
|
||||
68
sherpa-onnx/csrc/offline-transducer-model-config.cc
Normal file
68
sherpa-onnx/csrc/offline-transducer-model-config.cc
Normal file
@@ -0,0 +1,68 @@
|
||||
// sherpa-onnx/csrc/offline-transducer-model-config.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
#include "sherpa-onnx/csrc/offline-transducer-model-config.h"
|
||||
|
||||
#include <sstream>
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// Register the options of this config with the given option parser.
void OfflineTransducerModelConfig::Register(ParseOptions *po) {
  // Model files
  po->Register("encoder", &encoder_filename, "Path to encoder.onnx");
  po->Register("decoder", &decoder_filename, "Path to decoder.onnx");
  po->Register("joiner", &joiner_filename, "Path to joiner.onnx");
  po->Register("tokens", &tokens, "Path to tokens.txt");

  // Runtime behaviour
  po->Register("num_threads", &num_threads,
               "Number of threads to run the neural network");
  po->Register("debug", &debug,
               "true to print model information while loading it.");
}
|
||||
|
||||
bool OfflineTransducerModelConfig::Validate() const {
|
||||
if (!FileExists(tokens)) {
|
||||
SHERPA_ONNX_LOGE("%s does not exist", tokens.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!FileExists(encoder_filename)) {
|
||||
SHERPA_ONNX_LOGE("%s does not exist", encoder_filename.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!FileExists(decoder_filename)) {
|
||||
SHERPA_ONNX_LOGE("%s does not exist", decoder_filename.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!FileExists(joiner_filename)) {
|
||||
SHERPA_ONNX_LOGE("%s does not exist", joiner_filename.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
if (num_threads < 1) {
|
||||
SHERPA_ONNX_LOGE("num_threads should be > 0. Given %d", num_threads);
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// Return a one-line description of this config, listing every field as
// name=value.
std::string OfflineTransducerModelConfig::ToString() const {
  std::ostringstream os;

  os << "OfflineTransducerModelConfig("
     << "encoder_filename=\"" << encoder_filename << "\", "
     << "decoder_filename=\"" << decoder_filename << "\", "
     << "joiner_filename=\"" << joiner_filename << "\", "
     << "tokens=\"" << tokens << "\", "
     << "num_threads=" << num_threads << ", "
     << "debug=" << (debug ? "True" : "False") << ")";

  return os.str();
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
42
sherpa-onnx/csrc/offline-transducer-model-config.h
Normal file
42
sherpa-onnx/csrc/offline-transducer-model-config.h
Normal file
@@ -0,0 +1,42 @@
|
||||
// sherpa-onnx/csrc/offline-transducer-model-config.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_MODEL_CONFIG_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_MODEL_CONFIG_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
struct OfflineTransducerModelConfig {
|
||||
std::string encoder_filename;
|
||||
std::string decoder_filename;
|
||||
std::string joiner_filename;
|
||||
std::string tokens;
|
||||
int32_t num_threads = 2;
|
||||
bool debug = false;
|
||||
|
||||
OfflineTransducerModelConfig() = default;
|
||||
OfflineTransducerModelConfig(const std::string &encoder_filename,
|
||||
const std::string &decoder_filename,
|
||||
const std::string &joiner_filename,
|
||||
const std::string &tokens, int32_t num_threads,
|
||||
bool debug)
|
||||
: encoder_filename(encoder_filename),
|
||||
decoder_filename(decoder_filename),
|
||||
joiner_filename(joiner_filename),
|
||||
tokens(tokens),
|
||||
num_threads(num_threads),
|
||||
debug(debug) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
|
||||
std::string ToString() const;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_MODEL_CONFIG_H_
|
||||
238
sherpa-onnx/csrc/offline-transducer-model.cc
Normal file
238
sherpa-onnx/csrc/offline-transducer-model.cc
Normal file
@@ -0,0 +1,238 @@
|
||||
// sherpa-onnx/csrc/offline-transducer-model.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-transducer-model.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/offline-transducer-decoder.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// Private implementation of OfflineTransducerModel. It owns three
// onnxruntime sessions (encoder, decoder, joiner) together with the names of
// their inputs and outputs.
class OfflineTransducerModel::Impl {
 public:
  // Load the three models named in `config`. Both the intra-op and the
  // inter-op thread counts of the sessions are set to config.num_threads.
  explicit Impl(const OfflineTransducerModelConfig &config)
      : config_(config),
        env_(ORT_LOGGING_LEVEL_WARNING),
        sess_opts_{},
        allocator_{} {
    sess_opts_.SetIntraOpNumThreads(config.num_threads);
    sess_opts_.SetInterOpNumThreads(config.num_threads);

    // Each file buffer lives in its own scope, so it is released as soon as
    // the corresponding session has been created.
    {
      auto encoder_bytes = ReadFile(config.encoder_filename);
      InitEncoder(encoder_bytes.data(), encoder_bytes.size());
    }

    {
      auto decoder_bytes = ReadFile(config.decoder_filename);
      InitDecoder(decoder_bytes.data(), decoder_bytes.size());
    }

    {
      auto joiner_bytes = ReadFile(config.joiner_filename);
      InitJoiner(joiner_bytes.data(), joiner_bytes.size());
    }
  }

  // Run the encoder and return its first two outputs as a pair.
  std::pair<Ort::Value, Ort::Value> RunEncoder(Ort::Value features,
                                               Ort::Value features_length) {
    std::array<Ort::Value, 2> inputs = {std::move(features),
                                        std::move(features_length)};

    auto outputs = encoder_sess_->Run(
        {}, encoder_input_names_ptr_.data(), inputs.data(), inputs.size(),
        encoder_output_names_ptr_.data(), encoder_output_names_ptr_.size());

    return {std::move(outputs[0]), std::move(outputs[1])};
  }

  // Run the decoder on a single input tensor and return its first output.
  Ort::Value RunDecoder(Ort::Value decoder_input) {
    auto outputs = decoder_sess_->Run(
        {}, decoder_input_names_ptr_.data(), &decoder_input, 1,
        decoder_output_names_ptr_.data(), decoder_output_names_ptr_.size());

    return std::move(outputs[0]);
  }

  // Run the joiner on (encoder_out, decoder_out) and return its first output.
  Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) {
    std::array<Ort::Value, 2> inputs = {std::move(encoder_out),
                                        std::move(decoder_out)};

    auto outputs = joiner_sess_->Run(
        {}, joiner_input_names_ptr_.data(), inputs.data(), inputs.size(),
        joiner_output_names_ptr_.data(), joiner_output_names_ptr_.size());

    return std::move(outputs[0]);
  }

  int32_t VocabSize() const { return vocab_size_; }
  int32_t ContextSize() const { return context_size_; }
  // The subsampling factor is hard-coded; it is not read from the model.
  int32_t SubsamplingFactor() const { return 4; }
  OrtAllocator *Allocator() const { return allocator_; }

  // Build a tensor of shape (end_index, ContextSize()) whose i-th row holds
  // the last ContextSize() tokens of results[i].
  //
  // Every results[i].tokens with i < end_index must contain at least
  // ContextSize() entries.
  Ort::Value BuildDecoderInput(
      const std::vector<OfflineTransducerDecoderResult> &results,
      int32_t end_index) const {
    assert(end_index <= results.size());

    const int32_t batch_size = end_index;
    const int32_t context_size = ContextSize();
    std::array<int64_t, 2> shape{batch_size, context_size};

    Ort::Value decoder_input = Ort::Value::CreateTensor<int64_t>(
        Allocator(), shape.data(), shape.size());

    int64_t *dst = decoder_input.GetTensorMutableData<int64_t>();
    for (int32_t i = 0; i != batch_size; ++i, dst += context_size) {
      const auto &tokens = results[i].tokens;
      const int64_t *last = tokens.data() + tokens.size();
      std::copy(last - context_size, last, dst);
    }

    return decoder_input;
  }

 private:
  // Create the encoder session from an in-memory model and cache the names
  // of its inputs and outputs.
  void InitEncoder(void *model_data, size_t model_data_length) {
    encoder_sess_ = std::make_unique<Ort::Session>(
        env_, model_data, model_data_length, sess_opts_);

    GetInputNames(encoder_sess_.get(), &encoder_input_names_,
                  &encoder_input_names_ptr_);

    GetOutputNames(encoder_sess_.get(), &encoder_output_names_,
                   &encoder_output_names_ptr_);

    // get meta data
    Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata();
    if (config_.debug) {
      std::ostringstream os;
      os << "---encoder---\n";
      PrintModelMetadata(os, meta_data);
      SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
    }
  }

  // Create the decoder session from an in-memory model, cache the names of
  // its inputs and outputs, and read vocab_size and context_size from the
  // model meta data.
  void InitDecoder(void *model_data, size_t model_data_length) {
    decoder_sess_ = std::make_unique<Ort::Session>(
        env_, model_data, model_data_length, sess_opts_);

    GetInputNames(decoder_sess_.get(), &decoder_input_names_,
                  &decoder_input_names_ptr_);

    GetOutputNames(decoder_sess_.get(), &decoder_output_names_,
                   &decoder_output_names_ptr_);

    // get meta data
    Ort::ModelMetadata meta_data = decoder_sess_->GetModelMetadata();
    if (config_.debug) {
      std::ostringstream os;
      os << "---decoder---\n";
      PrintModelMetadata(os, meta_data);
      SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
    }

    // The macro below refers to the local `allocator` by name (and
    // presumably to `meta_data` as well), so do not rename them.
    Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below
    SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size");
    SHERPA_ONNX_READ_META_DATA(context_size_, "context_size");
  }

  // Create the joiner session from an in-memory model and cache the names
  // of its inputs and outputs.
  void InitJoiner(void *model_data, size_t model_data_length) {
    joiner_sess_ = std::make_unique<Ort::Session>(
        env_, model_data, model_data_length, sess_opts_);

    GetInputNames(joiner_sess_.get(), &joiner_input_names_,
                  &joiner_input_names_ptr_);

    GetOutputNames(joiner_sess_.get(), &joiner_output_names_,
                   &joiner_output_names_ptr_);

    // get meta data
    Ort::ModelMetadata meta_data = joiner_sess_->GetModelMetadata();
    if (config_.debug) {
      std::ostringstream os;
      os << "---joiner---\n";
      PrintModelMetadata(os, meta_data);
      SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
    }
  }

 private:
  // NOTE: members are initialized in declaration order; keep config_, env_
  // and sess_opts_ ahead of the sessions that are created from them.
  OfflineTransducerModelConfig config_;
  Ort::Env env_;
  Ort::SessionOptions sess_opts_;
  Ort::AllocatorWithDefaultOptions allocator_;

  std::unique_ptr<Ort::Session> encoder_sess_;
  std::unique_ptr<Ort::Session> decoder_sess_;
  std::unique_ptr<Ort::Session> joiner_sess_;

  // For each session the names are stored twice: as std::string and as the
  // `const char *` array that Ort::Session::Run() expects.
  std::vector<std::string> encoder_input_names_;
  std::vector<const char *> encoder_input_names_ptr_;

  std::vector<std::string> encoder_output_names_;
  std::vector<const char *> encoder_output_names_ptr_;

  std::vector<std::string> decoder_input_names_;
  std::vector<const char *> decoder_input_names_ptr_;

  std::vector<std::string> decoder_output_names_;
  std::vector<const char *> decoder_output_names_ptr_;

  std::vector<std::string> joiner_input_names_;
  std::vector<const char *> joiner_input_names_ptr_;

  std::vector<std::string> joiner_output_names_;
  std::vector<const char *> joiner_output_names_ptr_;

  int32_t vocab_size_ = 0;    // initialized in InitDecoder
  int32_t context_size_ = 0;  // initialized in InitDecoder
};
|
||||
|
||||
OfflineTransducerModel::OfflineTransducerModel(
|
||||
const OfflineTransducerModelConfig &config)
|
||||
: impl_(std::make_unique<Impl>(config)) {}
|
||||
|
||||
OfflineTransducerModel::~OfflineTransducerModel() = default;
|
||||
|
||||
std::pair<Ort::Value, Ort::Value> OfflineTransducerModel::RunEncoder(
|
||||
Ort::Value features, Ort::Value features_length) {
|
||||
return impl_->RunEncoder(std::move(features), std::move(features_length));
|
||||
}
|
||||
|
||||
Ort::Value OfflineTransducerModel::RunDecoder(Ort::Value decoder_input) {
|
||||
return impl_->RunDecoder(std::move(decoder_input));
|
||||
}
|
||||
|
||||
Ort::Value OfflineTransducerModel::RunJoiner(Ort::Value encoder_out,
|
||||
Ort::Value decoder_out) {
|
||||
return impl_->RunJoiner(std::move(encoder_out), std::move(decoder_out));
|
||||
}
|
||||
|
||||
int32_t OfflineTransducerModel::VocabSize() const { return impl_->VocabSize(); }
|
||||
|
||||
int32_t OfflineTransducerModel::ContextSize() const {
|
||||
return impl_->ContextSize();
|
||||
}
|
||||
|
||||
int32_t OfflineTransducerModel::SubsamplingFactor() const {
|
||||
return impl_->SubsamplingFactor();
|
||||
}
|
||||
|
||||
OrtAllocator *OfflineTransducerModel::Allocator() const {
|
||||
return impl_->Allocator();
|
||||
}
|
||||
|
||||
Ort::Value OfflineTransducerModel::BuildDecoderInput(
|
||||
const std::vector<OfflineTransducerDecoderResult> &results,
|
||||
int32_t end_index) const {
|
||||
return impl_->BuildDecoderInput(results, end_index);
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
95
sherpa-onnx/csrc/offline-transducer-model.h
Normal file
95
sherpa-onnx/csrc/offline-transducer-model.h
Normal file
@@ -0,0 +1,95 @@
|
||||
// sherpa-onnx/csrc/offline-transducer-model.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_MODEL_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_MODEL_H_
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/offline-transducer-model-config.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
struct OfflineTransducerDecoderResult;
|
||||
|
||||
class OfflineTransducerModel {
|
||||
public:
|
||||
explicit OfflineTransducerModel(const OfflineTransducerModelConfig &config);
|
||||
~OfflineTransducerModel();
|
||||
|
||||
/** Run the encoder.
|
||||
*
|
||||
* @param features A tensor of shape (N, T, C). It is changed in-place.
|
||||
* @param features_length A 1-D tensor of shape (N,) containing number of
|
||||
* valid frames in `features` before padding.
|
||||
*
|
||||
* @return Return a pair containing:
|
||||
* - encoder_out: A 3-D tensor of shape (N, T', encoder_dim)
|
||||
* - encoder_out_length: A 1-D tensor of shape (N,) containing number
|
||||
* of frames in `encoder_out` before padding.
|
||||
*/
|
||||
std::pair<Ort::Value, Ort::Value> RunEncoder(Ort::Value features,
|
||||
Ort::Value features_length);
|
||||
|
||||
/** Run the decoder network.
|
||||
*
|
||||
* Caution: We assume there are no recurrent connections in the decoder and
|
||||
* the decoder is stateless. See
|
||||
* https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py
|
||||
* for an example
|
||||
*
|
||||
* @param decoder_input It is usually of shape (N, context_size)
|
||||
* @return Return a tensor of shape (N, decoder_dim).
|
||||
*/
|
||||
Ort::Value RunDecoder(Ort::Value decoder_input);
|
||||
|
||||
/** Run the joint network.
|
||||
*
|
||||
* @param encoder_out Output of the encoder network. A tensor of shape
|
||||
* (N, joiner_dim).
|
||||
* @param decoder_out Output of the decoder network. A tensor of shape
|
||||
* (N, joiner_dim).
|
||||
* @return Return a tensor of shape (N, vocab_size). In icefall, the last
|
||||
* last layer of the joint network is `nn.Linear`,
|
||||
* not `nn.LogSoftmax`.
|
||||
*/
|
||||
Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out);
|
||||
|
||||
/** Return the vocabulary size of the model
|
||||
*/
|
||||
int32_t VocabSize() const;
|
||||
|
||||
/** Return the context_size of the decoder model.
|
||||
*/
|
||||
int32_t ContextSize() const;
|
||||
|
||||
/** Return the subsampling factor of the model.
|
||||
*/
|
||||
int32_t SubsamplingFactor() const;
|
||||
|
||||
/** Return an allocator for allocating memory
|
||||
*/
|
||||
OrtAllocator *Allocator() const;
|
||||
|
||||
/** Build decoder_input from the current results.
|
||||
*
|
||||
* @param results Current decoded results.
|
||||
* @param end_index We only use results[0:end_index] to build
|
||||
* the decoder_input.
|
||||
* @return Return a tensor of shape (results.size(), ContextSize())
|
||||
*/
|
||||
Ort::Value BuildDecoderInput(
|
||||
const std::vector<OfflineTransducerDecoderResult> &results,
|
||||
int32_t end_index) const;
|
||||
|
||||
private:
|
||||
class Impl;
|
||||
std::unique_ptr<Impl> impl_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_TRANSDUCER_MODEL_H_
|
||||
@@ -95,7 +95,7 @@ void OnlineLstmTransducerModel::InitEncoder(void *model_data,
|
||||
std::ostringstream os;
|
||||
os << "---encoder---\n";
|
||||
PrintModelMetadata(os, meta_data);
|
||||
fprintf(stderr, "%s\n", os.str().c_str());
|
||||
SHERPA_ONNX_LOGE("%s", os.str().c_str());
|
||||
}
|
||||
|
||||
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
|
||||
@@ -123,7 +123,7 @@ void OnlineLstmTransducerModel::InitDecoder(void *model_data,
|
||||
std::ostringstream os;
|
||||
os << "---decoder---\n";
|
||||
PrintModelMetadata(os, meta_data);
|
||||
fprintf(stderr, "%s\n", os.str().c_str());
|
||||
SHERPA_ONNX_LOGE("%s", os.str().c_str());
|
||||
}
|
||||
|
||||
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
|
||||
@@ -148,7 +148,7 @@ void OnlineLstmTransducerModel::InitJoiner(void *model_data,
|
||||
std::ostringstream os;
|
||||
os << "---joiner---\n";
|
||||
PrintModelMetadata(os, meta_data);
|
||||
fprintf(stderr, "%s\n", os.str().c_str());
|
||||
SHERPA_ONNX_LOGE("%s", os.str().c_str());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -228,9 +228,6 @@ std::vector<Ort::Value> OnlineLstmTransducerModel::GetEncoderInitStates() {
|
||||
std::pair<Ort::Value, std::vector<Ort::Value>>
|
||||
OnlineLstmTransducerModel::RunEncoder(Ort::Value features,
|
||||
std::vector<Ort::Value> states) {
|
||||
auto memory_info =
|
||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||
|
||||
std::array<Ort::Value, 3> encoder_inputs = {
|
||||
std::move(features), std::move(states[0]), std::move(states[1])};
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ class OnlineStream::Impl {
|
||||
feat_extractor_.AcceptWaveform(sampling_rate, waveform, n);
|
||||
}
|
||||
|
||||
void InputFinished() { feat_extractor_.InputFinished(); }
|
||||
void InputFinished() const { feat_extractor_.InputFinished(); }
|
||||
|
||||
int32_t NumFramesReady() const {
|
||||
return feat_extractor_.NumFramesReady() - start_frame_index_;
|
||||
@@ -68,11 +68,11 @@ OnlineStream::OnlineStream(const FeatureExtractorConfig &config /*= {}*/)
|
||||
OnlineStream::~OnlineStream() = default;
|
||||
|
||||
void OnlineStream::AcceptWaveform(int32_t sampling_rate, const float *waveform,
|
||||
int32_t n) {
|
||||
int32_t n) const {
|
||||
impl_->AcceptWaveform(sampling_rate, waveform, n);
|
||||
}
|
||||
|
||||
void OnlineStream::InputFinished() { impl_->InputFinished(); }
|
||||
void OnlineStream::InputFinished() const { impl_->InputFinished(); }
|
||||
|
||||
int32_t OnlineStream::NumFramesReady() const { return impl_->NumFramesReady(); }
|
||||
|
||||
|
||||
@@ -27,7 +27,8 @@ class OnlineStream {
|
||||
the range [-1, 1].
|
||||
@param n Number of entries in waveform
|
||||
*/
|
||||
void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n);
|
||||
void AcceptWaveform(int32_t sampling_rate, const float *waveform,
|
||||
int32_t n) const;
|
||||
|
||||
/**
|
||||
* InputFinished() tells the class you won't be providing any
|
||||
@@ -35,7 +36,7 @@ class OnlineStream {
|
||||
* of features, in the case where snip-edges == false; it also
|
||||
* affects the return value of IsLastFrame().
|
||||
*/
|
||||
void InputFinished();
|
||||
void InputFinished() const;
|
||||
|
||||
int32_t NumFramesReady() const;
|
||||
|
||||
|
||||
@@ -248,14 +248,21 @@ int32_t main(int32_t argc, char *argv[]) {
|
||||
std::string wave_filename = po.GetArg(1);
|
||||
|
||||
bool is_ok = false;
|
||||
int32_t actual_sample_rate = -1;
|
||||
std::vector<float> samples =
|
||||
sherpa_onnx::ReadWave(wave_filename, sample_rate, &is_ok);
|
||||
sherpa_onnx::ReadWave(wave_filename, &actual_sample_rate, &is_ok);
|
||||
|
||||
if (!is_ok) {
|
||||
SHERPA_ONNX_LOGE("Failed to read %s", wave_filename.c_str());
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (actual_sample_rate != sample_rate) {
|
||||
SHERPA_ONNX_LOGE("Expected sample rate: %d, given %d", sample_rate,
|
||||
actual_sample_rate);
|
||||
return -1;
|
||||
}
|
||||
|
||||
asio::io_context io_conn; // for network connections
|
||||
Client c(io_conn, server_ip, server_port, samples, samples_per_message,
|
||||
seconds_per_message);
|
||||
|
||||
@@ -97,7 +97,7 @@ void OnlineZipformerTransducerModel::InitEncoder(void *model_data,
|
||||
std::ostringstream os;
|
||||
os << "---encoder---\n";
|
||||
PrintModelMetadata(os, meta_data);
|
||||
fprintf(stderr, "%s\n", os.str().c_str());
|
||||
SHERPA_ONNX_LOGE("%s", os.str().c_str());
|
||||
}
|
||||
|
||||
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
|
||||
@@ -123,8 +123,8 @@ void OnlineZipformerTransducerModel::InitEncoder(void *model_data,
|
||||
print(num_encoder_layers_, "num_encoder_layers");
|
||||
print(cnn_module_kernels_, "cnn_module_kernels");
|
||||
print(left_context_len_, "left_context_len");
|
||||
fprintf(stderr, "T: %d\n", T_);
|
||||
fprintf(stderr, "decode_chunk_len_: %d\n", decode_chunk_len_);
|
||||
SHERPA_ONNX_LOGE("T: %d", T_);
|
||||
SHERPA_ONNX_LOGE("decode_chunk_len_: %d", decode_chunk_len_);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -145,7 +145,7 @@ void OnlineZipformerTransducerModel::InitDecoder(void *model_data,
|
||||
std::ostringstream os;
|
||||
os << "---decoder---\n";
|
||||
PrintModelMetadata(os, meta_data);
|
||||
fprintf(stderr, "%s\n", os.str().c_str());
|
||||
SHERPA_ONNX_LOGE("%s", os.str().c_str());
|
||||
}
|
||||
|
||||
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
|
||||
@@ -170,7 +170,7 @@ void OnlineZipformerTransducerModel::InitJoiner(void *model_data,
|
||||
std::ostringstream os;
|
||||
os << "---joiner---\n";
|
||||
PrintModelMetadata(os, meta_data);
|
||||
fprintf(stderr, "%s\n", os.str().c_str());
|
||||
SHERPA_ONNX_LOGE("%s", os.str().c_str());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -435,9 +435,6 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::GetEncoderInitStates() {
|
||||
std::pair<Ort::Value, std::vector<Ort::Value>>
|
||||
OnlineZipformerTransducerModel::RunEncoder(Ort::Value features,
|
||||
std::vector<Ort::Value> states) {
|
||||
auto memory_info =
|
||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||
|
||||
std::vector<Ort::Value> encoder_inputs;
|
||||
encoder_inputs.reserve(1 + states.size());
|
||||
|
||||
|
||||
@@ -41,7 +41,7 @@ PackedSequence PackPaddedSequence(OrtAllocator *allocator,
|
||||
std::vector<int64_t> l_shape = length->GetTensorTypeAndShapeInfo().GetShape();
|
||||
|
||||
assert(v_shape.size() == 3);
|
||||
assert(l_shape.size() == 3);
|
||||
assert(l_shape.size() == 1);
|
||||
assert(v_shape[0] == l_shape[0]);
|
||||
|
||||
std::vector<int32_t> indexes(v_shape[0]);
|
||||
|
||||
@@ -13,7 +13,26 @@ namespace sherpa_onnx {
|
||||
struct PackedSequence {
|
||||
std::vector<int32_t> sorted_indexes;
|
||||
std::vector<int32_t> batch_sizes;
|
||||
|
||||
// data is a 2-D tensor of shape (sum(batch_sizes), channels)
|
||||
Ort::Value data{nullptr};
|
||||
|
||||
// Return a shallow copy of data[start:start+size, :]
|
||||
Ort::Value Get(int32_t start, int32_t size) {
|
||||
auto shape = data.GetTensorTypeAndShapeInfo().GetShape();
|
||||
|
||||
std::array<int64_t, 2> ans_shape{size, shape[1]};
|
||||
|
||||
float *p = data.GetTensorMutableData<float>();
|
||||
|
||||
auto memory_info =
|
||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||
|
||||
// a shallow copy
|
||||
return Ort::Value::CreateTensor(memory_info, p + start * shape[1],
|
||||
size * shape[1], ans_shape.data(),
|
||||
ans_shape.size());
|
||||
}
|
||||
};
|
||||
|
||||
/** Similar to torch.nn.utils.rnn.pad_sequence but it supports only
|
||||
|
||||
@@ -46,7 +46,7 @@ I Gcd(I m, I n) {
|
||||
// this function is copied from kaldi/src/base/kaldi-math.h
|
||||
if (m == 0 || n == 0) {
|
||||
if (m == 0 && n == 0) { // gcd not defined, as all integers are divisors.
|
||||
fprintf(stderr, "Undefined GCD since m = 0, n = 0.");
|
||||
fprintf(stderr, "Undefined GCD since m = 0, n = 0.\n");
|
||||
exit(-1);
|
||||
}
|
||||
return (m == 0 ? (n > 0 ? n : -n) : (m > 0 ? m : -m));
|
||||
|
||||
@@ -95,6 +95,10 @@ as the device_name.
|
||||
|
||||
fprintf(stderr, "%s\n", config.ToString().c_str());
|
||||
|
||||
if (!config.Validate()) {
|
||||
fprintf(stderr, "Errors in config!\n");
|
||||
return -1;
|
||||
}
|
||||
sherpa_onnx::OnlineRecognizer recognizer(config);
|
||||
|
||||
int32_t expected_sample_rate = config.feat_config.sampling_rate;
|
||||
|
||||
@@ -86,6 +86,11 @@ for a list of pre-trained models to download.
|
||||
|
||||
fprintf(stderr, "%s\n", config.ToString().c_str());
|
||||
|
||||
if (!config.Validate()) {
|
||||
fprintf(stderr, "Errors in config!\n");
|
||||
return -1;
|
||||
}
|
||||
|
||||
sherpa_onnx::OnlineRecognizer recognizer(config);
|
||||
auto s = recognizer.CreateStream();
|
||||
|
||||
|
||||
115
sherpa-onnx/csrc/sherpa-onnx-offline.cc
Normal file
115
sherpa-onnx/csrc/sherpa-onnx-offline.cc
Normal file
@@ -0,0 +1,115 @@
|
||||
// sherpa-onnx/csrc/sherpa-onnx-offline.cc
|
||||
//
|
||||
// Copyright (c) 2022-2023 Xiaomi Corporation
|
||||
|
||||
#include <stdio.h>
|
||||
|
||||
#include <chrono> // NOLINT
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-recognizer.h"
|
||||
#include "sherpa-onnx/csrc/offline-stream.h"
|
||||
#include "sherpa-onnx/csrc/offline-transducer-decoder.h"
|
||||
#include "sherpa-onnx/csrc/offline-transducer-greedy-search-decoder.h"
|
||||
#include "sherpa-onnx/csrc/offline-transducer-model.h"
|
||||
#include "sherpa-onnx/csrc/pad-sequence.h"
|
||||
#include "sherpa-onnx/csrc/symbol-table.h"
|
||||
#include "sherpa-onnx/csrc/wave-reader.h"
|
||||
|
||||
int main(int32_t argc, char *argv[]) {
|
||||
if (argc < 6 || argc > 8) {
|
||||
const char *usage = R"usage(
|
||||
Usage:
|
||||
./bin/sherpa-onnx-offline \
|
||||
/path/to/tokens.txt \
|
||||
/path/to/encoder.onnx \
|
||||
/path/to/decoder.onnx \
|
||||
/path/to/joiner.onnx \
|
||||
/path/to/foo.wav [num_threads [decoding_method]]
|
||||
|
||||
Default value for num_threads is 2.
|
||||
Valid values for decoding_method: greedy_search.
|
||||
foo.wav should be of single channel, 16-bit PCM encoded wave file; its
|
||||
sampling rate can be arbitrary and does not need to be 16kHz.
|
||||
|
||||
Please refer to
|
||||
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
|
||||
for a list of pre-trained models to download.
|
||||
)usage";
|
||||
fprintf(stderr, "%s\n", usage);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
sherpa_onnx::OfflineRecognizerConfig config;
|
||||
|
||||
config.model_config.tokens = argv[1];
|
||||
|
||||
config.model_config.debug = false;
|
||||
config.model_config.encoder_filename = argv[2];
|
||||
config.model_config.decoder_filename = argv[3];
|
||||
config.model_config.joiner_filename = argv[4];
|
||||
|
||||
std::string wav_filename = argv[5];
|
||||
|
||||
config.model_config.num_threads = 2;
|
||||
if (argc == 7 && atoi(argv[6]) > 0) {
|
||||
config.model_config.num_threads = atoi(argv[6]);
|
||||
}
|
||||
|
||||
if (argc == 8) {
|
||||
config.decoding_method = argv[7];
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s\n", config.ToString().c_str());
|
||||
|
||||
if (!config.Validate()) {
|
||||
fprintf(stderr, "Errors in config!\n");
|
||||
return -1;
|
||||
}
|
||||
|
||||
int32_t sampling_rate = -1;
|
||||
|
||||
bool is_ok = false;
|
||||
std::vector<float> samples =
|
||||
sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok);
|
||||
if (!is_ok) {
|
||||
fprintf(stderr, "Failed to read %s\n", wav_filename.c_str());
|
||||
return -1;
|
||||
}
|
||||
fprintf(stderr, "sampling rate of input file: %d\n", sampling_rate);
|
||||
|
||||
float duration = samples.size() / static_cast<float>(sampling_rate);
|
||||
|
||||
sherpa_onnx::OfflineRecognizer recognizer(config);
|
||||
auto s = recognizer.CreateStream();
|
||||
|
||||
auto begin = std::chrono::steady_clock::now();
|
||||
fprintf(stderr, "Started\n");
|
||||
|
||||
s->AcceptWaveform(sampling_rate, samples.data(), samples.size());
|
||||
|
||||
recognizer.DecodeStream(s.get());
|
||||
|
||||
fprintf(stderr, "Done!\n");
|
||||
|
||||
fprintf(stderr, "Recognition result for %s:\n%s\n", wav_filename.c_str(),
|
||||
s->GetResult().text.c_str());
|
||||
|
||||
auto end = std::chrono::steady_clock::now();
|
||||
float elapsed_seconds =
|
||||
std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
|
||||
.count() /
|
||||
1000.;
|
||||
|
||||
fprintf(stderr, "num threads: %d\n", config.model_config.num_threads);
|
||||
fprintf(stderr, "decoding method: %s\n", config.decoding_method.c_str());
|
||||
|
||||
fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
|
||||
float rtf = elapsed_seconds / duration;
|
||||
fprintf(stderr, "Real time factor (RTF): %.3f / %.3f = %.3f\n",
|
||||
elapsed_seconds, duration, rtf);
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -26,6 +26,8 @@ Usage:
|
||||
|
||||
Default value for num_threads is 2.
|
||||
Valid values for decoding_method: greedy_search (default), modified_beam_search.
|
||||
foo.wav should be of single channel, 16-bit PCM encoded wave file; its
|
||||
sampling rate can be arbitrary and does not need to be 16kHz.
|
||||
|
||||
Please refer to
|
||||
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
|
||||
@@ -59,20 +61,26 @@ for a list of pre-trained models to download.
|
||||
|
||||
fprintf(stderr, "%s\n", config.ToString().c_str());
|
||||
|
||||
if (!config.Validate()) {
|
||||
fprintf(stderr, "Errors in config!\n");
|
||||
return -1;
|
||||
}
|
||||
|
||||
sherpa_onnx::OnlineRecognizer recognizer(config);
|
||||
|
||||
int32_t expected_sampling_rate = config.feat_config.sampling_rate;
|
||||
int32_t sampling_rate = -1;
|
||||
|
||||
bool is_ok = false;
|
||||
std::vector<float> samples =
|
||||
sherpa_onnx::ReadWave(wav_filename, expected_sampling_rate, &is_ok);
|
||||
sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok);
|
||||
|
||||
if (!is_ok) {
|
||||
fprintf(stderr, "Failed to read %s\n", wav_filename.c_str());
|
||||
return -1;
|
||||
}
|
||||
fprintf(stderr, "sampling rate of input file: %d\n", sampling_rate);
|
||||
|
||||
float duration = samples.size() / static_cast<float>(expected_sampling_rate);
|
||||
float duration = samples.size() / static_cast<float>(sampling_rate);
|
||||
|
||||
fprintf(stderr, "wav filename: %s\n", wav_filename.c_str());
|
||||
fprintf(stderr, "wav duration (s): %.3f\n", duration);
|
||||
@@ -81,12 +89,13 @@ for a list of pre-trained models to download.
|
||||
fprintf(stderr, "Started\n");
|
||||
|
||||
auto s = recognizer.CreateStream();
|
||||
s->AcceptWaveform(expected_sampling_rate, samples.data(), samples.size());
|
||||
s->AcceptWaveform(sampling_rate, samples.data(), samples.size());
|
||||
|
||||
std::vector<float> tail_paddings(
|
||||
static_cast<int>(0.2 * expected_sampling_rate));
|
||||
s->AcceptWaveform(expected_sampling_rate, tail_paddings.data(),
|
||||
tail_paddings.size());
|
||||
std::vector<float> tail_paddings(static_cast<int>(0.2 * sampling_rate));
|
||||
// Note: We can call AcceptWaveform() multiple times.
|
||||
s->AcceptWaveform(sampling_rate, tail_paddings.data(), tail_paddings.size());
|
||||
|
||||
// Call InputFinished() to indicate that no audio samples are available
|
||||
s->InputFinished();
|
||||
|
||||
while (recognizer.IsReady(s.get())) {
|
||||
|
||||
@@ -30,4 +30,23 @@ TEST(Slice, Slice3D) {
|
||||
// TODO(fangjun): Check that the results are correct
|
||||
}
|
||||
|
||||
// Slice rows out of a 5x8 tensor filled with 0, 1, 2, ... and verify that
// each slice has the expected shape and contains the expected rows.
TEST(Slice, Slice2D) {
  Ort::AllocatorWithDefaultOptions allocator;
  std::array<int64_t, 2> shape{5, 8};
  Ort::Value v =
      Ort::Value::CreateTensor<float>(allocator, shape.data(), shape.size());
  float *p = v.GetTensorMutableData<float>();

  std::iota(p, p + shape[0] * shape[1], 0);

  auto v1 = Slice(allocator, &v, 1, 3);
  auto v2 = Slice(allocator, &v, 0, 2);

  Print2D(&v);
  Print2D(&v1);
  Print2D(&v2);

  // Check that `slice` equals v[row_begin:row_end, :].
  auto expect_rows = [&](const Ort::Value &slice, int64_t row_begin,
                         int64_t row_end) {
    auto slice_shape = slice.GetTensorTypeAndShapeInfo().GetShape();
    ASSERT_EQ(slice_shape.size(), 2u);
    EXPECT_EQ(slice_shape[0], row_end - row_begin);
    EXPECT_EQ(slice_shape[1], shape[1]);

    const float *q = slice.GetTensorData<float>();
    const int64_t n = (row_end - row_begin) * shape[1];
    for (int64_t i = 0; i != n; ++i) {
      // Values are small integers copied verbatim, so exact comparison is OK.
      EXPECT_EQ(q[i], p[row_begin * shape[1] + i]);
    }
  };

  expect_rows(v1, 1, 3);
  expect_rows(v2, 0, 2);
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -24,7 +24,7 @@ Ort::Value Slice(OrtAllocator *allocator, const Ort::Value *v,
|
||||
|
||||
assert(0 <= dim1_start);
|
||||
assert(dim1_start < dim1_end);
|
||||
assert(dim1_end < shape[1]);
|
||||
assert(dim1_end <= shape[1]);
|
||||
|
||||
const T *src = v->GetTensorData<T>();
|
||||
|
||||
@@ -46,8 +46,35 @@ Ort::Value Slice(OrtAllocator *allocator, const Ort::Value *v,
|
||||
return ans;
|
||||
}
|
||||
|
||||
template <typename T /*= float*/>
|
||||
Ort::Value Slice(OrtAllocator *allocator, const Ort::Value *v,
|
||||
int32_t dim0_start, int32_t dim0_end) {
|
||||
std::vector<int64_t> shape = v->GetTensorTypeAndShapeInfo().GetShape();
|
||||
assert(shape.size() == 2);
|
||||
|
||||
assert(0 <= dim0_start);
|
||||
assert(dim0_start < dim0_end);
|
||||
assert(dim0_end <= shape[0]);
|
||||
|
||||
const T *src = v->GetTensorData<T>();
|
||||
|
||||
std::array<int64_t, 2> ans_shape{dim0_end - dim0_start, shape[1]};
|
||||
|
||||
Ort::Value ans = Ort::Value::CreateTensor<T>(allocator, ans_shape.data(),
|
||||
ans_shape.size());
|
||||
const T *start = v->GetTensorData<T>() + dim0_start * shape[1];
|
||||
const T *end = v->GetTensorData<T>() + dim0_end * shape[1];
|
||||
T *dst = ans.GetTensorMutableData<T>();
|
||||
std::copy(start, end, dst);
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
template Ort::Value Slice<float>(OrtAllocator *allocator, const Ort::Value *v,
|
||||
int32_t dim0_start, int32_t dim0_end,
|
||||
int32_t dim1_start, int32_t dim1_end);
|
||||
|
||||
template Ort::Value Slice<float>(OrtAllocator *allocator, const Ort::Value *v,
|
||||
int32_t dim0_start, int32_t dim0_end);
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -8,12 +8,12 @@
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
/** Get a deep copy by slicing v.
|
||||
/** Get a deep copy by slicing a 3-D tensor v.
|
||||
*
|
||||
* It returns v[dim0_start:dim0_end, dim1_start:dim1_end]
|
||||
* It returns v[dim0_start:dim0_end, dim1_start:dim1_end, :]
|
||||
*
|
||||
* @param allocator
|
||||
* @param v A 3-D tensor. Its data type is T.
|
||||
* @param v A 3-D tensor. Its data type is T.
|
||||
* @param dim0_start Start index of the first dimension..
|
||||
* @param dim0_end End index of the first dimension..
|
||||
* @param dim1_start Start index of the second dimension.
|
||||
@@ -26,6 +26,23 @@ template <typename T = float>
|
||||
Ort::Value Slice(OrtAllocator *allocator, const Ort::Value *v,
|
||||
int32_t dim0_start, int32_t dim0_end, int32_t dim1_start,
|
||||
int32_t dim1_end);
|
||||
|
||||
/** Get a deep copy by slicing a 2-D tensor v.
|
||||
*
|
||||
* It returns v[dim0_start:dim0_end, :]
|
||||
*
|
||||
* @param allocator
|
||||
* @param v A 2-D tensor. Its data type is T.
|
||||
* @param dim0_start Start index of the first dimension.
|
||||
* @param dim0_end End index of the first dimension.
|
||||
*
|
||||
* @return Return a 2-D tensor of shape
|
||||
* (dim0_end-dim0_start, v.shape[1])
|
||||
*/
|
||||
template <typename T = float>
|
||||
Ort::Value Slice(OrtAllocator *allocator, const Ort::Value *v,
|
||||
int32_t dim0_start, int32_t dim0_end);
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_SLICE_H_
|
||||
|
||||
@@ -6,10 +6,11 @@
|
||||
|
||||
#include <cassert>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
namespace {
|
||||
// see http://soundfile.sapp.org/doc/WaveFormat/
|
||||
@@ -20,26 +21,34 @@ struct WaveHeader {
|
||||
bool Validate() const {
|
||||
// F F I R
|
||||
if (chunk_id != 0x46464952) {
|
||||
SHERPA_ONNX_LOGE("Expected chunk_id RIFF. Given: 0x%08x\n", chunk_id);
|
||||
return false;
|
||||
}
|
||||
// E V A W
|
||||
if (format != 0x45564157) {
|
||||
SHERPA_ONNX_LOGE("Expected format WAVE. Given: 0x%08x\n", format);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (subchunk1_id != 0x20746d66) {
|
||||
SHERPA_ONNX_LOGE("Expected subchunk1_id 0x20746d66. Given: 0x%08x\n",
|
||||
subchunk1_id);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (subchunk1_size != 16) { // 16 for PCM
|
||||
SHERPA_ONNX_LOGE("Expected subchunk1_size 16. Given: %d\n",
|
||||
subchunk1_size);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (audio_format != 1) { // 1 for PCM
|
||||
SHERPA_ONNX_LOGE("Expected audio_format 1. Given: %d\n", audio_format);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (num_channels != 1) { // we support only single channel for now
|
||||
SHERPA_ONNX_LOGE("Expected single channel. Given: %d\n", num_channels);
|
||||
return false;
|
||||
}
|
||||
if (byte_rate != (sample_rate * num_channels * bits_per_sample / 8)) {
|
||||
@@ -51,6 +60,8 @@ struct WaveHeader {
|
||||
}
|
||||
|
||||
if (bits_per_sample != 16) { // we support only 16 bits per sample
|
||||
SHERPA_ONNX_LOGE("Expected bits_per_sample 16. Given: %d\n",
|
||||
bits_per_sample);
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -62,7 +73,7 @@ struct WaveHeader {
|
||||
// and
|
||||
// https://www.robotplanet.dk/audio/wav_meta_data/riff_mci.pdf
|
||||
void SeekToDataChunk(std::istream &is) {
|
||||
// a t a d
|
||||
// a t a d
|
||||
while (is && subchunk2_id != 0x61746164) {
|
||||
// const char *p = reinterpret_cast<const char *>(&subchunk2_id);
|
||||
// printf("Skip chunk (%x): %c%c%c%c of size: %d\n", subchunk2_id, p[0],
|
||||
@@ -91,7 +102,7 @@ static_assert(sizeof(WaveHeader) == 44, "");
|
||||
|
||||
// Read a wave file of mono-channel.
|
||||
// Return its samples normalized to the range [-1, 1).
|
||||
std::vector<float> ReadWaveImpl(std::istream &is, float expected_sample_rate,
|
||||
std::vector<float> ReadWaveImpl(std::istream &is, int32_t *sampling_rate,
|
||||
bool *is_ok) {
|
||||
WaveHeader header;
|
||||
is.read(reinterpret_cast<char *>(&header), sizeof(header));
|
||||
@@ -111,10 +122,7 @@ std::vector<float> ReadWaveImpl(std::istream &is, float expected_sample_rate,
|
||||
return {};
|
||||
}
|
||||
|
||||
if (expected_sample_rate != header.sample_rate) {
|
||||
*is_ok = false;
|
||||
return {};
|
||||
}
|
||||
*sampling_rate = header.sample_rate;
|
||||
|
||||
// header.subchunk2_size contains the number of bytes in the data.
|
||||
// As we assume each sample contains two bytes, so it is divided by 2 here
|
||||
@@ -137,15 +145,15 @@ std::vector<float> ReadWaveImpl(std::istream &is, float expected_sample_rate,
|
||||
|
||||
} // namespace
|
||||
|
||||
std::vector<float> ReadWave(const std::string &filename,
|
||||
float expected_sample_rate, bool *is_ok) {
|
||||
std::vector<float> ReadWave(const std::string &filename, int32_t *sampling_rate,
|
||||
bool *is_ok) {
|
||||
std::ifstream is(filename, std::ifstream::binary);
|
||||
return ReadWave(is, expected_sample_rate, is_ok);
|
||||
return ReadWave(is, sampling_rate, is_ok);
|
||||
}
|
||||
|
||||
std::vector<float> ReadWave(std::istream &is, float expected_sample_rate,
|
||||
std::vector<float> ReadWave(std::istream &is, int32_t *sampling_rate,
|
||||
bool *is_ok) {
|
||||
auto samples = ReadWaveImpl(is, expected_sample_rate, is_ok);
|
||||
auto samples = ReadWaveImpl(is, sampling_rate, is_ok);
|
||||
return samples;
|
||||
}
|
||||
|
||||
|
||||
@@ -13,17 +13,17 @@ namespace sherpa_onnx {
|
||||
|
||||
/** Read a wave file with expected sample rate.
|
||||
|
||||
@param filename Path to a wave file. It MUST be single channel, PCM encoded.
|
||||
@param expected_sample_rate Expected sample rate of the wave file. If the
|
||||
sample rate don't match, it throws an exception.
|
||||
@param filename Path to a wave file. It MUST be single channel, 16-bit
|
||||
PCM encoded.
|
||||
@param sampling_rate On return, it contains the sampling rate of the file.
|
||||
@param is_ok On return it is true if the reading succeeded; false otherwise.
|
||||
|
||||
@return Return wave samples normalized to the range [-1, 1).
|
||||
*/
|
||||
std::vector<float> ReadWave(const std::string &filename,
|
||||
float expected_sample_rate, bool *is_ok);
|
||||
std::vector<float> ReadWave(const std::string &filename, int32_t *sampling_rate,
|
||||
bool *is_ok);
|
||||
|
||||
std::vector<float> ReadWave(std::istream &is, float expected_sample_rate,
|
||||
std::vector<float> ReadWave(std::istream &is, int32_t *sampling_rate,
|
||||
bool *is_ok);
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#include "jni.h" // NOLINT
|
||||
|
||||
#include <strstream>
|
||||
#include <utility>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
@@ -43,14 +44,18 @@ class SherpaOnnx {
|
||||
stream_(recognizer_.CreateStream()) {
|
||||
}
|
||||
|
||||
void AcceptWaveform(int32_t sample_rate, const float *samples,
|
||||
int32_t n) const {
|
||||
void AcceptWaveform(int32_t sample_rate, const float *samples, int32_t n) {
|
||||
if (input_sample_rate_ == -1) {
|
||||
input_sample_rate_ = sample_rate;
|
||||
}
|
||||
|
||||
stream_->AcceptWaveform(sample_rate, samples, n);
|
||||
}
|
||||
|
||||
void InputFinished() const {
|
||||
std::vector<float> tail_padding(16000 * 0.32, 0);
|
||||
stream_->AcceptWaveform(16000, tail_padding.data(), tail_padding.size());
|
||||
std::vector<float> tail_padding(input_sample_rate_ * 0.32, 0);
|
||||
stream_->AcceptWaveform(input_sample_rate_, tail_padding.data(),
|
||||
tail_padding.size());
|
||||
stream_->InputFinished();
|
||||
}
|
||||
|
||||
@@ -70,6 +75,7 @@ class SherpaOnnx {
|
||||
private:
|
||||
sherpa_onnx::OnlineRecognizer recognizer_;
|
||||
std::unique_ptr<sherpa_onnx::OnlineStream> stream_;
|
||||
int32_t input_sample_rate_ = -1;
|
||||
};
|
||||
|
||||
static OnlineRecognizerConfig GetConfig(JNIEnv *env, jobject config) {
|
||||
@@ -276,17 +282,24 @@ JNIEXPORT jstring JNICALL Java_com_k2fsa_sherpa_onnx_SherpaOnnx_getText(
|
||||
return env->NewStringUTF(text.c_str());
|
||||
}
|
||||
|
||||
// see
|
||||
// https://stackoverflow.com/questions/29043872/android-jni-return-multiple-variables
|
||||
// Box `value` into a new java.lang.Integer object by calling its
// Integer(int) constructor through JNI.
static jobject NewInteger(JNIEnv *env, int32_t value) {
  jclass integer_class = env->FindClass("java/lang/Integer");

  // "<init>" with signature "(I)V" is the constructor taking a single int.
  jmethodID ctor = env->GetMethodID(integer_class, "<init>", "(I)V");

  return env->NewObject(integer_class, ctor, value);
}
|
||||
|
||||
SHERPA_ONNX_EXTERN_C
|
||||
JNIEXPORT jfloatArray JNICALL
|
||||
JNIEXPORT jobjectArray JNICALL
|
||||
Java_com_k2fsa_sherpa_onnx_WaveReader_00024Companion_readWave(
|
||||
JNIEnv *env, jclass /*cls*/, jobject asset_manager, jstring filename,
|
||||
jfloat expected_sample_rate) {
|
||||
JNIEnv *env, jclass /*cls*/, jobject asset_manager, jstring filename) {
|
||||
const char *p_filename = env->GetStringUTFChars(filename, nullptr);
|
||||
#if __ANDROID_API__ >= 9
|
||||
AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager);
|
||||
if (!mgr) {
|
||||
SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr);
|
||||
return nullptr;
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
std::vector<char> buffer = sherpa_onnx::ReadFile(mgr, p_filename);
|
||||
@@ -297,16 +310,25 @@ Java_com_k2fsa_sherpa_onnx_WaveReader_00024Companion_readWave(
|
||||
#endif
|
||||
|
||||
bool is_ok = false;
|
||||
int32_t sampling_rate = -1;
|
||||
std::vector<float> samples =
|
||||
sherpa_onnx::ReadWave(is, expected_sample_rate, &is_ok);
|
||||
sherpa_onnx::ReadWave(is, &sampling_rate, &is_ok);
|
||||
|
||||
env->ReleaseStringUTFChars(filename, p_filename);
|
||||
|
||||
if (!is_ok) {
|
||||
return nullptr;
|
||||
SHERPA_ONNX_LOGE("Failed to read %s", p_filename);
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
jfloatArray ans = env->NewFloatArray(samples.size());
|
||||
env->SetFloatArrayRegion(ans, 0, samples.size(), samples.data());
|
||||
return ans;
|
||||
|
||||
jobjectArray obj_arr = (jobjectArray)env->NewObjectArray(
|
||||
2, env->FindClass("java/lang/Object"), nullptr);
|
||||
|
||||
env->SetObjectArrayElement(obj_arr, 0, ans);
|
||||
env->SetObjectArrayElement(obj_arr, 1, NewInteger(env, sampling_rate));
|
||||
|
||||
return obj_arr;
|
||||
}
|
||||
|
||||
@@ -11,12 +11,10 @@ namespace sherpa_onnx {
|
||||
static void PybindFeatureExtractorConfig(py::module *m) {
|
||||
using PyClass = FeatureExtractorConfig;
|
||||
py::class_<PyClass>(*m, "FeatureExtractorConfig")
|
||||
.def(py::init<int32_t, int32_t, int32_t>(),
|
||||
py::arg("sampling_rate") = 16000, py::arg("feature_dim") = 80,
|
||||
py::arg("max_feature_vectors") = -1)
|
||||
.def(py::init<int32_t, int32_t>(), py::arg("sampling_rate") = 16000,
|
||||
py::arg("feature_dim") = 80)
|
||||
.def_readwrite("sampling_rate", &PyClass::sampling_rate)
|
||||
.def_readwrite("feature_dim", &PyClass::feature_dim)
|
||||
.def_readwrite("max_feature_vectors", &PyClass::max_feature_vectors)
|
||||
.def("__str__", &PyClass::ToString);
|
||||
}
|
||||
|
||||
|
||||
@@ -34,7 +34,6 @@ class OnlineRecognizer(object):
|
||||
rule3_min_utterance_length: int = 20,
|
||||
decoding_method: str = "greedy_search",
|
||||
max_active_paths: int = 4,
|
||||
max_feature_vectors: int = -1,
|
||||
):
|
||||
"""
|
||||
Please refer to
|
||||
@@ -82,9 +81,6 @@ class OnlineRecognizer(object):
|
||||
max_active_paths:
|
||||
Use only when decoding_method is modified_beam_search. It specifies
|
||||
the maximum number of active paths during beam search.
|
||||
max_feature_vectors:
|
||||
Number of feature vectors to cache. -1 means to cache all feature
|
||||
frames that have been processed.
|
||||
"""
|
||||
_assert_file_exists(tokens)
|
||||
_assert_file_exists(encoder)
|
||||
@@ -104,7 +100,6 @@ class OnlineRecognizer(object):
|
||||
feat_config = FeatureExtractorConfig(
|
||||
sampling_rate=sample_rate,
|
||||
feature_dim=feature_dim,
|
||||
max_feature_vectors=max_feature_vectors,
|
||||
)
|
||||
|
||||
endpoint_config = EndpointConfig(
|
||||
|
||||
@@ -8,18 +8,18 @@
|
||||
|
||||
import unittest
|
||||
|
||||
import sherpa_onnx
|
||||
import _sherpa_onnx
|
||||
|
||||
|
||||
class TestFeatureExtractorConfig(unittest.TestCase):
|
||||
def test_default_constructor(self):
|
||||
config = sherpa_onnx.FeatureExtractorConfig()
|
||||
config = _sherpa_onnx.FeatureExtractorConfig()
|
||||
assert config.sampling_rate == 16000, config.sampling_rate
|
||||
assert config.feature_dim == 80, config.feature_dim
|
||||
print(config)
|
||||
|
||||
def test_constructor(self):
|
||||
config = sherpa_onnx.FeatureExtractorConfig(sampling_rate=8000, feature_dim=40)
|
||||
config = _sherpa_onnx.FeatureExtractorConfig(sampling_rate=8000, feature_dim=40)
|
||||
assert config.sampling_rate == 8000, config.sampling_rate
|
||||
assert config.feature_dim == 40, config.feature_dim
|
||||
print(config)
|
||||
|
||||
@@ -8,21 +8,23 @@
|
||||
|
||||
import unittest
|
||||
|
||||
import sherpa_onnx
|
||||
import _sherpa_onnx
|
||||
|
||||
|
||||
class TestOnlineTransducerModelConfig(unittest.TestCase):
|
||||
def test_constructor(self):
|
||||
config = sherpa_onnx.OnlineTransducerModelConfig(
|
||||
config = _sherpa_onnx.OnlineTransducerModelConfig(
|
||||
encoder_filename="encoder.onnx",
|
||||
decoder_filename="decoder.onnx",
|
||||
joiner_filename="joiner.onnx",
|
||||
tokens="tokens.txt",
|
||||
num_threads=8,
|
||||
debug=True,
|
||||
)
|
||||
assert config.encoder_filename == "encoder.onnx", config.encoder_filename
|
||||
assert config.decoder_filename == "decoder.onnx", config.decoder_filename
|
||||
assert config.joiner_filename == "joiner.onnx", config.joiner_filename
|
||||
assert config.tokens == "tokens.txt", config.tokens
|
||||
assert config.num_threads == 8, config.num_threads
|
||||
assert config.debug is True, config.debug
|
||||
print(config)
|
||||
|
||||
Reference in New Issue
Block a user