convert wespeaker models to sherpa-onnx (#475)
This commit is contained in:
10
scripts/wespeaker/README.md
Normal file
10
scripts/wespeaker/README.md
Normal file
@@ -0,0 +1,10 @@
|
||||
# Introduction
|
||||
|
||||
This folder contains script for adding meta data to onnx models from
|
||||
https://github.com/wenet-e2e/wespeaker/blob/master/docs/pretrained.md
|
||||
|
||||
You can use the models with metadata in sherpa-onnx.
|
||||
|
||||
|
||||
**Caution**: You have to add model meta data to `*.onnx` since we plan
|
||||
to support models from different frameworks.
|
||||
143
scripts/wespeaker/add_meta_data.py
Executable file
143
scripts/wespeaker/add_meta_data.py
Executable file
@@ -0,0 +1,143 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
|
||||
"""
|
||||
This script adds meta data to a model so that it can be used in sherpa-onnx.
|
||||
|
||||
Usage:
|
||||
./add_meta_data.py --model ./voxceleb_resnet34.onnx --language English
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
import onnx
|
||||
import onnxruntime
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the input onnx model. Example value: model.onnx",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--language",
|
||||
type=str,
|
||||
required=True,
|
||||
help="""Supported language of the input model.
|
||||
Example value: Chinese, English.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--url",
|
||||
type=str,
|
||||
default="https://github.com/wenet-e2e/wespeaker/blob/master/docs/pretrained.md",
|
||||
help="Where the model is downloaded",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--comment",
|
||||
type=str,
|
||||
default="no comment",
|
||||
help="Comment about the model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sample-rate",
|
||||
type=int,
|
||||
default=16000,
|
||||
help="Sample rate expected by the model",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def add_meta_data(filename: str, meta_data: Dict[str, str]):
|
||||
"""Add meta data to an ONNX model. It is changed in-place.
|
||||
|
||||
Args:
|
||||
filename:
|
||||
Filename of the ONNX model to be changed.
|
||||
meta_data:
|
||||
Key-value pairs.
|
||||
"""
|
||||
model = onnx.load(filename)
|
||||
for key, value in meta_data.items():
|
||||
meta = model.metadata_props.add()
|
||||
meta.key = key
|
||||
meta.value = str(value)
|
||||
|
||||
onnx.save(model, filename)
|
||||
|
||||
|
||||
def get_output_dim(filename) -> int:
|
||||
filename = str(filename)
|
||||
session_opts = onnxruntime.SessionOptions()
|
||||
session_opts.log_severity_level = 3 # error level
|
||||
sess = onnxruntime.InferenceSession(filename, session_opts)
|
||||
|
||||
for i in sess.get_inputs():
|
||||
print(i)
|
||||
|
||||
print("----------")
|
||||
|
||||
for o in sess.get_outputs():
|
||||
print(o)
|
||||
|
||||
print("----------")
|
||||
|
||||
assert len(sess.get_inputs()) == 1
|
||||
assert len(sess.get_outputs()) == 1
|
||||
|
||||
i = sess.get_inputs()[0]
|
||||
o = sess.get_outputs()[0]
|
||||
|
||||
assert i.shape[:2] == ["B", "T"], i.shape
|
||||
assert o.shape[0] == "B"
|
||||
|
||||
assert i.shape[2] == 80, i.shape
|
||||
|
||||
return o.shape[1]
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
model = Path(args.model)
|
||||
language = args.language
|
||||
url = args.url
|
||||
comment = args.comment
|
||||
sample_rate = args.sample_rate
|
||||
|
||||
if not model.is_file():
|
||||
raise ValueError(f"{model} does not exist")
|
||||
|
||||
assert len(language) > 0, len(language)
|
||||
assert len(url) > 0, len(url)
|
||||
|
||||
output_dim = get_output_dim(model)
|
||||
|
||||
# all models from wespeaker expect input samples in the range
|
||||
# [-32768, 32767]
|
||||
normalize_features = 0
|
||||
|
||||
meta_data = {
|
||||
"framework": "wespeaker",
|
||||
"language": language,
|
||||
"url": url,
|
||||
"comment": comment,
|
||||
"sample_rate": sample_rate,
|
||||
"output_dim": output_dim,
|
||||
"normalize_features": normalize_features,
|
||||
}
|
||||
print(meta_data)
|
||||
add_meta_data(filename=str(model), meta_data=meta_data)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
132
scripts/wespeaker/run.sh
Executable file
132
scripts/wespeaker/run.sh
Executable file
@@ -0,0 +1,132 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -ex
|
||||
|
||||
echo "Downloading models"
|
||||
export GIT_LFS_SKIP_SMUDGE=1
|
||||
git clone https://huggingface.co/openspeech/wespeaker-models
|
||||
cd wespeaker-models
|
||||
git lfs pull --include "*.onnx"
|
||||
ls -lh
|
||||
cd ..
|
||||
mv wespeaker-models/*.onnx .
|
||||
ls -lh
|
||||
|
||||
./add_meta_data.py \
|
||||
--model ./voxceleb_resnet34.onnx \
|
||||
--language English \
|
||||
--url https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_resnet34.onnx
|
||||
./test.py --model ./voxceleb_resnet34.onnx \
|
||||
--file1 ./wespeaker-models/test_wavs/00001_spk1.wav \
|
||||
--file2 ./wespeaker-models/test_wavs/00024_spk1.wav \
|
||||
|
||||
./test.py --model ./voxceleb_resnet34.onnx \
|
||||
--file1 ./wespeaker-models/test_wavs/00001_spk1.wav \
|
||||
--file2 ./wespeaker-models/test_wavs/00010_spk2.wav
|
||||
|
||||
mv voxceleb_resnet34.onnx en_voxceleb_resnet34.onnx
|
||||
|
||||
./add_meta_data.py \
|
||||
--model ./voxceleb_resnet34_LM.onnx \
|
||||
--language English \
|
||||
--url https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_resnet34_LM.onnx
|
||||
./test.py --model ./voxceleb_resnet34_LM.onnx \
|
||||
--file1 ./wespeaker-models/test_wavs/00001_spk1.wav \
|
||||
--file2 ./wespeaker-models/test_wavs/00024_spk1.wav \
|
||||
|
||||
./test.py --model ./voxceleb_resnet34_LM.onnx \
|
||||
--file1 ./wespeaker-models/test_wavs/00001_spk1.wav \
|
||||
--file2 ./wespeaker-models/test_wavs/00010_spk2.wav
|
||||
|
||||
mv voxceleb_resnet34_LM.onnx en_voxceleb_resnet34_LM.onnx
|
||||
|
||||
./add_meta_data.py \
|
||||
--model ./voxceleb_resnet152_LM.onnx \
|
||||
--language English \
|
||||
--url https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_resnet152_LM.onnx
|
||||
|
||||
./test.py --model ./voxceleb_resnet152_LM.onnx \
|
||||
--file1 ./wespeaker-models/test_wavs/00001_spk1.wav \
|
||||
--file2 ./wespeaker-models/test_wavs/00024_spk1.wav \
|
||||
|
||||
./test.py --model ./voxceleb_resnet152_LM.onnx \
|
||||
--file1 ./wespeaker-models/test_wavs/00001_spk1.wav \
|
||||
--file2 ./wespeaker-models/test_wavs/00010_spk2.wav
|
||||
|
||||
mv voxceleb_resnet152_LM.onnx en_voxceleb_resnet152_LM.onnx
|
||||
|
||||
./add_meta_data.py \
|
||||
--model ./voxceleb_resnet221_LM.onnx \
|
||||
--language English \
|
||||
--url https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_resnet221_LM.onnx
|
||||
|
||||
./test.py --model ./voxceleb_resnet221_LM.onnx \
|
||||
--file1 ./wespeaker-models/test_wavs/00001_spk1.wav \
|
||||
--file2 ./wespeaker-models/test_wavs/00024_spk1.wav \
|
||||
|
||||
./test.py --model ./voxceleb_resnet221_LM.onnx \
|
||||
--file1 ./wespeaker-models/test_wavs/00001_spk1.wav \
|
||||
--file2 ./wespeaker-models/test_wavs/00010_spk2.wav
|
||||
|
||||
mv voxceleb_resnet221_LM.onnx en_voxceleb_resnet221_LM.onnx
|
||||
|
||||
./add_meta_data.py \
|
||||
--model ./voxceleb_resnet293_LM.onnx \
|
||||
--language English \
|
||||
--url https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_resnet293_LM.onnx
|
||||
|
||||
./test.py --model ./voxceleb_resnet293_LM.onnx \
|
||||
--file1 ./wespeaker-models/test_wavs/00001_spk1.wav \
|
||||
--file2 ./wespeaker-models/test_wavs/00024_spk1.wav \
|
||||
|
||||
./test.py --model ./voxceleb_resnet293_LM.onnx \
|
||||
--file1 ./wespeaker-models/test_wavs/00001_spk1.wav \
|
||||
--file2 ./wespeaker-models/test_wavs/00010_spk2.wav
|
||||
|
||||
mv voxceleb_resnet293_LM.onnx en_voxceleb_resnet293_LM.onnx
|
||||
|
||||
./add_meta_data.py \
|
||||
--model ./voxceleb_CAM++.onnx \
|
||||
--language English \
|
||||
--url https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_CAM++.onnx
|
||||
|
||||
./test.py --model ./voxceleb_CAM++.onnx \
|
||||
--file1 ./wespeaker-models/test_wavs/00001_spk1.wav \
|
||||
--file2 ./wespeaker-models/test_wavs/00024_spk1.wav \
|
||||
|
||||
./test.py --model ./voxceleb_CAM++.onnx \
|
||||
--file1 ./wespeaker-models/test_wavs/00001_spk1.wav \
|
||||
--file2 ./wespeaker-models/test_wavs/00010_spk2.wav
|
||||
|
||||
mv voxceleb_CAM++.onnx en_voxceleb_CAM++.onnx
|
||||
|
||||
./add_meta_data.py \
|
||||
--model ./voxceleb_CAM++_LM.onnx \
|
||||
--language English \
|
||||
--url https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/voxceleb/voxceleb_CAM++_LM.onnx
|
||||
|
||||
./test.py --model ./voxceleb_CAM++_LM.onnx \
|
||||
--file1 ./wespeaker-models/test_wavs/00001_spk1.wav \
|
||||
--file2 ./wespeaker-models/test_wavs/00024_spk1.wav \
|
||||
|
||||
./test.py --model ./voxceleb_CAM++_LM.onnx \
|
||||
--file1 ./wespeaker-models/test_wavs/00001_spk1.wav \
|
||||
--file2 ./wespeaker-models/test_wavs/00010_spk2.wav
|
||||
|
||||
mv voxceleb_CAM++_LM.onnx en_voxceleb_CAM++_LM.onnx
|
||||
|
||||
./add_meta_data.py \
|
||||
--model ./cnceleb_resnet34.onnx \
|
||||
--language Chinese \
|
||||
--url https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/cnceleb/cnceleb_resnet34.onnx
|
||||
|
||||
mv cnceleb_resnet34.onnx zh_cnceleb_resnet34.onnx
|
||||
|
||||
./add_meta_data.py \
|
||||
--model ./cnceleb_resnet34_LM.onnx \
|
||||
--language Chinese \
|
||||
--url https://wespeaker-1256283475.cos.ap-shanghai.myqcloud.com/models/cnceleb/cnceleb_resnet34_LM.onnx
|
||||
|
||||
mv cnceleb_resnet34_LM.onnx zh_cnceleb_resnet34_LM.onnx
|
||||
|
||||
ls -lh
|
||||
171
scripts/wespeaker/test.py
Executable file
171
scripts/wespeaker/test.py
Executable file
@@ -0,0 +1,171 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
|
||||
"""
|
||||
This script computes speaker similarity score in the range [0-1]
|
||||
of two wave files using a speaker recognition model.
|
||||
"""
|
||||
import argparse
|
||||
import wave
|
||||
from pathlib import Path
|
||||
|
||||
import kaldi_native_fbank as knf
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
from numpy.linalg import norm
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the input onnx model. Example value: model.onnx",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--file1",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Input wave 1",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--file2",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Input wave 2",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def read_wavefile(filename, expected_sample_rate: int = 16000) -> np.ndarray:
|
||||
"""
|
||||
Args:
|
||||
filename:
|
||||
Path to a wave file, which must be of 16-bit and 16kHz.
|
||||
expected_sample_rate:
|
||||
Expected sample rate of the wave file.
|
||||
Returns:
|
||||
Return a 1-D float32 array containing audio samples. Each sample is in
|
||||
the range [-1, 1].
|
||||
"""
|
||||
filename = str(filename)
|
||||
with wave.open(filename) as f:
|
||||
# Note: If wave_file_sample_rate is different from
|
||||
# recognizer.sample_rate, we will do resampling inside sherpa-ncnn
|
||||
wave_file_sample_rate = f.getframerate()
|
||||
assert wave_file_sample_rate == expected_sample_rate, (
|
||||
wave_file_sample_rate,
|
||||
expected_sample_rate,
|
||||
)
|
||||
|
||||
num_channels = f.getnchannels()
|
||||
assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes
|
||||
num_samples = f.getnframes()
|
||||
samples = f.readframes(num_samples)
|
||||
samples_int16 = np.frombuffer(samples, dtype=np.int16)
|
||||
samples_int16 = samples_int16.reshape(-1, num_channels)[:, 0]
|
||||
samples_float32 = samples_int16.astype(np.float32)
|
||||
|
||||
samples_float32 = samples_float32 / 32768
|
||||
|
||||
return samples_float32
|
||||
|
||||
|
||||
def compute_features(samples: np.ndarray, sample_rate: int) -> np.ndarray:
|
||||
opts = knf.FbankOptions()
|
||||
opts.frame_opts.dither = 0
|
||||
opts.frame_opts.samp_freq = sample_rate
|
||||
opts.frame_opts.snip_edges = False
|
||||
|
||||
opts.mel_opts.num_bins = 80
|
||||
opts.mel_opts.debug_mel = False
|
||||
|
||||
fbank = knf.OnlineFbank(opts)
|
||||
fbank.accept_waveform(sample_rate, samples)
|
||||
fbank.input_finished()
|
||||
|
||||
features = []
|
||||
for i in range(fbank.num_frames_ready):
|
||||
f = fbank.get_frame(i)
|
||||
features.append(f)
|
||||
features = np.stack(features, axis=0)
|
||||
|
||||
return features
|
||||
|
||||
|
||||
class OnnxModel:
|
||||
def __init__(
|
||||
self,
|
||||
filename: str,
|
||||
):
|
||||
session_opts = ort.SessionOptions()
|
||||
session_opts.inter_op_num_threads = 1
|
||||
session_opts.intra_op_num_threads = 4
|
||||
|
||||
self.session_opts = session_opts
|
||||
|
||||
self.model = ort.InferenceSession(
|
||||
filename,
|
||||
sess_options=self.session_opts,
|
||||
)
|
||||
|
||||
meta = self.model.get_modelmeta().custom_metadata_map
|
||||
self.normalize_features = int(meta["normalize_features"])
|
||||
self.sample_rate = int(meta["sample_rate"])
|
||||
self.output_dim = int(meta["output_dim"])
|
||||
|
||||
def __call__(self, x: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
A 2-D float32 tensor of shape (T, C).
|
||||
y:
|
||||
A 1-D float32 tensor containing model output.
|
||||
"""
|
||||
x = np.expand_dims(x, axis=0)
|
||||
|
||||
return self.model.run(
|
||||
[
|
||||
self.model.get_outputs()[0].name,
|
||||
],
|
||||
{
|
||||
self.model.get_inputs()[0].name: x,
|
||||
},
|
||||
)[0][0]
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
filename = Path(args.model)
|
||||
file1 = Path(args.file1)
|
||||
file2 = Path(args.file2)
|
||||
assert filename.is_file(), filename
|
||||
assert file1.is_file(), file1
|
||||
assert file2.is_file(), file2
|
||||
|
||||
model = OnnxModel(filename)
|
||||
wave1 = read_wavefile(file1, model.sample_rate)
|
||||
wave2 = read_wavefile(file2, model.sample_rate)
|
||||
|
||||
if not model.normalize_features:
|
||||
wave1 = wave1 * 32768
|
||||
wave2 = wave2 * 32768
|
||||
|
||||
features1 = compute_features(wave1, model.sample_rate)
|
||||
features2 = compute_features(wave2, model.sample_rate)
|
||||
|
||||
output1 = model(features1)
|
||||
output2 = model(features2)
|
||||
|
||||
print(output1.shape)
|
||||
print(output2.shape)
|
||||
similarity = np.dot(output1, output2) / (norm(output1) * norm(output2))
|
||||
print(f"similarity in the range [0-1]: {similarity}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user