Begin to support https://github.com/usefulsensors/moonshine (#1470)
This commit is contained in:
1
scripts/moonshine/.gitignore
vendored
Normal file
1
scripts/moonshine/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
tokenizer.json
|
||||
7
scripts/moonshine/README.md
Normal file
7
scripts/moonshine/README.md
Normal file
@@ -0,0 +1,7 @@
|
||||
# Introduction
|
||||
|
||||
This directory contains models from
|
||||
https://github.com/usefulsensors/moonshine
|
||||
|
||||
See its license at
|
||||
https://github.com/usefulsensors/moonshine/blob/main/LICENSE
|
||||
40
scripts/moonshine/export-onnx.py
Executable file
40
scripts/moonshine/export-onnx.py
Executable file
@@ -0,0 +1,40 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import tokenizers
|
||||
from onnxruntime.quantization import QuantType, quantize_dynamic
|
||||
|
||||
|
||||
def generate_tokens():
    """Write tokens.txt, one "token<TAB>id" line per vocab entry.

    The token table is read from ./tokenizer.json. The function is a no-op
    if ./tokens.txt already exists, so the export script can be re-run.
    """
    if Path("./tokens.txt").is_file():
        return
    print("Generating tokens.txt")
    tok = tokenizers.Tokenizer.from_file("./tokenizer.json")
    with open("tokens.txt", "w", encoding="utf-8") as out:
        for idx in range(tok.get_vocab_size()):
            out.write(f"{tok.id_to_token(idx).strip()}\t{idx}\n")
|
||||
|
||||
|
||||
def main():
    """Generate tokens.txt, then int8-quantize the encoder/decoder models.

    Each quantized file is written next to its float32 original as
    <name>.int8.onnx; files that already exist are skipped.
    """
    generate_tokens()

    # Note(fangjun): Don't use int8 for the preprocessor since it has
    # a larger impact on the accuracy
    for stem in ("uncached_decode", "cached_decode", "encode"):
        if Path(f"{stem}.int8.onnx").is_file():
            continue

        print("processing", stem)
        quantize_dynamic(
            model_input=f"{stem}.onnx",
            model_output=f"{stem}.int8.onnx",
            weight_type=QuantType.QInt8,
        )
|
||||
|
||||
|
||||
# Standard script entry point.
if __name__ == "__main__":
    main()
|
||||
90
scripts/moonshine/run.sh
Executable file
90
scripts/moonshine/run.sh
Executable file
@@ -0,0 +1,90 @@
|
||||
#!/usr/bin/env bash
# Copyright    2024  Xiaomi Corp. (authors: Fangjun Kuang)
# Echo every command and abort on the first failure.
set -ex

# Write the upstream moonshine MIT license to ./LICENSE so it can be bundled
# into each released model tarball below. The heredoc body must stay exactly
# as upstream publishes it.
cat >LICENSE <<EOF
MIT License

