Add runtime support for wespeaker models (#516)
This commit is contained in:
252
python-api-examples/speaker-identification.py
Executable file
252
python-api-examples/speaker-identification.py
Executable file
@@ -0,0 +1,252 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
"""
|
||||||
|
This script shows how to use Python APIs for speaker identification.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
|
||||||
|
(1) Prepare a text file containing speaker related files.
|
||||||
|
|
||||||
|
Each line in the text file contains two columns. The first column is the
|
||||||
|
speaker name, while the second column contains the wave file of the speaker.
|
||||||
|
|
||||||
|
If the text file contains multiple wave files for the same speaker, then the
|
||||||
|
embeddings of these files are averaged.
|
||||||
|
|
||||||
|
An example text file is given below:
|
||||||
|
|
||||||
|
foo /path/to/a.wav
|
||||||
|
bar /path/to/b.wav
|
||||||
|
foo /path/to/c.wav
|
||||||
|
foobar /path/to/d.wav
|
||||||
|
|
||||||
|
Each wave file should contain only a single channel; the sample format
|
||||||
|
should be int16_t; the sample rate can be arbitrary.
|
||||||
|
|
||||||
|
(2) Download a model for computing speaker embeddings
|
||||||
|
|
||||||
|
Please visit
|
||||||
|
https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-recongition-models
|
||||||
|
to download a model. An example is given below:
|
||||||
|
|
||||||
|
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/zh_cnceleb_resnet34.onnx
|
||||||
|
|
||||||
|
Note that `zh` means Chinese, while `en` means English.
|
||||||
|
|
||||||
|
(3) Run this script
|
||||||
|
|
||||||
|
Assume the filename of the text file is speaker.txt.
|
||||||
|
|
||||||
|
python3 ./python-api-examples/speaker-identification.py \
|
||||||
|
--speaker-file ./speaker.txt \
|
||||||
|
--model ./zh_cnceleb_resnet34.onnx
|
||||||
|
"""
|
||||||
|
import argparse
|
||||||
|
import queue
|
||||||
|
import threading
|
||||||
|
from collections import defaultdict
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import sherpa_onnx
|
||||||
|
import torchaudio
|
||||||
|
|
||||||
|
try:
|
||||||
|
import sounddevice as sd
|
||||||
|
except ImportError:
|
||||||
|
print("Please install sounddevice first. You can use")
|
||||||
|
print()
|
||||||
|
print(" pip install sounddevice")
|
||||||
|
print()
|
||||||
|
print("to install it")
|
||||||
|
sys.exit(-1)
|
||||||
|
|
||||||
|
|
||||||
|
def get_args():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--speaker-file",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="""Path to the speaker file. Read the help doc at the beginning of this
|
||||||
|
file for the format.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--model",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to the model file.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument("--threshold", type=float, default=0.6)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-threads",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="Number of threads for neural network computation",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--debug",
|
||||||
|
type=bool,
|
||||||
|
default=False,
|
||||||
|
help="True to show debug messages",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--provider",
|
||||||
|
type=str,
|
||||||
|
default="cpu",
|
||||||
|
help="Valid values: cpu, cuda, coreml",
|
||||||
|
)
|
||||||
|
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def load_speaker_embedding_model(args):
|
||||||
|
config = sherpa_onnx.SpeakerEmbeddingExtractorConfig(
|
||||||
|
model=args.model,
|
||||||
|
num_threads=args.num_threads,
|
||||||
|
debug=args.debug,
|
||||||
|
provider=args.provider,
|
||||||
|
)
|
||||||
|
if not config.validate():
|
||||||
|
raise ValueError(f"Invalid config. {config}")
|
||||||
|
extractor = sherpa_onnx.SpeakerEmbeddingExtractor(config)
|
||||||
|
return extractor
|
||||||
|
|
||||||
|
|
||||||
|
def load_speaker_file(args) -> Dict[str, List[str]]:
|
||||||
|
if not Path(args.speaker_file).is_file():
|
||||||
|
raise ValueError(f"--speaker-file {args.speaker_file} does not exist")
|
||||||
|
|
||||||
|
ans = defaultdict(list)
|
||||||
|
with open(args.speaker_file) as f:
|
||||||
|
for line in f:
|
||||||
|
line = line.strip()
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
|
||||||
|
fields = line.split()
|
||||||
|
if len(fields) != 2:
|
||||||
|
raise ValueError(f"Invalid line: {line}. Fields: {fields}")
|
||||||
|
|
||||||
|
speaker_name, filename = fields
|
||||||
|
ans[speaker_name].append(filename)
|
||||||
|
return ans
|
||||||
|
|
||||||
|
|
||||||
|
def load_audio(filename: str) -> Tuple[np.ndarray, int]:
|
||||||
|
samples, sample_rate = torchaudio.load(filename)
|
||||||
|
return samples[0].contiguous().numpy(), sample_rate
|
||||||
|
|
||||||
|
|
||||||
|
def compute_speaker_embedding(
|
||||||
|
filenames: List[str],
|
||||||
|
extractor: sherpa_onnx.SpeakerEmbeddingExtractor,
|
||||||
|
) -> np.ndarray:
|
||||||
|
assert len(filenames) > 0, f"filenames is empty"
|
||||||
|
|
||||||
|
ans = None
|
||||||
|
for filename in filenames:
|
||||||
|
print(f"processing {filename}")
|
||||||
|
samples, sample_rate = load_audio(filename)
|
||||||
|
stream = extractor.create_stream()
|
||||||
|
stream.accept_waveform(sample_rate=sample_rate, waveform=samples)
|
||||||
|
stream.input_finished()
|
||||||
|
|
||||||
|
assert extractor.is_ready(stream)
|
||||||
|
embedding = extractor.compute(stream)
|
||||||
|
embedding = np.array(embedding)
|
||||||
|
if ans is None:
|
||||||
|
ans = embedding
|
||||||
|
else:
|
||||||
|
ans += embedding
|
||||||
|
|
||||||
|
return ans / len(filenames)
|
||||||
|
|
||||||
|
|
||||||
|
g_buffer = queue.Queue()
|
||||||
|
g_stop = False
|
||||||
|
g_sample_rate = 16000
|
||||||
|
g_read_mic_thread = None
|
||||||
|
|
||||||
|
|
||||||
|
def read_mic():
|
||||||
|
print("Please speak!")
|
||||||
|
samples_per_read = int(0.1 * g_sample_rate) # 0.1 second = 100 ms
|
||||||
|
with sd.InputStream(channels=1, dtype="float32", samplerate=g_sample_rate) as s:
|
||||||
|
while not g_stop:
|
||||||
|
samples, _ = s.read(samples_per_read) # a blocking read
|
||||||
|
g_buffer.put(samples)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = get_args()
|
||||||
|
print(args)
|
||||||
|
extractor = load_speaker_embedding_model(args)
|
||||||
|
speaker_file = load_speaker_file(args)
|
||||||
|
|
||||||
|
manager = sherpa_onnx.SpeakerEmbeddingManager(extractor.dim)
|
||||||
|
for name, filename_list in speaker_file.items():
|
||||||
|
embedding = compute_speaker_embedding(
|
||||||
|
filenames=filename_list,
|
||||||
|
extractor=extractor,
|
||||||
|
)
|
||||||
|
status = manager.add(name, embedding)
|
||||||
|
if not status:
|
||||||
|
raise RuntimeError(f"Failed to register speaker {name}")
|
||||||
|
|
||||||
|
devices = sd.query_devices()
|
||||||
|
if len(devices) == 0:
|
||||||
|
print("No microphone devices found")
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
print(devices)
|
||||||
|
default_input_device_idx = sd.default.device[0]
|
||||||
|
print(f'Use default device: {devices[default_input_device_idx]["name"]}')
|
||||||
|
|
||||||
|
global g_stop
|
||||||
|
global g_read_mic_thread
|
||||||
|
while True:
|
||||||
|
key = input("Press enter to start recording")
|
||||||
|
if key.lower() in ("q", "quit"):
|
||||||
|
g_stop = True
|
||||||
|
break
|
||||||
|
|
||||||
|
g_stop = False
|
||||||
|
g_buffer.queue.clear()
|
||||||
|
g_read_mic_thread = threading.Thread(target=read_mic)
|
||||||
|
g_read_mic_thread.start()
|
||||||
|
input("Press enter to stop recording")
|
||||||
|
g_stop = True
|
||||||
|
g_read_mic_thread.join()
|
||||||
|
print("Compute embedding")
|
||||||
|
stream = extractor.create_stream()
|
||||||
|
while not g_buffer.empty():
|
||||||
|
samples = g_buffer.get()
|
||||||
|
stream.accept_waveform(sample_rate=g_sample_rate, waveform=samples)
|
||||||
|
stream.input_finished()
|
||||||
|
|
||||||
|
embedding = extractor.compute(stream)
|
||||||
|
embedding = np.array(embedding)
|
||||||
|
name = manager.search(embedding, threshold=args.threshold)
|
||||||
|
if not name:
|
||||||
|
name = "unknown"
|
||||||
|
print(f"Predicted name: {name}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
try:
|
||||||
|
main()
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\nCaught Ctrl + C. Exiting")
|
||||||
|
g_stop = True
|
||||||
|
if g_read_mic_thread.is_alive():
|
||||||
|
g_read_mic_thread.join()
|
||||||
@@ -96,6 +96,14 @@ set(sources
|
|||||||
wave-reader.cc
|
wave-reader.cc
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# speaker embedding extractor
|
||||||
|
list(APPEND sources
|
||||||
|
speaker-embedding-extractor-impl.cc
|
||||||
|
speaker-embedding-extractor-wespeaker-model.cc
|
||||||
|
speaker-embedding-extractor.cc
|
||||||
|
speaker-embedding-manager.cc
|
||||||
|
)
|
||||||
|
|
||||||
list(APPEND sources
|
list(APPEND sources
|
||||||
lexicon.cc
|
lexicon.cc
|
||||||
offline-tts-impl.cc
|
offline-tts-impl.cc
|
||||||
@@ -387,6 +395,10 @@ if(SHERPA_ONNX_ENABLE_TESTS)
|
|||||||
utfcpp-test.cc
|
utfcpp-test.cc
|
||||||
)
|
)
|
||||||
|
|
||||||
|
list(APPEND sherpa_onnx_test_srcs
|
||||||
|
speaker-embedding-manager-test.cc
|
||||||
|
)
|
||||||
|
|
||||||
function(sherpa_onnx_add_test source)
|
function(sherpa_onnx_add_test source)
|
||||||
get_filename_component(name ${source} NAME_WE)
|
get_filename_component(name ${source} NAME_WE)
|
||||||
set(target_name ${name})
|
set(target_name ${name})
|
||||||
|
|||||||
@@ -64,8 +64,8 @@ TEST(ContextGraph, Benchmark) {
|
|||||||
auto stop = std::chrono::high_resolution_clock::now();
|
auto stop = std::chrono::high_resolution_clock::now();
|
||||||
auto duration =
|
auto duration =
|
||||||
std::chrono::duration_cast<std::chrono::microseconds>(stop - start);
|
std::chrono::duration_cast<std::chrono::microseconds>(stop - start);
|
||||||
SHERPA_ONNX_LOGE("Construct context graph for %d item takes %ld us.", num,
|
SHERPA_ONNX_LOGE("Construct context graph for %d item takes %d us.", num,
|
||||||
duration.count());
|
static_cast<int32_t>(duration.count()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -91,4 +91,9 @@ Ort::SessionOptions GetSessionOptions(const OfflineTtsModelConfig &config) {
|
|||||||
return GetSessionOptionsImpl(config.num_threads, config.provider);
|
return GetSessionOptionsImpl(config.num_threads, config.provider);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Ort::SessionOptions GetSessionOptions(
|
||||||
|
const SpeakerEmbeddingExtractorConfig &config) {
|
||||||
|
return GetSessionOptionsImpl(config.num_threads, config.provider);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace sherpa_onnx
|
} // namespace sherpa_onnx
|
||||||
|
|||||||
@@ -11,6 +11,7 @@
|
|||||||
#include "sherpa-onnx/csrc/offline-tts-model-config.h"
|
#include "sherpa-onnx/csrc/offline-tts-model-config.h"
|
||||||
#include "sherpa-onnx/csrc/online-lm-config.h"
|
#include "sherpa-onnx/csrc/online-lm-config.h"
|
||||||
#include "sherpa-onnx/csrc/online-model-config.h"
|
#include "sherpa-onnx/csrc/online-model-config.h"
|
||||||
|
#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
|
||||||
#include "sherpa-onnx/csrc/vad-model-config.h"
|
#include "sherpa-onnx/csrc/vad-model-config.h"
|
||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
@@ -26,6 +27,9 @@ Ort::SessionOptions GetSessionOptions(const OnlineLMConfig &config);
|
|||||||
Ort::SessionOptions GetSessionOptions(const VadModelConfig &config);
|
Ort::SessionOptions GetSessionOptions(const VadModelConfig &config);
|
||||||
|
|
||||||
Ort::SessionOptions GetSessionOptions(const OfflineTtsModelConfig &config);
|
Ort::SessionOptions GetSessionOptions(const OfflineTtsModelConfig &config);
|
||||||
|
|
||||||
|
Ort::SessionOptions GetSessionOptions(
|
||||||
|
const SpeakerEmbeddingExtractorConfig &config);
|
||||||
} // namespace sherpa_onnx
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
#endif // SHERPA_ONNX_CSRC_SESSION_H_
|
#endif // SHERPA_ONNX_CSRC_SESSION_H_
|
||||||
|
|||||||
82
sherpa-onnx/csrc/speaker-embedding-extractor-impl.cc
Normal file
82
sherpa-onnx/csrc/speaker-embedding-extractor-impl.cc
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
// sherpa-onnx/csrc/speaker-embedding-extractor-impl.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
#include "sherpa-onnx/csrc/speaker-embedding-extractor-impl.h"
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
|
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||||
|
#include "sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-impl.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
enum class ModelType {
|
||||||
|
kWeSpeaker,
|
||||||
|
kUnkown,
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
static ModelType GetModelType(char *model_data, size_t model_data_length,
|
||||||
|
bool debug) {
|
||||||
|
Ort::Env env(ORT_LOGGING_LEVEL_WARNING);
|
||||||
|
Ort::SessionOptions sess_opts;
|
||||||
|
|
||||||
|
auto sess = std::make_unique<Ort::Session>(env, model_data, model_data_length,
|
||||||
|
sess_opts);
|
||||||
|
|
||||||
|
Ort::ModelMetadata meta_data = sess->GetModelMetadata();
|
||||||
|
if (debug) {
|
||||||
|
std::ostringstream os;
|
||||||
|
PrintModelMetadata(os, meta_data);
|
||||||
|
SHERPA_ONNX_LOGE("%s", os.str().c_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
Ort::AllocatorWithDefaultOptions allocator;
|
||||||
|
auto model_type =
|
||||||
|
meta_data.LookupCustomMetadataMapAllocated("framework", allocator);
|
||||||
|
if (!model_type) {
|
||||||
|
SHERPA_ONNX_LOGE(
|
||||||
|
"No model_type in the metadata!\n"
|
||||||
|
"Please make sure you have added metadata to the model.\n\n"
|
||||||
|
"For instance, you can use\n"
|
||||||
|
"https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wespeaker/"
|
||||||
|
"add_meta_data.py"
|
||||||
|
"to add metadata to models from WeSpeaker\n");
|
||||||
|
return ModelType::kUnkown;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (model_type.get() == std::string("wespeaker")) {
|
||||||
|
return ModelType::kWeSpeaker;
|
||||||
|
} else {
|
||||||
|
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
|
||||||
|
return ModelType::kUnkown;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<SpeakerEmbeddingExtractorImpl>
|
||||||
|
SpeakerEmbeddingExtractorImpl::Create(
|
||||||
|
const SpeakerEmbeddingExtractorConfig &config) {
|
||||||
|
ModelType model_type = ModelType::kUnkown;
|
||||||
|
|
||||||
|
{
|
||||||
|
auto buffer = ReadFile(config.model);
|
||||||
|
|
||||||
|
model_type = GetModelType(buffer.data(), buffer.size(), config.debug);
|
||||||
|
}
|
||||||
|
|
||||||
|
switch (model_type) {
|
||||||
|
case ModelType::kWeSpeaker:
|
||||||
|
return std::make_unique<SpeakerEmbeddingExtractorWeSpeakerImpl>(config);
|
||||||
|
case ModelType::kUnkown:
|
||||||
|
SHERPA_ONNX_LOGE(
|
||||||
|
"Unknown model type in for speaker embedding extractor!");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
// unreachable code
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
34
sherpa-onnx/csrc/speaker-embedding-extractor-impl.h
Normal file
34
sherpa-onnx/csrc/speaker-embedding-extractor-impl.h
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
// sherpa-onnx/csrc/speaker-embedding-extractor-impl.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_IMPL_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_IMPL_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
class SpeakerEmbeddingExtractorImpl {
|
||||||
|
public:
|
||||||
|
virtual ~SpeakerEmbeddingExtractorImpl() = default;
|
||||||
|
|
||||||
|
static std::unique_ptr<SpeakerEmbeddingExtractorImpl> Create(
|
||||||
|
const SpeakerEmbeddingExtractorConfig &config);
|
||||||
|
|
||||||
|
virtual int32_t Dim() const = 0;
|
||||||
|
|
||||||
|
virtual std::unique_ptr<OnlineStream> CreateStream() const = 0;
|
||||||
|
|
||||||
|
virtual bool IsReady(OnlineStream *s) const = 0;
|
||||||
|
|
||||||
|
virtual std::vector<float> Compute(OnlineStream *s) const = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_IMPL_H_
|
||||||
@@ -0,0 +1,79 @@
|
|||||||
|
// sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-impl.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_IMPL_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_IMPL_H_
|
||||||
|
#include <algorithm>
|
||||||
|
#include <memory>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/speaker-embedding-extractor-impl.h"
|
||||||
|
#include "sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
class SpeakerEmbeddingExtractorWeSpeakerImpl
|
||||||
|
: public SpeakerEmbeddingExtractorImpl {
|
||||||
|
public:
|
||||||
|
explicit SpeakerEmbeddingExtractorWeSpeakerImpl(
|
||||||
|
const SpeakerEmbeddingExtractorConfig &config)
|
||||||
|
: model_(config) {}
|
||||||
|
|
||||||
|
int32_t Dim() const override { return model_.GetMetaData().output_dim; }
|
||||||
|
|
||||||
|
std::unique_ptr<OnlineStream> CreateStream() const override {
|
||||||
|
FeatureExtractorConfig feat_config;
|
||||||
|
auto meta_data = model_.GetMetaData();
|
||||||
|
feat_config.sampling_rate = meta_data.sample_rate;
|
||||||
|
feat_config.normalize_samples = meta_data.normalize_features;
|
||||||
|
|
||||||
|
return std::make_unique<OnlineStream>(feat_config);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool IsReady(OnlineStream *s) const override {
|
||||||
|
return s->GetNumProcessedFrames() < s->NumFramesReady();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<float> Compute(OnlineStream *s) const override {
|
||||||
|
int32_t num_frames = s->NumFramesReady() - s->GetNumProcessedFrames();
|
||||||
|
if (num_frames <= 0) {
|
||||||
|
SHERPA_ONNX_LOGE(
|
||||||
|
"Please make sure IsReady(s) returns true. num_frames: %d",
|
||||||
|
num_frames);
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<float> features =
|
||||||
|
s->GetFrames(s->GetNumProcessedFrames(), num_frames);
|
||||||
|
|
||||||
|
s->GetNumProcessedFrames() += num_frames;
|
||||||
|
|
||||||
|
int32_t feat_dim = features.size() / num_frames;
|
||||||
|
|
||||||
|
auto memory_info =
|
||||||
|
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||||
|
|
||||||
|
std::array<int64_t, 3> x_shape{1, num_frames, feat_dim};
|
||||||
|
Ort::Value x =
|
||||||
|
Ort::Value::CreateTensor(memory_info, features.data(), features.size(),
|
||||||
|
x_shape.data(), x_shape.size());
|
||||||
|
Ort::Value embedding = model_.Compute(std::move(x));
|
||||||
|
std::vector<int64_t> embedding_shape =
|
||||||
|
embedding.GetTensorTypeAndShapeInfo().GetShape();
|
||||||
|
|
||||||
|
std::vector<float> ans(embedding_shape[1]);
|
||||||
|
std::copy(embedding.GetTensorData<float>(),
|
||||||
|
embedding.GetTensorData<float>() + ans.size(), ans.begin());
|
||||||
|
|
||||||
|
return ans;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
SpeakerEmbeddingExtractorWeSpeakerModel model_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_IMPL_H_
|
||||||
@@ -0,0 +1,20 @@
|
|||||||
|
// sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model-metadata.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_MODEL_METADATA_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_MODEL_METADATA_H_
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
struct SpeakerEmbeddingExtractorWeSpeakerModelMetaData {
|
||||||
|
int32_t output_dim = 0;
|
||||||
|
int32_t sample_rate = 0;
|
||||||
|
int32_t normalize_features = 0;
|
||||||
|
std::string language;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_MODEL_METADATA_H_
|
||||||
112
sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model.cc
Normal file
112
sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model.cc
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
// sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model.h"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
|
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||||
|
#include "sherpa-onnx/csrc/session.h"
|
||||||
|
#include "sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model-metadata.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
class SpeakerEmbeddingExtractorWeSpeakerModel::Impl {
|
||||||
|
public:
|
||||||
|
explicit Impl(const SpeakerEmbeddingExtractorConfig &config)
|
||||||
|
: config_(config),
|
||||||
|
env_(ORT_LOGGING_LEVEL_ERROR),
|
||||||
|
sess_opts_(GetSessionOptions(config)),
|
||||||
|
allocator_{} {
|
||||||
|
{
|
||||||
|
auto buf = ReadFile(config.model);
|
||||||
|
Init(buf.data(), buf.size());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ort::Value Compute(Ort::Value x) const {
|
||||||
|
std::array<Ort::Value, 1> inputs = {std::move(x)};
|
||||||
|
|
||||||
|
auto outputs =
|
||||||
|
sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
|
||||||
|
output_names_ptr_.data(), output_names_ptr_.size());
|
||||||
|
return std::move(outputs[0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
const SpeakerEmbeddingExtractorWeSpeakerModelMetaData &GetMetaData() const {
|
||||||
|
return meta_data_;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
void Init(void *model_data, size_t model_data_length) {
|
||||||
|
sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
|
||||||
|
sess_opts_);
|
||||||
|
|
||||||
|
GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
|
||||||
|
|
||||||
|
GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);
|
||||||
|
|
||||||
|
// get meta data
|
||||||
|
Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
|
||||||
|
if (config_.debug) {
|
||||||
|
std::ostringstream os;
|
||||||
|
PrintModelMetadata(os, meta_data);
|
||||||
|
SHERPA_ONNX_LOGE("%s", os.str().c_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
|
||||||
|
SHERPA_ONNX_READ_META_DATA(meta_data_.output_dim, "output_dim");
|
||||||
|
SHERPA_ONNX_READ_META_DATA(meta_data_.sample_rate, "sample_rate");
|
||||||
|
SHERPA_ONNX_READ_META_DATA(meta_data_.normalize_features,
|
||||||
|
"normalize_features");
|
||||||
|
SHERPA_ONNX_READ_META_DATA_STR(meta_data_.language, "language");
|
||||||
|
|
||||||
|
std::string framework;
|
||||||
|
SHERPA_ONNX_READ_META_DATA_STR(framework, "framework");
|
||||||
|
if (framework != "wespeaker") {
|
||||||
|
SHERPA_ONNX_LOGE("Expect a wespeaker model, given: %s",
|
||||||
|
framework.c_str());
|
||||||
|
exit(-1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
SpeakerEmbeddingExtractorConfig config_;
|
||||||
|
Ort::Env env_;
|
||||||
|
Ort::SessionOptions sess_opts_;
|
||||||
|
Ort::AllocatorWithDefaultOptions allocator_;
|
||||||
|
|
||||||
|
std::unique_ptr<Ort::Session> sess_;
|
||||||
|
|
||||||
|
std::vector<std::string> input_names_;
|
||||||
|
std::vector<const char *> input_names_ptr_;
|
||||||
|
|
||||||
|
std::vector<std::string> output_names_;
|
||||||
|
std::vector<const char *> output_names_ptr_;
|
||||||
|
|
||||||
|
SpeakerEmbeddingExtractorWeSpeakerModelMetaData meta_data_;
|
||||||
|
};
|
||||||
|
|
||||||
|
SpeakerEmbeddingExtractorWeSpeakerModel::
|
||||||
|
SpeakerEmbeddingExtractorWeSpeakerModel(
|
||||||
|
const SpeakerEmbeddingExtractorConfig &config)
|
||||||
|
: impl_(std::make_unique<Impl>(config)) {}
|
||||||
|
|
||||||
|
SpeakerEmbeddingExtractorWeSpeakerModel::
|
||||||
|
~SpeakerEmbeddingExtractorWeSpeakerModel() = default;
|
||||||
|
|
||||||
|
const SpeakerEmbeddingExtractorWeSpeakerModelMetaData &
|
||||||
|
SpeakerEmbeddingExtractorWeSpeakerModel::GetMetaData() const {
|
||||||
|
return impl_->GetMetaData();
|
||||||
|
}
|
||||||
|
|
||||||
|
Ort::Value SpeakerEmbeddingExtractorWeSpeakerModel::Compute(
|
||||||
|
Ort::Value x) const {
|
||||||
|
return impl_->Compute(std::move(x));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
@@ -0,0 +1,37 @@
|
|||||||
|
// sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_MODEL_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_MODEL_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||||
|
#include "sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model-metadata.h"
|
||||||
|
#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
class SpeakerEmbeddingExtractorWeSpeakerModel {
|
||||||
|
public:
|
||||||
|
explicit SpeakerEmbeddingExtractorWeSpeakerModel(
|
||||||
|
const SpeakerEmbeddingExtractorConfig &config);
|
||||||
|
|
||||||
|
~SpeakerEmbeddingExtractorWeSpeakerModel();
|
||||||
|
|
||||||
|
const SpeakerEmbeddingExtractorWeSpeakerModelMetaData &GetMetaData() const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param x A float32 tensor of shape (N, T, C)
|
||||||
|
* @return A float32 tensor of shape (N, C)
|
||||||
|
*/
|
||||||
|
Ort::Value Compute(Ort::Value x) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
class Impl;
|
||||||
|
std::unique_ptr<Impl> impl_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_MODEL_H_
|
||||||
74
sherpa-onnx/csrc/speaker-embedding-extractor.cc
Normal file
74
sherpa-onnx/csrc/speaker-embedding-extractor.cc
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
// sherpa-onnx/csrc/speaker-embedding-extractor.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/file-utils.h"
|
||||||
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
|
#include "sherpa-onnx/csrc/speaker-embedding-extractor-impl.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
void SpeakerEmbeddingExtractorConfig::Register(ParseOptions *po) {
|
||||||
|
po->Register("model", &model, "Path to the speaker embedding model.");
|
||||||
|
po->Register("num-threads", &num_threads,
|
||||||
|
"Number of threads to run the neural network");
|
||||||
|
|
||||||
|
po->Register("debug", &debug,
|
||||||
|
"true to print model information while loading it.");
|
||||||
|
|
||||||
|
po->Register("provider", &provider,
|
||||||
|
"Specify a provider to use: cpu, cuda, coreml");
|
||||||
|
}
|
||||||
|
|
||||||
|
bool SpeakerEmbeddingExtractorConfig::Validate() const {
|
||||||
|
if (model.empty()) {
|
||||||
|
SHERPA_ONNX_LOGE("Please provide --speaker-embedding-model");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!FileExists(model)) {
|
||||||
|
SHERPA_ONNX_LOGE("--speaker-embedding-model: %s does not exist",
|
||||||
|
model.c_str());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string SpeakerEmbeddingExtractorConfig::ToString() const {
|
||||||
|
std::ostringstream os;
|
||||||
|
|
||||||
|
os << "SpeakerEmbeddingExtractorConfig(";
|
||||||
|
os << "model=\"" << model << "\", ";
|
||||||
|
os << "num_threads=" << num_threads << ", ";
|
||||||
|
os << "debug=" << (debug ? "True" : "False") << ", ";
|
||||||
|
os << "provider=\"" << provider << "\")";
|
||||||
|
|
||||||
|
return os.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
SpeakerEmbeddingExtractor::SpeakerEmbeddingExtractor(
|
||||||
|
const SpeakerEmbeddingExtractorConfig &config)
|
||||||
|
: impl_(SpeakerEmbeddingExtractorImpl::Create(config)) {}
|
||||||
|
|
||||||
|
SpeakerEmbeddingExtractor::~SpeakerEmbeddingExtractor() = default;
|
||||||
|
|
||||||
|
int32_t SpeakerEmbeddingExtractor::Dim() const { return impl_->Dim(); }
|
||||||
|
|
||||||
|
std::unique_ptr<OnlineStream> SpeakerEmbeddingExtractor::CreateStream() const {
|
||||||
|
return impl_->CreateStream();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool SpeakerEmbeddingExtractor::IsReady(OnlineStream *s) const {
|
||||||
|
return impl_->IsReady(s);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<float> SpeakerEmbeddingExtractor::Compute(OnlineStream *s) const {
|
||||||
|
return impl_->Compute(s);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
67
sherpa-onnx/csrc/speaker-embedding-extractor.h
Normal file
67
sherpa-onnx/csrc/speaker-embedding-extractor.h
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
// sherpa-onnx/csrc/speaker-embedding-extractor.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/online-stream.h"
|
||||||
|
#include "sherpa-onnx/csrc/parse-options.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
struct SpeakerEmbeddingExtractorConfig {
|
||||||
|
std::string model;
|
||||||
|
int32_t num_threads = 1;
|
||||||
|
bool debug = false;
|
||||||
|
std::string provider = "cpu";
|
||||||
|
|
||||||
|
SpeakerEmbeddingExtractorConfig() = default;
|
||||||
|
SpeakerEmbeddingExtractorConfig(const std::string &model, int32_t num_threads,
|
||||||
|
bool debug, const std::string &provider)
|
||||||
|
: model(model),
|
||||||
|
num_threads(num_threads),
|
||||||
|
debug(debug),
|
||||||
|
provider(provider) {}
|
||||||
|
|
||||||
|
void Register(ParseOptions *po);
|
||||||
|
bool Validate() const;
|
||||||
|
std::string ToString() const;
|
||||||
|
};
|
||||||
|
|
||||||
|
class SpeakerEmbeddingExtractorImpl;
|
||||||
|
|
||||||
|
class SpeakerEmbeddingExtractor {
|
||||||
|
public:
|
||||||
|
explicit SpeakerEmbeddingExtractor(
|
||||||
|
const SpeakerEmbeddingExtractorConfig &config);
|
||||||
|
|
||||||
|
~SpeakerEmbeddingExtractor();
|
||||||
|
|
||||||
|
// Return the dimension of the embedding
|
||||||
|
int32_t Dim() const;
|
||||||
|
|
||||||
|
// Create a stream to accept audio samples and compute features
|
||||||
|
std::unique_ptr<OnlineStream> CreateStream() const;
|
||||||
|
|
||||||
|
// Return true if there are feature frames in OnlineStream that
|
||||||
|
// can be used to compute embeddings.
|
||||||
|
bool IsReady(OnlineStream *s) const;
|
||||||
|
|
||||||
|
// Compute the speaker embedding from the available unprocessed features
|
||||||
|
// of the given stream
|
||||||
|
//
|
||||||
|
// You have to ensure IsReady(s) returns true before you call this method.
|
||||||
|
std::vector<float> Compute(OnlineStream *s) const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::unique_ptr<SpeakerEmbeddingExtractorImpl> impl_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_
|
||||||
147
sherpa-onnx/csrc/speaker-embedding-manager-test.cc
Normal file
147
sherpa-onnx/csrc/speaker-embedding-manager-test.cc
Normal file
@@ -0,0 +1,147 @@
|
|||||||
|
// sherpa-onnx/csrc/speaker-embedding-manager-test.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Jingzhao Ou (jingzhao.ou@gmail.com)
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/speaker-embedding-manager.h"
|
||||||
|
|
||||||
|
#include "gtest/gtest.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
TEST(SpeakerEmbeddingManager, AddAndRemove) {
|
||||||
|
int32_t dim = 2;
|
||||||
|
SpeakerEmbeddingManager manager(dim);
|
||||||
|
std::vector<float> v = {0.1, 0.1};
|
||||||
|
bool status = manager.Add("first", v.data());
|
||||||
|
ASSERT_TRUE(status);
|
||||||
|
ASSERT_EQ(manager.NumSpeakers(), 1);
|
||||||
|
|
||||||
|
// duplicate
|
||||||
|
status = manager.Add("first", v.data());
|
||||||
|
ASSERT_FALSE(status);
|
||||||
|
ASSERT_EQ(manager.NumSpeakers(), 1);
|
||||||
|
|
||||||
|
// non-duplicate
|
||||||
|
v = {0.1, 0.9};
|
||||||
|
status = manager.Add("second", v.data());
|
||||||
|
ASSERT_TRUE(status);
|
||||||
|
ASSERT_EQ(manager.NumSpeakers(), 2);
|
||||||
|
|
||||||
|
// do not exist
|
||||||
|
status = manager.Remove("third");
|
||||||
|
ASSERT_FALSE(status);
|
||||||
|
|
||||||
|
status = manager.Remove("first");
|
||||||
|
ASSERT_TRUE(status);
|
||||||
|
ASSERT_EQ(manager.NumSpeakers(), 1);
|
||||||
|
|
||||||
|
v = {0.1, 0.1};
|
||||||
|
status = manager.Add("first", v.data());
|
||||||
|
ASSERT_TRUE(status);
|
||||||
|
ASSERT_EQ(manager.NumSpeakers(), 2);
|
||||||
|
|
||||||
|
status = manager.Remove("first");
|
||||||
|
ASSERT_TRUE(status);
|
||||||
|
ASSERT_EQ(manager.NumSpeakers(), 1);
|
||||||
|
|
||||||
|
status = manager.Remove("second");
|
||||||
|
ASSERT_TRUE(status);
|
||||||
|
ASSERT_EQ(manager.NumSpeakers(), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(SpeakerEmbeddingManager, Search) {
|
||||||
|
int32_t dim = 2;
|
||||||
|
SpeakerEmbeddingManager manager(dim);
|
||||||
|
std::vector<float> v1 = {0.1, 0.1};
|
||||||
|
std::vector<float> v2 = {0.1, 0.9};
|
||||||
|
std::vector<float> v3 = {0.9, 0.1};
|
||||||
|
bool status = manager.Add("first", v1.data());
|
||||||
|
ASSERT_TRUE(status);
|
||||||
|
|
||||||
|
status = manager.Add("second", v2.data());
|
||||||
|
ASSERT_TRUE(status);
|
||||||
|
|
||||||
|
status = manager.Add("third", v3.data());
|
||||||
|
ASSERT_TRUE(status);
|
||||||
|
|
||||||
|
ASSERT_EQ(manager.NumSpeakers(), 3);
|
||||||
|
|
||||||
|
std::vector<float> v = {15, 16};
|
||||||
|
float threshold = 0.9;
|
||||||
|
|
||||||
|
std::string name = manager.Search(v.data(), threshold);
|
||||||
|
EXPECT_EQ(name, "first");
|
||||||
|
|
||||||
|
v = {2, 17};
|
||||||
|
name = manager.Search(v.data(), threshold);
|
||||||
|
EXPECT_EQ(name, "second");
|
||||||
|
|
||||||
|
v = {17, 2};
|
||||||
|
name = manager.Search(v.data(), threshold);
|
||||||
|
EXPECT_EQ(name, "third");
|
||||||
|
|
||||||
|
threshold = 0.9;
|
||||||
|
v = {15, 16};
|
||||||
|
status = manager.Remove("first");
|
||||||
|
ASSERT_TRUE(status);
|
||||||
|
name = manager.Search(v.data(), threshold);
|
||||||
|
EXPECT_EQ(name, "");
|
||||||
|
|
||||||
|
v = {17, 2};
|
||||||
|
status = manager.Remove("third");
|
||||||
|
ASSERT_TRUE(status);
|
||||||
|
name = manager.Search(v.data(), threshold);
|
||||||
|
EXPECT_EQ(name, "");
|
||||||
|
|
||||||
|
v = {2, 17};
|
||||||
|
status = manager.Remove("second");
|
||||||
|
ASSERT_TRUE(status);
|
||||||
|
name = manager.Search(v.data(), threshold);
|
||||||
|
EXPECT_EQ(name, "");
|
||||||
|
|
||||||
|
ASSERT_EQ(manager.NumSpeakers(), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(SpeakerEmbeddingManager, Verify) {
|
||||||
|
int32_t dim = 2;
|
||||||
|
SpeakerEmbeddingManager manager(dim);
|
||||||
|
std::vector<float> v1 = {0.1, 0.1};
|
||||||
|
std::vector<float> v2 = {0.1, 0.9};
|
||||||
|
std::vector<float> v3 = {0.9, 0.1};
|
||||||
|
bool status = manager.Add("first", v1.data());
|
||||||
|
ASSERT_TRUE(status);
|
||||||
|
|
||||||
|
status = manager.Add("second", v2.data());
|
||||||
|
ASSERT_TRUE(status);
|
||||||
|
|
||||||
|
status = manager.Add("third", v3.data());
|
||||||
|
ASSERT_TRUE(status);
|
||||||
|
|
||||||
|
std::vector<float> v = {15, 16};
|
||||||
|
float threshold = 0.9;
|
||||||
|
|
||||||
|
status = manager.Verify("first", v.data(), threshold);
|
||||||
|
ASSERT_TRUE(status);
|
||||||
|
|
||||||
|
v = {2, 17};
|
||||||
|
status = manager.Verify("first", v.data(), threshold);
|
||||||
|
ASSERT_FALSE(status);
|
||||||
|
|
||||||
|
status = manager.Verify("second", v.data(), threshold);
|
||||||
|
ASSERT_TRUE(status);
|
||||||
|
|
||||||
|
v = {17, 2};
|
||||||
|
status = manager.Verify("first", v.data(), threshold);
|
||||||
|
ASSERT_FALSE(status);
|
||||||
|
|
||||||
|
status = manager.Verify("second", v.data(), threshold);
|
||||||
|
ASSERT_FALSE(status);
|
||||||
|
|
||||||
|
status = manager.Verify("third", v.data(), threshold);
|
||||||
|
ASSERT_TRUE(status);
|
||||||
|
|
||||||
|
status = manager.Verify("fourth", v.data(), threshold);
|
||||||
|
ASSERT_FALSE(status);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
144
sherpa-onnx/csrc/speaker-embedding-manager.cc
Normal file
144
sherpa-onnx/csrc/speaker-embedding-manager.cc
Normal file
@@ -0,0 +1,144 @@
|
|||||||
|
// sherpa-onnx/csrc/speaker-embedding-manager.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/speaker-embedding-manager.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <unordered_map>
|
||||||
|
|
||||||
|
#include "Eigen/Dense"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
using FloatMatrix =
|
||||||
|
Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
|
||||||
|
|
||||||
|
class SpeakerEmbeddingManager::Impl {
|
||||||
|
public:
|
||||||
|
explicit Impl(int32_t dim) : dim_(dim) {}
|
||||||
|
|
||||||
|
bool Add(const std::string &name, const float *p) {
|
||||||
|
if (name2row_.count(name)) {
|
||||||
|
// a speaker with the same name already exists
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
embedding_matrix_.conservativeResize(embedding_matrix_.rows() + 1, dim_);
|
||||||
|
|
||||||
|
std::copy(p, p + dim_, &embedding_matrix_.bottomRows(1)(0, 0));
|
||||||
|
|
||||||
|
embedding_matrix_.bottomRows(1).normalize(); // inplace
|
||||||
|
|
||||||
|
name2row_[name] = embedding_matrix_.rows() - 1;
|
||||||
|
row2name_[embedding_matrix_.rows() - 1] = name;
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Remove(const std::string &name) {
|
||||||
|
if (!name2row_.count(name)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t row_idx = name2row_.at(name);
|
||||||
|
|
||||||
|
int32_t num_rows = embedding_matrix_.rows();
|
||||||
|
|
||||||
|
if (row_idx < num_rows - 1) {
|
||||||
|
embedding_matrix_.block(row_idx, 0, num_rows - -1 - row_idx, dim_) =
|
||||||
|
embedding_matrix_.bottomRows(num_rows - 1 - row_idx);
|
||||||
|
}
|
||||||
|
|
||||||
|
embedding_matrix_.conservativeResize(num_rows - 1, dim_);
|
||||||
|
for (auto &p : name2row_) {
|
||||||
|
if (p.second > row_idx) {
|
||||||
|
p.second -= 1;
|
||||||
|
row2name_[p.second] = p.first;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
name2row_.erase(name);
|
||||||
|
row2name_.erase(num_rows - 1);
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string Search(const float *p, float threshold) {
|
||||||
|
if (embedding_matrix_.rows() == 0) {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
Eigen::VectorXf v =
|
||||||
|
Eigen::Map<Eigen::VectorXf>(const_cast<float *>(p), dim_);
|
||||||
|
v.normalize();
|
||||||
|
|
||||||
|
Eigen::VectorXf scores = embedding_matrix_ * v;
|
||||||
|
|
||||||
|
Eigen::VectorXf::Index max_index;
|
||||||
|
float max_score = scores.maxCoeff(&max_index);
|
||||||
|
if (max_score < threshold) {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
return row2name_.at(max_index);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Verify(const std::string &name, const float *p, float threshold) {
|
||||||
|
if (!name2row_.count(name)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t row_idx = name2row_.at(name);
|
||||||
|
|
||||||
|
Eigen::VectorXf v =
|
||||||
|
Eigen::Map<Eigen::VectorXf>(const_cast<float *>(p), dim_);
|
||||||
|
v.normalize();
|
||||||
|
|
||||||
|
float score = embedding_matrix_.row(row_idx) * v;
|
||||||
|
|
||||||
|
if (score < threshold) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t NumSpeakers() const { return embedding_matrix_.rows(); }
|
||||||
|
|
||||||
|
private:
|
||||||
|
int32_t dim_;
|
||||||
|
FloatMatrix embedding_matrix_;
|
||||||
|
std::unordered_map<std::string, int32_t> name2row_;
|
||||||
|
std::unordered_map<int32_t, std::string> row2name_;
|
||||||
|
};
|
||||||
|
|
||||||
|
SpeakerEmbeddingManager::SpeakerEmbeddingManager(int32_t dim)
|
||||||
|
: impl_(std::make_unique<Impl>(dim)) {}
|
||||||
|
|
||||||
|
SpeakerEmbeddingManager::~SpeakerEmbeddingManager() = default;
|
||||||
|
|
||||||
|
bool SpeakerEmbeddingManager::Add(const std::string &name,
|
||||||
|
const float *p) const {
|
||||||
|
return impl_->Add(name, p);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool SpeakerEmbeddingManager::Remove(const std::string &name) const {
|
||||||
|
return impl_->Remove(name);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string SpeakerEmbeddingManager::Search(const float *p,
|
||||||
|
float threshold) const {
|
||||||
|
return impl_->Search(p, threshold);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool SpeakerEmbeddingManager::Verify(const std::string &name, const float *p,
|
||||||
|
float threshold) const {
|
||||||
|
return impl_->Verify(name, p, threshold);
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t SpeakerEmbeddingManager::NumSpeakers() const {
|
||||||
|
return impl_->NumSpeakers();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
72
sherpa-onnx/csrc/speaker-embedding-manager.h
Normal file
72
sherpa-onnx/csrc/speaker-embedding-manager.h
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
// sherpa-onnx/csrc/speaker-embedding-manager.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_MANAGER_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_MANAGER_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
class SpeakerEmbeddingManager {
|
||||||
|
public:
|
||||||
|
// @param dim Embedding dimension.
|
||||||
|
explicit SpeakerEmbeddingManager(int32_t dim);
|
||||||
|
~SpeakerEmbeddingManager();
|
||||||
|
|
||||||
|
/* Add the embedding and name of a speaker to the manager.
|
||||||
|
*
|
||||||
|
* @param name Name of the speaker
|
||||||
|
* @param p Pointer to the embedding. Its length is `dim`.
|
||||||
|
* @return Return true if added successfully. Return false if it failed.
|
||||||
|
* At present, the only reason for a failure is that there is already
|
||||||
|
* a speaker with the same `name`.
|
||||||
|
*/
|
||||||
|
bool Add(const std::string &name, const float *p) const;
|
||||||
|
|
||||||
|
/* Remove a speaker by its name.
|
||||||
|
*
|
||||||
|
* @param name Name of the speaker to remove.
|
||||||
|
* @return Return true if it is removed successfully. Return false
|
||||||
|
* if there is no such a speaker.
|
||||||
|
*/
|
||||||
|
bool Remove(const std::string &name) const;
|
||||||
|
|
||||||
|
/** It is for speaker identification.
|
||||||
|
*
|
||||||
|
* It computes the cosine similarity between and given embedding and all
|
||||||
|
* other embeddings and find the embedding that has the largest score
|
||||||
|
* and the score is above or equal to threshold. Return the speaker
|
||||||
|
* name for the embedding if found; otherwise, it returns an empty string.
|
||||||
|
*
|
||||||
|
* @param p The input embedding.
|
||||||
|
* @param threshold A value between 0 and 1.
|
||||||
|
* @param If found, return the name of the speaker. Otherwise, return an
|
||||||
|
* empty string.
|
||||||
|
*/
|
||||||
|
std::string Search(const float *p, float threshold) const;
|
||||||
|
|
||||||
|
/* Check whether the input embedding matches the embedding of the input
|
||||||
|
* speaker.
|
||||||
|
*
|
||||||
|
* It is for speaker verification.
|
||||||
|
*
|
||||||
|
* @param name The target speaker name.
|
||||||
|
* @param p The input embedding to check.
|
||||||
|
* @param threshold A value between 0 and 1.
|
||||||
|
* @return Return true if it matches. Otherwise, it returns false.
|
||||||
|
*/
|
||||||
|
bool Verify(const std::string &name, const float *p, float threshold) const;
|
||||||
|
|
||||||
|
int32_t NumSpeakers() const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
class Impl;
|
||||||
|
std::unique_ptr<Impl> impl_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_MANAGER_H_
|
||||||
@@ -40,7 +40,9 @@ class VoiceActivityDetector::Impl {
|
|||||||
|
|
||||||
for (int32_t i = 0; i != k; ++i, p += window_size) {
|
for (int32_t i = 0; i != k; ++i, p += window_size) {
|
||||||
buffer_.Push(p, window_size);
|
buffer_.Push(p, window_size);
|
||||||
is_speech = is_speech || model_->IsSpeech(p, window_size);
|
// NOTE(fangjun): Please don't use a very large n.
|
||||||
|
bool this_window_is_speech = model_->IsSpeech(p, window_size);
|
||||||
|
is_speech = is_speech || this_window_is_speech;
|
||||||
}
|
}
|
||||||
|
|
||||||
last_ = std::vector<float>(
|
last_ = std::vector<float>(
|
||||||
@@ -102,6 +104,8 @@ class VoiceActivityDetector::Impl {
|
|||||||
|
|
||||||
bool IsSpeechDetected() const { return start_ != -1; }
|
bool IsSpeechDetected() const { return start_ != -1; }
|
||||||
|
|
||||||
|
const VadModelConfig &GetConfig() const { return config_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::queue<SpeechSegment> segments_;
|
std::queue<SpeechSegment> segments_;
|
||||||
|
|
||||||
@@ -146,4 +150,8 @@ bool VoiceActivityDetector::IsSpeechDetected() const {
|
|||||||
return impl_->IsSpeechDetected();
|
return impl_->IsSpeechDetected();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const VadModelConfig &VoiceActivityDetector::GetConfig() const {
|
||||||
|
return impl_->GetConfig();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace sherpa_onnx
|
} // namespace sherpa_onnx
|
||||||
|
|||||||
@@ -43,6 +43,8 @@ class VoiceActivityDetector {
|
|||||||
|
|
||||||
void Reset();
|
void Reset();
|
||||||
|
|
||||||
|
const VadModelConfig &GetConfig() const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
class Impl;
|
class Impl;
|
||||||
std::unique_ptr<Impl> impl_;
|
std::unique_ptr<Impl> impl_;
|
||||||
|
|||||||
@@ -30,6 +30,8 @@ pybind11_add_module(_sherpa_onnx
|
|||||||
online-zipformer2-ctc-model-config.cc
|
online-zipformer2-ctc-model-config.cc
|
||||||
sherpa-onnx.cc
|
sherpa-onnx.cc
|
||||||
silero-vad-model-config.cc
|
silero-vad-model-config.cc
|
||||||
|
speaker-embedding-extractor.cc
|
||||||
|
speaker-embedding-manager.cc
|
||||||
vad-model-config.cc
|
vad-model-config.cc
|
||||||
vad-model.cc
|
vad-model.cc
|
||||||
voice-activity-detector.cc
|
voice-activity-detector.cc
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
// sherpa-onnx/python/csrc/online-recongizer.h
|
// sherpa-onnx/python/csrc/online-recognizer.h
|
||||||
//
|
//
|
||||||
// Copyright (c) 2023 Xiaomi Corporation
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
|||||||
@@ -18,6 +18,8 @@
|
|||||||
#include "sherpa-onnx/python/csrc/online-model-config.h"
|
#include "sherpa-onnx/python/csrc/online-model-config.h"
|
||||||
#include "sherpa-onnx/python/csrc/online-recognizer.h"
|
#include "sherpa-onnx/python/csrc/online-recognizer.h"
|
||||||
#include "sherpa-onnx/python/csrc/online-stream.h"
|
#include "sherpa-onnx/python/csrc/online-stream.h"
|
||||||
|
#include "sherpa-onnx/python/csrc/speaker-embedding-extractor.h"
|
||||||
|
#include "sherpa-onnx/python/csrc/speaker-embedding-manager.h"
|
||||||
#include "sherpa-onnx/python/csrc/vad-model-config.h"
|
#include "sherpa-onnx/python/csrc/vad-model-config.h"
|
||||||
#include "sherpa-onnx/python/csrc/vad-model.h"
|
#include "sherpa-onnx/python/csrc/vad-model.h"
|
||||||
#include "sherpa-onnx/python/csrc/voice-activity-detector.h"
|
#include "sherpa-onnx/python/csrc/voice-activity-detector.h"
|
||||||
@@ -48,6 +50,8 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
|
|||||||
PybindVoiceActivityDetector(&m);
|
PybindVoiceActivityDetector(&m);
|
||||||
|
|
||||||
PybindOfflineTts(&m);
|
PybindOfflineTts(&m);
|
||||||
|
PybindSpeakerEmbeddingExtractor(&m);
|
||||||
|
PybindSpeakerEmbeddingManager(&m);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace sherpa_onnx
|
} // namespace sherpa_onnx
|
||||||
|
|||||||
44
sherpa-onnx/python/csrc/speaker-embedding-extractor.cc
Normal file
44
sherpa-onnx/python/csrc/speaker-embedding-extractor.cc
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
// sherpa-onnx/python/csrc/speaker-embedding-extractor.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#include "sherpa-onnx/python/csrc/speaker-embedding-extractor.h"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
static void PybindSpeakerEmbeddingExtractorConfig(py::module *m) {
|
||||||
|
using PyClass = SpeakerEmbeddingExtractorConfig;
|
||||||
|
py::class_<PyClass>(*m, "SpeakerEmbeddingExtractorConfig")
|
||||||
|
.def(py::init<>())
|
||||||
|
.def(py::init<const std::string &, int32_t, bool, const std::string>(),
|
||||||
|
py::arg("model"), py::arg("num_threads") = 1,
|
||||||
|
py::arg("debug") = false, py::arg("provider") = "cpu")
|
||||||
|
.def_readwrite("model", &PyClass::model)
|
||||||
|
.def_readwrite("num_threads", &PyClass::num_threads)
|
||||||
|
.def_readwrite("debug", &PyClass::debug)
|
||||||
|
.def_readwrite("provider", &PyClass::provider)
|
||||||
|
.def("validate", &PyClass::Validate)
|
||||||
|
.def("__str__", &PyClass::ToString);
|
||||||
|
}
|
||||||
|
|
||||||
|
void PybindSpeakerEmbeddingExtractor(py::module *m) {
|
||||||
|
PybindSpeakerEmbeddingExtractorConfig(m);
|
||||||
|
|
||||||
|
using PyClass = SpeakerEmbeddingExtractor;
|
||||||
|
py::class_<PyClass>(*m, "SpeakerEmbeddingExtractor")
|
||||||
|
.def(py::init<const SpeakerEmbeddingExtractorConfig &>(),
|
||||||
|
py::arg("config"), py::call_guard<py::gil_scoped_release>())
|
||||||
|
.def_property_readonly("dim", &PyClass::Dim)
|
||||||
|
.def("create_stream", &PyClass::CreateStream,
|
||||||
|
py::call_guard<py::gil_scoped_release>())
|
||||||
|
.def("compute", &PyClass::Compute,
|
||||||
|
py::call_guard<py::gil_scoped_release>())
|
||||||
|
.def("is_ready", &PyClass::IsReady,
|
||||||
|
py::call_guard<py::gil_scoped_release>());
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
16
sherpa-onnx/python/csrc/speaker-embedding-extractor.h
Normal file
16
sherpa-onnx/python/csrc/speaker-embedding-extractor.h
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
// sherpa-onnx/python/csrc/speaker-embedding-extractor.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#ifndef SHERPA_ONNX_PYTHON_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_
|
||||||
|
#define SHERPA_ONNX_PYTHON_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_
|
||||||
|
|
||||||
|
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
void PybindSpeakerEmbeddingExtractor(py::module *m);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_PYTHON_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_
|
||||||
50
sherpa-onnx/python/csrc/speaker-embedding-manager.cc
Normal file
50
sherpa-onnx/python/csrc/speaker-embedding-manager.cc
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
// sherpa-onnx/python/csrc/speaker-embedding-manager.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#include "sherpa-onnx/python/csrc/speaker-embedding-manager.h"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/speaker-embedding-manager.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
void PybindSpeakerEmbeddingManager(py::module *m) {
|
||||||
|
using PyClass = SpeakerEmbeddingManager;
|
||||||
|
py::class_<PyClass>(*m, "SpeakerEmbeddingManager")
|
||||||
|
.def(py::init<int32_t>(), py::arg("dim"),
|
||||||
|
py::call_guard<py::gil_scoped_release>())
|
||||||
|
.def_property_readonly("num_speakers", &PyClass::NumSpeakers)
|
||||||
|
.def(
|
||||||
|
"add",
|
||||||
|
[](const PyClass &self, const std::string &name,
|
||||||
|
const std::vector<float> &v) -> bool {
|
||||||
|
return self.Add(name, v.data());
|
||||||
|
},
|
||||||
|
py::arg("name"), py::arg("v"),
|
||||||
|
py::call_guard<py::gil_scoped_release>())
|
||||||
|
.def(
|
||||||
|
"remove",
|
||||||
|
[](const PyClass &self, const std::string &name) -> bool {
|
||||||
|
return self.Remove(name);
|
||||||
|
},
|
||||||
|
py::arg("name"), py::call_guard<py::gil_scoped_release>())
|
||||||
|
.def(
|
||||||
|
"search",
|
||||||
|
[](const PyClass &self, const std::vector<float> &v, float threshold)
|
||||||
|
-> std::string { return self.Search(v.data(), threshold); },
|
||||||
|
py::arg("v"), py::arg("threshold"),
|
||||||
|
py::call_guard<py::gil_scoped_release>())
|
||||||
|
.def(
|
||||||
|
"verify",
|
||||||
|
[](const PyClass &self, const std::string &name,
|
||||||
|
const std::vector<float> &v, float threshold) -> bool {
|
||||||
|
return self.Verify(name, v.data(), threshold);
|
||||||
|
},
|
||||||
|
py::arg("name"), py::arg("v"), py::arg("threshold"),
|
||||||
|
py::call_guard<py::gil_scoped_release>());
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
16
sherpa-onnx/python/csrc/speaker-embedding-manager.h
Normal file
16
sherpa-onnx/python/csrc/speaker-embedding-manager.h
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
// sherpa-onnx/python/csrc/speaker-embedding-manager.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#ifndef SHERPA_ONNX_PYTHON_CSRC_SPEAKER_EMBEDDING_MANAGER_H_
|
||||||
|
#define SHERPA_ONNX_PYTHON_CSRC_SPEAKER_EMBEDDING_MANAGER_H_
|
||||||
|
|
||||||
|
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
void PybindSpeakerEmbeddingManager(py::module *m);
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_PYTHON_CSRC_SPEAKER_EMBEDDING_MANAGER_H_
|
||||||
@@ -32,6 +32,7 @@ void PybindVoiceActivityDetector(py::module *m) {
|
|||||||
self.AcceptWaveform(samples.data(), samples.size());
|
self.AcceptWaveform(samples.data(), samples.size());
|
||||||
},
|
},
|
||||||
py::arg("samples"), py::call_guard<py::gil_scoped_release>())
|
py::arg("samples"), py::call_guard<py::gil_scoped_release>())
|
||||||
|
.def_property_readonly("config", &PyClass::GetConfig)
|
||||||
.def("empty", &PyClass::Empty, py::call_guard<py::gil_scoped_release>())
|
.def("empty", &PyClass::Empty, py::call_guard<py::gil_scoped_release>())
|
||||||
.def("pop", &PyClass::Pop, py::call_guard<py::gil_scoped_release>())
|
.def("pop", &PyClass::Pop, py::call_guard<py::gil_scoped_release>())
|
||||||
.def("is_speech_detected", &PyClass::IsSpeechDetected,
|
.def("is_speech_detected", &PyClass::IsSpeechDetected,
|
||||||
|
|||||||
@@ -8,6 +8,9 @@ from _sherpa_onnx import (
|
|||||||
OfflineTtsVitsModelConfig,
|
OfflineTtsVitsModelConfig,
|
||||||
OnlineStream,
|
OnlineStream,
|
||||||
SileroVadModelConfig,
|
SileroVadModelConfig,
|
||||||
|
SpeakerEmbeddingExtractor,
|
||||||
|
SpeakerEmbeddingExtractorConfig,
|
||||||
|
SpeakerEmbeddingManager,
|
||||||
SpeechSegment,
|
SpeechSegment,
|
||||||
VadModel,
|
VadModel,
|
||||||
VadModelConfig,
|
VadModelConfig,
|
||||||
|
|||||||
Reference in New Issue
Block a user