diff --git a/.gitignore b/.gitignore index b898c51c..1de2fb4a 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,4 @@ __pycache__ dist/ sherpa_onnx.egg-info/ .DS_Store +build-aarch64-linux-gnu diff --git a/CMakeLists.txt b/CMakeLists.txt index 4e2e16a8..9ae83e32 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -50,6 +50,12 @@ message(STATUS "SHERPA_ONNX_ENABLE_PYTHON ${SHERPA_ONNX_ENABLE_PYTHON}") set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.") set(CMAKE_CXX_EXTENSIONS OFF) +include(CheckIncludeFileCXX) +check_include_file_cxx(alsa/asoundlib.h SHERPA_ONNX_HAS_ALSA) +if(SHERPA_ONNX_HAS_ALSA) + add_definitions(-DSHERPA_ONNX_ENABLE_ALSA=1) +endif() + list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/Modules) list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake) diff --git a/build-aarch64-linux-gnu.sh b/build-aarch64-linux-gnu.sh new file mode 100755 index 00000000..3712b3ca --- /dev/null +++ b/build-aarch64-linux-gnu.sh @@ -0,0 +1,49 @@ +#!/usr/bin/env bash + +if ! command -v aarch64-linux-gnu-gcc &> /dev/null; then + echo "Please install a toolchain for cross-compiling." + echo "You can refer to: " + echo " https://k2-fsa.github.io/sherpa/onnx/install/aarch64-embedded-linux.html" + echo "for help." + exit 1 +fi + +set -ex + +dir=build-aarch64-linux-gnu +mkdir -p $dir +cd $dir + +if [ ! -f alsa-lib/src/.libs/libasound.so ]; then + echo "Start to cross-compile alsa-lib" + if [ ! -d alsa-lib ]; then + git clone --depth 1 https://github.com/alsa-project/alsa-lib + fi + # If it shows: + # ./gitcompile: line 79: libtoolize: command not found + # Please use: + # sudo apt-get install libtool m4 automake + # + pushd alsa-lib + CC=aarch64-linux-gnu-gcc ./gitcompile --host=aarch64-linux-gnu + popd + echo "Finish cross-compiling alsa-lib" +fi + +export CPLUS_INCLUDE_PATH=$PWD/alsa-lib/include:$CPLUS_INCLUDE_PATH +export SHERPA_ONNX_ALSA_LIB_DIR=$PWD/alsa-lib/src/.libs + +cmake \ + -DCMAKE_INSTALL_PREFIX=./install \ + -DCMAKE_BUILD_TYPE=Release \ + -DBUILD_SHARED_LIBS=OFF \ + -DSHERPA_ONNX_ENABLE_TESTS=OFF \ + -DSHERPA_ONNX_ENABLE_PYTHON=OFF \ + -DCMAKE_TOOLCHAIN_FILE=../toolchains/aarch64-linux-gnu.toolchain.cmake \ + .. + +make VERBOSE=1 -j4 +make install/strip + +# Enable it if only needed +# cp -v $SHERPA_ONNX_ALSA_LIB_DIR/libasound.so* ./install/lib/ diff --git a/cmake/onnxruntime.cmake b/cmake/onnxruntime.cmake index e7910bc1..cb4a7aa3 100644 --- a/cmake/onnxruntime.cmake +++ b/cmake/onnxruntime.cmake @@ -1,25 +1,39 @@ function(download_onnxruntime) include(FetchContent) - if(UNIX AND NOT APPLE) - # If you don't have access to the Internet, - # please pre-download onnxruntime - set(possible_file_locations - $ENV{HOME}/Downloads/onnxruntime-linux-x64-1.14.0.tgz - ${PROJECT_SOURCE_DIR}/onnxruntime-linux-x64-1.14.0.tgz - ${PROJECT_BINARY_DIR}/onnxruntime-linux-x64-1.14.0.tgz - /tmp/onnxruntime-linux-x64-1.14.0.tgz - /star-fj/fangjun/download/github/onnxruntime-linux-x64-1.14.0.tgz - ) + if(CMAKE_SYSTEM_NAME STREQUAL Linux) + if(CMAKE_SYSTEM_PROCESSOR STREQUAL aarch64) + # For embedded systems + set(possible_file_locations + $ENV{HOME}/Downloads/onnxruntime-linux-aarch64-1.14.0.tgz + ${PROJECT_SOURCE_DIR}/onnxruntime-linux-aarch64-1.14.0.tgz + ${PROJECT_BINARY_DIR}/onnxruntime-linux-aarch64-1.14.0.tgz + /tmp/onnxruntime-linux-aarch64-1.14.0.tgz + /star-fj/fangjun/download/github/onnxruntime-linux-aarch64-1.14.0.tgz + ) + set(onnxruntime_URL "https://github.com/microsoft/onnxruntime/releases/download/v1.14.0/onnxruntime-linux-aarch64-1.14.0.tgz") + set(onnxruntime_HASH "SHA256=9384d2e6e29fed693a4630303902392eead0c41bee5705ccac6d6d34a3d5db86") - set(onnxruntime_URL "https://github.com/microsoft/onnxruntime/releases/download/v1.14.0/onnxruntime-linux-x64-1.14.0.tgz") - set(onnxruntime_HASH "SHA256=92bf534e5fa5820c8dffe9de2850f84ed2a1c063e47c659ce09e8c7938aa2090") - # After downloading, it contains: - # ./lib/libonnxruntime.so.1.14.0 - # ./lib/libonnxruntime.so, which is a symlink to lib/libonnxruntime.so.1.14.0 - # - # ./include - # It contains all the needed header files + else() + # If you don't have access to the Internet, + # please pre-download onnxruntime + set(possible_file_locations + $ENV{HOME}/Downloads/onnxruntime-linux-x64-1.14.0.tgz + ${PROJECT_SOURCE_DIR}/onnxruntime-linux-x64-1.14.0.tgz + ${PROJECT_BINARY_DIR}/onnxruntime-linux-x64-1.14.0.tgz + /tmp/onnxruntime-linux-x64-1.14.0.tgz + /star-fj/fangjun/download/github/onnxruntime-linux-x64-1.14.0.tgz + ) + + set(onnxruntime_URL "https://github.com/microsoft/onnxruntime/releases/download/v1.14.0/onnxruntime-linux-x64-1.14.0.tgz") + set(onnxruntime_HASH "SHA256=92bf534e5fa5820c8dffe9de2850f84ed2a1c063e47c659ce09e8c7938aa2090") + # After downloading, it contains: + # ./lib/libonnxruntime.so.1.14.0 + # ./lib/libonnxruntime.so, which is a symlink to lib/libonnxruntime.so.1.14.0 + # + # ./include + # It contains all the needed header files + endif() elseif(APPLE) # If you don't have access to the Internet, # please pre-download onnxruntime diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index 8895fab4..801f78f4 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -11,6 +11,7 @@ add_library(sherpa-onnx-core online-transducer-model.cc online-zipformer-transducer-model.cc onnx-utils.cc + resample.cc symbol-table.cc text-utils.cc unbind.cc @@ -32,6 +33,18 @@ endif() install(TARGETS sherpa-onnx-core DESTINATION lib) install(TARGETS sherpa-onnx DESTINATION bin) +if(SHERPA_ONNX_HAS_ALSA) + add_executable(sherpa-onnx-alsa sherpa-onnx-alsa.cc alsa.cc) + target_link_libraries(sherpa-onnx-alsa PRIVATE sherpa-onnx-core) + + if(DEFINED ENV{SHERPA_ONNX_ALSA_LIB_DIR}) + target_link_libraries(sherpa-onnx-alsa PRIVATE -L$ENV{SHERPA_ONNX_ALSA_LIB_DIR} -lasound) + else() + target_link_libraries(sherpa-onnx-alsa PRIVATE asound) + endif() + install(TARGETS sherpa-onnx-alsa DESTINATION bin) +endif() + if(SHERPA_ONNX_ENABLE_TESTS) set(sherpa_onnx_test_srcs cat-test.cc diff --git a/sherpa-onnx/csrc/alsa.cc b/sherpa-onnx/csrc/alsa.cc new file mode 100644 index 00000000..3001a4d9 --- /dev/null +++ b/sherpa-onnx/csrc/alsa.cc @@ -0,0 +1,162 @@ +// sherpa-onnx/csrc/sherpa-alsa.cc +// +// Copyright (c) 2022-2023 Xiaomi Corporation + +#ifdef SHERPA_ONNX_ENABLE_ALSA + +#include "sherpa-onnx/csrc/alsa.h" + +#include + +#include "alsa/asoundlib.h" + +namespace sherpa_onnx { + +void ToFloat(const std::vector &in, int32_t num_channels, + std::vector *out) { + out->resize(in.size() / num_channels); + + int32_t n = in.size(); + for (int32_t i = 0, k = 0; i < n; i += num_channels, ++k) { + (*out)[k] = in[i] / 32768.; + } +} + +Alsa::Alsa(const char *device_name) { + const char *kDeviceHelp = R"( +Please use the command: + + arecord -l + +to list all available devices. For instance, if the output is: + +**** List of CAPTURE Hardware Devices **** +card 3: UACDemoV10 [UACDemoV1.0], device 0: USB Audio [USB Audio] + Subdevices: 1/1 + Subdevice #0: subdevice #0 + +and if you want to select card 3 and the device 0 on that card, please use: + + hw:3,0 + + )"; + + int32_t err = + snd_pcm_open(&capture_handle_, device_name, SND_PCM_STREAM_CAPTURE, 0); + if (err) { + fprintf(stderr, "Unable to open: %s. %s\n", device_name, snd_strerror(err)); + fprintf(stderr, "%s\n", kDeviceHelp); + exit(-1); + } + + snd_pcm_hw_params_t *hw_params; + snd_pcm_hw_params_alloca(&hw_params); + + err = snd_pcm_hw_params_any(capture_handle_, hw_params); + if (err) { + fprintf(stderr, "Failed to initialize hw_params: %s\n", snd_strerror(err)); + exit(-1); + } + + err = snd_pcm_hw_params_set_access(capture_handle_, hw_params, + SND_PCM_ACCESS_RW_INTERLEAVED); + if (err) { + fprintf(stderr, "Failed to set access type: %s\n", snd_strerror(err)); + exit(-1); + } + + err = snd_pcm_hw_params_set_format(capture_handle_, hw_params, + SND_PCM_FORMAT_S16_LE); + if (err) { + fprintf(stderr, "Failed to set format: %s\n", snd_strerror(err)); + exit(-1); + } + + // mono + err = snd_pcm_hw_params_set_channels(capture_handle_, hw_params, 1); + if (err) { + fprintf(stderr, "Failed to set number of channels to 1. %s\n", + snd_strerror(err)); + + err = snd_pcm_hw_params_set_channels(capture_handle_, hw_params, 2); + if (err) { + fprintf(stderr, "Failed to set number of channels to 2. %s\n", + snd_strerror(err)); + + exit(-1); + } + actual_channel_count_ = 2; + fprintf(stderr, + "Channel count is set to 2. Will use only 1 channel of it.\n"); + } + + uint32_t actual_sample_rate = expected_sample_rate_; + + int32_t dir = 0; + err = snd_pcm_hw_params_set_rate_near(capture_handle_, hw_params, + &actual_sample_rate, &dir); + if (err) { + fprintf(stderr, "Failed to set sample rate to, %d: %s\n", + expected_sample_rate_, snd_strerror(err)); + exit(-1); + } + actual_sample_rate_ = actual_sample_rate; + + if (actual_sample_rate_ != expected_sample_rate_) { + fprintf(stderr, "Failed to set sample rate to %d\n", expected_sample_rate_); + fprintf(stderr, "Current sample rate is %d\n", actual_sample_rate_); + fprintf(stderr, + "Creating a resampler:\n" + " in_sample_rate: %d\n" + " output_sample_rate: %d\n", + actual_sample_rate_, expected_sample_rate_); + + float min_freq = std::min(actual_sample_rate_, expected_sample_rate_); + float lowpass_cutoff = 0.99 * 0.5 * min_freq; + + int32_t lowpass_filter_width = 6; + resampler_ = std::make_unique( + actual_sample_rate_, expected_sample_rate_, lowpass_cutoff, + lowpass_filter_width); + } else { + fprintf(stderr, "Current sample rate: %d\n", actual_sample_rate_); + } + + err = snd_pcm_hw_params(capture_handle_, hw_params); + if (err) { + fprintf(stderr, "Failed to set hw params: %s\n", snd_strerror(err)); + exit(-1); + } + + err = snd_pcm_prepare(capture_handle_); + if (err) { + fprintf(stderr, "Failed to prepare for recording: %s\n", snd_strerror(err)); + exit(-1); + } + + fprintf(stderr, "Recording started!\n"); +} + +Alsa::~Alsa() { snd_pcm_close(capture_handle_); } + +const std::vector &Alsa::Read(int32_t num_samples) { + samples_.resize(num_samples * actual_channel_count_); + + // count is in frames. Each frame contains actual_channel_count_ samples + int32_t count = snd_pcm_readi(capture_handle_, samples_.data(), num_samples); + + samples_.resize(count * actual_channel_count_); + + ToFloat(samples_, actual_channel_count_, &samples1_); + + if (!resampler_) { + return samples1_; + } + + resampler_->Resample(samples1_.data(), samples_.size(), false, &samples2_); + return samples2_; +} + +} // namespace sherpa_onnx + +#endif diff --git a/sherpa-onnx/csrc/alsa.h b/sherpa-onnx/csrc/alsa.h new file mode 100644 index 00000000..7a9ca7a7 --- /dev/null +++ b/sherpa-onnx/csrc/alsa.h @@ -0,0 +1,46 @@ +// sherpa-onnx/csrc/sherpa-alsa.h +// +// Copyright (c) 2022-2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_ALSA_H_ +#define SHERPA_ONNX_CSRC_ALSA_H_ + +#include +#include + +#include "alsa/asoundlib.h" +#include "sherpa-onnx/csrc/resample.h" + +namespace sherpa_onnx { + +class Alsa { + public: + explicit Alsa(const char *device_name); + ~Alsa(); + + // This is a blocking read. + // + // @param num_samples Number of samples to read. + // + // The returned value is valid until the next call to Read(). + const std::vector &Read(int32_t num_samples); + + int32_t GetExpectedSampleRate() const { return expected_sample_rate_; } + int32_t GetActualSampleRate() const { return actual_sample_rate_; } + + private: + snd_pcm_t *capture_handle_; + int32_t expected_sample_rate_ = 16000; + int32_t actual_sample_rate_; + + int32_t actual_channel_count_ = 1; + + std::unique_ptr resampler_; + std::vector samples_; // directly from the microphone + std::vector samples1_; // normalized version of samples_ + std::vector samples2_; // possibly resampled from samples1_ +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_ALSA_H_ diff --git a/sherpa-onnx/csrc/display.h b/sherpa-onnx/csrc/display.h new file mode 100644 index 00000000..c7bbf292 --- /dev/null +++ b/sherpa-onnx/csrc/display.h @@ -0,0 +1,79 @@ +// sherpa-onnx/csrc/display.h +// +// Copyright (c) 2022-2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_DISPLAY_H_ +#define SHERPA_ONNX_CSRC_DISPLAY_H_ +#include + +#include + +namespace sherpa_onnx { + +class Display { + public: + void Print(int32_t segment_id, const std::string &s) { +#ifdef _MSC_VER + fprintf(stderr, "%d:%s\n", segment_id, s.c_str()); + return; +#endif + if (last_segment_ == segment_id) { + Clear(); + } else { + if (last_segment_ != -1) { + fprintf(stderr, "\n\r"); + } + last_segment_ = segment_id; + num_previous_lines_ = 0; + } + + fprintf(stderr, "\r%d:", segment_id); + + int32_t i = 0; + for (size_t n = 0; n < s.size();) { + if (s[n] > 0 && s[n] < 0x7f) { + fprintf(stderr, "%c", s[n]); + ++n; + } else { + // Each Chinese character occupies 3 bytes for UTF-8 encoding. + std::string tmp(s.begin() + n, s.begin() + n + 3); + fprintf(stderr, "%s", tmp.data()); + n += 3; + } + + ++i; + if (i >= max_word_per_line_ && n + 1 < s.size() && + (s[n] == ' ' || s[n] < 0)) { + fprintf(stderr, "\n\r "); + ++num_previous_lines_; + i = 0; + } + } + } + + private: + // Clear the output for the current segment + void Clear() { + ClearCurrentLine(); + while (num_previous_lines_ > 0) { + GoUpOneLine(); + ClearCurrentLine(); + --num_previous_lines_; + } + } + + // Clear the current line + void ClearCurrentLine() const { fprintf(stderr, "\33[2K\r"); } + + // Move the cursor to the previous line + void GoUpOneLine() const { fprintf(stderr, "\033[1A\r"); } + + private: + int32_t max_word_per_line_ = 60; + int32_t num_previous_lines_ = 0; + int32_t last_segment_ = -1; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_DISPLAY_H_ diff --git a/sherpa-onnx/csrc/resample.cc b/sherpa-onnx/csrc/resample.cc new file mode 100644 index 00000000..8ef3a1b5 --- /dev/null +++ b/sherpa-onnx/csrc/resample.cc @@ -0,0 +1,309 @@ +/** + * Copyright 2013 Pegah Ghahremani + * 2014 IMSL, PKU-HKUST (author: Wei Shi) + * 2014 Yanqing Sun, Junjie Wang + * 2014 Johns Hopkins University (author: Daniel Povey) + * Copyright 2023 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +// this file is copied and modified from +// kaldi/src/feat/resample.cc + +#include "sherpa-onnx/csrc/resample.h" + +#include +#include +#include + +#include +#include + +#ifndef M_2PI +#define M_2PI 6.283185307179586476925286766559005 +#endif + +#ifndef M_PI +#define M_PI 3.1415926535897932384626433832795 +#endif + +namespace sherpa_onnx { + +template +I Gcd(I m, I n) { + // this function is copied from kaldi/src/base/kaldi-math.h + if (m == 0 || n == 0) { + if (m == 0 && n == 0) { // gcd not defined, as all integers are divisors. + fprintf(stderr, "Undefined GCD since m = 0, n = 0."); + exit(-1); + } + return (m == 0 ? (n > 0 ? n : -n) : (m > 0 ? m : -m)); + // return absolute value of whichever is nonzero + } + // could use compile-time assertion + // but involves messing with complex template stuff. + static_assert(std::is_integral::value, ""); + while (1) { + m %= n; + if (m == 0) return (n > 0 ? n : -n); + n %= m; + if (n == 0) return (m > 0 ? m : -m); + } +} + +/// Returns the least common multiple of two integers. Will +/// crash unless the inputs are positive. +template +I Lcm(I m, I n) { + // This function is copied from kaldi/src/base/kaldi-math.h + assert(m > 0 && n > 0); + I gcd = Gcd(m, n); + return gcd * (m / gcd) * (n / gcd); +} + +static float DotProduct(const float *a, const float *b, int32_t n) { + float sum = 0; + for (int32_t i = 0; i != n; ++i) { + sum += a[i] * b[i]; + } + return sum; +} + +LinearResample::LinearResample(int32_t samp_rate_in_hz, + int32_t samp_rate_out_hz, float filter_cutoff_hz, + int32_t num_zeros) + : samp_rate_in_(samp_rate_in_hz), + samp_rate_out_(samp_rate_out_hz), + filter_cutoff_(filter_cutoff_hz), + num_zeros_(num_zeros) { + assert(samp_rate_in_hz > 0.0 && samp_rate_out_hz > 0.0 && + filter_cutoff_hz > 0.0 && filter_cutoff_hz * 2 <= samp_rate_in_hz && + filter_cutoff_hz * 2 <= samp_rate_out_hz && num_zeros > 0); + + // base_freq is the frequency of the repeating unit, which is the gcd + // of the input frequencies. + int32_t base_freq = Gcd(samp_rate_in_, samp_rate_out_); + input_samples_in_unit_ = samp_rate_in_ / base_freq; + output_samples_in_unit_ = samp_rate_out_ / base_freq; + + SetIndexesAndWeights(); + Reset(); +} + +void LinearResample::SetIndexesAndWeights() { + first_index_.resize(output_samples_in_unit_); + weights_.resize(output_samples_in_unit_); + + double window_width = num_zeros_ / (2.0 * filter_cutoff_); + + for (int32_t i = 0; i < output_samples_in_unit_; i++) { + double output_t = i / static_cast(samp_rate_out_); + double min_t = output_t - window_width, max_t = output_t + window_width; + // we do ceil on the min and floor on the max, because if we did it + // the other way around we would unnecessarily include indexes just + // outside the window, with zero coefficients. It's possible + // if the arguments to the ceil and floor expressions are integers + // (e.g. if filter_cutoff_ has an exact ratio with the sample rates), + // that we unnecessarily include something with a zero coefficient, + // but this is only a slight efficiency issue. + int32_t min_input_index = ceil(min_t * samp_rate_in_), + max_input_index = floor(max_t * samp_rate_in_), + num_indices = max_input_index - min_input_index + 1; + first_index_[i] = min_input_index; + weights_[i].resize(num_indices); + for (int32_t j = 0; j < num_indices; j++) { + int32_t input_index = min_input_index + j; + double input_t = input_index / static_cast(samp_rate_in_), + delta_t = input_t - output_t; + // sign of delta_t doesn't matter. + weights_[i][j] = FilterFunc(delta_t) / samp_rate_in_; + } + } +} + +/** Here, t is a time in seconds representing an offset from + the center of the windowed filter function, and FilterFunction(t) + returns the windowed filter function, described + in the header as h(t) = f(t)g(t), evaluated at t. +*/ +float LinearResample::FilterFunc(float t) const { + float window, // raised-cosine (Hanning) window of width + // num_zeros_/2*filter_cutoff_ + filter; // sinc filter function + if (fabs(t) < num_zeros_ / (2.0 * filter_cutoff_)) + window = 0.5 * (1 + cos(M_2PI * filter_cutoff_ / num_zeros_ * t)); + else + window = 0.0; // outside support of window function + if (t != 0) + filter = sin(M_2PI * filter_cutoff_ * t) / (M_PI * t); + else + filter = 2 * filter_cutoff_; // limit of the function at t = 0 + return filter * window; +} + +void LinearResample::Reset() { + input_sample_offset_ = 0; + output_sample_offset_ = 0; + input_remainder_.resize(0); +} + +void LinearResample::Resample(const float *input, int32_t input_dim, bool flush, + std::vector *output) { + int64_t tot_input_samp = input_sample_offset_ + input_dim, + tot_output_samp = GetNumOutputSamples(tot_input_samp, flush); + + assert(tot_output_samp >= output_sample_offset_); + + output->resize(tot_output_samp - output_sample_offset_); + + // samp_out is the index into the total output signal, not just the part + // of it we are producing here. + for (int64_t samp_out = output_sample_offset_; samp_out < tot_output_samp; + samp_out++) { + int64_t first_samp_in; + int32_t samp_out_wrapped; + GetIndexes(samp_out, &first_samp_in, &samp_out_wrapped); + const std::vector &weights = weights_[samp_out_wrapped]; + // first_input_index is the first index into "input" that we have a weight + // for. + int32_t first_input_index = + static_cast(first_samp_in - input_sample_offset_); + float this_output; + if (first_input_index >= 0 && + first_input_index + static_cast(weights.size()) <= input_dim) { + this_output = + DotProduct(input + first_input_index, weights.data(), weights.size()); + } else { // Handle edge cases. + this_output = 0.0; + for (int32_t i = 0; i < static_cast(weights.size()); i++) { + float weight = weights[i]; + int32_t input_index = first_input_index + i; + if (input_index < 0 && + static_cast(input_remainder_.size()) + input_index >= 0) { + this_output += + weight * input_remainder_[input_remainder_.size() + input_index]; + } else if (input_index >= 0 && input_index < input_dim) { + this_output += weight * input[input_index]; + } else if (input_index >= input_dim) { + // We're past the end of the input and are adding zero; should only + // happen if the user specified flush == true, or else we would not + // be trying to output this sample. + assert(flush); + } + } + } + int32_t output_index = + static_cast(samp_out - output_sample_offset_); + (*output)[output_index] = this_output; + } + + if (flush) { + Reset(); // Reset the internal state. + } else { + SetRemainder(input, input_dim); + input_sample_offset_ = tot_input_samp; + output_sample_offset_ = tot_output_samp; + } +} + +int64_t LinearResample::GetNumOutputSamples(int64_t input_num_samp, + bool flush) const { + // For exact computation, we measure time in "ticks" of 1.0 / tick_freq, + // where tick_freq is the least common multiple of samp_rate_in_ and + // samp_rate_out_. + int32_t tick_freq = Lcm(samp_rate_in_, samp_rate_out_); + int32_t ticks_per_input_period = tick_freq / samp_rate_in_; + + // work out the number of ticks in the time interval + // [ 0, input_num_samp/samp_rate_in_ ). + int64_t interval_length_in_ticks = input_num_samp * ticks_per_input_period; + if (!flush) { + float window_width = num_zeros_ / (2.0 * filter_cutoff_); + // To count the window-width in ticks we take the floor. This + // is because since we're looking for the largest integer num-out-samp + // that fits in the interval, which is open on the right, a reduction + // in interval length of less than a tick will never make a difference. + // For example, the largest integer in the interval [ 0, 2 ) and the + // largest integer in the interval [ 0, 2 - 0.9 ) are the same (both one). + // So when we're subtracting the window-width we can ignore the fractional + // part. + int32_t window_width_ticks = floor(window_width * tick_freq); + // The time-period of the output that we can sample gets reduced + // by the window-width (which is actually the distance from the + // center to the edge of the windowing function) if we're not + // "flushing the output". + interval_length_in_ticks -= window_width_ticks; + } + if (interval_length_in_ticks <= 0) return 0; + + int32_t ticks_per_output_period = tick_freq / samp_rate_out_; + // Get the last output-sample in the closed interval, i.e. replacing [ ) with + // [ ]. Note: integer division rounds down. See + // http://en.wikipedia.org/wiki/Interval_(mathematics) for an explanation of + // the notation. + int64_t last_output_samp = interval_length_in_ticks / ticks_per_output_period; + // We need the last output-sample in the open interval, so if it takes us to + // the end of the interval exactly, subtract one. + if (last_output_samp * ticks_per_output_period == interval_length_in_ticks) + last_output_samp--; + + // First output-sample index is zero, so the number of output samples + // is the last output-sample plus one. + int64_t num_output_samp = last_output_samp + 1; + return num_output_samp; +} + +// inline +void LinearResample::GetIndexes(int64_t samp_out, int64_t *first_samp_in, + int32_t *samp_out_wrapped) const { + // A unit is the smallest nonzero amount of time that is an exact + // multiple of the input and output sample periods. The unit index + // is the answer to "which numbered unit we are in". + int64_t unit_index = samp_out / output_samples_in_unit_; + // samp_out_wrapped is equal to samp_out % output_samples_in_unit_ + *samp_out_wrapped = + static_cast(samp_out - unit_index * output_samples_in_unit_); + *first_samp_in = + first_index_[*samp_out_wrapped] + unit_index * input_samples_in_unit_; +} + +void LinearResample::SetRemainder(const float *input, int32_t input_dim) { + std::vector old_remainder(input_remainder_); + // max_remainder_needed is the width of the filter from side to side, + // measured in input samples. you might think it should be half that, + // but you have to consider that you might be wanting to output samples + // that are "in the past" relative to the beginning of the latest + // input... anyway, storing more remainder than needed is not harmful. + int32_t max_remainder_needed = + ceil(samp_rate_in_ * num_zeros_ / filter_cutoff_); + input_remainder_.resize(max_remainder_needed); + for (int32_t index = -static_cast(input_remainder_.size()); + index < 0; index++) { + // we interpret "index" as an offset from the end of "input" and + // from the end of input_remainder_. + int32_t input_index = index + input_dim; + if (input_index >= 0) { + input_remainder_[index + static_cast(input_remainder_.size())] = + input[input_index]; + } else if (input_index + static_cast(old_remainder.size()) >= 0) { + input_remainder_[index + static_cast(input_remainder_.size())] = + old_remainder[input_index + + static_cast(old_remainder.size())]; + // else leave it at zero. + } + } +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/resample.h b/sherpa-onnx/csrc/resample.h new file mode 100644 index 00000000..2006ae90 --- /dev/null +++ b/sherpa-onnx/csrc/resample.h @@ -0,0 +1,144 @@ +/** + * Copyright 2013 Pegah Ghahremani + * 2014 IMSL, PKU-HKUST (author: Wei Shi) + * 2014 Yanqing Sun, Junjie Wang + * 2014 Johns Hopkins University (author: Daniel Povey) + * Copyright 2023 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +// this file is copied and modified from +// kaldi/src/feat/resample.h +#ifndef SHERPA_ONNX_CSRC_RESAMPLE_H_ +#define SHERPA_ONNX_CSRC_RESAMPLE_H_ + +#include +#include + +namespace sherpa_onnx { + +/* + We require that the input and output sampling rate be specified as + integers, as this is an easy way to specify that their ratio be rational. +*/ + +class LinearResample { + public: + /// Constructor. We make the input and output sample rates integers, because + /// we are going to need to find a common divisor. This should just remind + /// you that they need to be integers. The filter cutoff needs to be less + /// than samp_rate_in_hz/2 and less than samp_rate_out_hz/2. num_zeros + /// controls the sharpness of the filter, more == sharper but less efficient. + /// We suggest around 4 to 10 for normal use. + LinearResample(int32_t samp_rate_in_hz, int32_t samp_rate_out_hz, + float filter_cutoff_hz, int32_t num_zeros); + + /// Calling the function Reset() resets the state of the object prior to + /// processing a new signal; it is only necessary if you have called + /// Resample(x, x_size, false, y) for some signal, leading to a remainder of + /// the signal being called, but then abandon processing the signal before + /// calling Resample(x, x_size, true, y) for the last piece. Call it + /// unnecessarily between signals will not do any harm. + void Reset(); + + /// This function does the resampling. If you call it with flush == true and + /// you have never called it with flush == false, it just resamples the input + /// signal (it resizes the output to a suitable number of samples). + /// + /// You can also use this function to process a signal a piece at a time. + /// suppose you break it into piece1, piece2, ... pieceN. You can call + /// \code{.cc} + /// Resample(piece1, piece1_size, false, &output1); + /// Resample(piece2, piece2_size, false, &output2); + /// Resample(piece3, piece3_size, true, &output3); + /// \endcode + /// If you call it with flush == false, it won't output the last few samples + /// but will remember them, so that if you later give it a second piece of + /// the input signal it can process it correctly. + /// If your most recent call to the object was with flush == false, it will + /// have internal state; you can remove this by calling Reset(). + /// Empty input is acceptable. + void Resample(const float *input, int32_t input_dim, bool flush, + std::vector *output); + + //// Return the input and output sampling rates (for checks, for example) + int32_t GetInputSamplingRate() const { return samp_rate_in_; } + int32_t GetOutputSamplingRate() const { return samp_rate_out_; } + + private: + void SetIndexesAndWeights(); + + float FilterFunc(float) const; + + /// This function outputs the number of output samples we will output + /// for a signal with "input_num_samp" input samples. If flush == true, + /// we return the largest n such that + /// (n/samp_rate_out_) is in the interval [ 0, input_num_samp/samp_rate_in_ ), + /// and note that the interval is half-open. If flush == false, + /// define window_width as num_zeros / (2.0 * filter_cutoff_); + /// we return the largest n such that (n/samp_rate_out_) is in the interval + /// [ 0, input_num_samp/samp_rate_in_ - window_width ). + int64_t GetNumOutputSamples(int64_t input_num_samp, bool flush) const; + + /// Given an output-sample index, this function outputs to *first_samp_in the + /// first input-sample index that we have a weight on (may be negative), + /// and to *samp_out_wrapped the index into weights_ where we can get the + /// corresponding weights on the input. + inline void GetIndexes(int64_t samp_out, int64_t *first_samp_in, + int32_t *samp_out_wrapped) const; + + void SetRemainder(const float *input, int32_t input_dim); + + private: + // The following variables are provided by the user. + int32_t samp_rate_in_; + int32_t samp_rate_out_; + float filter_cutoff_; + int32_t num_zeros_; + + int32_t input_samples_in_unit_; ///< The number of input samples in the + ///< smallest repeating unit: num_samp_in_ = + ///< samp_rate_in_hz / Gcd(samp_rate_in_hz, + ///< samp_rate_out_hz) + + int32_t output_samples_in_unit_; ///< The number of output samples in the + ///< smallest repeating unit: num_samp_out_ + ///< = samp_rate_out_hz / + ///< Gcd(samp_rate_in_hz, samp_rate_out_hz) + + /// The first input-sample index that we sum over, for this output-sample + /// index. May be negative; any truncation at the beginning is handled + /// separately. This is just for the first few output samples, but we can + /// extrapolate the correct input-sample index for arbitrary output samples. + std::vector first_index_; + + /// Weights on the input samples, for this output-sample index. + std::vector> weights_; + + // the following variables keep track of where we are in a particular signal, + // if it is being provided over multiple calls to Resample(). + + int64_t input_sample_offset_; ///< The number of input samples we have + ///< already received for this signal + ///< (including anything in remainder_) + int64_t output_sample_offset_; ///< The number of samples we have already + ///< output for this signal. + std::vector input_remainder_; ///< A small trailing part of the + ///< previously seen input signal. +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_RESAMPLE_H_ diff --git a/sherpa-onnx/csrc/sherpa-onnx.cc b/sherpa-onnx/csrc/sherpa-onnx.cc index 7b838b53..ad23ce41 100644 --- a/sherpa-onnx/csrc/sherpa-onnx.cc +++ b/sherpa-onnx/csrc/sherpa-onnx.cc @@ -10,9 +10,6 @@ #include "sherpa-onnx/csrc/online-recognizer.h" #include "sherpa-onnx/csrc/online-stream.h" -#include "sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h" -#include "sherpa-onnx/csrc/online-transducer-model-config.h" -#include "sherpa-onnx/csrc/online-transducer-model.h" #include "sherpa-onnx/csrc/symbol-table.h" #include "sherpa-onnx/csrc/wave-reader.h" diff --git a/toolchains/aarch64-linux-gnu.toolchain.cmake b/toolchains/aarch64-linux-gnu.toolchain.cmake new file mode 100644 index 00000000..e72e0cba --- /dev/null +++ b/toolchains/aarch64-linux-gnu.toolchain.cmake @@ -0,0 +1,18 @@ +# Copied from https://github.com/Tencent/ncnn/blob/master/toolchains/aarch64-linux-gnu.toolchain.cmake + +set(CMAKE_SYSTEM_NAME Linux) +set(CMAKE_SYSTEM_PROCESSOR aarch64) + +set(CMAKE_C_COMPILER "aarch64-linux-gnu-gcc") +set(CMAKE_CXX_COMPILER "aarch64-linux-gnu-g++") + +set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER) +set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY ONLY) +set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY) + +set(CMAKE_C_FLAGS "-march=armv8-a") +set(CMAKE_CXX_FLAGS "-march=armv8-a") + +# cache flags +set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS}" CACHE STRING "c flags") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}" CACHE STRING "c++ flags")