Copyright (c) 2024 Useful Sensors

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
EOF
|
||||
|
||||
# Download the upstream moonshine onnx models (tiny and base), some test wave
# files, and the tokenizer definition.
function download_files() {
  for d in tiny base; do
    # -p keeps the script re-runnable: plain mkdir would abort under `set -e`
    # if the directory already exists from a previous (partial) run.
    mkdir -p "$d"

    pushd "$d"
    curl -SL -O https://huggingface.co/UsefulSensors/moonshine/resolve/main/onnx/$d/preprocess.onnx
    curl -SL -O https://huggingface.co/UsefulSensors/moonshine/resolve/main/onnx/$d/encode.onnx
    curl -SL -O https://huggingface.co/UsefulSensors/moonshine/resolve/main/onnx/$d/uncached_decode.onnx
    curl -SL -O https://huggingface.co/UsefulSensors/moonshine/resolve/main/onnx/$d/cached_decode.onnx
    popd
  done

  # Reuse the whisper test waves; trans.txt holds their reference transcripts.
  curl -SL -O https://huggingface.co/csukuangfj/sherpa-onnx-whisper-base/resolve/main/test_wavs/0.wav
  curl -SL -O https://huggingface.co/csukuangfj/sherpa-onnx-whisper-base/resolve/main/test_wavs/1.wav
  curl -SL -O https://huggingface.co/csukuangfj/sherpa-onnx-whisper-base/resolve/main/test_wavs/8k.wav
  curl -SL -O https://huggingface.co/csukuangfj/sherpa-onnx-whisper-base/resolve/main/test_wavs/trans.txt

  curl -SL -O https://raw.githubusercontent.com/usefulsensors/moonshine/refs/heads/main/moonshine/assets/tokenizer.json
}
|
||||
|
||||
# For each model size: move its onnx files to the current directory, run
# export-onnx.py to produce tokens.txt and the int8 models, sanity-check them
# with test.py, then move everything back into the per-size directory.
# preprocess.onnx is deliberately left float32 (see the note in export-onnx.py).
function quantize() {
  for d in tiny base; do
    echo "==========$d=========="
    ls -lh
    mv "$d"/*.onnx .
    ./export-onnx.py
    # The float32 originals are superseded by their .int8.onnx counterparts.
    rm cached_decode.onnx uncached_decode.onnx encode.onnx
    ls -lh

    ./test.py

    mv *.onnx "$d"
    mv tokens.txt "$d"
    ls -lh "$d"

  done
}
|
||||
|
||||
# Package each model directory as sherpa-onnx-moonshine-<size>-en-int8.tar.bz2
# together with the test waves, reference transcripts, LICENSE and README.
# NOTE: this function shadows the external "zip" command, which is fine here
# because the archive is produced with tar, not zip.
function zip() {
  for d in tiny base; do
    s=sherpa-onnx-moonshine-$d-en-int8
    mv "$d" "$s"

    # -p keeps the script re-runnable if the directory already exists.
    mkdir -p "$s"/test_wavs

    cp -v *.wav "$s"/test_wavs
    cp trans.txt "$s"/test_wavs
    cp LICENSE "$s"/
    cp ./README.md "$s"

    ls -lh "$s"
    tar cjfv "$s".tar.bz2 "$s"
  done
}
|
||||
|
||||
# Pipeline: fetch models and test data, quantize + sanity-check, then package.
download_files
quantize
zip

ls -lh
|
||||
274
scripts/moonshine/test.py
Executable file
274
scripts/moonshine/test.py
Executable file
@@ -0,0 +1,274 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
import datetime as dt
|
||||
|
||||
import librosa
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
import soundfile as sf
|
||||
|
||||
|
||||
def display(sess, name):
    """Print the input and output tensor descriptions of an onnxruntime session.

    Args:
      sess: An onnxruntime InferenceSession (anything with get_inputs/get_outputs).
      name: A label used in the section headers.
    """
    for kind, getter in (("Input", sess.get_inputs), ("Output", sess.get_outputs)):
        print(f"=========={name} {kind}==========")
        for item in getter():
            print(item)
|
||||
|
||||
|
||||
class OnnxModel:
    """Wrapper around the four onnxruntime sessions of a moonshine model:
    audio preprocessor, encoder, first-step decoder (uncached) and
    subsequent-step decoder (cached)."""

    def __init__(
        self,
        preprocess: str,
        encode: str,
        uncached_decode: str,
        cached_decode: str,
    ):
        """
        Args:
          preprocess: Path to the preprocessor onnx model.
          encode: Path to the encoder onnx model.
          uncached_decode: Path to the decoder model for the first step.
          cached_decode: Path to the decoder model for subsequent steps.
        """
        self.init_preprocess(preprocess)
        display(self.preprocess, "preprocess")

        self.init_encode(encode)
        display(self.encode, "encode")

        self.init_uncached_decode(uncached_decode)
        display(self.uncached_decode, "uncached_decode")

        self.init_cached_decode(cached_decode)
        display(self.cached_decode, "cached_decode")

    def _new_session(self, filename: str):
        # All four models share the same single-threaded CPU configuration;
        # previously this block was duplicated in each init_* method.
        session_opts = ort.SessionOptions()
        session_opts.inter_op_num_threads = 1
        session_opts.intra_op_num_threads = 1

        return ort.InferenceSession(
            filename,
            sess_options=session_opts,
            providers=["CPUExecutionProvider"],
        )

    def init_preprocess(self, preprocess):
        self.preprocess = self._new_session(preprocess)

    def init_encode(self, encode):
        self.encode = self._new_session(encode)

    def init_uncached_decode(self, uncached_decode):
        self.uncached_decode = self._new_session(uncached_decode)

    def init_cached_decode(self, cached_decode):
        self.cached_decode = self._new_session(cached_decode)

    def run_preprocess(self, audio):
        """
        Args:
          audio: (batch_size, num_samples), float32
        Returns:
          A tensor of shape (batch_size, T, dim), float32
        """
        return self.preprocess.run(
            [
                self.preprocess.get_outputs()[0].name,
            ],
            {
                self.preprocess.get_inputs()[0].name: audio,
            },
        )[0]

    def run_encode(self, features):
        """
        Args:
          features: (batch_size, T, dim)
        Returns:
          A tensor of shape (batch_size, T, dim)
        """
        features_len = np.array([features.shape[1]], dtype=np.int32)

        return self.encode.run(
            [
                self.encode.get_outputs()[0].name,
            ],
            {
                self.encode.get_inputs()[0].name: features,
                self.encode.get_inputs()[1].name: features_len,
            },
        )[0]

    def run_uncached_decode(self, token: int, token_len: int, encoder_out: np.ndarray):
        """Run the decoder for the first step, when no cached states exist yet.

        Args:
          token: The current token
          token_len: Number of predicted tokens so far
          encoder_out: A tensor of shape (batch_size, T, dim)
        Returns:
          A tuple:
            - a tensor of shape (batch_size, 1, dim)
            - a list of states
        """
        token_tensor = np.array([[token]], dtype=np.int32)
        token_len_tensor = np.array([token_len], dtype=np.int32)

        # Request every output: output 0 is the logits, the rest are states.
        out_names = [o.name for o in self.uncached_decode.get_outputs()]

        out = self.uncached_decode.run(
            out_names,
            {
                self.uncached_decode.get_inputs()[0].name: token_tensor,
                self.uncached_decode.get_inputs()[1].name: encoder_out,
                self.uncached_decode.get_inputs()[2].name: token_len_tensor,
            },
        )

        logits = out[0]
        states = out[1:]

        return logits, states

    def run_cached_decode(
        self, token: int, token_len: int, encoder_out: np.ndarray, states
    ):
        """Run the decoder for subsequent steps, reusing previously cached states.

        Args:
          token: The current token
          token_len: Number of predicted tokens so far
          encoder_out: A tensor of shape (batch_size, T, dim)
          states: previous states
        Returns:
          A tuple:
            - a tensor of shape (batch_size, 1, dim)
            - a list of states
        """
        token_tensor = np.array([[token]], dtype=np.int32)
        token_len_tensor = np.array([token_len], dtype=np.int32)

        out_names = [o.name for o in self.cached_decode.get_outputs()]

        # Inputs 3.. are the cached state tensors, in the same order as the
        # states returned by the previous decode call.
        states_inputs = {}
        for i in range(3, len(self.cached_decode.get_inputs())):
            name = self.cached_decode.get_inputs()[i].name
            states_inputs[name] = states[i - 3]

        out = self.cached_decode.run(
            out_names,
            {
                self.cached_decode.get_inputs()[0].name: token_tensor,
                self.cached_decode.get_inputs()[1].name: encoder_out,
                self.cached_decode.get_inputs()[2].name: token_len_tensor,
                **states_inputs,
            },
        )

        logits = out[0]
        states = out[1:]

        return logits, states
|
||||
|
||||
|
||||
def main():
    """Greedy-decode ./1.wav with the (int8) moonshine models and print the text."""
    wave = "./1.wav"
    # Build both directions of the token table from tokens.txt
    # ("token<TAB>id" per line, as written by export-onnx.py).
    id2token = dict()
    token2id = dict()
    with open("./tokens.txt", encoding="utf-8") as f:
        for k, line in enumerate(f):
            t, idx = line.split("\t")
            id2token[int(idx)] = t
            token2id[t] = int(idx)

    # Note: the preprocessor stays float32; only encoder/decoders are int8.
    model = OnnxModel(
        preprocess="./preprocess.onnx",
        encode="./encode.int8.onnx",
        uncached_decode="./uncached_decode.int8.onnx",
        cached_decode="./cached_decode.int8.onnx",
    )

    audio, sample_rate = sf.read(wave, dtype="float32", always_2d=True)
    audio = audio[:, 0]  # only use the first channel
    # The model expects 16 kHz input; resample anything else.
    if sample_rate != 16000:
        audio = librosa.resample(
            audio,
            orig_sr=sample_rate,
            target_sr=16000,
        )
        sample_rate = 16000
    audio = audio[None]  # add batch dim -> (1, num_samples)
    print("audio.shape", audio.shape)  # (1, 159414)

    start_t = dt.datetime.now()

    features = model.run_preprocess(audio)  # (1, 413, 288)
    print("features", features.shape)

    sos = token2id["<s>"]
    eos = token2id["</s>"]

    # Decoding starts from the start-of-sequence token.
    tokens = [sos]

    encoder_out = model.run_encode(features)
    print("encoder_out.shape", encoder_out.shape)  # (1, 413, 288)

    # First decode step has no cached states yet.
    logits, states = model.run_uncached_decode(
        token=tokens[-1],
        token_len=len(tokens),
        encoder_out=encoder_out,
    )

    print("logits.shape", logits.shape)  # (1, 1, 32768)
    print("len(states)", len(states))  # 24

    # Cap the number of decode steps: presumably at most ~6 tokens per second
    # of audio — heuristic, confirm against upstream moonshine.
    max_len = int((audio.shape[-1] / 16000) * 6)

    # Greedy search: pick the argmax token each step until </s> or the cap.
    for i in range(max_len):
        token = logits.squeeze().argmax()
        if token == eos:
            break
        tokens.append(token)

        # Subsequent steps reuse the cached decoder states.
        logits, states = model.run_cached_decode(
            token=tokens[-1],
            token_len=len(tokens),
            encoder_out=encoder_out,
            states=states,
        )

    tokens = tokens[1:]  # remove sos
    words = [id2token[i] for i in tokens]
    # SentencePiece-style word-boundary marker (U+2581) is mapped to a space.
    underline = "▁"
    # underline = b"\xe2\x96\x81".decode()
    text = "".join(words).replace(underline, " ").strip()

    end_t = dt.datetime.now()
    t = (end_t - start_t).total_seconds()
    # Real-time factor: processing time divided by audio duration.
    rtf = t * 16000 / audio.shape[-1]

    print(text)
    print("RTF:", rtf)


# Standard script entry point.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user