diff --git a/python-api-examples/speaker-identification.py b/python-api-examples/speaker-identification.py new file mode 100755 index 00000000..20b46639 --- /dev/null +++ b/python-api-examples/speaker-identification.py @@ -0,0 +1,252 @@ +#!/usr/bin/env python3 + +""" +This script shows how to use Python APIs for speaker identification. + +Usage: + +(1) Prepare a text file containing speaker related files. + +Each line in the text file contains two columns. The first column is the +speaker name, while the second column contains the wave file of the speaker. + +If the text file contains multiple wave files for the same speaker, then the +embeddings of these files are averaged. + +An example text file is given below: + + foo /path/to/a.wav + bar /path/to/b.wav + foo /path/to/c.wav + foobar /path/to/d.wav + +Each wave file should contain only a single channel; the sample format +should be int16_t; the sample rate can be arbitrary. + +(2) Download a model for computing speaker embeddings + +Please visit +https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-recongition-models +to download a model. An example is given below: + + wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/zh_cnceleb_resnet34.onnx + +Note that `zh` means Chinese, while `en` means English. + +(3) Run this script + +Assume the filename of the text file is speaker.txt. + +python3 ./python-api-examples/speaker-identification.py \ + --speaker-file ./speaker.txt \ + --model ./zh_cnceleb_resnet34.onnx +""" +import argparse +import queue +import threading +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Tuple + +import numpy as np +import sherpa_onnx +import torchaudio + +try: + import sounddevice as sd +except ImportError: + print("Please install sounddevice first. 
You can use") + print() + print(" pip install sounddevice") + print() + print("to install it") + sys.exit(-1) + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--speaker-file", + type=str, + required=True, + help="""Path to the speaker file. Read the help doc at the beginning of this + file for the format.""", + ) + + parser.add_argument( + "--model", + type=str, + required=True, + help="Path to the model file.", + ) + + parser.add_argument("--threshold", type=float, default=0.6) + + parser.add_argument( + "--num-threads", + type=int, + default=1, + help="Number of threads for neural network computation", + ) + + parser.add_argument( + "--debug", + type=bool, + default=False, + help="True to show debug messages", + ) + + parser.add_argument( + "--provider", + type=str, + default="cpu", + help="Valid values: cpu, cuda, coreml", + ) + + return parser.parse_args() + + +def load_speaker_embedding_model(args): + config = sherpa_onnx.SpeakerEmbeddingExtractorConfig( + model=args.model, + num_threads=args.num_threads, + debug=args.debug, + provider=args.provider, + ) + if not config.validate(): + raise ValueError(f"Invalid config. {config}") + extractor = sherpa_onnx.SpeakerEmbeddingExtractor(config) + return extractor + + +def load_speaker_file(args) -> Dict[str, List[str]]: + if not Path(args.speaker_file).is_file(): + raise ValueError(f"--speaker-file {args.speaker_file} does not exist") + + ans = defaultdict(list) + with open(args.speaker_file) as f: + for line in f: + line = line.strip() + if not line: + continue + + fields = line.split() + if len(fields) != 2: + raise ValueError(f"Invalid line: {line}. 
Fields: {fields}") + + speaker_name, filename = fields + ans[speaker_name].append(filename) + return ans + + +def load_audio(filename: str) -> Tuple[np.ndarray, int]: + samples, sample_rate = torchaudio.load(filename) + return samples[0].contiguous().numpy(), sample_rate + + +def compute_speaker_embedding( + filenames: List[str], + extractor: sherpa_onnx.SpeakerEmbeddingExtractor, +) -> np.ndarray: + assert len(filenames) > 0, f"filenames is empty" + + ans = None + for filename in filenames: + print(f"processing {filename}") + samples, sample_rate = load_audio(filename) + stream = extractor.create_stream() + stream.accept_waveform(sample_rate=sample_rate, waveform=samples) + stream.input_finished() + + assert extractor.is_ready(stream) + embedding = extractor.compute(stream) + embedding = np.array(embedding) + if ans is None: + ans = embedding + else: + ans += embedding + + return ans / len(filenames) + + +g_buffer = queue.Queue() +g_stop = False +g_sample_rate = 16000 +g_read_mic_thread = None + + +def read_mic(): + print("Please speak!") + samples_per_read = int(0.1 * g_sample_rate) # 0.1 second = 100 ms + with sd.InputStream(channels=1, dtype="float32", samplerate=g_sample_rate) as s: + while not g_stop: + samples, _ = s.read(samples_per_read) # a blocking read + g_buffer.put(samples) + + +def main(): + args = get_args() + print(args) + extractor = load_speaker_embedding_model(args) + speaker_file = load_speaker_file(args) + + manager = sherpa_onnx.SpeakerEmbeddingManager(extractor.dim) + for name, filename_list in speaker_file.items(): + embedding = compute_speaker_embedding( + filenames=filename_list, + extractor=extractor, + ) + status = manager.add(name, embedding) + if not status: + raise RuntimeError(f"Failed to register speaker {name}") + + devices = sd.query_devices() + if len(devices) == 0: + print("No microphone devices found") + sys.exit(0) + + print(devices) + default_input_device_idx = sd.default.device[0] + print(f'Use default device: 
{devices[default_input_device_idx]["name"]}') + + global g_stop + global g_read_mic_thread + while True: + key = input("Press enter to start recording") + if key.lower() in ("q", "quit"): + g_stop = True + break + + g_stop = False + g_buffer.queue.clear() + g_read_mic_thread = threading.Thread(target=read_mic) + g_read_mic_thread.start() + input("Press enter to stop recording") + g_stop = True + g_read_mic_thread.join() + print("Compute embedding") + stream = extractor.create_stream() + while not g_buffer.empty(): + samples = g_buffer.get() + stream.accept_waveform(sample_rate=g_sample_rate, waveform=samples) + stream.input_finished() + + embedding = extractor.compute(stream) + embedding = np.array(embedding) + name = manager.search(embedding, threshold=args.threshold) + if not name: + name = "unknown" + print(f"Predicted name: {name}") + + +if __name__ == "__main__": + try: + main() + except KeyboardInterrupt: + print("\nCaught Ctrl + C. Exiting") + g_stop = True + if g_read_mic_thread.is_alive(): + g_read_mic_thread.join() diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index 2b325b40..ba0c5645 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -96,6 +96,14 @@ set(sources wave-reader.cc ) +# speaker embedding extractor +list(APPEND sources + speaker-embedding-extractor-impl.cc + speaker-embedding-extractor-wespeaker-model.cc + speaker-embedding-extractor.cc + speaker-embedding-manager.cc +) + list(APPEND sources lexicon.cc offline-tts-impl.cc @@ -387,6 +395,10 @@ if(SHERPA_ONNX_ENABLE_TESTS) utfcpp-test.cc ) + list(APPEND sherpa_onnx_test_srcs + speaker-embedding-manager-test.cc + ) + function(sherpa_onnx_add_test source) get_filename_component(name ${source} NAME_WE) set(target_name ${name}) diff --git a/sherpa-onnx/csrc/context-graph-test.cc b/sherpa-onnx/csrc/context-graph-test.cc index 029fecf4..6ad7a5c7 100644 --- a/sherpa-onnx/csrc/context-graph-test.cc +++ 
b/sherpa-onnx/csrc/context-graph-test.cc @@ -64,8 +64,8 @@ TEST(ContextGraph, Benchmark) { auto stop = std::chrono::high_resolution_clock::now(); auto duration = std::chrono::duration_cast(stop - start); - SHERPA_ONNX_LOGE("Construct context graph for %d item takes %ld us.", num, - duration.count()); + SHERPA_ONNX_LOGE("Construct context graph for %d item takes %d us.", num, + static_cast(duration.count())); } } diff --git a/sherpa-onnx/csrc/session.cc b/sherpa-onnx/csrc/session.cc index 23c8a9cb..6080bae2 100644 --- a/sherpa-onnx/csrc/session.cc +++ b/sherpa-onnx/csrc/session.cc @@ -91,4 +91,9 @@ Ort::SessionOptions GetSessionOptions(const OfflineTtsModelConfig &config) { return GetSessionOptionsImpl(config.num_threads, config.provider); } +Ort::SessionOptions GetSessionOptions( + const SpeakerEmbeddingExtractorConfig &config) { + return GetSessionOptionsImpl(config.num_threads, config.provider); +} + } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/session.h b/sherpa-onnx/csrc/session.h index 25675fa2..53cc22b7 100644 --- a/sherpa-onnx/csrc/session.h +++ b/sherpa-onnx/csrc/session.h @@ -11,6 +11,7 @@ #include "sherpa-onnx/csrc/offline-tts-model-config.h" #include "sherpa-onnx/csrc/online-lm-config.h" #include "sherpa-onnx/csrc/online-model-config.h" +#include "sherpa-onnx/csrc/speaker-embedding-extractor.h" #include "sherpa-onnx/csrc/vad-model-config.h" namespace sherpa_onnx { @@ -26,6 +27,9 @@ Ort::SessionOptions GetSessionOptions(const OnlineLMConfig &config); Ort::SessionOptions GetSessionOptions(const VadModelConfig &config); Ort::SessionOptions GetSessionOptions(const OfflineTtsModelConfig &config); + +Ort::SessionOptions GetSessionOptions( + const SpeakerEmbeddingExtractorConfig &config); } // namespace sherpa_onnx #endif // SHERPA_ONNX_CSRC_SESSION_H_ diff --git a/sherpa-onnx/csrc/speaker-embedding-extractor-impl.cc b/sherpa-onnx/csrc/speaker-embedding-extractor-impl.cc new file mode 100644 index 00000000..6dff5ac5 --- /dev/null +++ 
b/sherpa-onnx/csrc/speaker-embedding-extractor-impl.cc @@ -0,0 +1,82 @@ +// sherpa-onnx/csrc/speaker-embedding-extractor-impl.cc +// +// Copyright (c) 2023 Xiaomi Corporation +#include "sherpa-onnx/csrc/speaker-embedding-extractor-impl.h" + +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/onnx-utils.h" +#include "sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-impl.h" + +namespace sherpa_onnx { + +namespace { + +enum class ModelType { + kWeSpeaker, + kUnkown, +}; + +} // namespace + +static ModelType GetModelType(char *model_data, size_t model_data_length, + bool debug) { + Ort::Env env(ORT_LOGGING_LEVEL_WARNING); + Ort::SessionOptions sess_opts; + + auto sess = std::make_unique(env, model_data, model_data_length, + sess_opts); + + Ort::ModelMetadata meta_data = sess->GetModelMetadata(); + if (debug) { + std::ostringstream os; + PrintModelMetadata(os, meta_data); + SHERPA_ONNX_LOGE("%s", os.str().c_str()); + } + // The model type is decided by the "framework" entry of the metadata + Ort::AllocatorWithDefaultOptions allocator; + auto model_type = + meta_data.LookupCustomMetadataMapAllocated("framework", allocator); + if (!model_type) { + SHERPA_ONNX_LOGE( + "No framework in the metadata!\n" + "Please make sure you have added metadata to the model.\n\n" + "For instance, you can use\n" + "https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wespeaker/" + "add_meta_data.py\n" + "to add metadata to models from WeSpeaker\n"); + return ModelType::kUnkown; + } + + if (model_type.get() == std::string("wespeaker")) { + return ModelType::kWeSpeaker; + } else { + SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get()); + return ModelType::kUnkown; + } +} + +std::unique_ptr +SpeakerEmbeddingExtractorImpl::Create( + const SpeakerEmbeddingExtractorConfig &config) { + ModelType model_type = ModelType::kUnkown; + + { + auto buffer = ReadFile(config.model); + + model_type = GetModelType(buffer.data(), buffer.size(), config.debug); + } + + switch (model_type) { + case ModelType::kWeSpeaker: + return 
std::make_unique(config); + case ModelType::kUnkown: + SHERPA_ONNX_LOGE( + "Unknown model type in for speaker embedding extractor!"); + return nullptr; + } + + // unreachable code + return nullptr; +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/speaker-embedding-extractor-impl.h b/sherpa-onnx/csrc/speaker-embedding-extractor-impl.h new file mode 100644 index 00000000..fa84b43e --- /dev/null +++ b/sherpa-onnx/csrc/speaker-embedding-extractor-impl.h @@ -0,0 +1,34 @@ +// sherpa-onnx/csrc/speaker-embedding-extractor-impl.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_IMPL_H_ +#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_IMPL_H_ + +#include +#include +#include + +#include "sherpa-onnx/csrc/speaker-embedding-extractor.h" + +namespace sherpa_onnx { + +class SpeakerEmbeddingExtractorImpl { + public: + virtual ~SpeakerEmbeddingExtractorImpl() = default; + + static std::unique_ptr Create( + const SpeakerEmbeddingExtractorConfig &config); + + virtual int32_t Dim() const = 0; + + virtual std::unique_ptr CreateStream() const = 0; + + virtual bool IsReady(OnlineStream *s) const = 0; + + virtual std::vector Compute(OnlineStream *s) const = 0; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_IMPL_H_ diff --git a/sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-impl.h b/sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-impl.h new file mode 100644 index 00000000..f69de574 --- /dev/null +++ b/sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-impl.h @@ -0,0 +1,79 @@ +// sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-impl.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_IMPL_H_ +#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_IMPL_H_ +#include +#include +#include +#include + +#include "sherpa-onnx/csrc/speaker-embedding-extractor-impl.h" +#include 
"sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model.h" + +namespace sherpa_onnx { + +class SpeakerEmbeddingExtractorWeSpeakerImpl + : public SpeakerEmbeddingExtractorImpl { + public: + explicit SpeakerEmbeddingExtractorWeSpeakerImpl( + const SpeakerEmbeddingExtractorConfig &config) + : model_(config) {} + + int32_t Dim() const override { return model_.GetMetaData().output_dim; } + + std::unique_ptr CreateStream() const override { + FeatureExtractorConfig feat_config; + auto meta_data = model_.GetMetaData(); + feat_config.sampling_rate = meta_data.sample_rate; + feat_config.normalize_samples = meta_data.normalize_features; + + return std::make_unique(feat_config); + } + + bool IsReady(OnlineStream *s) const override { + return s->GetNumProcessedFrames() < s->NumFramesReady(); + } + + std::vector Compute(OnlineStream *s) const override { + int32_t num_frames = s->NumFramesReady() - s->GetNumProcessedFrames(); + if (num_frames <= 0) { + SHERPA_ONNX_LOGE( + "Please make sure IsReady(s) returns true. 
num_frames: %d", + num_frames); + return {}; + } + + std::vector features = + s->GetFrames(s->GetNumProcessedFrames(), num_frames); + + s->GetNumProcessedFrames() += num_frames; + + int32_t feat_dim = features.size() / num_frames; + + auto memory_info = + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); + + std::array x_shape{1, num_frames, feat_dim}; + Ort::Value x = + Ort::Value::CreateTensor(memory_info, features.data(), features.size(), + x_shape.data(), x_shape.size()); + Ort::Value embedding = model_.Compute(std::move(x)); + std::vector embedding_shape = + embedding.GetTensorTypeAndShapeInfo().GetShape(); + + std::vector ans(embedding_shape[1]); + std::copy(embedding.GetTensorData(), + embedding.GetTensorData() + ans.size(), ans.begin()); + + return ans; + } + + private: + SpeakerEmbeddingExtractorWeSpeakerModel model_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_IMPL_H_ diff --git a/sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model-metadata.h b/sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model-metadata.h new file mode 100644 index 00000000..4d8997c4 --- /dev/null +++ b/sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model-metadata.h @@ -0,0 +1,20 @@ +// sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model-metadata.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_MODEL_METADATA_H_ +#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_MODEL_METADATA_H_ + +#include +#include + +namespace sherpa_onnx { + +struct SpeakerEmbeddingExtractorWeSpeakerModelMetaData { + int32_t output_dim = 0; + int32_t sample_rate = 0; + int32_t normalize_features = 0; + std::string language; +}; + +} // namespace sherpa_onnx +#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_MODEL_METADATA_H_ diff --git a/sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model.cc 
b/sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model.cc new file mode 100644 index 00000000..b23cc95e --- /dev/null +++ b/sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model.cc @@ -0,0 +1,112 @@ +// sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model.h" + +#include +#include +#include + +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/onnx-utils.h" +#include "sherpa-onnx/csrc/session.h" +#include "sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model-metadata.h" + +namespace sherpa_onnx { + +class SpeakerEmbeddingExtractorWeSpeakerModel::Impl { + public: + explicit Impl(const SpeakerEmbeddingExtractorConfig &config) + : config_(config), + env_(ORT_LOGGING_LEVEL_ERROR), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + { + auto buf = ReadFile(config.model); + Init(buf.data(), buf.size()); + } + } + + Ort::Value Compute(Ort::Value x) const { + std::array inputs = {std::move(x)}; + + auto outputs = + sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(), + output_names_ptr_.data(), output_names_ptr_.size()); + return std::move(outputs[0]); + } + + const SpeakerEmbeddingExtractorWeSpeakerModelMetaData &GetMetaData() const { + return meta_data_; + } + + private: + void Init(void *model_data, size_t model_data_length) { + sess_ = std::make_unique(env_, model_data, model_data_length, + sess_opts_); + + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); + + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); + + // get meta data + Ort::ModelMetadata meta_data = sess_->GetModelMetadata(); + if (config_.debug) { + std::ostringstream os; + PrintModelMetadata(os, meta_data); + SHERPA_ONNX_LOGE("%s", os.str().c_str()); + } + + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below + SHERPA_ONNX_READ_META_DATA(meta_data_.output_dim, 
"output_dim"); + SHERPA_ONNX_READ_META_DATA(meta_data_.sample_rate, "sample_rate"); + SHERPA_ONNX_READ_META_DATA(meta_data_.normalize_features, + "normalize_features"); + SHERPA_ONNX_READ_META_DATA_STR(meta_data_.language, "language"); + + std::string framework; + SHERPA_ONNX_READ_META_DATA_STR(framework, "framework"); + if (framework != "wespeaker") { + SHERPA_ONNX_LOGE("Expect a wespeaker model, given: %s", + framework.c_str()); + exit(-1); + } + } + + private: + SpeakerEmbeddingExtractorConfig config_; + Ort::Env env_; + Ort::SessionOptions sess_opts_; + Ort::AllocatorWithDefaultOptions allocator_; + + std::unique_ptr sess_; + + std::vector input_names_; + std::vector input_names_ptr_; + + std::vector output_names_; + std::vector output_names_ptr_; + + SpeakerEmbeddingExtractorWeSpeakerModelMetaData meta_data_; +}; + +SpeakerEmbeddingExtractorWeSpeakerModel:: + SpeakerEmbeddingExtractorWeSpeakerModel( + const SpeakerEmbeddingExtractorConfig &config) + : impl_(std::make_unique(config)) {} + +SpeakerEmbeddingExtractorWeSpeakerModel:: + ~SpeakerEmbeddingExtractorWeSpeakerModel() = default; + +const SpeakerEmbeddingExtractorWeSpeakerModelMetaData & +SpeakerEmbeddingExtractorWeSpeakerModel::GetMetaData() const { + return impl_->GetMetaData(); +} + +Ort::Value SpeakerEmbeddingExtractorWeSpeakerModel::Compute( + Ort::Value x) const { + return impl_->Compute(std::move(x)); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model.h b/sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model.h new file mode 100644 index 00000000..f0b910f3 --- /dev/null +++ b/sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model.h @@ -0,0 +1,37 @@ +// sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_MODEL_H_ +#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_MODEL_H_ + +#include + +#include 
"onnxruntime_cxx_api.h" // NOLINT +#include "sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model-metadata.h" +#include "sherpa-onnx/csrc/speaker-embedding-extractor.h" + +namespace sherpa_onnx { + +class SpeakerEmbeddingExtractorWeSpeakerModel { + public: + explicit SpeakerEmbeddingExtractorWeSpeakerModel( + const SpeakerEmbeddingExtractorConfig &config); + + ~SpeakerEmbeddingExtractorWeSpeakerModel(); + + const SpeakerEmbeddingExtractorWeSpeakerModelMetaData &GetMetaData() const; + + /** + * @param x A float32 tensor of shape (N, T, C) + * @return A float32 tensor of shape (N, C) + */ + Ort::Value Compute(Ort::Value x) const; + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_MODEL_H_ diff --git a/sherpa-onnx/csrc/speaker-embedding-extractor.cc b/sherpa-onnx/csrc/speaker-embedding-extractor.cc new file mode 100644 index 00000000..7826e4fb --- /dev/null +++ b/sherpa-onnx/csrc/speaker-embedding-extractor.cc @@ -0,0 +1,74 @@ +// sherpa-onnx/csrc/speaker-embedding-extractor.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/speaker-embedding-extractor.h" + +#include + +#include "sherpa-onnx/csrc/file-utils.h" +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/speaker-embedding-extractor-impl.h" + +namespace sherpa_onnx { + +void SpeakerEmbeddingExtractorConfig::Register(ParseOptions *po) { + po->Register("model", &model, "Path to the speaker embedding model."); + po->Register("num-threads", &num_threads, + "Number of threads to run the neural network"); + + po->Register("debug", &debug, + "true to print model information while loading it."); + + po->Register("provider", &provider, + "Specify a provider to use: cpu, cuda, coreml"); +} + +bool SpeakerEmbeddingExtractorConfig::Validate() const { + if (model.empty()) { + SHERPA_ONNX_LOGE("Please provide --speaker-embedding-model"); + return false; + } + + if 
(!FileExists(model)) { + SHERPA_ONNX_LOGE("--speaker-embedding-model: %s does not exist", + model.c_str()); + return false; + } + + return true; +} + +std::string SpeakerEmbeddingExtractorConfig::ToString() const { + std::ostringstream os; + + os << "SpeakerEmbeddingExtractorConfig("; + os << "model=\"" << model << "\", "; + os << "num_threads=" << num_threads << ", "; + os << "debug=" << (debug ? "True" : "False") << ", "; + os << "provider=\"" << provider << "\")"; + + return os.str(); +} + +SpeakerEmbeddingExtractor::SpeakerEmbeddingExtractor( + const SpeakerEmbeddingExtractorConfig &config) + : impl_(SpeakerEmbeddingExtractorImpl::Create(config)) {} + +SpeakerEmbeddingExtractor::~SpeakerEmbeddingExtractor() = default; + +int32_t SpeakerEmbeddingExtractor::Dim() const { return impl_->Dim(); } + +std::unique_ptr SpeakerEmbeddingExtractor::CreateStream() const { + return impl_->CreateStream(); +} + +bool SpeakerEmbeddingExtractor::IsReady(OnlineStream *s) const { + return impl_->IsReady(s); +} + +std::vector SpeakerEmbeddingExtractor::Compute(OnlineStream *s) const { + return impl_->Compute(s); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/speaker-embedding-extractor.h b/sherpa-onnx/csrc/speaker-embedding-extractor.h new file mode 100644 index 00000000..cb23d40c --- /dev/null +++ b/sherpa-onnx/csrc/speaker-embedding-extractor.h @@ -0,0 +1,67 @@ +// sherpa-onnx/csrc/speaker-embedding-extractor.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_ +#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_ + +#include +#include +#include + +#include "sherpa-onnx/csrc/online-stream.h" +#include "sherpa-onnx/csrc/parse-options.h" + +namespace sherpa_onnx { + +struct SpeakerEmbeddingExtractorConfig { + std::string model; + int32_t num_threads = 1; + bool debug = false; + std::string provider = "cpu"; + + SpeakerEmbeddingExtractorConfig() = default; + SpeakerEmbeddingExtractorConfig(const std::string 
&model, int32_t num_threads, + bool debug, const std::string &provider) + : model(model), + num_threads(num_threads), + debug(debug), + provider(provider) {} + + void Register(ParseOptions *po); + bool Validate() const; + std::string ToString() const; +}; + +class SpeakerEmbeddingExtractorImpl; + +class SpeakerEmbeddingExtractor { + public: + explicit SpeakerEmbeddingExtractor( + const SpeakerEmbeddingExtractorConfig &config); + + ~SpeakerEmbeddingExtractor(); + + // Return the dimension of the embedding + int32_t Dim() const; + + // Create a stream to accept audio samples and compute features + std::unique_ptr CreateStream() const; + + // Return true if there are feature frames in OnlineStream that + // can be used to compute embeddings. + bool IsReady(OnlineStream *s) const; + + // Compute the speaker embedding from the available unprocessed features + // of the given stream + // + // You have to ensure IsReady(s) returns true before you call this method. + std::vector Compute(OnlineStream *s) const; + + private: + std::unique_ptr impl_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_ diff --git a/sherpa-onnx/csrc/speaker-embedding-manager-test.cc b/sherpa-onnx/csrc/speaker-embedding-manager-test.cc new file mode 100644 index 00000000..0e1603c2 --- /dev/null +++ b/sherpa-onnx/csrc/speaker-embedding-manager-test.cc @@ -0,0 +1,147 @@ +// sherpa-onnx/csrc/speaker-embedding-manager-test.cc +// +// Copyright (c) 2023 Jingzhao Ou (jingzhao.ou@gmail.com) + +#include "sherpa-onnx/csrc/speaker-embedding-manager.h" + +#include "gtest/gtest.h" + +namespace sherpa_onnx { + +TEST(SpeakerEmbeddingManager, AddAndRemove) { + int32_t dim = 2; + SpeakerEmbeddingManager manager(dim); + std::vector v = {0.1, 0.1}; + bool status = manager.Add("first", v.data()); + ASSERT_TRUE(status); + ASSERT_EQ(manager.NumSpeakers(), 1); + + // duplicate + status = manager.Add("first", v.data()); + ASSERT_FALSE(status); + 
ASSERT_EQ(manager.NumSpeakers(), 1); + + // non-duplicate + v = {0.1, 0.9}; + status = manager.Add("second", v.data()); + ASSERT_TRUE(status); + ASSERT_EQ(manager.NumSpeakers(), 2); + + // do not exist + status = manager.Remove("third"); + ASSERT_FALSE(status); + + status = manager.Remove("first"); + ASSERT_TRUE(status); + ASSERT_EQ(manager.NumSpeakers(), 1); + + v = {0.1, 0.1}; + status = manager.Add("first", v.data()); + ASSERT_TRUE(status); + ASSERT_EQ(manager.NumSpeakers(), 2); + + status = manager.Remove("first"); + ASSERT_TRUE(status); + ASSERT_EQ(manager.NumSpeakers(), 1); + + status = manager.Remove("second"); + ASSERT_TRUE(status); + ASSERT_EQ(manager.NumSpeakers(), 0); +} + +TEST(SpeakerEmbeddingManager, Search) { + int32_t dim = 2; + SpeakerEmbeddingManager manager(dim); + std::vector v1 = {0.1, 0.1}; + std::vector v2 = {0.1, 0.9}; + std::vector v3 = {0.9, 0.1}; + bool status = manager.Add("first", v1.data()); + ASSERT_TRUE(status); + + status = manager.Add("second", v2.data()); + ASSERT_TRUE(status); + + status = manager.Add("third", v3.data()); + ASSERT_TRUE(status); + + ASSERT_EQ(manager.NumSpeakers(), 3); + + std::vector v = {15, 16}; + float threshold = 0.9; + + std::string name = manager.Search(v.data(), threshold); + EXPECT_EQ(name, "first"); + + v = {2, 17}; + name = manager.Search(v.data(), threshold); + EXPECT_EQ(name, "second"); + + v = {17, 2}; + name = manager.Search(v.data(), threshold); + EXPECT_EQ(name, "third"); + + threshold = 0.9; + v = {15, 16}; + status = manager.Remove("first"); + ASSERT_TRUE(status); + name = manager.Search(v.data(), threshold); + EXPECT_EQ(name, ""); + + v = {17, 2}; + status = manager.Remove("third"); + ASSERT_TRUE(status); + name = manager.Search(v.data(), threshold); + EXPECT_EQ(name, ""); + + v = {2, 17}; + status = manager.Remove("second"); + ASSERT_TRUE(status); + name = manager.Search(v.data(), threshold); + EXPECT_EQ(name, ""); + + ASSERT_EQ(manager.NumSpeakers(), 0); +} + +TEST(SpeakerEmbeddingManager, 
Verify) { + int32_t dim = 2; + SpeakerEmbeddingManager manager(dim); + std::vector v1 = {0.1, 0.1}; + std::vector v2 = {0.1, 0.9}; + std::vector v3 = {0.9, 0.1}; + bool status = manager.Add("first", v1.data()); + ASSERT_TRUE(status); + + status = manager.Add("second", v2.data()); + ASSERT_TRUE(status); + + status = manager.Add("third", v3.data()); + ASSERT_TRUE(status); + + std::vector v = {15, 16}; + float threshold = 0.9; + + status = manager.Verify("first", v.data(), threshold); + ASSERT_TRUE(status); + + v = {2, 17}; + status = manager.Verify("first", v.data(), threshold); + ASSERT_FALSE(status); + + status = manager.Verify("second", v.data(), threshold); + ASSERT_TRUE(status); + + v = {17, 2}; + status = manager.Verify("first", v.data(), threshold); + ASSERT_FALSE(status); + + status = manager.Verify("second", v.data(), threshold); + ASSERT_FALSE(status); + + status = manager.Verify("third", v.data(), threshold); + ASSERT_TRUE(status); + + status = manager.Verify("fourth", v.data(), threshold); + ASSERT_FALSE(status); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/speaker-embedding-manager.cc b/sherpa-onnx/csrc/speaker-embedding-manager.cc new file mode 100644 index 00000000..02894436 --- /dev/null +++ b/sherpa-onnx/csrc/speaker-embedding-manager.cc @@ -0,0 +1,144 @@ +// sherpa-onnx/csrc/speaker-embedding-manager.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/speaker-embedding-manager.h" + +#include +#include + +#include "Eigen/Dense" + +namespace sherpa_onnx { + +using FloatMatrix = + Eigen::Matrix; + +class SpeakerEmbeddingManager::Impl { + public: + explicit Impl(int32_t dim) : dim_(dim) {} + + bool Add(const std::string &name, const float *p) { + if (name2row_.count(name)) { + // a speaker with the same name already exists + return false; + } + + embedding_matrix_.conservativeResize(embedding_matrix_.rows() + 1, dim_); + + std::copy(p, p + dim_, &embedding_matrix_.bottomRows(1)(0, 0)); + + 
embedding_matrix_.bottomRows(1).normalize(); // inplace + + name2row_[name] = embedding_matrix_.rows() - 1; + row2name_[embedding_matrix_.rows() - 1] = name; + + return true; + } + + bool Remove(const std::string &name) { + if (!name2row_.count(name)) { + return false; + } + + int32_t row_idx = name2row_.at(name); + + int32_t num_rows = embedding_matrix_.rows(); + // Shift the rows below row_idx up by one; eval() avoids aliasing + if (row_idx < num_rows - 1) { + embedding_matrix_.block(row_idx, 0, num_rows - 1 - row_idx, dim_) = + embedding_matrix_.bottomRows(num_rows - 1 - row_idx).eval(); + } + + embedding_matrix_.conservativeResize(num_rows - 1, dim_); + for (auto &p : name2row_) { + if (p.second > row_idx) { + p.second -= 1; + row2name_[p.second] = p.first; + } + } + + name2row_.erase(name); + row2name_.erase(num_rows - 1); + + return true; + } + + std::string Search(const float *p, float threshold) { + if (embedding_matrix_.rows() == 0) { + return {}; + } + + Eigen::VectorXf v = + Eigen::Map(const_cast(p), dim_); + v.normalize(); + + Eigen::VectorXf scores = embedding_matrix_ * v; + + Eigen::VectorXf::Index max_index; + float max_score = scores.maxCoeff(&max_index); + if (max_score < threshold) { + return {}; + } + + return row2name_.at(max_index); + } + + bool Verify(const std::string &name, const float *p, float threshold) { + if (!name2row_.count(name)) { + return false; + } + + int32_t row_idx = name2row_.at(name); + + Eigen::VectorXf v = + Eigen::Map(const_cast(p), dim_); + v.normalize(); + + float score = embedding_matrix_.row(row_idx) * v; + + if (score < threshold) { + return false; + } + + return true; + } + + int32_t NumSpeakers() const { return embedding_matrix_.rows(); } + + private: + int32_t dim_; + FloatMatrix embedding_matrix_; + std::unordered_map name2row_; + std::unordered_map row2name_; +}; + +SpeakerEmbeddingManager::SpeakerEmbeddingManager(int32_t dim) + : impl_(std::make_unique(dim)) {} + +SpeakerEmbeddingManager::~SpeakerEmbeddingManager() = default; + +bool SpeakerEmbeddingManager::Add(const std::string 
// Stores named speaker embeddings and answers identification (Search) and
// verification (Verify) queries against them.
//
// All methods are const-qualified because the state lives behind `impl_`;
// note that Add() and Remove() nevertheless change the set of registered
// speakers.
class SpeakerEmbeddingManager {
 public:
  // @param dim Embedding dimension.
  explicit SpeakerEmbeddingManager(int32_t dim);
  ~SpeakerEmbeddingManager();

  /* Add the embedding and name of a speaker to the manager.
   *
   * @param name Name of the speaker.
   * @param p Pointer to the embedding. Its length is `dim`.
   * @return Return true if added successfully. Return false if it failed.
   *         At present, the only reason for a failure is that there is
   *         already a speaker with the same `name`.
   */
  bool Add(const std::string &name, const float *p) const;

  /* Remove a speaker by its name.
   *
   * @param name Name of the speaker to remove.
   * @return Return true if it is removed successfully. Return false
   *         if there is no such speaker.
   */
  bool Remove(const std::string &name) const;

  /* It is for speaker identification.
   *
   * It computes the cosine similarity between the given embedding and all
   * registered embeddings and finds the registered embedding with the
   * largest score. If that score is above or equal to `threshold`, the
   * name of the corresponding speaker is returned.
   *
   * @param p The input embedding. Its length is `dim`.
   * @param threshold Similarity threshold, typically a value between 0
   *                  and 1.
   * @return If found, return the name of the speaker. Otherwise, return an
   *         empty string.
   */
  std::string Search(const float *p, float threshold) const;

  /* Check whether the input embedding matches the embedding of the input
   * speaker.
   *
   * It is for speaker verification.
   *
   * @param name The target speaker name.
   * @param p The input embedding to check. Its length is `dim`.
   * @param threshold Similarity threshold, typically a value between 0
   *                  and 1.
   * @return Return true if `name` is registered and the cosine similarity
   *         is above or equal to `threshold`. Otherwise, return false.
   */
  bool Verify(const std::string &name, const float *p, float threshold) const;

  // Return the number of speakers currently registered.
  int32_t NumSpeakers() const;

 private:
  class Impl;
  std::unique_ptr<Impl> impl_;
};
+ bool this_window_is_speech = model_->IsSpeech(p, window_size); + is_speech = is_speech || this_window_is_speech; } last_ = std::vector( @@ -102,6 +104,8 @@ class VoiceActivityDetector::Impl { bool IsSpeechDetected() const { return start_ != -1; } + const VadModelConfig &GetConfig() const { return config_; } + private: std::queue segments_; @@ -146,4 +150,8 @@ bool VoiceActivityDetector::IsSpeechDetected() const { return impl_->IsSpeechDetected(); } +const VadModelConfig &VoiceActivityDetector::GetConfig() const { + return impl_->GetConfig(); +} + } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/voice-activity-detector.h b/sherpa-onnx/csrc/voice-activity-detector.h index 603bfbe7..c7a3cb99 100644 --- a/sherpa-onnx/csrc/voice-activity-detector.h +++ b/sherpa-onnx/csrc/voice-activity-detector.h @@ -43,6 +43,8 @@ class VoiceActivityDetector { void Reset(); + const VadModelConfig &GetConfig() const; + private: class Impl; std::unique_ptr impl_; diff --git a/sherpa-onnx/python/csrc/CMakeLists.txt b/sherpa-onnx/python/csrc/CMakeLists.txt index 85120a80..a94a1ff6 100644 --- a/sherpa-onnx/python/csrc/CMakeLists.txt +++ b/sherpa-onnx/python/csrc/CMakeLists.txt @@ -30,6 +30,8 @@ pybind11_add_module(_sherpa_onnx online-zipformer2-ctc-model-config.cc sherpa-onnx.cc silero-vad-model-config.cc + speaker-embedding-extractor.cc + speaker-embedding-manager.cc vad-model-config.cc vad-model.cc voice-activity-detector.cc diff --git a/sherpa-onnx/python/csrc/online-recognizer.h b/sherpa-onnx/python/csrc/online-recognizer.h index 0e652c7f..bc953511 100644 --- a/sherpa-onnx/python/csrc/online-recognizer.h +++ b/sherpa-onnx/python/csrc/online-recognizer.h @@ -1,4 +1,4 @@ -// sherpa-onnx/python/csrc/online-recongizer.h +// sherpa-onnx/python/csrc/online-recognizer.h // // Copyright (c) 2023 Xiaomi Corporation diff --git a/sherpa-onnx/python/csrc/sherpa-onnx.cc b/sherpa-onnx/python/csrc/sherpa-onnx.cc index 4af65fe7..37728426 100644 --- a/sherpa-onnx/python/csrc/sherpa-onnx.cc +++ 
b/sherpa-onnx/python/csrc/sherpa-onnx.cc @@ -18,6 +18,8 @@ #include "sherpa-onnx/python/csrc/online-model-config.h" #include "sherpa-onnx/python/csrc/online-recognizer.h" #include "sherpa-onnx/python/csrc/online-stream.h" +#include "sherpa-onnx/python/csrc/speaker-embedding-extractor.h" +#include "sherpa-onnx/python/csrc/speaker-embedding-manager.h" #include "sherpa-onnx/python/csrc/vad-model-config.h" #include "sherpa-onnx/python/csrc/vad-model.h" #include "sherpa-onnx/python/csrc/voice-activity-detector.h" @@ -48,6 +50,8 @@ PYBIND11_MODULE(_sherpa_onnx, m) { PybindVoiceActivityDetector(&m); PybindOfflineTts(&m); + PybindSpeakerEmbeddingExtractor(&m); + PybindSpeakerEmbeddingManager(&m); } } // namespace sherpa_onnx diff --git a/sherpa-onnx/python/csrc/speaker-embedding-extractor.cc b/sherpa-onnx/python/csrc/speaker-embedding-extractor.cc new file mode 100644 index 00000000..2749ba3b --- /dev/null +++ b/sherpa-onnx/python/csrc/speaker-embedding-extractor.cc @@ -0,0 +1,44 @@ +// sherpa-onnx/python/csrc/speaker-embedding-extractor.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/python/csrc/speaker-embedding-extractor.h" + +#include + +#include "sherpa-onnx/csrc/speaker-embedding-extractor.h" + +namespace sherpa_onnx { + +static void PybindSpeakerEmbeddingExtractorConfig(py::module *m) { + using PyClass = SpeakerEmbeddingExtractorConfig; + py::class_(*m, "SpeakerEmbeddingExtractorConfig") + .def(py::init<>()) + .def(py::init(), + py::arg("model"), py::arg("num_threads") = 1, + py::arg("debug") = false, py::arg("provider") = "cpu") + .def_readwrite("model", &PyClass::model) + .def_readwrite("num_threads", &PyClass::num_threads) + .def_readwrite("debug", &PyClass::debug) + .def_readwrite("provider", &PyClass::provider) + .def("validate", &PyClass::Validate) + .def("__str__", &PyClass::ToString); +} + +void PybindSpeakerEmbeddingExtractor(py::module *m) { + PybindSpeakerEmbeddingExtractorConfig(m); + + using PyClass = SpeakerEmbeddingExtractor; 
+ py::class_(*m, "SpeakerEmbeddingExtractor") + .def(py::init(), + py::arg("config"), py::call_guard()) + .def_property_readonly("dim", &PyClass::Dim) + .def("create_stream", &PyClass::CreateStream, + py::call_guard()) + .def("compute", &PyClass::Compute, + py::call_guard()) + .def("is_ready", &PyClass::IsReady, + py::call_guard()); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/python/csrc/speaker-embedding-extractor.h b/sherpa-onnx/python/csrc/speaker-embedding-extractor.h new file mode 100644 index 00000000..64014b90 --- /dev/null +++ b/sherpa-onnx/python/csrc/speaker-embedding-extractor.h @@ -0,0 +1,16 @@ +// sherpa-onnx/python/csrc/speaker-embedding-extractor.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_ +#define SHERPA_ONNX_PYTHON_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_ + +#include "sherpa-onnx/python/csrc/sherpa-onnx.h" + +namespace sherpa_onnx { + +void PybindSpeakerEmbeddingExtractor(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_ diff --git a/sherpa-onnx/python/csrc/speaker-embedding-manager.cc b/sherpa-onnx/python/csrc/speaker-embedding-manager.cc new file mode 100644 index 00000000..3279df82 --- /dev/null +++ b/sherpa-onnx/python/csrc/speaker-embedding-manager.cc @@ -0,0 +1,50 @@ +// sherpa-onnx/python/csrc/speaker-embedding-manager.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/python/csrc/speaker-embedding-manager.h" + +#include +#include + +#include "sherpa-onnx/csrc/speaker-embedding-manager.h" + +namespace sherpa_onnx { + +void PybindSpeakerEmbeddingManager(py::module *m) { + using PyClass = SpeakerEmbeddingManager; + py::class_(*m, "SpeakerEmbeddingManager") + .def(py::init(), py::arg("dim"), + py::call_guard()) + .def_property_readonly("num_speakers", &PyClass::NumSpeakers) + .def( + "add", + [](const PyClass &self, const std::string &name, + const std::vector &v) -> bool { + return self.Add(name, 
v.data()); + }, + py::arg("name"), py::arg("v"), + py::call_guard()) + .def( + "remove", + [](const PyClass &self, const std::string &name) -> bool { + return self.Remove(name); + }, + py::arg("name"), py::call_guard()) + .def( + "search", + [](const PyClass &self, const std::vector &v, float threshold) + -> std::string { return self.Search(v.data(), threshold); }, + py::arg("v"), py::arg("threshold"), + py::call_guard()) + .def( + "verify", + [](const PyClass &self, const std::string &name, + const std::vector &v, float threshold) -> bool { + return self.Verify(name, v.data(), threshold); + }, + py::arg("name"), py::arg("v"), py::arg("threshold"), + py::call_guard()); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/python/csrc/speaker-embedding-manager.h b/sherpa-onnx/python/csrc/speaker-embedding-manager.h new file mode 100644 index 00000000..a3f9875b --- /dev/null +++ b/sherpa-onnx/python/csrc/speaker-embedding-manager.h @@ -0,0 +1,16 @@ +// sherpa-onnx/python/csrc/speaker-embedding-manager.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_SPEAKER_EMBEDDING_MANAGER_H_ +#define SHERPA_ONNX_PYTHON_CSRC_SPEAKER_EMBEDDING_MANAGER_H_ + +#include "sherpa-onnx/python/csrc/sherpa-onnx.h" + +namespace sherpa_onnx { + +void PybindSpeakerEmbeddingManager(py::module *m); + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_PYTHON_CSRC_SPEAKER_EMBEDDING_MANAGER_H_ diff --git a/sherpa-onnx/python/csrc/voice-activity-detector.cc b/sherpa-onnx/python/csrc/voice-activity-detector.cc index f360c7ec..698297bc 100644 --- a/sherpa-onnx/python/csrc/voice-activity-detector.cc +++ b/sherpa-onnx/python/csrc/voice-activity-detector.cc @@ -32,6 +32,7 @@ void PybindVoiceActivityDetector(py::module *m) { self.AcceptWaveform(samples.data(), samples.size()); }, py::arg("samples"), py::call_guard()) + .def_property_readonly("config", &PyClass::GetConfig) .def("empty", &PyClass::Empty, py::call_guard()) .def("pop", &PyClass::Pop, py::call_guard()) 
.def("is_speech_detected", &PyClass::IsSpeechDetected, diff --git a/sherpa-onnx/python/sherpa_onnx/__init__.py b/sherpa-onnx/python/sherpa_onnx/__init__.py index d8ed0d4d..0f13f38c 100644 --- a/sherpa-onnx/python/sherpa_onnx/__init__.py +++ b/sherpa-onnx/python/sherpa_onnx/__init__.py @@ -8,6 +8,9 @@ from _sherpa_onnx import ( OfflineTtsVitsModelConfig, OnlineStream, SileroVadModelConfig, + SpeakerEmbeddingExtractor, + SpeakerEmbeddingExtractorConfig, + SpeakerEmbeddingManager, SpeechSegment, VadModel, VadModelConfig,