Add Python API (#31)

This commit is contained in:
Fangjun Kuang
2023-02-19 19:36:03 +08:00
committed by GitHub
parent 8acc059b3f
commit ea09d5fbc5
51 changed files with 967 additions and 57 deletions

View File

@@ -1,6 +1,6 @@
include_directories(${CMAKE_SOURCE_DIR})
add_executable(sherpa-onnx
add_library(sherpa-onnx-core
features.cc
online-lstm-transducer-model.cc
online-recognizer.cc
@@ -9,15 +9,21 @@ add_executable(sherpa-onnx
online-transducer-model-config.cc
online-transducer-model.cc
onnx-utils.cc
sherpa-onnx.cc
symbol-table.cc
wave-reader.cc
)
target_link_libraries(sherpa-onnx
target_link_libraries(sherpa-onnx-core
onnxruntime
kaldi-native-fbank-core
)
add_executable(sherpa-onnx-show-info show-onnx-info.cc)
target_link_libraries(sherpa-onnx-show-info onnxruntime)
add_executable(sherpa-onnx sherpa-onnx.cc)
target_link_libraries(sherpa-onnx sherpa-onnx-core)
if(NOT WIN32)
target_link_libraries(sherpa-onnx "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../lib")
endif()
install(TARGETS sherpa-onnx-core DESTINATION lib)
install(TARGETS sherpa-onnx DESTINATION bin)

View File

@@ -1,4 +1,4 @@
// sherpa/csrc/features.cc
// sherpa-onnx/csrc/features.cc
//
// Copyright (c) 2023 Xiaomi Corporation

View File

@@ -1,4 +1,4 @@
// sherpa/csrc/features.h
// sherpa-onnx/csrc/features.h
//
// Copyright (c) 2023 Xiaomi Corporation

View File

@@ -1,4 +1,4 @@
// sherpa/csrc/online-lstm-transducer-model.cc
// sherpa-onnx/csrc/online-lstm-transducer-model.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/online-lstm-transducer-model.h"
@@ -232,7 +232,7 @@ std::vector<Ort::Value> OnlineLstmTransducerModel::GetEncoderInitStates() {
std::pair<Ort::Value, std::vector<Ort::Value>>
OnlineLstmTransducerModel::RunEncoder(Ort::Value features,
std::vector<Ort::Value> &states) {
std::vector<Ort::Value> states) {
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);

View File

@@ -1,4 +1,4 @@
// sherpa/csrc/online-lstm-transducer-model.h
// sherpa-onnx/csrc/online-lstm-transducer-model.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_ONLINE_LSTM_TRANSDUCER_MODEL_H_
@@ -28,7 +28,7 @@ class OnlineLstmTransducerModel : public OnlineTransducerModel {
std::vector<Ort::Value> GetEncoderInitStates() override;
std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder(
Ort::Value features, std::vector<Ort::Value> &states) override;
Ort::Value features, std::vector<Ort::Value> states) override;
Ort::Value BuildDecoderInput(
const std::vector<OnlineTransducerDecoderResult> &results) override;

View File

@@ -98,7 +98,7 @@ class OnlineRecognizer::Impl {
auto states = model_->StackStates(states_vec);
auto pair = model_->RunEncoder(std::move(x), states);
auto pair = model_->RunEncoder(std::move(x), std::move(states));
decoder_->Decode(std::move(pair.first), &results);

View File

@@ -23,6 +23,13 @@ struct OnlineRecognizerConfig {
OnlineTransducerModelConfig model_config;
std::string tokens;
OnlineRecognizerConfig() = default;
OnlineRecognizerConfig(const FeatureExtractorConfig &feat_config,
const OnlineTransducerModelConfig &model_config,
const std::string &tokens)
: feat_config(feat_config), model_config(model_config), tokens(tokens) {}
std::string ToString() const;
};

View File

@@ -1,4 +1,4 @@
// sherpa/csrc/online-transducer-decoder.h
// sherpa-onnx/csrc/online-transducer-decoder.h
//
// Copyright (c) 2023 Xiaomi Corporation

View File

@@ -1,4 +1,4 @@
// sherpa/csrc/online-transducer-greedy-search-decoder.cc
// sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc
//
// Copyright (c) 2023 Xiaomi Corporation

View File

@@ -1,4 +1,4 @@
// sherpa/csrc/online-transducer-greedy-search-decoder.h
// sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h
//
// Copyright (c) 2023 Xiaomi Corporation

View File

@@ -1,4 +1,4 @@
// sherpa/csrc/online-transducer-model-config.cc
// sherpa-onnx/csrc/online-transducer-model-config.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/online-transducer-model-config.h"

View File

@@ -1,4 +1,4 @@
// sherpa/csrc/online-transducer-model-config.h
// sherpa-onnx/csrc/online-transducer-model-config.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_MODEL_CONFIG_H_
@@ -15,6 +15,17 @@ struct OnlineTransducerModelConfig {
int32_t num_threads;
bool debug = false;
OnlineTransducerModelConfig() = default;
OnlineTransducerModelConfig(const std::string &encoder_filename,
const std::string &decoder_filename,
const std::string &joiner_filename,
int32_t num_threads, bool debug)
: encoder_filename(encoder_filename),
decoder_filename(decoder_filename),
joiner_filename(joiner_filename),
num_threads(num_threads),
debug(debug) {}
std::string ToString() const;
};

View File

@@ -1,4 +1,4 @@
// sherpa/csrc/online-transducer-model.cc
// sherpa-onnx/csrc/online-transducer-model.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/online-transducer-model.h"

View File

@@ -1,4 +1,4 @@
// sherpa/csrc/online-transducer-model.h
// sherpa-onnx/csrc/online-transducer-model.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_MODEL_H_
@@ -59,7 +59,7 @@ class OnlineTransducerModel {
*/
virtual std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder(
Ort::Value features,
std::vector<Ort::Value> &states) = 0; // NOLINT
std::vector<Ort::Value> states) = 0; // NOLINT
virtual Ort::Value BuildDecoderInput(
const std::vector<OnlineTransducerDecoderResult> &results) = 0;

View File

@@ -1,4 +1,4 @@
// sherpa/csrc/onnx-utils.cc
// sherpa-onnx/csrc/onnx-utils.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/onnx-utils.h"

View File

@@ -1,4 +1,4 @@
// sherpa/csrc/onnx-utils.h
// sherpa-onnx/csrc/onnx-utils.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_ONNX_UTILS_H_

View File

@@ -1,22 +0,0 @@
// sherpa-onnx/csrc/show-onnx-info.cc
//
// Copyright (c) 2022-2023 Xiaomi Corporation
#include <iostream>
#include <sstream>
#include "onnxruntime_cxx_api.h" // NOLINT
int main() {
std::cout << "ORT_API_VERSION: " << ORT_API_VERSION << "\n";
std::vector<std::string> providers = Ort::GetAvailableProviders();
std::ostringstream os;
os << "Available providers: ";
std::string sep = "";
for (const auto &p : providers) {
os << sep << p;
sep = ", ";
}
std::cout << os.str() << "\n";
return 0;
}

View File

@@ -1,4 +1,4 @@
// sherpa-onnx/csrc/symbol-table.cc
// sherpa-onnx/csrc/symbol-table.h
//
// Copyright (c) 2022-2023 Xiaomi Corporation

View File

@@ -1,4 +1,4 @@
// sherpa/csrc/wave-reader.cc
// sherpa-onnx/csrc/wave-reader.cc
//
// Copyright (c) 2023 Xiaomi Corporation

View File

@@ -1,4 +1,4 @@
// sherpa/csrc/wave-reader.h
// sherpa-onnx/csrc/wave-reader.h
//
// Copyright (c) 2023 Xiaomi Corporation