Add swift online punctuation (#1661)

This commit is contained in:
yujinqiu
2024-12-31 11:26:32 +08:00
committed by GitHub
parent 49154c957b
commit 5c2cc48f50
5 changed files with 198 additions and 0 deletions

View File

@@ -24,6 +24,7 @@
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-punctuation.h"
#include "sherpa-onnx/csrc/offline-recognizer.h"
#include "sherpa-onnx/csrc/online-punctuation.h"
#include "sherpa-onnx/csrc/online-recognizer.h"
#include "sherpa-onnx/csrc/resample.h"
#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
@@ -1717,6 +1718,53 @@ const char *SherpaOfflinePunctuationAddPunct(
void SherpaOfflinePunctuationFreeText(const char *text) { delete[] text; }
struct SherpaOnnxOnlinePunctuation {
std::unique_ptr<sherpa_onnx::OnlinePunctuation> impl;
};
const SherpaOnnxOnlinePunctuation *SherpaOnnxCreateOnlinePunctuation(
const SherpaOnnxOnlinePunctuationConfig *config) {
auto p = new SherpaOnnxOnlinePunctuation;
try {
sherpa_onnx::OnlinePunctuationConfig punctuation_config;
punctuation_config.model.cnn_bilstm = SHERPA_ONNX_OR(config->model.cnn_bilstm, "");
punctuation_config.model.bpe_vocab = SHERPA_ONNX_OR(config->model.bpe_vocab, "");
punctuation_config.model.num_threads = SHERPA_ONNX_OR(config->model.num_threads, 1);
punctuation_config.model.debug = config->model.debug;
punctuation_config.model.provider = SHERPA_ONNX_OR(config->model.provider, "cpu");
p->impl =
std::make_unique<sherpa_onnx::OnlinePunctuation>(punctuation_config);
} catch (const std::exception &e) {
SHERPA_ONNX_LOGE("Failed to create online punctuation: %s", e.what());
delete p;
return nullptr;
}
return p;
}
void SherpaOnnxDestroyOnlinePunctuation(const SherpaOnnxOnlinePunctuation *p) {
delete p;
}
const char *SherpaOnnxOnlinePunctuationAddPunct(
const SherpaOnnxOnlinePunctuation *punctuation, const char *text) {
if (!punctuation || !text) return nullptr;
try {
std::string s = punctuation->impl->AddPunctuationWithCase(text);
char *p = new char[s.size() + 1];
std::copy(s.begin(), s.end(), p);
p[s.size()] = '\0';
return p;
} catch (const std::exception &e) {
SHERPA_ONNX_LOGE("Failed to add punctuation: %s", e.what());
return nullptr;
}
}
void SherpaOnnxOnlinePunctuationFreeText(const char *text) { delete[] text; }
struct SherpaOnnxLinearResampler {
std::unique_ptr<sherpa_onnx::LinearResample> impl;
};