Add Python API (#31)
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
include_directories(${CMAKE_SOURCE_DIR})
|
||||
|
||||
add_executable(sherpa-onnx
|
||||
add_library(sherpa-onnx-core
|
||||
features.cc
|
||||
online-lstm-transducer-model.cc
|
||||
online-recognizer.cc
|
||||
@@ -9,15 +9,21 @@ add_executable(sherpa-onnx
|
||||
online-transducer-model-config.cc
|
||||
online-transducer-model.cc
|
||||
onnx-utils.cc
|
||||
sherpa-onnx.cc
|
||||
symbol-table.cc
|
||||
wave-reader.cc
|
||||
)
|
||||
|
||||
target_link_libraries(sherpa-onnx
|
||||
target_link_libraries(sherpa-onnx-core
|
||||
onnxruntime
|
||||
kaldi-native-fbank-core
|
||||
)
|
||||
|
||||
add_executable(sherpa-onnx-show-info show-onnx-info.cc)
|
||||
target_link_libraries(sherpa-onnx-show-info onnxruntime)
|
||||
add_executable(sherpa-onnx sherpa-onnx.cc)
|
||||
|
||||
target_link_libraries(sherpa-onnx sherpa-onnx-core)
|
||||
if(NOT WIN32)
|
||||
target_link_libraries(sherpa-onnx "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../lib")
|
||||
endif()
|
||||
|
||||
install(TARGETS sherpa-onnx-core DESTINATION lib)
|
||||
install(TARGETS sherpa-onnx DESTINATION bin)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// sherpa/csrc/features.cc
|
||||
// sherpa-onnx/csrc/features.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// sherpa/csrc/features.h
|
||||
// sherpa-onnx/csrc/features.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// sherpa/csrc/online-lstm-transducer-model.cc
|
||||
// sherpa-onnx/csrc/online-lstm-transducer-model.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
#include "sherpa-onnx/csrc/online-lstm-transducer-model.h"
|
||||
@@ -232,7 +232,7 @@ std::vector<Ort::Value> OnlineLstmTransducerModel::GetEncoderInitStates() {
|
||||
|
||||
std::pair<Ort::Value, std::vector<Ort::Value>>
|
||||
OnlineLstmTransducerModel::RunEncoder(Ort::Value features,
|
||||
std::vector<Ort::Value> &states) {
|
||||
std::vector<Ort::Value> states) {
|
||||
auto memory_info =
|
||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// sherpa/csrc/online-lstm-transducer-model.h
|
||||
// sherpa-onnx/csrc/online-lstm-transducer-model.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_ONLINE_LSTM_TRANSDUCER_MODEL_H_
|
||||
@@ -28,7 +28,7 @@ class OnlineLstmTransducerModel : public OnlineTransducerModel {
|
||||
std::vector<Ort::Value> GetEncoderInitStates() override;
|
||||
|
||||
std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder(
|
||||
Ort::Value features, std::vector<Ort::Value> &states) override;
|
||||
Ort::Value features, std::vector<Ort::Value> states) override;
|
||||
|
||||
Ort::Value BuildDecoderInput(
|
||||
const std::vector<OnlineTransducerDecoderResult> &results) override;
|
||||
|
||||
@@ -98,7 +98,7 @@ class OnlineRecognizer::Impl {
|
||||
|
||||
auto states = model_->StackStates(states_vec);
|
||||
|
||||
auto pair = model_->RunEncoder(std::move(x), states);
|
||||
auto pair = model_->RunEncoder(std::move(x), std::move(states));
|
||||
|
||||
decoder_->Decode(std::move(pair.first), &results);
|
||||
|
||||
|
||||
@@ -23,6 +23,13 @@ struct OnlineRecognizerConfig {
|
||||
OnlineTransducerModelConfig model_config;
|
||||
std::string tokens;
|
||||
|
||||
OnlineRecognizerConfig() = default;
|
||||
|
||||
OnlineRecognizerConfig(const FeatureExtractorConfig &feat_config,
|
||||
const OnlineTransducerModelConfig &model_config,
|
||||
const std::string &tokens)
|
||||
: feat_config(feat_config), model_config(model_config), tokens(tokens) {}
|
||||
|
||||
std::string ToString() const;
|
||||
};
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// sherpa/csrc/online-transducer-decoder.h
|
||||
// sherpa-onnx/csrc/online-transducer-decoder.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// sherpa/csrc/online-transducer-greedy-search-decoder.cc
|
||||
// sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// sherpa/csrc/online-transducer-greedy-search-decoder.h
|
||||
// sherpa-onnx/csrc/online-transducer-greedy-search-decoder.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// sherpa/csrc/online-transducer-model-config.cc
|
||||
// sherpa-onnx/csrc/online-transducer-model-config.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// sherpa/csrc/online-transducer-model-config.h
|
||||
// sherpa-onnx/csrc/online-transducer-model-config.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_MODEL_CONFIG_H_
|
||||
@@ -15,6 +15,17 @@ struct OnlineTransducerModelConfig {
|
||||
int32_t num_threads;
|
||||
bool debug = false;
|
||||
|
||||
OnlineTransducerModelConfig() = default;
|
||||
OnlineTransducerModelConfig(const std::string &encoder_filename,
|
||||
const std::string &decoder_filename,
|
||||
const std::string &joiner_filename,
|
||||
int32_t num_threads, bool debug)
|
||||
: encoder_filename(encoder_filename),
|
||||
decoder_filename(decoder_filename),
|
||||
joiner_filename(joiner_filename),
|
||||
num_threads(num_threads),
|
||||
debug(debug) {}
|
||||
|
||||
std::string ToString() const;
|
||||
};
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// sherpa/csrc/online-transducer-model.cc
|
||||
// sherpa-onnx/csrc/online-transducer-model.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
#include "sherpa-onnx/csrc/online-transducer-model.h"
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// sherpa/csrc/online-transducer-model.h
|
||||
// sherpa-onnx/csrc/online-transducer-model.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_ONLINE_TRANSDUCER_MODEL_H_
|
||||
@@ -59,7 +59,7 @@ class OnlineTransducerModel {
|
||||
*/
|
||||
virtual std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder(
|
||||
Ort::Value features,
|
||||
std::vector<Ort::Value> &states) = 0; // NOLINT
|
||||
std::vector<Ort::Value> states) = 0; // NOLINT
|
||||
|
||||
virtual Ort::Value BuildDecoderInput(
|
||||
const std::vector<OnlineTransducerDecoderResult> &results) = 0;
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// sherpa/csrc/onnx-utils.cc
|
||||
// sherpa-onnx/csrc/onnx-utils.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// sherpa/csrc/onnx-utils.h
|
||||
// sherpa-onnx/csrc/onnx-utils.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_ONNX_UTILS_H_
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
// sherpa-onnx/csrc/show-onnx-info.cc
|
||||
//
|
||||
// Copyright (c) 2022-2023 Xiaomi Corporation
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
|
||||
int main() {
|
||||
std::cout << "ORT_API_VERSION: " << ORT_API_VERSION << "\n";
|
||||
std::vector<std::string> providers = Ort::GetAvailableProviders();
|
||||
std::ostringstream os;
|
||||
os << "Available providers: ";
|
||||
std::string sep = "";
|
||||
for (const auto &p : providers) {
|
||||
os << sep << p;
|
||||
sep = ", ";
|
||||
}
|
||||
std::cout << os.str() << "\n";
|
||||
return 0;
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
// sherpa-onnx/csrc/symbol-table.cc
|
||||
// sherpa-onnx/csrc/symbol-table.h
|
||||
//
|
||||
// Copyright (c) 2022-2023 Xiaomi Corporation
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// sherpa/csrc/wave-reader.cc
|
||||
// sherpa-onnx/csrc/wave-reader.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// sherpa/csrc/wave-reader.h
|
||||
// sherpa-onnx/csrc/wave-reader.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
|
||||
Reference in New Issue
Block a user