Add swift online punctuation (#1661)
This commit is contained in:
@@ -24,6 +24,7 @@
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/offline-punctuation.h"
|
||||
#include "sherpa-onnx/csrc/offline-recognizer.h"
|
||||
#include "sherpa-onnx/csrc/online-punctuation.h"
|
||||
#include "sherpa-onnx/csrc/online-recognizer.h"
|
||||
#include "sherpa-onnx/csrc/resample.h"
|
||||
#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
|
||||
@@ -1717,6 +1718,53 @@ const char *SherpaOfflinePunctuationAddPunct(
|
||||
|
||||
// Release a string buffer that was allocated with new[]. Passing nullptr is
// harmless because delete[] on a null pointer does nothing.
void SherpaOfflinePunctuationFreeText(const char *text) {
  delete[] text;
}
|
||||
|
||||
struct SherpaOnnxOnlinePunctuation {
|
||||
std::unique_ptr<sherpa_onnx::OnlinePunctuation> impl;
|
||||
};
|
||||
|
||||
// Create an online punctuation processor from a C-API configuration.
//
// Each field of config->model is copied into a
// sherpa_onnx::OnlinePunctuationConfig. cnn_bilstm, bpe_vocab, num_threads and
// provider go through SHERPA_ONNX_OR with the fallbacks "", "", 1 and "cpu"
// respectively (the macro is defined outside this view; presumably it picks
// the fallback when the field is null/zero). debug is copied as-is.
//
// Returns nullptr when config is null or when construction throws an
// exception derived from std::exception (the message is logged). A non-null
// result must be released with SherpaOnnxDestroyOnlinePunctuation().
const SherpaOnnxOnlinePunctuation *SherpaOnnxCreateOnlinePunctuation(
    const SherpaOnnxOnlinePunctuationConfig *config) {
  // Guard against a null config instead of dereferencing it below.
  if (config == nullptr) {
    SHERPA_ONNX_LOGE("Failed to create online punctuation: %s",
                     "config is null");
    return nullptr;
  }

  try {
    sherpa_onnx::OnlinePunctuationConfig punctuation_config;
    punctuation_config.model.cnn_bilstm =
        SHERPA_ONNX_OR(config->model.cnn_bilstm, "");
    punctuation_config.model.bpe_vocab =
        SHERPA_ONNX_OR(config->model.bpe_vocab, "");
    punctuation_config.model.num_threads =
        SHERPA_ONNX_OR(config->model.num_threads, 1);
    punctuation_config.model.debug = config->model.debug;
    punctuation_config.model.provider =
        SHERPA_ONNX_OR(config->model.provider, "cpu");

    // Own the wrapper through a unique_ptr so it is freed on every failure
    // path, including exceptions that are not derived from std::exception.
    auto p = std::make_unique<SherpaOnnxOnlinePunctuation>();
    p->impl =
        std::make_unique<sherpa_onnx::OnlinePunctuation>(punctuation_config);

    // Success: hand ownership over to the caller.
    return p.release();
  } catch (const std::exception &e) {
    SHERPA_ONNX_LOGE("Failed to create online punctuation: %s", e.what());
    return nullptr;
  }
}
|
||||
|
||||
// Release a handle obtained from SherpaOnnxCreateOnlinePunctuation().
// Deleting the wrapper also destroys the implementation object held by its
// unique_ptr member. Passing nullptr is a no-op.
void SherpaOnnxDestroyOnlinePunctuation(const SherpaOnnxOnlinePunctuation *p) {
  delete p;
}
|
||||
|
||||
const char *SherpaOnnxOnlinePunctuationAddPunct(
|
||||
const SherpaOnnxOnlinePunctuation *punctuation, const char *text) {
|
||||
if (!punctuation || !text) return nullptr;
|
||||
|
||||
try {
|
||||
std::string s = punctuation->impl->AddPunctuationWithCase(text);
|
||||
char *p = new char[s.size() + 1];
|
||||
std::copy(s.begin(), s.end(), p);
|
||||
p[s.size()] = '\0';
|
||||
return p;
|
||||
} catch (const std::exception &e) {
|
||||
SHERPA_ONNX_LOGE("Failed to add punctuation: %s", e.what());
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
// Release a string returned by SherpaOnnxOnlinePunctuationAddPunct(), which
// allocates it with new[]. Passing nullptr is harmless.
void SherpaOnnxOnlinePunctuationFreeText(const char *text) {
  delete[] text;
}
|
||||
|
||||
struct SherpaOnnxLinearResampler {
|
||||
std::unique_ptr<sherpa_onnx::LinearResample> impl;
|
||||
};
|
||||
|
||||
@@ -1369,6 +1369,39 @@ SHERPA_ONNX_API const char *SherpaOfflinePunctuationAddPunct(
|
||||
|
||||
// Free a text buffer with delete[]; intended for strings returned by
// SherpaOfflinePunctuationAddPunct().
SHERPA_ONNX_API void SherpaOfflinePunctuationFreeText(const char *text);
|
||||
|
||||
// Model settings for online punctuation. SherpaOnnxCreateOnlinePunctuation()
// copies every field into sherpa_onnx::OnlinePunctuationConfig::model.
SHERPA_ONNX_API typedef struct SherpaOnnxOnlinePunctuationModelConfig {
  // Presumably the path to the CNN-BiLSTM model file (confirm against the
  // C++ config). Fallback used by the create function: "".
  const char *cnn_bilstm;
  // Presumably the path to the BPE vocabulary file (confirm against the
  // C++ config). Fallback used by the create function: "".
  const char *bpe_vocab;
  // Fallback used by the create function: 1.
  int32_t num_threads;
  // Copied unchanged into the C++ config.
  int32_t debug;
  // Fallback used by the create function: "cpu".
  const char *provider;
} SherpaOnnxOnlinePunctuationModelConfig;
|
||||
|
||||
// Top-level configuration passed to SherpaOnnxCreateOnlinePunctuation().
// It currently carries only the model settings.
SHERPA_ONNX_API typedef struct SherpaOnnxOnlinePunctuationConfig {
  SherpaOnnxOnlinePunctuationModelConfig model;
} SherpaOnnxOnlinePunctuationConfig;
|
||||
|
||||
// Handle type for an online punctuation processor. Only declared here, so
// users of this header work with pointers to it.
SHERPA_ONNX_API typedef struct SherpaOnnxOnlinePunctuation
    SherpaOnnxOnlinePunctuation;
|
||||
|
||||
// Create an online punctuation processor. The user has to invoke
|
||||
// SherpaOnnxDestroyOnlinePunctuation() to free the returned pointer
|
||||
// to avoid memory leak
|
||||
SHERPA_ONNX_API const SherpaOnnxOnlinePunctuation *SherpaOnnxCreateOnlinePunctuation(
|
||||
const SherpaOnnxOnlinePunctuationConfig *config);
|
||||
|
||||
// Free a pointer returned by SherpaOnnxCreateOnlinePunctuation()
|
||||
SHERPA_ONNX_API void SherpaOnnxDestroyOnlinePunctuation(
|
||||
const SherpaOnnxOnlinePunctuation *punctuation);
|
||||
|
||||
// Add punctuations to the input text. The user has to invoke
|
||||
// SherpaOnnxOnlinePunctuationFreeText() to free the returned pointer
|
||||
// to avoid memory leak
|
||||
SHERPA_ONNX_API const char *SherpaOnnxOnlinePunctuationAddPunct(
|
||||
const SherpaOnnxOnlinePunctuation *punctuation, const char *text);
|
||||
|
||||
// Free a pointer returned by SherpaOnnxOnlinePunctuationAddPunct()
|
||||
SHERPA_ONNX_API void SherpaOnnxOnlinePunctuationFreeText(const char *text);
|
||||
|
||||
// for resampling
|
||||
SHERPA_ONNX_API typedef struct SherpaOnnxLinearResampler
|
||||
SherpaOnnxLinearResampler;
|
||||
|
||||
Reference in New Issue
Block a user