Add runtime support for wespeaker models (#516)
This commit is contained in:
252
python-api-examples/speaker-identification.py
Executable file
252
python-api-examples/speaker-identification.py
Executable file
@@ -0,0 +1,252 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
"""
|
||||
This script shows how to use Python APIs for speaker identification.
|
||||
|
||||
Usage:
|
||||
|
||||
(1) Prepare a text file containing speaker related files.
|
||||
|
||||
Each line in the text file contains two columns. The first column is the
|
||||
speaker name, while the second column contains the wave file of the speaker.
|
||||
|
||||
If the text file contains multiple wave files for the same speaker, then the
|
||||
embeddings of these files are averaged.
|
||||
|
||||
An example text file is given below:
|
||||
|
||||
foo /path/to/a.wav
|
||||
bar /path/to/b.wav
|
||||
foo /path/to/c.wav
|
||||
foobar /path/to/d.wav
|
||||
|
||||
Each wave file should contain only a single channel; the sample format
|
||||
should be int16_t; the sample rate can be arbitrary.
|
||||
|
||||
(2) Download a model for computing speaker embeddings
|
||||
|
||||
Please visit
|
||||
https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-recongition-models
|
||||
to download a model. An example is given below:
|
||||
|
||||
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/zh_cnceleb_resnet34.onnx
|
||||
|
||||
Note that `zh` means Chinese, while `en` means English.
|
||||
|
||||
(3) Run this script
|
||||
|
||||
Assume the filename of the text file is speaker.txt.
|
||||
|
||||
python3 ./python-api-examples/speaker-identification.py \
|
||||
--speaker-file ./speaker.txt \
|
||||
--model ./zh_cnceleb_resnet34.onnx
|
||||
"""
|
||||
import argparse
|
||||
import queue
|
||||
import threading
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import sherpa_onnx
|
||||
import torchaudio
|
||||
|
||||
try:
|
||||
import sounddevice as sd
|
||||
except ImportError:
|
||||
print("Please install sounddevice first. You can use")
|
||||
print()
|
||||
print(" pip install sounddevice")
|
||||
print()
|
||||
print("to install it")
|
||||
sys.exit(-1)
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--speaker-file",
|
||||
type=str,
|
||||
required=True,
|
||||
help="""Path to the speaker file. Read the help doc at the beginning of this
|
||||
file for the format.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the model file.",
|
||||
)
|
||||
|
||||
parser.add_argument("--threshold", type=float, default=0.6)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-threads",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of threads for neural network computation",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--debug",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="True to show debug messages",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--provider",
|
||||
type=str,
|
||||
default="cpu",
|
||||
help="Valid values: cpu, cuda, coreml",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def load_speaker_embedding_model(args):
|
||||
config = sherpa_onnx.SpeakerEmbeddingExtractorConfig(
|
||||
model=args.model,
|
||||
num_threads=args.num_threads,
|
||||
debug=args.debug,
|
||||
provider=args.provider,
|
||||
)
|
||||
if not config.validate():
|
||||
raise ValueError(f"Invalid config. {config}")
|
||||
extractor = sherpa_onnx.SpeakerEmbeddingExtractor(config)
|
||||
return extractor
|
||||
|
||||
|
||||
def load_speaker_file(args) -> Dict[str, List[str]]:
|
||||
if not Path(args.speaker_file).is_file():
|
||||
raise ValueError(f"--speaker-file {args.speaker_file} does not exist")
|
||||
|
||||
ans = defaultdict(list)
|
||||
with open(args.speaker_file) as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
fields = line.split()
|
||||
if len(fields) != 2:
|
||||
raise ValueError(f"Invalid line: {line}. Fields: {fields}")
|
||||
|
||||
speaker_name, filename = fields
|
||||
ans[speaker_name].append(filename)
|
||||
return ans
|
||||
|
||||
|
||||
def load_audio(filename: str) -> Tuple[np.ndarray, int]:
|
||||
samples, sample_rate = torchaudio.load(filename)
|
||||
return samples[0].contiguous().numpy(), sample_rate
|
||||
|
||||
|
||||
def compute_speaker_embedding(
|
||||
filenames: List[str],
|
||||
extractor: sherpa_onnx.SpeakerEmbeddingExtractor,
|
||||
) -> np.ndarray:
|
||||
assert len(filenames) > 0, f"filenames is empty"
|
||||
|
||||
ans = None
|
||||
for filename in filenames:
|
||||
print(f"processing {filename}")
|
||||
samples, sample_rate = load_audio(filename)
|
||||
stream = extractor.create_stream()
|
||||
stream.accept_waveform(sample_rate=sample_rate, waveform=samples)
|
||||
stream.input_finished()
|
||||
|
||||
assert extractor.is_ready(stream)
|
||||
embedding = extractor.compute(stream)
|
||||
embedding = np.array(embedding)
|
||||
if ans is None:
|
||||
ans = embedding
|
||||
else:
|
||||
ans += embedding
|
||||
|
||||
return ans / len(filenames)
|
||||
|
||||
|
||||
g_buffer = queue.Queue()
|
||||
g_stop = False
|
||||
g_sample_rate = 16000
|
||||
g_read_mic_thread = None
|
||||
|
||||
|
||||
def read_mic():
|
||||
print("Please speak!")
|
||||
samples_per_read = int(0.1 * g_sample_rate) # 0.1 second = 100 ms
|
||||
with sd.InputStream(channels=1, dtype="float32", samplerate=g_sample_rate) as s:
|
||||
while not g_stop:
|
||||
samples, _ = s.read(samples_per_read) # a blocking read
|
||||
g_buffer.put(samples)
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
print(args)
|
||||
extractor = load_speaker_embedding_model(args)
|
||||
speaker_file = load_speaker_file(args)
|
||||
|
||||
manager = sherpa_onnx.SpeakerEmbeddingManager(extractor.dim)
|
||||
for name, filename_list in speaker_file.items():
|
||||
embedding = compute_speaker_embedding(
|
||||
filenames=filename_list,
|
||||
extractor=extractor,
|
||||
)
|
||||
status = manager.add(name, embedding)
|
||||
if not status:
|
||||
raise RuntimeError(f"Failed to register speaker {name}")
|
||||
|
||||
devices = sd.query_devices()
|
||||
if len(devices) == 0:
|
||||
print("No microphone devices found")
|
||||
sys.exit(0)
|
||||
|
||||
print(devices)
|
||||
default_input_device_idx = sd.default.device[0]
|
||||
print(f'Use default device: {devices[default_input_device_idx]["name"]}')
|
||||
|
||||
global g_stop
|
||||
global g_read_mic_thread
|
||||
while True:
|
||||
key = input("Press enter to start recording")
|
||||
if key.lower() in ("q", "quit"):
|
||||
g_stop = True
|
||||
break
|
||||
|
||||
g_stop = False
|
||||
g_buffer.queue.clear()
|
||||
g_read_mic_thread = threading.Thread(target=read_mic)
|
||||
g_read_mic_thread.start()
|
||||
input("Press enter to stop recording")
|
||||
g_stop = True
|
||||
g_read_mic_thread.join()
|
||||
print("Compute embedding")
|
||||
stream = extractor.create_stream()
|
||||
while not g_buffer.empty():
|
||||
samples = g_buffer.get()
|
||||
stream.accept_waveform(sample_rate=g_sample_rate, waveform=samples)
|
||||
stream.input_finished()
|
||||
|
||||
embedding = extractor.compute(stream)
|
||||
embedding = np.array(embedding)
|
||||
name = manager.search(embedding, threshold=args.threshold)
|
||||
if not name:
|
||||
name = "unknown"
|
||||
print(f"Predicted name: {name}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
main()
|
||||
except KeyboardInterrupt:
|
||||
print("\nCaught Ctrl + C. Exiting")
|
||||
g_stop = True
|
||||
if g_read_mic_thread.is_alive():
|
||||
g_read_mic_thread.join()
|
||||
@@ -96,6 +96,14 @@ set(sources
|
||||
wave-reader.cc
|
||||
)
|
||||
|
||||
# speaker embedding extractor
|
||||
list(APPEND sources
|
||||
speaker-embedding-extractor-impl.cc
|
||||
speaker-embedding-extractor-wespeaker-model.cc
|
||||
speaker-embedding-extractor.cc
|
||||
speaker-embedding-manager.cc
|
||||
)
|
||||
|
||||
list(APPEND sources
|
||||
lexicon.cc
|
||||
offline-tts-impl.cc
|
||||
@@ -387,6 +395,10 @@ if(SHERPA_ONNX_ENABLE_TESTS)
|
||||
utfcpp-test.cc
|
||||
)
|
||||
|
||||
list(APPEND sherpa_onnx_test_srcs
|
||||
speaker-embedding-manager-test.cc
|
||||
)
|
||||
|
||||
function(sherpa_onnx_add_test source)
|
||||
get_filename_component(name ${source} NAME_WE)
|
||||
set(target_name ${name})
|
||||
|
||||
@@ -64,8 +64,8 @@ TEST(ContextGraph, Benchmark) {
|
||||
auto stop = std::chrono::high_resolution_clock::now();
|
||||
auto duration =
|
||||
std::chrono::duration_cast<std::chrono::microseconds>(stop - start);
|
||||
SHERPA_ONNX_LOGE("Construct context graph for %d item takes %ld us.", num,
|
||||
duration.count());
|
||||
SHERPA_ONNX_LOGE("Construct context graph for %d item takes %d us.", num,
|
||||
static_cast<int32_t>(duration.count()));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -91,4 +91,9 @@ Ort::SessionOptions GetSessionOptions(const OfflineTtsModelConfig &config) {
|
||||
return GetSessionOptionsImpl(config.num_threads, config.provider);
|
||||
}
|
||||
|
||||
Ort::SessionOptions GetSessionOptions(
|
||||
const SpeakerEmbeddingExtractorConfig &config) {
|
||||
return GetSessionOptionsImpl(config.num_threads, config.provider);
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#include "sherpa-onnx/csrc/offline-tts-model-config.h"
|
||||
#include "sherpa-onnx/csrc/online-lm-config.h"
|
||||
#include "sherpa-onnx/csrc/online-model-config.h"
|
||||
#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
|
||||
#include "sherpa-onnx/csrc/vad-model-config.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
@@ -26,6 +27,9 @@ Ort::SessionOptions GetSessionOptions(const OnlineLMConfig &config);
|
||||
Ort::SessionOptions GetSessionOptions(const VadModelConfig &config);
|
||||
|
||||
Ort::SessionOptions GetSessionOptions(const OfflineTtsModelConfig &config);
|
||||
|
||||
Ort::SessionOptions GetSessionOptions(
|
||||
const SpeakerEmbeddingExtractorConfig &config);
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_SESSION_H_
|
||||
|
||||
82
sherpa-onnx/csrc/speaker-embedding-extractor-impl.cc
Normal file
82
sherpa-onnx/csrc/speaker-embedding-extractor-impl.cc
Normal file
@@ -0,0 +1,82 @@
|
||||
// sherpa-onnx/csrc/speaker-embedding-extractor-impl.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
#include "sherpa-onnx/csrc/speaker-embedding-extractor-impl.h"
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-impl.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
namespace {
|
||||
|
||||
enum class ModelType {
|
||||
kWeSpeaker,
|
||||
kUnkown,
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
static ModelType GetModelType(char *model_data, size_t model_data_length,
|
||||
bool debug) {
|
||||
Ort::Env env(ORT_LOGGING_LEVEL_WARNING);
|
||||
Ort::SessionOptions sess_opts;
|
||||
|
||||
auto sess = std::make_unique<Ort::Session>(env, model_data, model_data_length,
|
||||
sess_opts);
|
||||
|
||||
Ort::ModelMetadata meta_data = sess->GetModelMetadata();
|
||||
if (debug) {
|
||||
std::ostringstream os;
|
||||
PrintModelMetadata(os, meta_data);
|
||||
SHERPA_ONNX_LOGE("%s", os.str().c_str());
|
||||
}
|
||||
|
||||
Ort::AllocatorWithDefaultOptions allocator;
|
||||
auto model_type =
|
||||
meta_data.LookupCustomMetadataMapAllocated("framework", allocator);
|
||||
if (!model_type) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"No model_type in the metadata!\n"
|
||||
"Please make sure you have added metadata to the model.\n\n"
|
||||
"For instance, you can use\n"
|
||||
"https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wespeaker/"
|
||||
"add_meta_data.py"
|
||||
"to add metadata to models from WeSpeaker\n");
|
||||
return ModelType::kUnkown;
|
||||
}
|
||||
|
||||
if (model_type.get() == std::string("wespeaker")) {
|
||||
return ModelType::kWeSpeaker;
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
|
||||
return ModelType::kUnkown;
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<SpeakerEmbeddingExtractorImpl>
|
||||
SpeakerEmbeddingExtractorImpl::Create(
|
||||
const SpeakerEmbeddingExtractorConfig &config) {
|
||||
ModelType model_type = ModelType::kUnkown;
|
||||
|
||||
{
|
||||
auto buffer = ReadFile(config.model);
|
||||
|
||||
model_type = GetModelType(buffer.data(), buffer.size(), config.debug);
|
||||
}
|
||||
|
||||
switch (model_type) {
|
||||
case ModelType::kWeSpeaker:
|
||||
return std::make_unique<SpeakerEmbeddingExtractorWeSpeakerImpl>(config);
|
||||
case ModelType::kUnkown:
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Unknown model type in for speaker embedding extractor!");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// unreachable code
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
34
sherpa-onnx/csrc/speaker-embedding-extractor-impl.h
Normal file
34
sherpa-onnx/csrc/speaker-embedding-extractor-impl.h
Normal file
@@ -0,0 +1,34 @@
|
||||
// sherpa-onnx/csrc/speaker-embedding-extractor-impl.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_IMPL_H_
|
||||
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_IMPL_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class SpeakerEmbeddingExtractorImpl {
|
||||
public:
|
||||
virtual ~SpeakerEmbeddingExtractorImpl() = default;
|
||||
|
||||
static std::unique_ptr<SpeakerEmbeddingExtractorImpl> Create(
|
||||
const SpeakerEmbeddingExtractorConfig &config);
|
||||
|
||||
virtual int32_t Dim() const = 0;
|
||||
|
||||
virtual std::unique_ptr<OnlineStream> CreateStream() const = 0;
|
||||
|
||||
virtual bool IsReady(OnlineStream *s) const = 0;
|
||||
|
||||
virtual std::vector<float> Compute(OnlineStream *s) const = 0;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_IMPL_H_
|
||||
@@ -0,0 +1,79 @@
|
||||
// sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-impl.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_IMPL_H_
|
||||
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_IMPL_H_
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/speaker-embedding-extractor-impl.h"
|
||||
#include "sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class SpeakerEmbeddingExtractorWeSpeakerImpl
|
||||
: public SpeakerEmbeddingExtractorImpl {
|
||||
public:
|
||||
explicit SpeakerEmbeddingExtractorWeSpeakerImpl(
|
||||
const SpeakerEmbeddingExtractorConfig &config)
|
||||
: model_(config) {}
|
||||
|
||||
int32_t Dim() const override { return model_.GetMetaData().output_dim; }
|
||||
|
||||
std::unique_ptr<OnlineStream> CreateStream() const override {
|
||||
FeatureExtractorConfig feat_config;
|
||||
auto meta_data = model_.GetMetaData();
|
||||
feat_config.sampling_rate = meta_data.sample_rate;
|
||||
feat_config.normalize_samples = meta_data.normalize_features;
|
||||
|
||||
return std::make_unique<OnlineStream>(feat_config);
|
||||
}
|
||||
|
||||
bool IsReady(OnlineStream *s) const override {
|
||||
return s->GetNumProcessedFrames() < s->NumFramesReady();
|
||||
}
|
||||
|
||||
std::vector<float> Compute(OnlineStream *s) const override {
|
||||
int32_t num_frames = s->NumFramesReady() - s->GetNumProcessedFrames();
|
||||
if (num_frames <= 0) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Please make sure IsReady(s) returns true. num_frames: %d",
|
||||
num_frames);
|
||||
return {};
|
||||
}
|
||||
|
||||
std::vector<float> features =
|
||||
s->GetFrames(s->GetNumProcessedFrames(), num_frames);
|
||||
|
||||
s->GetNumProcessedFrames() += num_frames;
|
||||
|
||||
int32_t feat_dim = features.size() / num_frames;
|
||||
|
||||
auto memory_info =
|
||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||
|
||||
std::array<int64_t, 3> x_shape{1, num_frames, feat_dim};
|
||||
Ort::Value x =
|
||||
Ort::Value::CreateTensor(memory_info, features.data(), features.size(),
|
||||
x_shape.data(), x_shape.size());
|
||||
Ort::Value embedding = model_.Compute(std::move(x));
|
||||
std::vector<int64_t> embedding_shape =
|
||||
embedding.GetTensorTypeAndShapeInfo().GetShape();
|
||||
|
||||
std::vector<float> ans(embedding_shape[1]);
|
||||
std::copy(embedding.GetTensorData<float>(),
|
||||
embedding.GetTensorData<float>() + ans.size(), ans.begin());
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
private:
|
||||
SpeakerEmbeddingExtractorWeSpeakerModel model_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_IMPL_H_
|
||||
@@ -0,0 +1,20 @@
|
||||
// sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model-metadata.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_MODEL_METADATA_H_
|
||||
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_MODEL_METADATA_H_
|
||||
|
||||
#include <cstdint>
|
||||
#include <string>
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
struct SpeakerEmbeddingExtractorWeSpeakerModelMetaData {
|
||||
int32_t output_dim = 0;
|
||||
int32_t sample_rate = 0;
|
||||
int32_t normalize_features = 0;
|
||||
std::string language;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_MODEL_METADATA_H_
|
||||
112
sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model.cc
Normal file
112
sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model.cc
Normal file
@@ -0,0 +1,112 @@
|
||||
// sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model.h"
|
||||
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/session.h"
|
||||
#include "sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model-metadata.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class SpeakerEmbeddingExtractorWeSpeakerModel::Impl {
|
||||
public:
|
||||
explicit Impl(const SpeakerEmbeddingExtractorConfig &config)
|
||||
: config_(config),
|
||||
env_(ORT_LOGGING_LEVEL_ERROR),
|
||||
sess_opts_(GetSessionOptions(config)),
|
||||
allocator_{} {
|
||||
{
|
||||
auto buf = ReadFile(config.model);
|
||||
Init(buf.data(), buf.size());
|
||||
}
|
||||
}
|
||||
|
||||
Ort::Value Compute(Ort::Value x) const {
|
||||
std::array<Ort::Value, 1> inputs = {std::move(x)};
|
||||
|
||||
auto outputs =
|
||||
sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
|
||||
output_names_ptr_.data(), output_names_ptr_.size());
|
||||
return std::move(outputs[0]);
|
||||
}
|
||||
|
||||
const SpeakerEmbeddingExtractorWeSpeakerModelMetaData &GetMetaData() const {
|
||||
return meta_data_;
|
||||
}
|
||||
|
||||
private:
|
||||
void Init(void *model_data, size_t model_data_length) {
|
||||
sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
|
||||
sess_opts_);
|
||||
|
||||
GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
|
||||
|
||||
GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);
|
||||
|
||||
// get meta data
|
||||
Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
|
||||
if (config_.debug) {
|
||||
std::ostringstream os;
|
||||
PrintModelMetadata(os, meta_data);
|
||||
SHERPA_ONNX_LOGE("%s", os.str().c_str());
|
||||
}
|
||||
|
||||
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
|
||||
SHERPA_ONNX_READ_META_DATA(meta_data_.output_dim, "output_dim");
|
||||
SHERPA_ONNX_READ_META_DATA(meta_data_.sample_rate, "sample_rate");
|
||||
SHERPA_ONNX_READ_META_DATA(meta_data_.normalize_features,
|
||||
"normalize_features");
|
||||
SHERPA_ONNX_READ_META_DATA_STR(meta_data_.language, "language");
|
||||
|
||||
std::string framework;
|
||||
SHERPA_ONNX_READ_META_DATA_STR(framework, "framework");
|
||||
if (framework != "wespeaker") {
|
||||
SHERPA_ONNX_LOGE("Expect a wespeaker model, given: %s",
|
||||
framework.c_str());
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
SpeakerEmbeddingExtractorConfig config_;
|
||||
Ort::Env env_;
|
||||
Ort::SessionOptions sess_opts_;
|
||||
Ort::AllocatorWithDefaultOptions allocator_;
|
||||
|
||||
std::unique_ptr<Ort::Session> sess_;
|
||||
|
||||
std::vector<std::string> input_names_;
|
||||
std::vector<const char *> input_names_ptr_;
|
||||
|
||||
std::vector<std::string> output_names_;
|
||||
std::vector<const char *> output_names_ptr_;
|
||||
|
||||
SpeakerEmbeddingExtractorWeSpeakerModelMetaData meta_data_;
|
||||
};
|
||||
|
||||
SpeakerEmbeddingExtractorWeSpeakerModel::
|
||||
SpeakerEmbeddingExtractorWeSpeakerModel(
|
||||
const SpeakerEmbeddingExtractorConfig &config)
|
||||
: impl_(std::make_unique<Impl>(config)) {}
|
||||
|
||||
SpeakerEmbeddingExtractorWeSpeakerModel::
|
||||
~SpeakerEmbeddingExtractorWeSpeakerModel() = default;
|
||||
|
||||
const SpeakerEmbeddingExtractorWeSpeakerModelMetaData &
|
||||
SpeakerEmbeddingExtractorWeSpeakerModel::GetMetaData() const {
|
||||
return impl_->GetMetaData();
|
||||
}
|
||||
|
||||
Ort::Value SpeakerEmbeddingExtractorWeSpeakerModel::Compute(
|
||||
Ort::Value x) const {
|
||||
return impl_->Compute(std::move(x));
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
@@ -0,0 +1,37 @@
|
||||
// sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_MODEL_H_
|
||||
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_MODEL_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model-metadata.h"
|
||||
#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class SpeakerEmbeddingExtractorWeSpeakerModel {
|
||||
public:
|
||||
explicit SpeakerEmbeddingExtractorWeSpeakerModel(
|
||||
const SpeakerEmbeddingExtractorConfig &config);
|
||||
|
||||
~SpeakerEmbeddingExtractorWeSpeakerModel();
|
||||
|
||||
const SpeakerEmbeddingExtractorWeSpeakerModelMetaData &GetMetaData() const;
|
||||
|
||||
/**
|
||||
* @param x A float32 tensor of shape (N, T, C)
|
||||
* @return A float32 tensor of shape (N, C)
|
||||
*/
|
||||
Ort::Value Compute(Ort::Value x) const;
|
||||
|
||||
private:
|
||||
class Impl;
|
||||
std::unique_ptr<Impl> impl_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_MODEL_H_
|
||||
74
sherpa-onnx/csrc/speaker-embedding-extractor.cc
Normal file
74
sherpa-onnx/csrc/speaker-embedding-extractor.cc
Normal file
@@ -0,0 +1,74 @@
|
||||
// sherpa-onnx/csrc/speaker-embedding-extractor.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/speaker-embedding-extractor-impl.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void SpeakerEmbeddingExtractorConfig::Register(ParseOptions *po) {
|
||||
po->Register("model", &model, "Path to the speaker embedding model.");
|
||||
po->Register("num-threads", &num_threads,
|
||||
"Number of threads to run the neural network");
|
||||
|
||||
po->Register("debug", &debug,
|
||||
"true to print model information while loading it.");
|
||||
|
||||
po->Register("provider", &provider,
|
||||
"Specify a provider to use: cpu, cuda, coreml");
|
||||
}
|
||||
|
||||
bool SpeakerEmbeddingExtractorConfig::Validate() const {
|
||||
if (model.empty()) {
|
||||
SHERPA_ONNX_LOGE("Please provide --speaker-embedding-model");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!FileExists(model)) {
|
||||
SHERPA_ONNX_LOGE("--speaker-embedding-model: %s does not exist",
|
||||
model.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string SpeakerEmbeddingExtractorConfig::ToString() const {
|
||||
std::ostringstream os;
|
||||
|
||||
os << "SpeakerEmbeddingExtractorConfig(";
|
||||
os << "model=\"" << model << "\", ";
|
||||
os << "num_threads=" << num_threads << ", ";
|
||||
os << "debug=" << (debug ? "True" : "False") << ", ";
|
||||
os << "provider=\"" << provider << "\")";
|
||||
|
||||
return os.str();
|
||||
}
|
||||
|
||||
SpeakerEmbeddingExtractor::SpeakerEmbeddingExtractor(
|
||||
const SpeakerEmbeddingExtractorConfig &config)
|
||||
: impl_(SpeakerEmbeddingExtractorImpl::Create(config)) {}
|
||||
|
||||
SpeakerEmbeddingExtractor::~SpeakerEmbeddingExtractor() = default;
|
||||
|
||||
int32_t SpeakerEmbeddingExtractor::Dim() const { return impl_->Dim(); }
|
||||
|
||||
std::unique_ptr<OnlineStream> SpeakerEmbeddingExtractor::CreateStream() const {
|
||||
return impl_->CreateStream();
|
||||
}
|
||||
|
||||
bool SpeakerEmbeddingExtractor::IsReady(OnlineStream *s) const {
|
||||
return impl_->IsReady(s);
|
||||
}
|
||||
|
||||
std::vector<float> SpeakerEmbeddingExtractor::Compute(OnlineStream *s) const {
|
||||
return impl_->Compute(s);
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
67
sherpa-onnx/csrc/speaker-embedding-extractor.h
Normal file
67
sherpa-onnx/csrc/speaker-embedding-extractor.h
Normal file
@@ -0,0 +1,67 @@
|
||||
// sherpa-onnx/csrc/speaker-embedding-extractor.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_
|
||||
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/online-stream.h"
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
struct SpeakerEmbeddingExtractorConfig {
|
||||
std::string model;
|
||||
int32_t num_threads = 1;
|
||||
bool debug = false;
|
||||
std::string provider = "cpu";
|
||||
|
||||
SpeakerEmbeddingExtractorConfig() = default;
|
||||
SpeakerEmbeddingExtractorConfig(const std::string &model, int32_t num_threads,
|
||||
bool debug, const std::string &provider)
|
||||
: model(model),
|
||||
num_threads(num_threads),
|
||||
debug(debug),
|
||||
provider(provider) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
std::string ToString() const;
|
||||
};
|
||||
|
||||
class SpeakerEmbeddingExtractorImpl;
|
||||
|
||||
class SpeakerEmbeddingExtractor {
|
||||
public:
|
||||
explicit SpeakerEmbeddingExtractor(
|
||||
const SpeakerEmbeddingExtractorConfig &config);
|
||||
|
||||
~SpeakerEmbeddingExtractor();
|
||||
|
||||
// Return the dimension of the embedding
|
||||
int32_t Dim() const;
|
||||
|
||||
// Create a stream to accept audio samples and compute features
|
||||
std::unique_ptr<OnlineStream> CreateStream() const;
|
||||
|
||||
// Return true if there are feature frames in OnlineStream that
|
||||
// can be used to compute embeddings.
|
||||
bool IsReady(OnlineStream *s) const;
|
||||
|
||||
// Compute the speaker embedding from the available unprocessed features
|
||||
// of the given stream
|
||||
//
|
||||
// You have to ensure IsReady(s) returns true before you call this method.
|
||||
std::vector<float> Compute(OnlineStream *s) const;
|
||||
|
||||
private:
|
||||
std::unique_ptr<SpeakerEmbeddingExtractorImpl> impl_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_
|
||||
147
sherpa-onnx/csrc/speaker-embedding-manager-test.cc
Normal file
147
sherpa-onnx/csrc/speaker-embedding-manager-test.cc
Normal file
@@ -0,0 +1,147 @@
|
||||
// sherpa-onnx/csrc/speaker-embedding-manager-test.cc
|
||||
//
|
||||
// Copyright (c) 2023 Jingzhao Ou (jingzhao.ou@gmail.com)
|
||||
|
||||
#include "sherpa-onnx/csrc/speaker-embedding-manager.h"
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
TEST(SpeakerEmbeddingManager, AddAndRemove) {
|
||||
int32_t dim = 2;
|
||||
SpeakerEmbeddingManager manager(dim);
|
||||
std::vector<float> v = {0.1, 0.1};
|
||||
bool status = manager.Add("first", v.data());
|
||||
ASSERT_TRUE(status);
|
||||
ASSERT_EQ(manager.NumSpeakers(), 1);
|
||||
|
||||
// duplicate
|
||||
status = manager.Add("first", v.data());
|
||||
ASSERT_FALSE(status);
|
||||
ASSERT_EQ(manager.NumSpeakers(), 1);
|
||||
|
||||
// non-duplicate
|
||||
v = {0.1, 0.9};
|
||||
status = manager.Add("second", v.data());
|
||||
ASSERT_TRUE(status);
|
||||
ASSERT_EQ(manager.NumSpeakers(), 2);
|
||||
|
||||
// do not exist
|
||||
status = manager.Remove("third");
|
||||
ASSERT_FALSE(status);
|
||||
|
||||
status = manager.Remove("first");
|
||||
ASSERT_TRUE(status);
|
||||
ASSERT_EQ(manager.NumSpeakers(), 1);
|
||||
|
||||
v = {0.1, 0.1};
|
||||
status = manager.Add("first", v.data());
|
||||
ASSERT_TRUE(status);
|
||||
ASSERT_EQ(manager.NumSpeakers(), 2);
|
||||
|
||||
status = manager.Remove("first");
|
||||
ASSERT_TRUE(status);
|
||||
ASSERT_EQ(manager.NumSpeakers(), 1);
|
||||
|
||||
status = manager.Remove("second");
|
||||
ASSERT_TRUE(status);
|
||||
ASSERT_EQ(manager.NumSpeakers(), 0);
|
||||
}
|
||||
|
||||
TEST(SpeakerEmbeddingManager, Search) {
|
||||
int32_t dim = 2;
|
||||
SpeakerEmbeddingManager manager(dim);
|
||||
std::vector<float> v1 = {0.1, 0.1};
|
||||
std::vector<float> v2 = {0.1, 0.9};
|
||||
std::vector<float> v3 = {0.9, 0.1};
|
||||
bool status = manager.Add("first", v1.data());
|
||||
ASSERT_TRUE(status);
|
||||
|
||||
status = manager.Add("second", v2.data());
|
||||
ASSERT_TRUE(status);
|
||||
|
||||
status = manager.Add("third", v3.data());
|
||||
ASSERT_TRUE(status);
|
||||
|
||||
ASSERT_EQ(manager.NumSpeakers(), 3);
|
||||
|
||||
std::vector<float> v = {15, 16};
|
||||
float threshold = 0.9;
|
||||
|
||||
std::string name = manager.Search(v.data(), threshold);
|
||||
EXPECT_EQ(name, "first");
|
||||
|
||||
v = {2, 17};
|
||||
name = manager.Search(v.data(), threshold);
|
||||
EXPECT_EQ(name, "second");
|
||||
|
||||
v = {17, 2};
|
||||
name = manager.Search(v.data(), threshold);
|
||||
EXPECT_EQ(name, "third");
|
||||
|
||||
threshold = 0.9;
|
||||
v = {15, 16};
|
||||
status = manager.Remove("first");
|
||||
ASSERT_TRUE(status);
|
||||
name = manager.Search(v.data(), threshold);
|
||||
EXPECT_EQ(name, "");
|
||||
|
||||
v = {17, 2};
|
||||
status = manager.Remove("third");
|
||||
ASSERT_TRUE(status);
|
||||
name = manager.Search(v.data(), threshold);
|
||||
EXPECT_EQ(name, "");
|
||||
|
||||
v = {2, 17};
|
||||
status = manager.Remove("second");
|
||||
ASSERT_TRUE(status);
|
||||
name = manager.Search(v.data(), threshold);
|
||||
EXPECT_EQ(name, "");
|
||||
|
||||
ASSERT_EQ(manager.NumSpeakers(), 0);
|
||||
}
|
||||
|
||||
TEST(SpeakerEmbeddingManager, Verify) {
|
||||
int32_t dim = 2;
|
||||
SpeakerEmbeddingManager manager(dim);
|
||||
std::vector<float> v1 = {0.1, 0.1};
|
||||
std::vector<float> v2 = {0.1, 0.9};
|
||||
std::vector<float> v3 = {0.9, 0.1};
|
||||
bool status = manager.Add("first", v1.data());
|
||||
ASSERT_TRUE(status);
|
||||
|
||||
status = manager.Add("second", v2.data());
|
||||
ASSERT_TRUE(status);
|
||||
|
||||
status = manager.Add("third", v3.data());
|
||||
ASSERT_TRUE(status);
|
||||
|
||||
std::vector<float> v = {15, 16};
|
||||
float threshold = 0.9;
|
||||
|
||||
status = manager.Verify("first", v.data(), threshold);
|
||||
ASSERT_TRUE(status);
|
||||
|
||||
v = {2, 17};
|
||||
status = manager.Verify("first", v.data(), threshold);
|
||||
ASSERT_FALSE(status);
|
||||
|
||||
status = manager.Verify("second", v.data(), threshold);
|
||||
ASSERT_TRUE(status);
|
||||
|
||||
v = {17, 2};
|
||||
status = manager.Verify("first", v.data(), threshold);
|
||||
ASSERT_FALSE(status);
|
||||
|
||||
status = manager.Verify("second", v.data(), threshold);
|
||||
ASSERT_FALSE(status);
|
||||
|
||||
status = manager.Verify("third", v.data(), threshold);
|
||||
ASSERT_TRUE(status);
|
||||
|
||||
status = manager.Verify("fourth", v.data(), threshold);
|
||||
ASSERT_FALSE(status);
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
144
sherpa-onnx/csrc/speaker-embedding-manager.cc
Normal file
144
sherpa-onnx/csrc/speaker-embedding-manager.cc
Normal file
@@ -0,0 +1,144 @@
|
||||
// sherpa-onnx/csrc/speaker-embedding-manager.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/speaker-embedding-manager.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "Eigen/Dense"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
using FloatMatrix =
|
||||
Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
|
||||
|
||||
class SpeakerEmbeddingManager::Impl {
|
||||
public:
|
||||
explicit Impl(int32_t dim) : dim_(dim) {}
|
||||
|
||||
bool Add(const std::string &name, const float *p) {
|
||||
if (name2row_.count(name)) {
|
||||
// a speaker with the same name already exists
|
||||
return false;
|
||||
}
|
||||
|
||||
embedding_matrix_.conservativeResize(embedding_matrix_.rows() + 1, dim_);
|
||||
|
||||
std::copy(p, p + dim_, &embedding_matrix_.bottomRows(1)(0, 0));
|
||||
|
||||
embedding_matrix_.bottomRows(1).normalize(); // inplace
|
||||
|
||||
name2row_[name] = embedding_matrix_.rows() - 1;
|
||||
row2name_[embedding_matrix_.rows() - 1] = name;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Remove(const std::string &name) {
|
||||
if (!name2row_.count(name)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
int32_t row_idx = name2row_.at(name);
|
||||
|
||||
int32_t num_rows = embedding_matrix_.rows();
|
||||
|
||||
if (row_idx < num_rows - 1) {
|
||||
embedding_matrix_.block(row_idx, 0, num_rows - -1 - row_idx, dim_) =
|
||||
embedding_matrix_.bottomRows(num_rows - 1 - row_idx);
|
||||
}
|
||||
|
||||
embedding_matrix_.conservativeResize(num_rows - 1, dim_);
|
||||
for (auto &p : name2row_) {
|
||||
if (p.second > row_idx) {
|
||||
p.second -= 1;
|
||||
row2name_[p.second] = p.first;
|
||||
}
|
||||
}
|
||||
|
||||
name2row_.erase(name);
|
||||
row2name_.erase(num_rows - 1);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string Search(const float *p, float threshold) {
|
||||
if (embedding_matrix_.rows() == 0) {
|
||||
return {};
|
||||
}
|
||||
|
||||
Eigen::VectorXf v =
|
||||
Eigen::Map<Eigen::VectorXf>(const_cast<float *>(p), dim_);
|
||||
v.normalize();
|
||||
|
||||
Eigen::VectorXf scores = embedding_matrix_ * v;
|
||||
|
||||
Eigen::VectorXf::Index max_index;
|
||||
float max_score = scores.maxCoeff(&max_index);
|
||||
if (max_score < threshold) {
|
||||
return {};
|
||||
}
|
||||
|
||||
return row2name_.at(max_index);
|
||||
}
|
||||
|
||||
bool Verify(const std::string &name, const float *p, float threshold) {
|
||||
if (!name2row_.count(name)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
int32_t row_idx = name2row_.at(name);
|
||||
|
||||
Eigen::VectorXf v =
|
||||
Eigen::Map<Eigen::VectorXf>(const_cast<float *>(p), dim_);
|
||||
v.normalize();
|
||||
|
||||
float score = embedding_matrix_.row(row_idx) * v;
|
||||
|
||||
if (score < threshold) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
int32_t NumSpeakers() const { return embedding_matrix_.rows(); }
|
||||
|
||||
private:
|
||||
int32_t dim_;
|
||||
FloatMatrix embedding_matrix_;
|
||||
std::unordered_map<std::string, int32_t> name2row_;
|
||||
std::unordered_map<int32_t, std::string> row2name_;
|
||||
};
|
||||
|
||||
SpeakerEmbeddingManager::SpeakerEmbeddingManager(int32_t dim)
|
||||
: impl_(std::make_unique<Impl>(dim)) {}
|
||||
|
||||
SpeakerEmbeddingManager::~SpeakerEmbeddingManager() = default;
|
||||
|
||||
bool SpeakerEmbeddingManager::Add(const std::string &name,
|
||||
const float *p) const {
|
||||
return impl_->Add(name, p);
|
||||
}
|
||||
|
||||
bool SpeakerEmbeddingManager::Remove(const std::string &name) const {
|
||||
return impl_->Remove(name);
|
||||
}
|
||||
|
||||
std::string SpeakerEmbeddingManager::Search(const float *p,
|
||||
float threshold) const {
|
||||
return impl_->Search(p, threshold);
|
||||
}
|
||||
|
||||
bool SpeakerEmbeddingManager::Verify(const std::string &name, const float *p,
|
||||
float threshold) const {
|
||||
return impl_->Verify(name, p, threshold);
|
||||
}
|
||||
|
||||
int32_t SpeakerEmbeddingManager::NumSpeakers() const {
|
||||
return impl_->NumSpeakers();
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
72
sherpa-onnx/csrc/speaker-embedding-manager.h
Normal file
72
sherpa-onnx/csrc/speaker-embedding-manager.h
Normal file
@@ -0,0 +1,72 @@
|
||||
// sherpa-onnx/csrc/speaker-embedding-manager.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_MANAGER_H_
|
||||
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_MANAGER_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class SpeakerEmbeddingManager {
|
||||
public:
|
||||
// @param dim Embedding dimension.
|
||||
explicit SpeakerEmbeddingManager(int32_t dim);
|
||||
~SpeakerEmbeddingManager();
|
||||
|
||||
/* Add the embedding and name of a speaker to the manager.
|
||||
*
|
||||
* @param name Name of the speaker
|
||||
* @param p Pointer to the embedding. Its length is `dim`.
|
||||
* @return Return true if added successfully. Return false if it failed.
|
||||
* At present, the only reason for a failure is that there is already
|
||||
* a speaker with the same `name`.
|
||||
*/
|
||||
bool Add(const std::string &name, const float *p) const;
|
||||
|
||||
/* Remove a speaker by its name.
|
||||
*
|
||||
* @param name Name of the speaker to remove.
|
||||
* @return Return true if it is removed successfully. Return false
|
||||
* if there is no such a speaker.
|
||||
*/
|
||||
bool Remove(const std::string &name) const;
|
||||
|
||||
/** It is for speaker identification.
|
||||
*
|
||||
* It computes the cosine similarity between and given embedding and all
|
||||
* other embeddings and find the embedding that has the largest score
|
||||
* and the score is above or equal to threshold. Return the speaker
|
||||
* name for the embedding if found; otherwise, it returns an empty string.
|
||||
*
|
||||
* @param p The input embedding.
|
||||
* @param threshold A value between 0 and 1.
|
||||
* @param If found, return the name of the speaker. Otherwise, return an
|
||||
* empty string.
|
||||
*/
|
||||
std::string Search(const float *p, float threshold) const;
|
||||
|
||||
/* Check whether the input embedding matches the embedding of the input
|
||||
* speaker.
|
||||
*
|
||||
* It is for speaker verification.
|
||||
*
|
||||
* @param name The target speaker name.
|
||||
* @param p The input embedding to check.
|
||||
* @param threshold A value between 0 and 1.
|
||||
* @return Return true if it matches. Otherwise, it returns false.
|
||||
*/
|
||||
bool Verify(const std::string &name, const float *p, float threshold) const;
|
||||
|
||||
int32_t NumSpeakers() const;
|
||||
|
||||
private:
|
||||
class Impl;
|
||||
std::unique_ptr<Impl> impl_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_MANAGER_H_
|
||||
@@ -40,7 +40,9 @@ class VoiceActivityDetector::Impl {
|
||||
|
||||
for (int32_t i = 0; i != k; ++i, p += window_size) {
|
||||
buffer_.Push(p, window_size);
|
||||
is_speech = is_speech || model_->IsSpeech(p, window_size);
|
||||
// NOTE(fangjun): Please don't use a very large n.
|
||||
bool this_window_is_speech = model_->IsSpeech(p, window_size);
|
||||
is_speech = is_speech || this_window_is_speech;
|
||||
}
|
||||
|
||||
last_ = std::vector<float>(
|
||||
@@ -102,6 +104,8 @@ class VoiceActivityDetector::Impl {
|
||||
|
||||
bool IsSpeechDetected() const { return start_ != -1; }
|
||||
|
||||
const VadModelConfig &GetConfig() const { return config_; }
|
||||
|
||||
private:
|
||||
std::queue<SpeechSegment> segments_;
|
||||
|
||||
@@ -146,4 +150,8 @@ bool VoiceActivityDetector::IsSpeechDetected() const {
|
||||
return impl_->IsSpeechDetected();
|
||||
}
|
||||
|
||||
const VadModelConfig &VoiceActivityDetector::GetConfig() const {
|
||||
return impl_->GetConfig();
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -43,6 +43,8 @@ class VoiceActivityDetector {
|
||||
|
||||
void Reset();
|
||||
|
||||
const VadModelConfig &GetConfig() const;
|
||||
|
||||
private:
|
||||
class Impl;
|
||||
std::unique_ptr<Impl> impl_;
|
||||
|
||||
@@ -30,6 +30,8 @@ pybind11_add_module(_sherpa_onnx
|
||||
online-zipformer2-ctc-model-config.cc
|
||||
sherpa-onnx.cc
|
||||
silero-vad-model-config.cc
|
||||
speaker-embedding-extractor.cc
|
||||
speaker-embedding-manager.cc
|
||||
vad-model-config.cc
|
||||
vad-model.cc
|
||||
voice-activity-detector.cc
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// sherpa-onnx/python/csrc/online-recongizer.h
|
||||
// sherpa-onnx/python/csrc/online-recognizer.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
|
||||
@@ -18,6 +18,8 @@
|
||||
#include "sherpa-onnx/python/csrc/online-model-config.h"
|
||||
#include "sherpa-onnx/python/csrc/online-recognizer.h"
|
||||
#include "sherpa-onnx/python/csrc/online-stream.h"
|
||||
#include "sherpa-onnx/python/csrc/speaker-embedding-extractor.h"
|
||||
#include "sherpa-onnx/python/csrc/speaker-embedding-manager.h"
|
||||
#include "sherpa-onnx/python/csrc/vad-model-config.h"
|
||||
#include "sherpa-onnx/python/csrc/vad-model.h"
|
||||
#include "sherpa-onnx/python/csrc/voice-activity-detector.h"
|
||||
@@ -48,6 +50,8 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
|
||||
PybindVoiceActivityDetector(&m);
|
||||
|
||||
PybindOfflineTts(&m);
|
||||
PybindSpeakerEmbeddingExtractor(&m);
|
||||
PybindSpeakerEmbeddingManager(&m);
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
44
sherpa-onnx/python/csrc/speaker-embedding-extractor.cc
Normal file
44
sherpa-onnx/python/csrc/speaker-embedding-extractor.cc
Normal file
@@ -0,0 +1,44 @@
|
||||
// sherpa-onnx/python/csrc/speaker-embedding-extractor.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/python/csrc/speaker-embedding-extractor.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
static void PybindSpeakerEmbeddingExtractorConfig(py::module *m) {
|
||||
using PyClass = SpeakerEmbeddingExtractorConfig;
|
||||
py::class_<PyClass>(*m, "SpeakerEmbeddingExtractorConfig")
|
||||
.def(py::init<>())
|
||||
.def(py::init<const std::string &, int32_t, bool, const std::string>(),
|
||||
py::arg("model"), py::arg("num_threads") = 1,
|
||||
py::arg("debug") = false, py::arg("provider") = "cpu")
|
||||
.def_readwrite("model", &PyClass::model)
|
||||
.def_readwrite("num_threads", &PyClass::num_threads)
|
||||
.def_readwrite("debug", &PyClass::debug)
|
||||
.def_readwrite("provider", &PyClass::provider)
|
||||
.def("validate", &PyClass::Validate)
|
||||
.def("__str__", &PyClass::ToString);
|
||||
}
|
||||
|
||||
void PybindSpeakerEmbeddingExtractor(py::module *m) {
|
||||
PybindSpeakerEmbeddingExtractorConfig(m);
|
||||
|
||||
using PyClass = SpeakerEmbeddingExtractor;
|
||||
py::class_<PyClass>(*m, "SpeakerEmbeddingExtractor")
|
||||
.def(py::init<const SpeakerEmbeddingExtractorConfig &>(),
|
||||
py::arg("config"), py::call_guard<py::gil_scoped_release>())
|
||||
.def_property_readonly("dim", &PyClass::Dim)
|
||||
.def("create_stream", &PyClass::CreateStream,
|
||||
py::call_guard<py::gil_scoped_release>())
|
||||
.def("compute", &PyClass::Compute,
|
||||
py::call_guard<py::gil_scoped_release>())
|
||||
.def("is_ready", &PyClass::IsReady,
|
||||
py::call_guard<py::gil_scoped_release>());
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
16
sherpa-onnx/python/csrc/speaker-embedding-extractor.h
Normal file
16
sherpa-onnx/python/csrc/speaker-embedding-extractor.h
Normal file
@@ -0,0 +1,16 @@
|
||||
// sherpa-onnx/python/csrc/speaker-embedding-extractor.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_PYTHON_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_
|
||||
#define SHERPA_ONNX_PYTHON_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_
|
||||
|
||||
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void PybindSpeakerEmbeddingExtractor(py::module *m);
|
||||
|
||||
}
|
||||
|
||||
#endif // SHERPA_ONNX_PYTHON_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_
|
||||
50
sherpa-onnx/python/csrc/speaker-embedding-manager.cc
Normal file
50
sherpa-onnx/python/csrc/speaker-embedding-manager.cc
Normal file
@@ -0,0 +1,50 @@
|
||||
// sherpa-onnx/python/csrc/speaker-embedding-manager.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/python/csrc/speaker-embedding-manager.h"
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/speaker-embedding-manager.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void PybindSpeakerEmbeddingManager(py::module *m) {
|
||||
using PyClass = SpeakerEmbeddingManager;
|
||||
py::class_<PyClass>(*m, "SpeakerEmbeddingManager")
|
||||
.def(py::init<int32_t>(), py::arg("dim"),
|
||||
py::call_guard<py::gil_scoped_release>())
|
||||
.def_property_readonly("num_speakers", &PyClass::NumSpeakers)
|
||||
.def(
|
||||
"add",
|
||||
[](const PyClass &self, const std::string &name,
|
||||
const std::vector<float> &v) -> bool {
|
||||
return self.Add(name, v.data());
|
||||
},
|
||||
py::arg("name"), py::arg("v"),
|
||||
py::call_guard<py::gil_scoped_release>())
|
||||
.def(
|
||||
"remove",
|
||||
[](const PyClass &self, const std::string &name) -> bool {
|
||||
return self.Remove(name);
|
||||
},
|
||||
py::arg("name"), py::call_guard<py::gil_scoped_release>())
|
||||
.def(
|
||||
"search",
|
||||
[](const PyClass &self, const std::vector<float> &v, float threshold)
|
||||
-> std::string { return self.Search(v.data(), threshold); },
|
||||
py::arg("v"), py::arg("threshold"),
|
||||
py::call_guard<py::gil_scoped_release>())
|
||||
.def(
|
||||
"verify",
|
||||
[](const PyClass &self, const std::string &name,
|
||||
const std::vector<float> &v, float threshold) -> bool {
|
||||
return self.Verify(name, v.data(), threshold);
|
||||
},
|
||||
py::arg("name"), py::arg("v"), py::arg("threshold"),
|
||||
py::call_guard<py::gil_scoped_release>());
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
16
sherpa-onnx/python/csrc/speaker-embedding-manager.h
Normal file
16
sherpa-onnx/python/csrc/speaker-embedding-manager.h
Normal file
@@ -0,0 +1,16 @@
|
||||
// sherpa-onnx/python/csrc/speaker-embedding-manager.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_PYTHON_CSRC_SPEAKER_EMBEDDING_MANAGER_H_
|
||||
#define SHERPA_ONNX_PYTHON_CSRC_SPEAKER_EMBEDDING_MANAGER_H_
|
||||
|
||||
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void PybindSpeakerEmbeddingManager(py::module *m);
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_PYTHON_CSRC_SPEAKER_EMBEDDING_MANAGER_H_
|
||||
@@ -32,6 +32,7 @@ void PybindVoiceActivityDetector(py::module *m) {
|
||||
self.AcceptWaveform(samples.data(), samples.size());
|
||||
},
|
||||
py::arg("samples"), py::call_guard<py::gil_scoped_release>())
|
||||
.def_property_readonly("config", &PyClass::GetConfig)
|
||||
.def("empty", &PyClass::Empty, py::call_guard<py::gil_scoped_release>())
|
||||
.def("pop", &PyClass::Pop, py::call_guard<py::gil_scoped_release>())
|
||||
.def("is_speech_detected", &PyClass::IsSpeechDetected,
|
||||
|
||||
@@ -8,6 +8,9 @@ from _sherpa_onnx import (
|
||||
OfflineTtsVitsModelConfig,
|
||||
OnlineStream,
|
||||
SileroVadModelConfig,
|
||||
SpeakerEmbeddingExtractor,
|
||||
SpeakerEmbeddingExtractorConfig,
|
||||
SpeakerEmbeddingManager,
|
||||
SpeechSegment,
|
||||
VadModel,
|
||||
VadModelConfig,
|
||||
|
||||
Reference in New Issue
Block a user