Add C API for speaker embedding extractor. (#711)

This commit is contained in:
Fangjun Kuang
2024-03-28 18:05:40 +08:00
committed by GitHub
parent 638f48f47a
commit 2e0bccad36
23 changed files with 739 additions and 80 deletions

View File

@@ -16,6 +16,8 @@
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-recognizer.h"
#include "sherpa-onnx/csrc/online-recognizer.h"
#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
#include "sherpa-onnx/csrc/speaker-embedding-manager.h"
#include "sherpa-onnx/csrc/spoken-language-identification.h"
#include "sherpa-onnx/csrc/voice-activity-detector.h"
#include "sherpa-onnx/csrc/wave-reader.h"
@@ -114,7 +116,7 @@ SherpaOnnxOnlineRecognizer *CreateOnlineRecognizer(
return recognizer;
}
void DestroyOnlineRecognizer(SherpaOnnxOnlineRecognizer *recognizer) {
void DestroyOnlineRecognizer(const SherpaOnnxOnlineRecognizer *recognizer) {
delete recognizer;
}
@@ -132,25 +134,28 @@ SherpaOnnxOnlineStream *CreateOnlineStreamWithHotwords(
return stream;
}
void DestroyOnlineStream(SherpaOnnxOnlineStream *stream) { delete stream; }
void DestroyOnlineStream(const SherpaOnnxOnlineStream *stream) {
delete stream;
}
void AcceptWaveform(SherpaOnnxOnlineStream *stream, int32_t sample_rate,
void AcceptWaveform(const SherpaOnnxOnlineStream *stream, int32_t sample_rate,
const float *samples, int32_t n) {
stream->impl->AcceptWaveform(sample_rate, samples, n);
}
int32_t IsOnlineStreamReady(SherpaOnnxOnlineRecognizer *recognizer,
SherpaOnnxOnlineStream *stream) {
int32_t IsOnlineStreamReady(const SherpaOnnxOnlineRecognizer *recognizer,
const SherpaOnnxOnlineStream *stream) {
return recognizer->impl->IsReady(stream->impl.get());
}
void DecodeOnlineStream(SherpaOnnxOnlineRecognizer *recognizer,
SherpaOnnxOnlineStream *stream) {
void DecodeOnlineStream(const SherpaOnnxOnlineRecognizer *recognizer,
const SherpaOnnxOnlineStream *stream) {
recognizer->impl->DecodeStream(stream->impl.get());
}
void DecodeMultipleOnlineStreams(SherpaOnnxOnlineRecognizer *recognizer,
SherpaOnnxOnlineStream **streams, int32_t n) {
void DecodeMultipleOnlineStreams(const SherpaOnnxOnlineRecognizer *recognizer,
const SherpaOnnxOnlineStream **streams,
int32_t n) {
std::vector<sherpa_onnx::OnlineStream *> ss(n);
for (int32_t i = 0; i != n; ++i) {
ss[i] = streams[i]->impl.get();
@@ -159,7 +164,8 @@ void DecodeMultipleOnlineStreams(SherpaOnnxOnlineRecognizer *recognizer,
}
const SherpaOnnxOnlineRecognizerResult *GetOnlineStreamResult(
SherpaOnnxOnlineRecognizer *recognizer, SherpaOnnxOnlineStream *stream) {
const SherpaOnnxOnlineRecognizer *recognizer,
const SherpaOnnxOnlineStream *stream) {
sherpa_onnx::OnlineRecognizerResult result =
recognizer->impl->GetResult(stream->impl.get());
const auto &text = result.text;
@@ -232,29 +238,30 @@ void DestroyOnlineRecognizerResult(const SherpaOnnxOnlineRecognizerResult *r) {
}
}
void Reset(SherpaOnnxOnlineRecognizer *recognizer,
SherpaOnnxOnlineStream *stream) {
void Reset(const SherpaOnnxOnlineRecognizer *recognizer,
const SherpaOnnxOnlineStream *stream) {
recognizer->impl->Reset(stream->impl.get());
}
void InputFinished(SherpaOnnxOnlineStream *stream) {
void InputFinished(const SherpaOnnxOnlineStream *stream) {
stream->impl->InputFinished();
}
int32_t IsEndpoint(SherpaOnnxOnlineRecognizer *recognizer,
SherpaOnnxOnlineStream *stream) {
int32_t IsEndpoint(const SherpaOnnxOnlineRecognizer *recognizer,
const SherpaOnnxOnlineStream *stream) {
return recognizer->impl->IsEndpoint(stream->impl.get());
}
SherpaOnnxDisplay *CreateDisplay(int32_t max_word_per_line) {
const SherpaOnnxDisplay *CreateDisplay(int32_t max_word_per_line) {
SherpaOnnxDisplay *ans = new SherpaOnnxDisplay;
ans->impl = std::make_unique<sherpa_onnx::Display>(max_word_per_line);
return ans;
}
void DestroyDisplay(SherpaOnnxDisplay *display) { delete display; }
void DestroyDisplay(const SherpaOnnxDisplay *display) { delete display; }
void SherpaOnnxPrint(SherpaOnnxDisplay *display, int32_t idx, const char *s) {
void SherpaOnnxPrint(const SherpaOnnxDisplay *display, int32_t idx,
const char *s) {
display->impl->Print(idx, s);
}
@@ -808,9 +815,8 @@ int32_t SherpaOnnxOfflineTtsNumSpeakers(const SherpaOnnxOfflineTts *tts) {
}
static const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerateInternal(
const SherpaOnnxOfflineTts *tts, const char *text, int32_t sid,
float speed, std::function<void(const float *, int32_t, float)> callback)
{
const SherpaOnnxOfflineTts *tts, const char *text, int32_t sid, float speed,
std::function<void(const float *, int32_t, float)> callback) {
sherpa_onnx::GeneratedAudio audio =
tts->impl->Generate(text, sid, speed, callback);
@@ -833,36 +839,37 @@ static const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerateInternal(
const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerate(
const SherpaOnnxOfflineTts *tts, const char *text, int32_t sid,
float speed) {
return SherpaOnnxOfflineTtsGenerateInternal( tts, text, sid, speed, nullptr );
return SherpaOnnxOfflineTtsGenerateInternal(tts, text, sid, speed, nullptr);
}
const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerateWithCallback(
const SherpaOnnxOfflineTts *tts, const char *text, int32_t sid, float speed,
SherpaOnnxGeneratedAudioCallback callback) {
auto wrapper = [callback](const float *samples, int32_t n, float /*progress*/) {
callback(samples, n );
};
auto wrapper = [callback](const float *samples, int32_t n,
float /*progress*/) { callback(samples, n); };
return SherpaOnnxOfflineTtsGenerateInternal( tts, text, sid, speed, wrapper );
return SherpaOnnxOfflineTtsGenerateInternal(tts, text, sid, speed, wrapper);
}
const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerateWithProgressCallback(
const SherpaOnnxGeneratedAudio *
SherpaOnnxOfflineTtsGenerateWithProgressCallback(
const SherpaOnnxOfflineTts *tts, const char *text, int32_t sid, float speed,
SherpaOnnxGeneratedAudioProgressCallback callback) {
auto wrapper = [callback](const float *samples, int32_t n, float progress) {
callback(samples, n, progress );
callback(samples, n, progress);
};
return SherpaOnnxOfflineTtsGenerateInternal( tts, text, sid, speed, wrapper );
return SherpaOnnxOfflineTtsGenerateInternal(tts, text, sid, speed, wrapper);
}
const SherpaOnnxGeneratedAudio *SherpaOnnxOfflineTtsGenerateWithCallbackWithArg(
const SherpaOnnxOfflineTts *tts, const char *text, int32_t sid, float speed,
SherpaOnnxGeneratedAudioCallbackWithArg callback, void *arg) {
auto wrapper = [callback, arg](const float *samples, int32_t n, float /*progress*/) {
auto wrapper = [callback, arg](const float *samples, int32_t n,
float /*progress*/) {
callback(samples, n, arg);
};
return SherpaOnnxOfflineTtsGenerateInternal( tts, text, sid, speed, wrapper );
return SherpaOnnxOfflineTtsGenerateInternal(tts, text, sid, speed, wrapper);
}
void SherpaOnnxDestroyOfflineTtsGeneratedAudio(
@@ -972,3 +979,200 @@ void SherpaOnnxDestroySpokenLanguageIdentificationResult(
delete r;
}
}
struct SherpaOnnxSpeakerEmbeddingExtractor {
std::unique_ptr<sherpa_onnx::SpeakerEmbeddingExtractor> impl;
};
const SherpaOnnxSpeakerEmbeddingExtractor *
SherpaOnnxCreateSpeakerEmbeddingExtractor(
const SherpaOnnxSpeakerEmbeddingExtractorConfig *config) {
sherpa_onnx::SpeakerEmbeddingExtractorConfig c;
c.model = SHERPA_ONNX_OR(config->model, "");
c.num_threads = SHERPA_ONNX_OR(config->num_threads, 1);
c.debug = SHERPA_ONNX_OR(config->debug, 0);
c.provider = SHERPA_ONNX_OR(config->provider, "cpu");
if (config->debug) {
SHERPA_ONNX_LOGE("%s\n", c.ToString().c_str());
}
if (!c.Validate()) {
SHERPA_ONNX_LOGE("Errors in config!");
return nullptr;
}
auto p = new SherpaOnnxSpeakerEmbeddingExtractor;
p->impl = std::make_unique<sherpa_onnx::SpeakerEmbeddingExtractor>(c);
return p;
}
void SherpaOnnxDestroySpeakerEmbeddingExtractor(
const SherpaOnnxSpeakerEmbeddingExtractor *p) {
delete p;
}
int32_t SherpaOnnxSpeakerEmbeddingExtractorDim(
const SherpaOnnxSpeakerEmbeddingExtractor *p) {
return p->impl->Dim();
}
const SherpaOnnxOnlineStream *SherpaOnnxSpeakerEmbeddingExtractorCreateStream(
const SherpaOnnxSpeakerEmbeddingExtractor *p) {
SherpaOnnxOnlineStream *stream =
new SherpaOnnxOnlineStream(p->impl->CreateStream());
return stream;
}
int32_t SherpaOnnxSpeakerEmbeddingExtractorIsReady(
const SherpaOnnxSpeakerEmbeddingExtractor *p,
const SherpaOnnxOnlineStream *s) {
return p->impl->IsReady(s->impl.get());
}
const float *SherpaOnnxSpeakerEmbeddingExtractorComputeEmbedding(
const SherpaOnnxSpeakerEmbeddingExtractor *p,
const SherpaOnnxOnlineStream *s) {
std::vector<float> v = p->impl->Compute(s->impl.get());
float *ans = new float[v.size()];
std::copy(v.begin(), v.end(), ans);
return ans;
}
void SherpaOnnxSpeakerEmbeddingExtractorDestroyEmbedding(const float *v) {
delete[] v;
}
struct SherpaOnnxSpeakerEmbeddingManager {
std::unique_ptr<sherpa_onnx::SpeakerEmbeddingManager> impl;
};
const SherpaOnnxSpeakerEmbeddingManager *
SherpaOnnxCreateSpeakerEmbeddingManager(int32_t dim) {
auto p = new SherpaOnnxSpeakerEmbeddingManager;
p->impl = std::make_unique<sherpa_onnx::SpeakerEmbeddingManager>(dim);
return p;
}
void SherpaOnnxDestroySpeakerEmbeddingManager(
const SherpaOnnxSpeakerEmbeddingManager *p) {
delete p;
}
int32_t SherpaOnnxSpeakerEmbeddingManagerAdd(
const SherpaOnnxSpeakerEmbeddingManager *p, const char *name,
const float *v) {
return p->impl->Add(name, v);
}
int32_t SherpaOnnxSpeakerEmbeddingManagerAddList(
const SherpaOnnxSpeakerEmbeddingManager *p, const char *name,
const float **v) {
int32_t n = 0;
auto q = v;
while (q && q[0]) {
++n;
++q;
}
if (n == 0) {
SHERPA_ONNX_LOGE("Empty embedding!");
return 0;
}
std::vector<std::vector<float>> vec(n);
int32_t dim = p->impl->Dim();
for (int32_t i = 0; i != n; ++i) {
vec[i] = std::vector<float>(v[i], v[i] + dim);
}
return p->impl->Add(name, vec);
}
int32_t SherpaOnnxSpeakerEmbeddingManagerAddListFlattened(
const SherpaOnnxSpeakerEmbeddingManager *p, const char *name,
const float *v, int32_t n) {
std::vector<std::vector<float>> vec(n);
int32_t dim = p->impl->Dim();
for (int32_t i = 0; i != n; ++i, v += dim) {
vec[i] = std::vector<float>(v, v + dim);
}
return p->impl->Add(name, vec);
}
int32_t SherpaOnnxSpeakerEmbeddingManagerRemove(
const SherpaOnnxSpeakerEmbeddingManager *p, const char *name) {
return p->impl->Remove(name);
}
const char *SherpaOnnxSpeakerEmbeddingManagerSearch(
const SherpaOnnxSpeakerEmbeddingManager *p, const float *v,
float threshold) {
auto r = p->impl->Search(v, threshold);
if (r.empty()) {
return nullptr;
}
char *name = new char[r.size() + 1];
std::copy(r.begin(), r.end(), name);
name[r.size()] = '\0';
return name;
}
void SherpaOnnxSpeakerEmbeddingManagerFreeSearch(const char *name) {
delete[] name;
}
int32_t SherpaOnnxSpeakerEmbeddingManagerVerify(
const SherpaOnnxSpeakerEmbeddingManager *p, const char *name,
const float *v, float threshold) {
return p->impl->Verify(name, v, threshold);
}
int32_t SherpaOnnxSpeakerEmbeddingManagerContains(
const SherpaOnnxSpeakerEmbeddingManager *p, const char *name) {
return p->impl->Contains(name);
}
int32_t SherpaOnnxSpeakerEmbeddingManagerNumSpeakers(
const SherpaOnnxSpeakerEmbeddingManager *p) {
return p->impl->NumSpeakers();
}
const char *const *SherpaOnnxSpeakerEmbeddingManagerGetAllSpeakers(
const SherpaOnnxSpeakerEmbeddingManager *manager) {
std::vector<std::string> all_speakers = manager->impl->GetAllSpeakers();
int32_t num_speakers = all_speakers.size();
char **p = new char *[num_speakers + 1];
p[num_speakers] = nullptr;
int32_t i = 0;
for (const auto &name : all_speakers) {
p[i] = new char[name.size() + 1];
std::copy(name.begin(), name.end(), p[i]);
p[i][name.size()] = '\0';
i += 1;
}
return p;
}
void SherpaOnnxSpeakerEmbeddingManagerFreeAllSpeakers(
const char *const *names) {
auto p = names;
while (p && p[0]) {
delete[] p[0];
++p;
}
delete[] names;
}

View File

@@ -186,7 +186,7 @@ SHERPA_ONNX_API SherpaOnnxOnlineRecognizer *CreateOnlineRecognizer(
///
/// @param p A pointer returned by CreateOnlineRecognizer()
SHERPA_ONNX_API void DestroyOnlineRecognizer(
SherpaOnnxOnlineRecognizer *recognizer);
const SherpaOnnxOnlineRecognizer *recognizer);
/// Create an online stream for accepting wave samples.
///
@@ -208,7 +208,7 @@ SHERPA_ONNX_API SherpaOnnxOnlineStream *CreateOnlineStreamWithHotwords(
/// Destroy an online stream.
///
/// @param stream A pointer returned by CreateOnlineStream()
SHERPA_ONNX_API void DestroyOnlineStream(SherpaOnnxOnlineStream *stream);
SHERPA_ONNX_API void DestroyOnlineStream(const SherpaOnnxOnlineStream *stream);
/// Accept input audio samples and compute the features.
/// The user has to invoke DecodeOnlineStream() to run the neural network and
@@ -221,7 +221,7 @@ SHERPA_ONNX_API void DestroyOnlineStream(SherpaOnnxOnlineStream *stream);
/// @param samples A pointer to a 1-D array containing audio samples.
/// The range of samples has to be normalized to [-1, 1].
/// @param n Number of elements in the samples array.
SHERPA_ONNX_API void AcceptWaveform(SherpaOnnxOnlineStream *stream,
SHERPA_ONNX_API void AcceptWaveform(const SherpaOnnxOnlineStream *stream,
int32_t sample_rate, const float *samples,
int32_t n);
@@ -230,8 +230,9 @@ SHERPA_ONNX_API void AcceptWaveform(SherpaOnnxOnlineStream *stream,
///
/// @param recognizer A pointer returned by CreateOnlineRecognizer
/// @param stream A pointer returned by CreateOnlineStream
SHERPA_ONNX_API int32_t IsOnlineStreamReady(
SherpaOnnxOnlineRecognizer *recognizer, SherpaOnnxOnlineStream *stream);
SHERPA_ONNX_API int32_t
IsOnlineStreamReady(const SherpaOnnxOnlineRecognizer *recognizer,
const SherpaOnnxOnlineStream *stream);
/// Call this function to run the neural network model and decoding.
//
@@ -243,8 +244,9 @@ SHERPA_ONNX_API int32_t IsOnlineStreamReady(
/// DecodeOnlineStream(recognizer, stream);
/// }
///
SHERPA_ONNX_API void DecodeOnlineStream(SherpaOnnxOnlineRecognizer *recognizer,
SherpaOnnxOnlineStream *stream);
SHERPA_ONNX_API void DecodeOnlineStream(
const SherpaOnnxOnlineRecognizer *recognizer,
const SherpaOnnxOnlineStream *stream);
/// This function is similar to DecodeOnlineStream(). It decodes multiple
/// OnlineStream in parallel.
@@ -257,8 +259,8 @@ SHERPA_ONNX_API void DecodeOnlineStream(SherpaOnnxOnlineRecognizer *recognizer,
/// CreateOnlineRecognizer()
/// @param n Number of elements in the given streams array.
SHERPA_ONNX_API void DecodeMultipleOnlineStreams(
SherpaOnnxOnlineRecognizer *recognizer, SherpaOnnxOnlineStream **streams,
int32_t n);
const SherpaOnnxOnlineRecognizer *recognizer,
const SherpaOnnxOnlineStream **streams, int32_t n);
/// Get the decoding results so far for an OnlineStream.
///
@@ -268,7 +270,8 @@ SHERPA_ONNX_API void DecodeMultipleOnlineStreams(
/// DestroyOnlineRecognizerResult() to free the returned pointer to
/// avoid memory leak.
SHERPA_ONNX_API const SherpaOnnxOnlineRecognizerResult *GetOnlineStreamResult(
SherpaOnnxOnlineRecognizer *recognizer, SherpaOnnxOnlineStream *stream);
const SherpaOnnxOnlineRecognizer *recognizer,
const SherpaOnnxOnlineStream *stream);
/// Destroy the pointer returned by GetOnlineStreamResult().
///
@@ -281,35 +284,36 @@ SHERPA_ONNX_API void DestroyOnlineRecognizerResult(
///
/// @param recognizer A pointer returned by CreateOnlineRecognizer().
/// @param stream A pointer returned by CreateOnlineStream
SHERPA_ONNX_API void Reset(SherpaOnnxOnlineRecognizer *recognizer,
SherpaOnnxOnlineStream *stream);
SHERPA_ONNX_API void Reset(const SherpaOnnxOnlineRecognizer *recognizer,
const SherpaOnnxOnlineStream *stream);
/// Signal that no more audio samples would be available.
/// After this call, you cannot call AcceptWaveform() any more.
///
/// @param stream A pointer returned by CreateOnlineStream()
SHERPA_ONNX_API void InputFinished(SherpaOnnxOnlineStream *stream);
SHERPA_ONNX_API void InputFinished(const SherpaOnnxOnlineStream *stream);
/// Return 1 if an endpoint has been detected.
///
/// @param recognizer A pointer returned by CreateOnlineRecognizer()
/// @param stream A pointer returned by CreateOnlineStream()
/// @return Return 1 if an endpoint is detected. Return 0 otherwise.
SHERPA_ONNX_API int32_t IsEndpoint(SherpaOnnxOnlineRecognizer *recognizer,
SherpaOnnxOnlineStream *stream);
SHERPA_ONNX_API int32_t IsEndpoint(const SherpaOnnxOnlineRecognizer *recognizer,
const SherpaOnnxOnlineStream *stream);
// for displaying results on Linux/macOS.
SHERPA_ONNX_API typedef struct SherpaOnnxDisplay SherpaOnnxDisplay;
/// Create a display object. Must be freed using DestroyDisplay to avoid
/// memory leak.
SHERPA_ONNX_API SherpaOnnxDisplay *CreateDisplay(int32_t max_word_per_line);
SHERPA_ONNX_API const SherpaOnnxDisplay *CreateDisplay(
int32_t max_word_per_line);
SHERPA_ONNX_API void DestroyDisplay(SherpaOnnxDisplay *display);
SHERPA_ONNX_API void DestroyDisplay(const SherpaOnnxDisplay *display);
/// Print the result.
SHERPA_ONNX_API void SherpaOnnxPrint(SherpaOnnxDisplay *display, int32_t idx,
const char *s);
SHERPA_ONNX_API void SherpaOnnxPrint(const SherpaOnnxDisplay *display,
int32_t idx, const char *s);
// ============================================================
// For offline ASR (i.e., non-streaming ASR)
// ============================================================
@@ -769,7 +773,7 @@ typedef void (*SherpaOnnxGeneratedAudioCallbackWithArg)(const float *samples,
int32_t n, void *arg);
typedef void (*SherpaOnnxGeneratedAudioProgressCallback)(const float *samples,
int32_t n, float p);
int32_t n, float p);
SHERPA_ONNX_API typedef struct SherpaOnnxOfflineTts SherpaOnnxOfflineTts;
@@ -839,7 +843,9 @@ SHERPA_ONNX_API const SherpaOnnxWave *SherpaOnnxReadWave(const char *filename);
SHERPA_ONNX_API void SherpaOnnxFreeWave(const SherpaOnnxWave *wave);
// Spoken language identification
// ============================================================
// For spoken language identification
// ============================================================
SHERPA_ONNX_API typedef struct
SherpaOnnxSpokenLanguageIdentificationWhisperConfig {
@@ -893,6 +899,169 @@ SherpaOnnxSpokenLanguageIdentificationCompute(
SHERPA_ONNX_API void SherpaOnnxDestroySpokenLanguageIdentificationResult(
const SherpaOnnxSpokenLanguageIdentificationResult *r);
// ============================================================
// For speaker embedding extraction
// ============================================================
SHERPA_ONNX_API typedef struct SherpaOnnxSpeakerEmbeddingExtractorConfig {
const char *model;
int32_t num_threads;
int32_t debug;
const char *provider;
} SherpaOnnxSpeakerEmbeddingExtractorConfig;
SHERPA_ONNX_API typedef struct SherpaOnnxSpeakerEmbeddingExtractor
SherpaOnnxSpeakerEmbeddingExtractor;
// The user has to invoke SherpaOnnxDestroySpeakerEmbeddingExtractor()
// to free the returned pointer to avoid memory leak
SHERPA_ONNX_API const SherpaOnnxSpeakerEmbeddingExtractor *
SherpaOnnxCreateSpeakerEmbeddingExtractor(
const SherpaOnnxSpeakerEmbeddingExtractorConfig *config);
SHERPA_ONNX_API void SherpaOnnxDestroySpeakerEmbeddingExtractor(
const SherpaOnnxSpeakerEmbeddingExtractor *p);
SHERPA_ONNX_API int32_t SherpaOnnxSpeakerEmbeddingExtractorDim(
const SherpaOnnxSpeakerEmbeddingExtractor *p);
// The user has to invoke DestroyOnlineStream() to free the returned pointer
// to avoid memory leak
SHERPA_ONNX_API const SherpaOnnxOnlineStream *
SherpaOnnxSpeakerEmbeddingExtractorCreateStream(
const SherpaOnnxSpeakerEmbeddingExtractor *p);
// Return 1 if the stream has enough feature frames for computing embeddings.
// Return 0 otherwise.
SHERPA_ONNX_API int32_t SherpaOnnxSpeakerEmbeddingExtractorIsReady(
const SherpaOnnxSpeakerEmbeddingExtractor *p,
const SherpaOnnxOnlineStream *s);
// Compute the embedding of the stream.
//
// @return Return a pointer pointing to an array containing the embedding.
// The length of the array is `dim` as returned by
// SherpaOnnxSpeakerEmbeddingExtractorDim(p)
//
// The user has to invoke SherpaOnnxSpeakerEmbeddingExtractorDestroyEmbedding()
// to free the returned pointer to avoid memory leak.
SHERPA_ONNX_API const float *
SherpaOnnxSpeakerEmbeddingExtractorComputeEmbedding(
const SherpaOnnxSpeakerEmbeddingExtractor *p,
const SherpaOnnxOnlineStream *s);
SHERPA_ONNX_API void SherpaOnnxSpeakerEmbeddingExtractorDestroyEmbedding(
const float *v);
SHERPA_ONNX_API typedef struct SherpaOnnxSpeakerEmbeddingManager
SherpaOnnxSpeakerEmbeddingManager;
// The user has to invoke SherpaOnnxDestroySpeakerEmbeddingManager()
// to free the returned pointer to avoid memory leak
SHERPA_ONNX_API const SherpaOnnxSpeakerEmbeddingManager *
SherpaOnnxCreateSpeakerEmbeddingManager(int32_t dim);
SHERPA_ONNX_API void SherpaOnnxDestroySpeakerEmbeddingManager(
const SherpaOnnxSpeakerEmbeddingManager *p);
// Register the embedding of a user
//
// @param name The name of the user
// @param p Pointer to an array containing the embeddings. The length of the
// array must be equal to `dim` used to construct the manager `p`.
//
// @return Return 1 if added successfully. Return 0 on error
SHERPA_ONNX_API int32_t
SherpaOnnxSpeakerEmbeddingManagerAdd(const SherpaOnnxSpeakerEmbeddingManager *p,
const char *name, const float *v);
// @param v Pointer to an array of embeddings. If there are n embeddings, then
// v[0] is the pointer to the 0-th array containing the embeddings
// v[1] is the pointer to the 1-st array containing the embeddings
// v[n-1] is the pointer to the last array containing the embeddings
// v[n] is a NULL pointer
// @return Return 1 if added successfully. Return 0 on error
SHERPA_ONNX_API int32_t SherpaOnnxSpeakerEmbeddingManagerAddList(
const SherpaOnnxSpeakerEmbeddingManager *p, const char *name,
const float **v);
// Similar to SherpaOnnxSpeakerEmbeddingManagerAddList() but the memory
// is flattened.
//
// The length of the input array should be `n * dim`.
//
// @return Return 1 if added successfully. Return 0 on error
SHERPA_ONNX_API int32_t SherpaOnnxSpeakerEmbeddingManagerAddListFlattened(
const SherpaOnnxSpeakerEmbeddingManager *p, const char *name,
const float *v, int32_t n);
// Remove a user.
// @param naem The name of the user to remove.
// @return Return 1 if removed successfully; return 0 on error.
//
// Note if the user does not exist, it also returns 0.
SHERPA_ONNX_API int32_t SherpaOnnxSpeakerEmbeddingManagerRemove(
const SherpaOnnxSpeakerEmbeddingManager *p, const char *name);
// Search if an existing users' embedding matches the given one.
//
// @param p Pointer to an array containing the embedding. The dim
// of the array must equal to `dim` used to construct the manager `p`.
// @param threshold A value between 0 and 1. If the similarity score exceeds
// this threshold, we say a match is found.
// @return Returns the name of the user if found. Return NULL if not found.
// If not NULL, the caller has to invoke
// SherpaOnnxSpeakerEmbeddingManagerFreeSearch() to free the returned
// pointer to avoid memory leak.
SHERPA_ONNX_API const char *SherpaOnnxSpeakerEmbeddingManagerSearch(
const SherpaOnnxSpeakerEmbeddingManager *p, const float *v,
float threshold);
SHERPA_ONNX_API void SherpaOnnxSpeakerEmbeddingManagerFreeSearch(
const char *name);
// Check whether the input embedding matches the embedding of the input
// speaker.
//
// It is for speaker verification.
//
// @param name The target speaker name.
// @param p The input embedding to check.
// @param threshold A value between 0 and 1.
// @return Return 1 if it matches. Otherwise, it returns 0.
SHERPA_ONNX_API int32_t SherpaOnnxSpeakerEmbeddingManagerVerify(
const SherpaOnnxSpeakerEmbeddingManager *p, const char *name,
const float *v, float threshold);
// Return 1 if the user with the name is in the manager.
// Return 0 if the user does not exist.
SHERPA_ONNX_API int32_t SherpaOnnxSpeakerEmbeddingManagerContains(
const SherpaOnnxSpeakerEmbeddingManager *p, const char *name);
// Return number of speakers in the manager.
SHERPA_ONNX_API int32_t SherpaOnnxSpeakerEmbeddingManagerNumSpeakers(
const SherpaOnnxSpeakerEmbeddingManager *p);
// Return the name of all speakers in the manager.
//
// @return Return an array of pointers `ans`. If there are n speakers, then
// - ans[0] contains the name of the 0-th speaker
// - ans[1] contains the name of the 1-st speaker
// - ans[n-1] contains the name of the last speaker
// - ans[n] is NULL
// If there are no users at all, then ans[0] is NULL. In any case,
// `ans` is not NULL.
//
// Each name is NULL-terminated
//
// The caller has to invoke SherpaOnnxSpeakerEmbeddingManagerFreeAllSpeakers()
// to free the returned pointer to avoid memory leak.
SHERPA_ONNX_API const char *const *
SherpaOnnxSpeakerEmbeddingManagerGetAllSpeakers(
const SherpaOnnxSpeakerEmbeddingManager *p);
SHERPA_ONNX_API void SherpaOnnxSpeakerEmbeddingManagerFreeAllSpeakers(
const char *const *names);
#if defined(__GNUC__)
#pragma GCC diagnostic pop
#endif

View File

@@ -168,7 +168,8 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
ans.samples.insert(ans.samples.end(), audio.samples.begin(),
audio.samples.end());
if (callback) {
callback(audio.samples.data(), audio.samples.size(), b * 1.0 / num_batches);
callback(audio.samples.data(), audio.samples.size(),
b * 1.0 / num_batches);
// Caution(fangjun): audio is freed when the callback returns, so users
// should copy the data if they want to access the data after
// the callback returns to avoid segmentation fault.

View File

@@ -54,8 +54,8 @@ struct GeneratedAudio {
class OfflineTtsImpl;
using GeneratedAudioCallback =
std::function<void(const float * /*samples*/, int32_t /*n*/, float /*progress*/)>;
using GeneratedAudioCallback = std::function<void(
const float * /*samples*/, int32_t /*n*/, float /*progress*/)>;
class OfflineTts {
public:

View File

@@ -44,7 +44,8 @@ static void Handler(int32_t /*sig*/) {
fprintf(stderr, "\nCaught Ctrl + C. Exiting\n");
}
static void AudioGeneratedCallback(const float *s, int32_t n) {
static void AudioGeneratedCallback(const float *s, int32_t n,
float /*progress*/) {
if (n > 0) {
std::lock_guard<std::mutex> lock(g_buffer.mutex);
g_buffer.samples.push({s, s + n});

View File

@@ -47,7 +47,8 @@ static void Handler(int32_t /*sig*/) {
fprintf(stderr, "\nCaught Ctrl + C. Exiting\n");
}
static void AudioGeneratedCallback(const float *s, int32_t n, float /*progress*/) {
static void AudioGeneratedCallback(const float *s, int32_t n,
float /*progress*/) {
if (n > 0) {
Samples samples;
samples.data = std::vector<float>{s, s + n};

View File

@@ -9,9 +9,8 @@
#include "sherpa-onnx/csrc/parse-options.h"
#include "sherpa-onnx/csrc/wave-writer.h"
void audioCallback(const float *samples, int32_t n, float progress)
{
printf( "sample=%d, progress=%f\n", n, progress );
void audioCallback(const float *samples, int32_t n, float progress) {
printf("sample=%d, progress=%f\n", n, progress);
}
int main(int32_t argc, char *argv[]) {

View File

@@ -93,7 +93,7 @@ class SpeakerEmbeddingManager::Impl {
int32_t num_rows = embedding_matrix_.rows();
if (row_idx < num_rows - 1) {
embedding_matrix_.block(row_idx, 0, num_rows - -1 - row_idx, dim_) =
embedding_matrix_.block(row_idx, 0, num_rows - 1 - row_idx, dim_) =
embedding_matrix_.bottomRows(num_rows - 1 - row_idx);
}

View File

@@ -795,9 +795,10 @@ class SherpaOnnxOfflineTts {
explicit SherpaOnnxOfflineTts(const OfflineTtsConfig &config)
: tts_(config) {}
GeneratedAudio Generate(
const std::string &text, int64_t sid = 0, float speed = 1.0,
std::function<void(const float *, int32_t, float)> callback = nullptr) const {
GeneratedAudio Generate(const std::string &text, int64_t sid = 0,
float speed = 1.0,
std::function<void(const float *, int32_t, float)>
callback = nullptr) const {
return tts_.Generate(text, sid, speed, callback);
}

View File

@@ -55,14 +55,16 @@ void PybindOfflineTts(py::module *m) {
.def(
"generate",
[](const PyClass &self, const std::string &text, int64_t sid,
float speed, std::function<void(py::array_t<float>, float)> callback)
float speed,
std::function<void(py::array_t<float>, float)> callback)
-> GeneratedAudio {
if (!callback) {
return self.Generate(text, sid, speed);
}
std::function<void(const float *, int32_t, float)> callback_wrapper =
[callback](const float *samples, int32_t n, float progress) {
std::function<void(const float *, int32_t, float)>
callback_wrapper = [callback](const float *samples, int32_t n,
float progress) {
// CAUTION(fangjun): we have to copy samples since it is
// freed once the call back returns.