Add C++ API for streaming ASR. (#1455)
It is a wrapper around the C API.
This commit is contained in:
@@ -3,12 +3,25 @@ add_library(sherpa-onnx-c-api c-api.cc)
|
||||
target_link_libraries(sherpa-onnx-c-api sherpa-onnx-core)
|
||||
|
||||
if(BUILD_SHARED_LIBS)
|
||||
target_compile_definitions(sherpa-onnx-c-api PRIVATE SHERPA_ONNX_BUILD_SHARED_LIBS=1)
|
||||
target_compile_definitions(sherpa-onnx-c-api PRIVATE SHERPA_ONNX_BUILD_MAIN_LIB=1)
|
||||
target_compile_definitions(sherpa-onnx-c-api PUBLIC SHERPA_ONNX_BUILD_SHARED_LIBS=1)
|
||||
target_compile_definitions(sherpa-onnx-c-api PUBLIC SHERPA_ONNX_BUILD_MAIN_LIB=1)
|
||||
endif()
|
||||
|
||||
install(TARGETS sherpa-onnx-c-api DESTINATION lib)
|
||||
add_library(sherpa-onnx-cxx-api cxx-api.cc)
|
||||
target_link_libraries(sherpa-onnx-cxx-api sherpa-onnx-c-api)
|
||||
|
||||
install(FILES c-api.h
|
||||
DESTINATION include/sherpa-onnx/c-api
|
||||
install(
|
||||
TARGETS
|
||||
sherpa-onnx-c-api
|
||||
sherpa-onnx-cxx-api
|
||||
DESTINATION
|
||||
lib
|
||||
)
|
||||
|
||||
install(
|
||||
FILES
|
||||
c-api.h
|
||||
cxx-api.h
|
||||
DESTINATION
|
||||
include/sherpa-onnx/c-api
|
||||
)
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
#include "sherpa-onnx/c-api/c-api.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstring>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
@@ -51,7 +52,7 @@ struct SherpaOnnxDisplay {
|
||||
|
||||
#define SHERPA_ONNX_OR(x, y) (x ? x : y)
|
||||
|
||||
SherpaOnnxOnlineRecognizer *SherpaOnnxCreateOnlineRecognizer(
|
||||
const SherpaOnnxOnlineRecognizer *SherpaOnnxCreateOnlineRecognizer(
|
||||
const SherpaOnnxOnlineRecognizerConfig *config) {
|
||||
sherpa_onnx::OnlineRecognizerConfig recognizer_config;
|
||||
|
||||
|
||||
@@ -205,7 +205,8 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOnlineStream SherpaOnnxOnlineStream;
|
||||
/// @param config Config for the recognizer.
|
||||
/// @return Return a pointer to the recognizer. The user has to invoke
|
||||
// SherpaOnnxDestroyOnlineRecognizer() to free it to avoid memory leak.
|
||||
SHERPA_ONNX_API SherpaOnnxOnlineRecognizer *SherpaOnnxCreateOnlineRecognizer(
|
||||
SHERPA_ONNX_API const SherpaOnnxOnlineRecognizer *
|
||||
SherpaOnnxCreateOnlineRecognizer(
|
||||
const SherpaOnnxOnlineRecognizerConfig *config);
|
||||
|
||||
/// Free a pointer returned by SherpaOnnxCreateOnlineRecognizer()
|
||||
|
||||
159
sherpa-onnx/c-api/cxx-api.cc
Normal file
159
sherpa-onnx/c-api/cxx-api.cc
Normal file
@@ -0,0 +1,159 @@
|
||||
// sherpa-onnx/c-api/cxx-api.cc
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
#include "sherpa-onnx/c-api/cxx-api.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstring>
|
||||
|
||||
namespace sherpa_onnx::cxx {
|
||||
|
||||
Wave ReadWave(const std::string &filename) {
|
||||
auto p = SherpaOnnxReadWave(filename.c_str());
|
||||
|
||||
Wave ans;
|
||||
if (p) {
|
||||
ans.samples.resize(p->num_samples);
|
||||
|
||||
std::copy(p->samples, p->samples + p->num_samples, ans.samples.data());
|
||||
|
||||
ans.sample_rate = p->sample_rate;
|
||||
SherpaOnnxFreeWave(p);
|
||||
}
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
OnlineStream::OnlineStream(const SherpaOnnxOnlineStream *p)
|
||||
: MoveOnly<OnlineStream, SherpaOnnxOnlineStream>(p) {}
|
||||
|
||||
void OnlineStream::Destroy(const SherpaOnnxOnlineStream *p) const {
|
||||
SherpaOnnxDestroyOnlineStream(p);
|
||||
}
|
||||
|
||||
void OnlineStream::AcceptWaveform(int32_t sample_rate, const float *samples,
|
||||
int32_t n) const {
|
||||
SherpaOnnxOnlineStreamAcceptWaveform(p_, sample_rate, samples, n);
|
||||
}
|
||||
|
||||
OnlineRecognizer OnlineRecognizer::Create(
|
||||
const OnlineRecognizerConfig &config) {
|
||||
struct SherpaOnnxOnlineRecognizerConfig c;
|
||||
memset(&c, 0, sizeof(c));
|
||||
|
||||
c.feat_config.sample_rate = config.feat_config.sample_rate;
|
||||
c.feat_config.feature_dim = config.feat_config.feature_dim;
|
||||
|
||||
c.model_config.transducer.encoder =
|
||||
config.model_config.transducer.encoder.c_str();
|
||||
c.model_config.transducer.decoder =
|
||||
config.model_config.transducer.decoder.c_str();
|
||||
c.model_config.transducer.joiner =
|
||||
config.model_config.transducer.joiner.c_str();
|
||||
|
||||
c.model_config.paraformer.encoder =
|
||||
config.model_config.paraformer.encoder.c_str();
|
||||
c.model_config.paraformer.decoder =
|
||||
config.model_config.paraformer.decoder.c_str();
|
||||
|
||||
c.model_config.zipformer2_ctc.model =
|
||||
config.model_config.zipformer2_ctc.model.c_str();
|
||||
|
||||
c.model_config.tokens = config.model_config.tokens.c_str();
|
||||
c.model_config.num_threads = config.model_config.num_threads;
|
||||
c.model_config.provider = config.model_config.provider.c_str();
|
||||
c.model_config.debug = config.model_config.debug;
|
||||
c.model_config.model_type = config.model_config.model_type.c_str();
|
||||
c.model_config.modeling_unit = config.model_config.modeling_unit.c_str();
|
||||
c.model_config.bpe_vocab = config.model_config.bpe_vocab.c_str();
|
||||
c.model_config.tokens_buf = config.model_config.tokens_buf.c_str();
|
||||
c.model_config.tokens_buf_size = config.model_config.tokens_buf.size();
|
||||
|
||||
c.decoding_method = config.decoding_method.c_str();
|
||||
c.max_active_paths = config.max_active_paths;
|
||||
c.enable_endpoint = config.enable_endpoint;
|
||||
c.rule1_min_trailing_silence = config.rule1_min_trailing_silence;
|
||||
c.rule2_min_trailing_silence = config.rule2_min_trailing_silence;
|
||||
c.rule3_min_utterance_length = config.rule3_min_utterance_length;
|
||||
c.hotwords_file = config.hotwords_file.c_str();
|
||||
c.hotwords_score = config.hotwords_score;
|
||||
|
||||
c.ctc_fst_decoder_config.graph = config.ctc_fst_decoder_config.graph.c_str();
|
||||
c.ctc_fst_decoder_config.max_active =
|
||||
config.ctc_fst_decoder_config.max_active;
|
||||
|
||||
c.rule_fsts = config.rule_fsts.c_str();
|
||||
c.rule_fars = config.rule_fars.c_str();
|
||||
|
||||
c.blank_penalty = config.blank_penalty;
|
||||
|
||||
c.hotwords_buf = config.hotwords_buf.c_str();
|
||||
c.hotwords_buf_size = config.hotwords_buf.size();
|
||||
|
||||
auto p = SherpaOnnxCreateOnlineRecognizer(&c);
|
||||
return OnlineRecognizer(p);
|
||||
}
|
||||
|
||||
OnlineRecognizer::OnlineRecognizer(const SherpaOnnxOnlineRecognizer *p)
|
||||
: MoveOnly<OnlineRecognizer, SherpaOnnxOnlineRecognizer>(p) {}
|
||||
|
||||
void OnlineRecognizer::Destroy(const SherpaOnnxOnlineRecognizer *p) const {
|
||||
SherpaOnnxDestroyOnlineRecognizer(p);
|
||||
}
|
||||
|
||||
OnlineStream OnlineRecognizer::CreateStream() const {
|
||||
auto s = SherpaOnnxCreateOnlineStream(p_);
|
||||
return OnlineStream{s};
|
||||
}
|
||||
|
||||
OnlineStream OnlineRecognizer::CreateStream(const std::string &hotwords) const {
|
||||
auto s = SherpaOnnxCreateOnlineStreamWithHotwords(p_, hotwords.c_str());
|
||||
return OnlineStream{s};
|
||||
}
|
||||
|
||||
bool OnlineRecognizer::IsReady(const OnlineStream *s) const {
|
||||
return SherpaOnnxIsOnlineStreamReady(p_, s->Get());
|
||||
}
|
||||
|
||||
void OnlineRecognizer::Decode(const OnlineStream *s) const {
|
||||
SherpaOnnxDecodeOnlineStream(p_, s->Get());
|
||||
}
|
||||
|
||||
void OnlineRecognizer::Decode(const OnlineStream *ss, int32_t n) const {
|
||||
if (n <= 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<const SherpaOnnxOnlineStream *> streams(n);
|
||||
for (int32_t i = 0; i != n; ++n) {
|
||||
streams[i] = ss[i].Get();
|
||||
}
|
||||
|
||||
SherpaOnnxDecodeMultipleOnlineStreams(p_, streams.data(), n);
|
||||
}
|
||||
|
||||
OnlineRecognizerResult OnlineRecognizer::GetResult(
|
||||
const OnlineStream *s) const {
|
||||
auto r = SherpaOnnxGetOnlineStreamResult(p_, s->Get());
|
||||
|
||||
OnlineRecognizerResult ans;
|
||||
ans.text = r->text;
|
||||
|
||||
ans.tokens.resize(r->count);
|
||||
for (int32_t i = 0; i != r->count; ++i) {
|
||||
ans.tokens[i] = r->tokens_arr[i];
|
||||
}
|
||||
|
||||
if (r->timestamps) {
|
||||
ans.timestamps.resize(r->count);
|
||||
std::copy(r->timestamps, r->timestamps + r->count, ans.timestamps.data());
|
||||
}
|
||||
|
||||
ans.json = r->json;
|
||||
|
||||
SherpaOnnxDestroyOnlineRecognizerResult(r);
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx::cxx
|
||||
179
sherpa-onnx/c-api/cxx-api.h
Normal file
179
sherpa-onnx/c-api/cxx-api.h
Normal file
@@ -0,0 +1,179 @@
|
||||
// sherpa-onnx/c-api/cxx-api.h
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
// C++ Wrapper of the C API for sherpa-onnx
|
||||
#ifndef SHERPA_ONNX_C_API_CXX_API_H_
|
||||
#define SHERPA_ONNX_C_API_CXX_API_H_
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/c-api/c-api.h"
|
||||
|
||||
namespace sherpa_onnx::cxx {
|
||||
|
||||
struct SHERPA_ONNX_API OnlineTransducerModelConfig {
|
||||
std::string encoder;
|
||||
std::string decoder;
|
||||
std::string joiner;
|
||||
};
|
||||
|
||||
struct SHERPA_ONNX_API OnlineParaformerModelConfig {
|
||||
std::string encoder;
|
||||
std::string decoder;
|
||||
};
|
||||
|
||||
struct SHERPA_ONNX_API OnlineZipformer2CtcModelConfig {
|
||||
std::string model;
|
||||
};
|
||||
|
||||
struct SHERPA_ONNX_API OnlineModelConfig {
|
||||
OnlineTransducerModelConfig transducer;
|
||||
OnlineParaformerModelConfig paraformer;
|
||||
OnlineZipformer2CtcModelConfig zipformer2_ctc;
|
||||
std::string tokens;
|
||||
int32_t num_threads = 1;
|
||||
std::string provider = "cpu";
|
||||
bool debug = false;
|
||||
std::string model_type;
|
||||
std::string modeling_unit = "cjkchar";
|
||||
std::string bpe_vocab;
|
||||
std::string tokens_buf;
|
||||
};
|
||||
|
||||
struct SHERPA_ONNX_API FeatureConfig {
|
||||
int32_t sample_rate = 16000;
|
||||
int32_t feature_dim = 80;
|
||||
};
|
||||
|
||||
struct SHERPA_ONNX_API OnlineCtcFstDecoderConfig {
|
||||
std::string graph;
|
||||
int32_t max_active = 3000;
|
||||
};
|
||||
|
||||
struct SHERPA_ONNX_API OnlineRecognizerConfig {
|
||||
FeatureConfig feat_config;
|
||||
OnlineModelConfig model_config;
|
||||
|
||||
std::string decoding_method = "greedy_search";
|
||||
|
||||
int32_t max_active_paths = 4;
|
||||
|
||||
bool enable_endpoint = false;
|
||||
|
||||
float rule1_min_trailing_silence = 2.4;
|
||||
|
||||
float rule2_min_trailing_silence = 1.2;
|
||||
|
||||
float rule3_min_utterance_length = 20;
|
||||
|
||||
std::string hotwords_file;
|
||||
|
||||
float hotwords_score = 1.5;
|
||||
|
||||
OnlineCtcFstDecoderConfig ctc_fst_decoder_config;
|
||||
std::string rule_fsts;
|
||||
std::string rule_fars;
|
||||
float blank_penalty = 0;
|
||||
|
||||
std::string hotwords_buf;
|
||||
};
|
||||
|
||||
struct SHERPA_ONNX_API OnlineRecognizerResult {
|
||||
std::string text;
|
||||
std::vector<std::string> tokens;
|
||||
std::vector<float> timestamps;
|
||||
std::string json;
|
||||
};
|
||||
|
||||
struct SHERPA_ONNX_API Wave {
|
||||
std::vector<float> samples;
|
||||
int32_t sample_rate;
|
||||
};
|
||||
|
||||
SHERPA_ONNX_API Wave ReadWave(const std::string &filename);
|
||||
|
||||
template <typename Derived, typename T>
|
||||
class SHERPA_ONNX_API MoveOnly {
|
||||
public:
|
||||
explicit MoveOnly(const T *p) : p_(p) {}
|
||||
|
||||
~MoveOnly() { Destroy(); }
|
||||
|
||||
MoveOnly(const MoveOnly &) = delete;
|
||||
|
||||
MoveOnly &operator=(const MoveOnly &) = delete;
|
||||
|
||||
MoveOnly(MoveOnly &&other) : p_(other.Release()) {}
|
||||
|
||||
MoveOnly &operator=(MoveOnly &&other) {
|
||||
if (&other == this) {
|
||||
return *this;
|
||||
}
|
||||
|
||||
Destroy();
|
||||
|
||||
p_ = other.Release();
|
||||
}
|
||||
|
||||
const T *Get() const { return p_; }
|
||||
|
||||
const T *Release() {
|
||||
const T *p = p_;
|
||||
p_ = nullptr;
|
||||
return p;
|
||||
}
|
||||
|
||||
private:
|
||||
void Destroy() {
|
||||
if (p_ == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
static_cast<Derived *>(this)->Destroy(p_);
|
||||
|
||||
p_ = nullptr;
|
||||
}
|
||||
|
||||
protected:
|
||||
const T *p_ = nullptr;
|
||||
};
|
||||
|
||||
class SHERPA_ONNX_API OnlineStream
|
||||
: public MoveOnly<OnlineStream, SherpaOnnxOnlineStream> {
|
||||
public:
|
||||
explicit OnlineStream(const SherpaOnnxOnlineStream *p);
|
||||
|
||||
void AcceptWaveform(int32_t sample_rate, const float *samples,
|
||||
int32_t n) const;
|
||||
|
||||
void Destroy(const SherpaOnnxOnlineStream *p) const;
|
||||
};
|
||||
|
||||
class SHERPA_ONNX_API OnlineRecognizer
|
||||
: public MoveOnly<OnlineRecognizer, SherpaOnnxOnlineRecognizer> {
|
||||
public:
|
||||
static OnlineRecognizer Create(const OnlineRecognizerConfig &config);
|
||||
|
||||
void Destroy(const SherpaOnnxOnlineRecognizer *p) const;
|
||||
|
||||
OnlineStream CreateStream() const;
|
||||
|
||||
OnlineStream CreateStream(const std::string &hotwords) const;
|
||||
|
||||
bool IsReady(const OnlineStream *s) const;
|
||||
|
||||
void Decode(const OnlineStream *s) const;
|
||||
|
||||
void Decode(const OnlineStream *ss, int32_t n) const;
|
||||
|
||||
OnlineRecognizerResult GetResult(const OnlineStream *s) const;
|
||||
|
||||
private:
|
||||
explicit OnlineRecognizer(const SherpaOnnxOnlineRecognizer *p);
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx::cxx
|
||||
|
||||
#endif // SHERPA_ONNX_C_API_CXX_API_H_
|
||||
Reference in New Issue
Block a user