Add C API for speaker embedding extractor. (#711)

This commit is contained in:
Fangjun Kuang
2024-03-28 18:05:40 +08:00
committed by GitHub
parent 638f48f47a
commit 2e0bccad36
23 changed files with 739 additions and 80 deletions

View File

@@ -16,6 +16,8 @@
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-recognizer.h"
#include "sherpa-onnx/csrc/online-recognizer.h"
#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
#include "sherpa-onnx/csrc/speaker-embedding-manager.h"
#include "sherpa-onnx/csrc/spoken-language-identification.h"
#include "sherpa-onnx/csrc/voice-activity-detector.h"
#include "sherpa-onnx/csrc/wave-reader.h"
@@ -114,7 +116,7 @@ SherpaOnnxOnlineRecognizer *CreateOnlineRecognizer(
return recognizer;
}
void DestroyOnlineRecognizer(SherpaOnnxOnlineRecognizer *recognizer) {
void DestroyOnlineRecognizer(const SherpaOnnxOnlineRecognizer *recognizer) {
delete recognizer;
}
@@ -132,25 +134,28 @@ SherpaOnnxOnlineStream *CreateOnlineStreamWithHotwords(
return stream;
}
void DestroyOnlineStream(SherpaOnnxOnlineStream *stream) { delete stream; }
void DestroyOnlineStream(const SherpaOnnxOnlineStream *stream) {
delete stream;
}
void AcceptWaveform(SherpaOnnxOnlineStream *stream, int32_t sample_rate,
void AcceptWaveform(const SherpaOnnxOnlineStream *stream, int32_t sample_rate,
const float *samples, int32_t n) {
stream->impl->AcceptWaveform(sample_rate, samples, n);
}
int32_t IsOnlineStreamReady(SherpaOnnxOnlineRecognizer *recognizer,
SherpaOnnxOnlineStream *stream) {
int32_t IsOnlineStreamReady(const SherpaOnnxOnlineRecognizer *recognizer,
const SherpaOnnxOnlineStream *stream) {
return recognizer->impl->IsReady(stream->impl.get());
}
void DecodeOnlineStream(SherpaOnnxOnlineRecognizer *recognizer,
SherpaOnnxOnlineStream *stream) {
void DecodeOnlineStream(const SherpaOnnxOnlineRecognizer *recognizer,
const SherpaOnnxOnlineStream *stream) {
recognizer->impl->DecodeStream(stream->impl.get());
}
void DecodeMultipleOnlineStreams(SherpaOnnxOnlineRecognizer *recognizer,
SherpaOnnxOnlineStream **streams, int32_t n) {
void DecodeMultipleOnlineStreams(const SherpaOnnxOnlineRecognizer *recognizer,
const SherpaOnnxOnlineStream **streams,
int32_t n) {
std::vector<sherpa_onnx::OnlineStream *> ss(n);
for (int32_t i = 0; i != n; ++i) {
ss[i] = streams[i]->impl.get();
@@ -159,7 +164,8 @@ void DecodeMultipleOnlineStreams(SherpaOnnxOnlineRecognizer *recognizer,
}
const SherpaOnnxOnlineRecognizerResult *GetOnlineStreamResult(
SherpaOnnxOnlineRecognizer *recognizer, SherpaOnnxOnlineStream *stream) {
const SherpaOnnxOnlineRecognizer *recognizer,
const SherpaOnnxOnlineStream *stream) {
sherpa_onnx::OnlineRecognizerResult result =
recognizer->impl->GetResult(stream->impl.get());
const auto &text = result.text;
@@ -232,29 +238,30 @@ void DestroyOnlineRecognizerResult(const SherpaOnnxOnlineRecognizerResult *r) {
}
}
void Reset(SherpaOnnxOnlineRecognizer *recognizer,
SherpaOnnxOnlineStream *stream) {
void Reset(const SherpaOnnxOnlineRecognizer *recognizer,
const SherpaOnnxOnlineStream *stream) {
recognizer->impl->Reset(stream->impl.get());
}
void InputFinished(SherpaOnnxOnlineStream *stream) {
void InputFinished(const SherpaOnnxOnlineStream *stream) {
stream->impl->InputFinished();
}
int32_t IsEndpoint(SherpaOnnxOnlineRecognizer *recognizer,
SherpaOnnxOnlineStream *stream) {
int32_t IsEndpoint(const SherpaOnnxOnlineRecognizer *recognizer,
const SherpaOnnxOnlineStream *stream) {
return recognizer->impl->IsEndpoint(stream->impl.get());
}
SherpaOnnxDisplay *CreateDisplay(int32_t max_word_per_line) {
const SherpaOnnxDisplay *CreateDisplay(int32_t max_word_per_line) {
SherpaOnnxDisplay *ans = new SherpaOnnxDisplay;
ans->impl = std::make_unique<sherpa_onnx::Display>(max_word_per_line);
return ans;
}
void DestroyDisplay(SherpaOnnxDisplay *display) { delete display; }
void DestroyDisplay(const SherpaOnnxDisplay *display) { delete display; }
void SherpaOnnxPrint(SherpaOnnxDisplay *display, int32_t idx, const char *s) {
void SherpaOnnxPrint(const SherpaOnnxDisplay *display, int32_t idx,
const char *s) {
display->impl->Print(idx, s);
}
@@ -808,9 +815,8 @@ int32_t SherpaOnnxOfflineTtsNumSpeakers(const SherpaOnnxOfflineTts *tts) {
}
static const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerateInternal(
const SherpaOnnxOfflineTts *tts, const char *text, int32_t sid,
float speed, std::function<void(const float *, int32_t, float)> callback)
{
const SherpaOnnxOfflineTts *tts, const char *text, int32_t sid, float speed,
std::function<void(const float *, int32_t, float)> callback) {
sherpa_onnx::GeneratedAudio audio =
tts->impl->Generate(text, sid, speed, callback);
@@ -833,36 +839,37 @@ static const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerateInternal(
const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerate(
const SherpaOnnxOfflineTts *tts, const char *text, int32_t sid,
float speed) {
return SherpaOnnxOfflineTtsGenerateInternal( tts, text, sid, speed, nullptr );
return SherpaOnnxOfflineTtsGenerateInternal(tts, text, sid, speed, nullptr);
}
const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerateWithCallback(
const SherpaOnnxOfflineTts *tts, const char *text, int32_t sid, float speed,
SherpaOnnxGeneratedAudioCallback callback) {
auto wrapper = [callback](const float *samples, int32_t n, float /*progress*/) {
callback(samples, n );
};
auto wrapper = [callback](const float *samples, int32_t n,
float /*progress*/) { callback(samples, n); };
return SherpaOnnxOfflineTtsGenerateInternal( tts, text, sid, speed, wrapper );
return SherpaOnnxOfflineTtsGenerateInternal(tts, text, sid, speed, wrapper);
}
const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerateWithProgressCallback(
const SherpaOnnxGeneratedAudio *
SherpaOnnxOfflineTtsGenerateWithProgressCallback(
const SherpaOnnxOfflineTts *tts, const char *text, int32_t sid, float speed,
SherpaOnnxGeneratedAudioProgressCallback callback) {
auto wrapper = [callback](const float *samples, int32_t n, float progress) {
callback(samples, n, progress );
callback(samples, n, progress);
};
return SherpaOnnxOfflineTtsGenerateInternal( tts, text, sid, speed, wrapper );
return SherpaOnnxOfflineTtsGenerateInternal(tts, text, sid, speed, wrapper);
}
const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerateWithCallbackWithArg(
const SherpaOnnxOfflineTts *tts, const char *text, int32_t sid, float speed,
SherpaOnnxGeneratedAudioCallbackWithArg callback, void *arg) {
auto wrapper = [callback, arg](const float *samples, int32_t n, float /*progress*/) {
auto wrapper = [callback, arg](const float *samples, int32_t n,
float /*progress*/) {
callback(samples, n, arg);
};
return SherpaOnnxOfflineTtsGenerateInternal( tts, text, sid, speed, wrapper );
return SherpaOnnxOfflineTtsGenerateInternal(tts, text, sid, speed, wrapper);
}
void SherpaOnnxDestroyOfflineTtsGeneratedAudio(
@@ -972,3 +979,200 @@ void SherpaOnnxDestroySpokenLanguageIdentificationResult(
delete r;
}
}
// Handle type returned by SherpaOnnxCreateSpeakerEmbeddingExtractor().
// It owns the underlying C++ sherpa_onnx::SpeakerEmbeddingExtractor via
// `impl`, so deleting the handle also destroys the extractor.
struct SherpaOnnxSpeakerEmbeddingExtractor {
std::unique_ptr<sherpa_onnx::SpeakerEmbeddingExtractor> impl;
};
// Creates a speaker embedding extractor from a C configuration struct.
//
// Each field of `config` is passed through SHERPA_ONNX_OR together with a
// fallback value (model: "", num_threads: 1, debug: 0, provider: "cpu").
// NOTE(review): SHERPA_ONNX_OR is defined elsewhere; presumably it selects the
// fallback when the field is zero/null -- confirm against the macro.
//
// Returns nullptr if `config` is a null pointer or if the resulting C++
// config fails Validate(). Otherwise returns a heap-allocated handle that
// must be released with SherpaOnnxDestroySpeakerEmbeddingExtractor().
const SherpaOnnxSpeakerEmbeddingExtractor *
SherpaOnnxCreateSpeakerEmbeddingExtractor(
    const SherpaOnnxSpeakerEmbeddingExtractorConfig *config) {
  // Guard against a null config instead of dereferencing it below.
  if (config == nullptr) {
    SHERPA_ONNX_LOGE("config is a null pointer!");
    return nullptr;
  }

  sherpa_onnx::SpeakerEmbeddingExtractorConfig c;
  c.model = SHERPA_ONNX_OR(config->model, "");
  c.num_threads = SHERPA_ONNX_OR(config->num_threads, 1);
  c.debug = SHERPA_ONNX_OR(config->debug, 0);
  c.provider = SHERPA_ONNX_OR(config->provider, "cpu");

  if (config->debug) {
    SHERPA_ONNX_LOGE("%s\n", c.ToString().c_str());
  }

  if (!c.Validate()) {
    SHERPA_ONNX_LOGE("Errors in config!");
    return nullptr;
  }

  // Keep the handle in a unique_ptr while the extractor is being constructed
  // so the handle is not leaked if that construction throws.
  auto p = std::make_unique<SherpaOnnxSpeakerEmbeddingExtractor>();
  p->impl = std::make_unique<sherpa_onnx::SpeakerEmbeddingExtractor>(c);
  return p.release();
}
// Releases an extractor handle with `delete`; the handle's unique_ptr member
// destroys the wrapped C++ extractor. Deleting a null pointer is a no-op, so
// passing nullptr is harmless.
void SherpaOnnxDestroySpeakerEmbeddingExtractor(
const SherpaOnnxSpeakerEmbeddingExtractor *p) {
delete p;
}
// Returns the value of Dim() from the wrapped extractor
// (presumably the length of each embedding vector -- confirm against
// sherpa_onnx::SpeakerEmbeddingExtractor).
// `p` is dereferenced without a null check.
int32_t SherpaOnnxSpeakerEmbeddingExtractorDim(
const SherpaOnnxSpeakerEmbeddingExtractor *p) {
return p->impl->Dim();
}
// Asks the wrapped extractor for a new stream and wraps it in a
// heap-allocated SherpaOnnxOnlineStream handle. The caller owns the returned
// handle. `p` is dereferenced without a null check.
const SherpaOnnxOnlineStream *SherpaOnnxSpeakerEmbeddingExtractorCreateStream(
    const SherpaOnnxSpeakerEmbeddingExtractor *p) {
  return new SherpaOnnxOnlineStream(p->impl->CreateStream());
}
// Forwards to the wrapped extractor's IsReady() for stream `s` and returns
// its result converted to int32_t.
// Neither `p` nor `s` is checked for null.
int32_t SherpaOnnxSpeakerEmbeddingExtractorIsReady(
const SherpaOnnxSpeakerEmbeddingExtractor *p,
const SherpaOnnxOnlineStream *s) {
return p->impl->IsReady(s->impl.get());
}
// Runs Compute() on stream `s` and returns a copy of the resulting floats in
// a buffer allocated with new[]. The buffer holds exactly as many elements as
// Compute() returned; free it with
// SherpaOnnxSpeakerEmbeddingExtractorDestroyEmbedding().
// Neither `p` nor `s` is checked for null.
const float *SherpaOnnxSpeakerEmbeddingExtractorComputeEmbedding(
    const SherpaOnnxSpeakerEmbeddingExtractor *p,
    const SherpaOnnxOnlineStream *s) {
  const std::vector<float> embedding = p->impl->Compute(s->impl.get());

  // Copy into a plain array so that ownership can cross the C boundary.
  auto *out = new float[embedding.size()];
  std::copy(embedding.cbegin(), embedding.cend(), out);
  return out;
}
// Frees an embedding buffer with delete[]. `v` must have been allocated with
// new[] (as SherpaOnnxSpeakerEmbeddingExtractorComputeEmbedding() does).
// Passing nullptr is a no-op.
void SherpaOnnxSpeakerEmbeddingExtractorDestroyEmbedding(const float *v) {
delete[] v;
}
// Handle type returned by SherpaOnnxCreateSpeakerEmbeddingManager().
// It owns the underlying C++ sherpa_onnx::SpeakerEmbeddingManager via `impl`,
// so deleting the handle also destroys the manager.
struct SherpaOnnxSpeakerEmbeddingManager {
std::unique_ptr<sherpa_onnx::SpeakerEmbeddingManager> impl;
};
// Creates a speaker embedding manager constructed with `dim`
// (presumably the length of every embedding it will store -- confirm against
// sherpa_onnx::SpeakerEmbeddingManager).
//
// Returns a heap-allocated handle that must be released with
// SherpaOnnxDestroySpeakerEmbeddingManager().
const SherpaOnnxSpeakerEmbeddingManager *
SherpaOnnxCreateSpeakerEmbeddingManager(int32_t dim) {
  // Keep the handle in a unique_ptr while the manager is being constructed so
  // the handle is not leaked if that construction throws.
  auto p = std::make_unique<SherpaOnnxSpeakerEmbeddingManager>();
  p->impl = std::make_unique<sherpa_onnx::SpeakerEmbeddingManager>(dim);
  return p.release();
}
// Releases a manager handle with `delete`; the handle's unique_ptr member
// destroys the wrapped C++ manager. Passing nullptr is a no-op.
void SherpaOnnxDestroySpeakerEmbeddingManager(
const SherpaOnnxSpeakerEmbeddingManager *p) {
delete p;
}
// Forwards a single embedding `v` for speaker `name` to the wrapped manager's
// Add() and returns its result converted to int32_t.
// NOTE(review): no length is passed, so `v` presumably must point to at least
// Dim() floats -- confirm against SpeakerEmbeddingManager::Add.
// `p` is dereferenced without a null check.
int32_t SherpaOnnxSpeakerEmbeddingManagerAdd(
const SherpaOnnxSpeakerEmbeddingManager *p, const char *name,
const float *v) {
return p->impl->Add(name, v);
}
// Registers several embeddings for one speaker.
//
// `v` is an array of pointers terminated by a null pointer; every non-null
// entry is read as Dim() consecutive floats. If `v` is null or its first
// entry is null, logs "Empty embedding!" and returns 0. Otherwise returns the
// result of the wrapped manager's Add(name, list) converted to int32_t.
// `p` is dereferenced without a null check.
int32_t SherpaOnnxSpeakerEmbeddingManagerAddList(
    const SherpaOnnxSpeakerEmbeddingManager *p, const char *name,
    const float **v) {
  // Count the entries that precede the terminating null pointer.
  int32_t count = 0;
  if (v != nullptr) {
    while (v[count] != nullptr) {
      ++count;
    }
  }

  if (count == 0) {
    SHERPA_ONNX_LOGE("Empty embedding!");
    return 0;
  }

  const int32_t dim = p->impl->Dim();

  std::vector<std::vector<float>> embeddings;
  embeddings.reserve(count);
  for (int32_t i = 0; i != count; ++i) {
    // Copy dim floats starting at v[i] into an owned vector.
    embeddings.emplace_back(v[i], v[i] + dim);
  }

  return p->impl->Add(name, embeddings);
}
int32_t SherpaOnnxSpeakerEmbeddingManagerAddListFlattened(
const SherpaOnnxSpeakerEmbeddingManager *p, const char *name,
const float *v, int32_t n) {
std::vector<std::vector<float>> vec(n);
int32_t dim = p->impl->Dim();
for (int32_t i = 0; i != n; ++i, v += dim) {
vec[i] = std::vector<float>(v, v + dim);
}
return p->impl->Add(name, vec);
}
// Forwards to the wrapped manager's Remove(name) and returns its result
// converted to int32_t. `p` is dereferenced without a null check.
int32_t SherpaOnnxSpeakerEmbeddingManagerRemove(
const SherpaOnnxSpeakerEmbeddingManager *p, const char *name) {
return p->impl->Remove(name);
}
// Calls the wrapped manager's Search(v, threshold).
//
// Returns nullptr when Search() yields an empty result. Otherwise returns a
// NUL-terminated copy of the result allocated with new[]; free it with
// SherpaOnnxSpeakerEmbeddingManagerFreeSearch().
// `p` is dereferenced without a null check.
const char *SherpaOnnxSpeakerEmbeddingManagerSearch(
    const SherpaOnnxSpeakerEmbeddingManager *p, const float *v,
    float threshold) {
  const auto found = p->impl->Search(v, threshold);
  if (found.empty()) {
    return nullptr;
  }

  // One extra byte for the terminating NUL.
  const auto len = found.size();
  char *out = new char[len + 1];
  std::copy(found.begin(), found.end(), out);
  out[len] = '\0';

  return out;
}
// Frees a string with delete[]. `name` must have been allocated with new[]
// (as SherpaOnnxSpeakerEmbeddingManagerSearch() does). Passing nullptr is a
// no-op.
void SherpaOnnxSpeakerEmbeddingManagerFreeSearch(const char *name) {
delete[] name;
}
// Forwards to the wrapped manager's Verify(name, v, threshold) and returns
// its result converted to int32_t.
// NOTE(review): no length is passed, so `v` presumably must point to at least
// Dim() floats -- confirm against SpeakerEmbeddingManager::Verify.
// `p` is dereferenced without a null check.
int32_t SherpaOnnxSpeakerEmbeddingManagerVerify(
const SherpaOnnxSpeakerEmbeddingManager *p, const char *name,
const float *v, float threshold) {
return p->impl->Verify(name, v, threshold);
}
// Forwards to the wrapped manager's Contains(name) and returns its result
// converted to int32_t. `p` is dereferenced without a null check.
int32_t SherpaOnnxSpeakerEmbeddingManagerContains(
const SherpaOnnxSpeakerEmbeddingManager *p, const char *name) {
return p->impl->Contains(name);
}
// Returns the wrapped manager's NumSpeakers() converted to int32_t.
// `p` is dereferenced without a null check.
int32_t SherpaOnnxSpeakerEmbeddingManagerNumSpeakers(
const SherpaOnnxSpeakerEmbeddingManager *p) {
return p->impl->NumSpeakers();
}
// Returns every name reported by the wrapped manager's GetAllSpeakers() as a
// null-terminated array of NUL-terminated C strings.
//
// The array and each string are allocated with new[]; release them with
// SherpaOnnxSpeakerEmbeddingManagerFreeAllSpeakers(). When there are no
// speakers the result is an array holding only the terminating null pointer.
// `manager` is dereferenced without a null check.
const char *const *SherpaOnnxSpeakerEmbeddingManagerGetAllSpeakers(
    const SherpaOnnxSpeakerEmbeddingManager *manager) {
  const std::vector<std::string> speakers = manager->impl->GetAllSpeakers();
  const int32_t count = speakers.size();

  // One extra slot for the null pointer that marks the end of the array.
  char **names = new char *[count + 1];

  for (int32_t i = 0; i != count; ++i) {
    const std::string &speaker = speakers[i];
    names[i] = new char[speaker.size() + 1];
    std::copy(speaker.begin(), speaker.end(), names[i]);
    names[i][speaker.size()] = '\0';
  }
  names[count] = nullptr;

  return names;
}
// Frees a null-terminated array of strings in which both the array and each
// string were allocated with new[] (the layout produced by
// SherpaOnnxSpeakerEmbeddingManagerGetAllSpeakers()).
// Passing nullptr is a no-op.
void SherpaOnnxSpeakerEmbeddingManagerFreeAllSpeakers(
    const char *const *names) {
  if (names != nullptr) {
    // Release each string up to, but not including, the terminating null.
    for (const char *const *it = names; *it != nullptr; ++it) {
      delete[] *it;
    }
  }

  delete[] names;
}