Add C API for spoken language identification. (#695)

This commit is contained in:
Fangjun Kuang
2024-03-25 15:16:47 +08:00
committed by GitHub
parent 0d258dd150
commit ab7cff2513
18 changed files with 366 additions and 70 deletions

View File

@@ -6,6 +6,7 @@
#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>
@@ -16,7 +17,9 @@
#include "sherpa-onnx/csrc/offline-recognizer.h"
#include "sherpa-onnx/csrc/offline-tts.h"
#include "sherpa-onnx/csrc/online-recognizer.h"
#include "sherpa-onnx/csrc/spoken-language-identification.h"
#include "sherpa-onnx/csrc/voice-activity-detector.h"
#include "sherpa-onnx/csrc/wave-reader.h"
#include "sherpa-onnx/csrc/wave-writer.h"
struct SherpaOnnxOnlineRecognizer {
@@ -859,3 +862,97 @@ int32_t SherpaOnnxWriteWave(const float *samples, int32_t n,
int32_t sample_rate, const char *filename) {
return sherpa_onnx::WriteWave(filename, sample_rate, samples, n);
}
const SherpaOnnxWave *SherpaOnnxReadWave(const char *filename) {
int32_t sample_rate = -1;
bool is_ok = false;
std::vector<float> samples =
sherpa_onnx::ReadWave(filename, &sample_rate, &is_ok);
if (!is_ok) {
return nullptr;
}
float *c_samples = new float[samples.size()];
std::copy(samples.begin(), samples.end(), c_samples);
SherpaOnnxWave *wave = new SherpaOnnxWave;
wave->samples = c_samples;
wave->sample_rate = sample_rate;
wave->num_samples = samples.size();
return wave;
}
void SherpaOnnxFreeWave(const SherpaOnnxWave *wave) {
if (wave) {
delete[] wave->samples;
delete wave;
}
}
struct SherpaOnnxSpokenLanguageIdentification {
std::unique_ptr<sherpa_onnx::SpokenLanguageIdentification> impl;
};
const SherpaOnnxSpokenLanguageIdentification *
SherpaOnnxCreateSpokenLanguageIdentification(
const SherpaOnnxSpokenLanguageIdentificationConfig *config) {
sherpa_onnx::SpokenLanguageIdentificationConfig slid_config;
slid_config.whisper.encoder = SHERPA_ONNX_OR(config->whisper.encoder, "");
slid_config.whisper.decoder = SHERPA_ONNX_OR(config->whisper.decoder, "");
slid_config.whisper.tail_paddings =
SHERPA_ONNX_OR(config->whisper.tail_paddings, -1);
slid_config.num_threads = SHERPA_ONNX_OR(config->num_threads, 1);
slid_config.debug = config->debug;
slid_config.provider = SHERPA_ONNX_OR(config->provider, "cpu");
if (slid_config.debug) {
SHERPA_ONNX_LOGE("%s\n", slid_config.ToString().c_str());
}
if (!slid_config.Validate()) {
SHERPA_ONNX_LOGE("Errors in config");
return nullptr;
}
SherpaOnnxSpokenLanguageIdentification *slid =
new SherpaOnnxSpokenLanguageIdentification;
slid->impl =
std::make_unique<sherpa_onnx::SpokenLanguageIdentification>(slid_config);
return slid;
}
void SherpaOnnxDestroySpokenLanguageIdentification(
const SherpaOnnxSpokenLanguageIdentification *slid) {
delete slid;
}
SherpaOnnxOfflineStream *
SherpaOnnxSpokenLanguageIdentificationCreateOfflineStream(
const SherpaOnnxSpokenLanguageIdentification *slid) {
SherpaOnnxOfflineStream *stream =
new SherpaOnnxOfflineStream(slid->impl->CreateStream());
return stream;
}
const SherpaOnnxSpokenLanguageIdentificationResult *
SherpaOnnxSpokenLanguageIdentificationCompute(
const SherpaOnnxSpokenLanguageIdentification *slid,
const SherpaOnnxOfflineStream *s) {
std::string lang = slid->impl->Compute(s->impl.get());
char *c_lang = new char[lang.size() + 1];
std::copy(lang.begin(), lang.end(), c_lang);
c_lang[lang.size()] = '\0';
SherpaOnnxSpokenLanguageIdentificationResult *r =
new SherpaOnnxSpokenLanguageIdentificationResult;
r->lang = c_lang;
return r;
}
void SherpaOnnxDestroySpokenLanguageIdentificationResult(
const SherpaOnnxSpokenLanguageIdentificationResult *r) {
if (r) {
delete[] r->lang;
delete r;
}
}