Begin to support https://github.com/usefulsensors/moonshine (#1470)
This commit is contained in:
106
.github/workflows/export-moonshine-to-onnx.yaml
vendored
Normal file
106
.github/workflows/export-moonshine-to-onnx.yaml
vendored
Normal file
@@ -0,0 +1,106 @@
|
|||||||
|
name: export-moonshine-to-onnx
|
||||||
|
|
||||||
|
on:
|
||||||
|
workflow_dispatch:
|
||||||
|
|
||||||
|
concurrency:
|
||||||
|
group: export-moonshine-to-onnx-${{ github.ref }}
|
||||||
|
cancel-in-progress: true
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
export-moonshine-to-onnx:
|
||||||
|
if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj'
|
||||||
|
name: export moonshine models to ONNX
|
||||||
|
runs-on: ${{ matrix.os }}
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
os: [macos-latest]
|
||||||
|
python-version: ["3.10"]
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Setup Python ${{ matrix.python-version }}
|
||||||
|
uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: ${{ matrix.python-version }}
|
||||||
|
|
||||||
|
- name: Install Python dependencies
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
pip install -q onnx onnxruntime librosa tokenizers soundfile
|
||||||
|
|
||||||
|
- name: Run
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
pushd scripts/moonshine
|
||||||
|
./run.sh
|
||||||
|
popd
|
||||||
|
|
||||||
|
mv -v scripts/moonshine/*.tar.bz2 .
|
||||||
|
mv -v scripts/moonshine/sherpa-onnx-* ./
|
||||||
|
|
||||||
|
- name: Release
|
||||||
|
uses: svenstaro/upload-release-action@v2
|
||||||
|
with:
|
||||||
|
file_glob: true
|
||||||
|
file: ./*.tar.bz2
|
||||||
|
overwrite: true
|
||||||
|
repo_name: k2-fsa/sherpa-onnx
|
||||||
|
repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
|
||||||
|
tag: asr-models
|
||||||
|
|
||||||
|
- name: Publish to huggingface (tiny)
|
||||||
|
env:
|
||||||
|
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||||
|
uses: nick-fields/retry@v3
|
||||||
|
with:
|
||||||
|
max_attempts: 20
|
||||||
|
timeout_seconds: 200
|
||||||
|
shell: bash
|
||||||
|
command: |
|
||||||
|
git config --global user.email "csukuangfj@gmail.com"
|
||||||
|
git config --global user.name "Fangjun Kuang"
|
||||||
|
|
||||||
|
d=sherpa-onnx-moonshine-tiny-en-int8
|
||||||
|
export GIT_LFS_SKIP_SMUDGE=1
|
||||||
|
export GIT_CLONE_PROTECTION_ACTIVE=false
|
||||||
|
git clone https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/$d huggingface
|
||||||
|
mv -v $d/* ./huggingface
|
||||||
|
cd huggingface
|
||||||
|
git lfs track "*.onnx"
|
||||||
|
git lfs track "*.wav"
|
||||||
|
git status
|
||||||
|
git add .
|
||||||
|
git status
|
||||||
|
git commit -m "add models"
|
||||||
|
git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/$d main
|
||||||
|
rm -rf huggingface
|
||||||
|
|
||||||
|
- name: Publish to huggingface (base)
|
||||||
|
env:
|
||||||
|
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||||
|
uses: nick-fields/retry@v3
|
||||||
|
with:
|
||||||
|
max_attempts: 20
|
||||||
|
timeout_seconds: 200
|
||||||
|
shell: bash
|
||||||
|
command: |
|
||||||
|
git config --global user.email "csukuangfj@gmail.com"
|
||||||
|
git config --global user.name "Fangjun Kuang"
|
||||||
|
|
||||||
|
d=sherpa-onnx-moonshine-base-en-int8
|
||||||
|
export GIT_LFS_SKIP_SMUDGE=1
|
||||||
|
export GIT_CLONE_PROTECTION_ACTIVE=false
|
||||||
|
git clone https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/$d huggingface
|
||||||
|
mv -v $d/* ./huggingface
|
||||||
|
cd huggingface
|
||||||
|
git lfs track "*.onnx"
|
||||||
|
git lfs track "*.wav"
|
||||||
|
git status
|
||||||
|
git add .
|
||||||
|
git status
|
||||||
|
git commit -m "add models"
|
||||||
|
git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/$d main
|
||||||
|
rm -rf huggingface
|
||||||
1
scripts/moonshine/.gitignore
vendored
Normal file
1
scripts/moonshine/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
|||||||
|
tokenizer.json
|
||||||
7
scripts/moonshine/README.md
Normal file
7
scripts/moonshine/README.md
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
# Introduction
|
||||||
|
|
||||||
|
This directory contains models from
|
||||||
|
https://github.com/usefulsensors/moonshine
|
||||||
|
|
||||||
|
See its license at
|
||||||
|
https://github.com/usefulsensors/moonshine/blob/main/LICENSE
|
||||||
40
scripts/moonshine/export-onnx.py
Executable file
40
scripts/moonshine/export-onnx.py
Executable file
@@ -0,0 +1,40 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import tokenizers
|
||||||
|
from onnxruntime.quantization import QuantType, quantize_dynamic
|
||||||
|
|
||||||
|
|
||||||
|
def generate_tokens():
|
||||||
|
if Path("./tokens.txt").is_file():
|
||||||
|
return
|
||||||
|
print("Generating tokens.txt")
|
||||||
|
tokenizer = tokenizers.Tokenizer.from_file("./tokenizer.json")
|
||||||
|
vocab_size = tokenizer.get_vocab_size()
|
||||||
|
with open("tokens.txt", "w", encoding="utf-8") as f:
|
||||||
|
for i in range(vocab_size):
|
||||||
|
s = tokenizer.id_to_token(i).strip()
|
||||||
|
f.write(f"{s}\t{i}\n")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
generate_tokens()
|
||||||
|
|
||||||
|
# Note(fangjun): Don't use int8 for the preprocessor since it has
|
||||||
|
# a larger impact on the accuracy
|
||||||
|
for f in ["uncached_decode", "cached_decode", "encode"]:
|
||||||
|
if Path(f"{f}.int8.onnx").is_file():
|
||||||
|
continue
|
||||||
|
|
||||||
|
print("processing", f)
|
||||||
|
quantize_dynamic(
|
||||||
|
model_input=f"{f}.onnx",
|
||||||
|
model_output=f"{f}.int8.onnx",
|
||||||
|
weight_type=QuantType.QInt8,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
90
scripts/moonshine/run.sh
Executable file
90
scripts/moonshine/run.sh
Executable file
@@ -0,0 +1,90 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||||
|
set -ex
|
||||||
|
|
||||||
|
cat >LICENSE <<EOF
|
||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) 2024 Useful Sensors
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
||||||
|
EOF
|
||||||
|
|
||||||
|
function download_files() {
|
||||||
|
for d in tiny base; do
|
||||||
|
mkdir $d
|
||||||
|
|
||||||
|
pushd $d
|
||||||
|
curl -SL -O https://huggingface.co/UsefulSensors/moonshine/resolve/main/onnx/$d/preprocess.onnx
|
||||||
|
curl -SL -O https://huggingface.co/UsefulSensors/moonshine/resolve/main/onnx/$d/encode.onnx
|
||||||
|
curl -SL -O https://huggingface.co/UsefulSensors/moonshine/resolve/main/onnx/$d/uncached_decode.onnx
|
||||||
|
curl -SL -O https://huggingface.co/UsefulSensors/moonshine/resolve/main/onnx/$d/cached_decode.onnx
|
||||||
|
popd
|
||||||
|
done
|
||||||
|
|
||||||
|
curl -SL -O https://huggingface.co/csukuangfj/sherpa-onnx-whisper-base/resolve/main/test_wavs/0.wav
|
||||||
|
curl -SL -O https://huggingface.co/csukuangfj/sherpa-onnx-whisper-base/resolve/main/test_wavs/1.wav
|
||||||
|
curl -SL -O https://huggingface.co/csukuangfj/sherpa-onnx-whisper-base/resolve/main/test_wavs/8k.wav
|
||||||
|
curl -SL -O https://huggingface.co/csukuangfj/sherpa-onnx-whisper-base/resolve/main/test_wavs/trans.txt
|
||||||
|
|
||||||
|
curl -SL -O https://raw.githubusercontent.com/usefulsensors/moonshine/refs/heads/main/moonshine/assets/tokenizer.json
|
||||||
|
}
|
||||||
|
|
||||||
|
function quantize() {
|
||||||
|
for d in tiny base; do
|
||||||
|
echo "==========$d=========="
|
||||||
|
ls -lh
|
||||||
|
mv $d/*.onnx .
|
||||||
|
./export-onnx.py
|
||||||
|
rm cached_decode.onnx
|
||||||
|
rm uncached_decode.onnx
|
||||||
|
rm encode.onnx
|
||||||
|
ls -lh
|
||||||
|
|
||||||
|
./test.py
|
||||||
|
|
||||||
|
mv *.onnx $d
|
||||||
|
mv tokens.txt $d
|
||||||
|
ls -lh $d
|
||||||
|
|
||||||
|
done
|
||||||
|
}
|
||||||
|
|
||||||
|
function zip() {
|
||||||
|
for d in tiny base; do
|
||||||
|
s=sherpa-onnx-moonshine-$d-en-int8
|
||||||
|
mv $d $s
|
||||||
|
|
||||||
|
mkdir $s/test_wavs
|
||||||
|
|
||||||
|
cp -v *.wav $s/test_wavs
|
||||||
|
cp trans.txt $s/test_wavs
|
||||||
|
cp LICENSE $s/
|
||||||
|
cp ./README.md $s
|
||||||
|
|
||||||
|
ls -lh $s
|
||||||
|
tar cjfv $s.tar.bz2 $s
|
||||||
|
done
|
||||||
|
}
|
||||||
|
|
||||||
|
download_files
|
||||||
|
quantize
|
||||||
|
zip
|
||||||
|
|
||||||
|
ls -lh
|
||||||
274
scripts/moonshine/test.py
Executable file
274
scripts/moonshine/test.py
Executable file
@@ -0,0 +1,274 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||||
|
import datetime as dt
|
||||||
|
|
||||||
|
import librosa
|
||||||
|
import numpy as np
|
||||||
|
import onnxruntime as ort
|
||||||
|
import soundfile as sf
|
||||||
|
|
||||||
|
|
||||||
|
def display(sess, name):
|
||||||
|
print(f"=========={name} Input==========")
|
||||||
|
for i in sess.get_inputs():
|
||||||
|
print(i)
|
||||||
|
print(f"=========={name} Output==========")
|
||||||
|
for i in sess.get_outputs():
|
||||||
|
print(i)
|
||||||
|
|
||||||
|
|
||||||
|
class OnnxModel:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
preprocess: str,
|
||||||
|
encode: str,
|
||||||
|
uncached_decode: str,
|
||||||
|
cached_decode: str,
|
||||||
|
):
|
||||||
|
self.init_preprocess(preprocess)
|
||||||
|
display(self.preprocess, "preprocess")
|
||||||
|
|
||||||
|
self.init_encode(encode)
|
||||||
|
display(self.encode, "encode")
|
||||||
|
|
||||||
|
self.init_uncached_decode(uncached_decode)
|
||||||
|
display(self.uncached_decode, "uncached_decode")
|
||||||
|
|
||||||
|
self.init_cached_decode(cached_decode)
|
||||||
|
display(self.cached_decode, "cached_decode")
|
||||||
|
|
||||||
|
def init_preprocess(self, preprocess):
|
||||||
|
session_opts = ort.SessionOptions()
|
||||||
|
session_opts.inter_op_num_threads = 1
|
||||||
|
session_opts.intra_op_num_threads = 1
|
||||||
|
|
||||||
|
self.preprocess = ort.InferenceSession(
|
||||||
|
preprocess,
|
||||||
|
sess_options=session_opts,
|
||||||
|
providers=["CPUExecutionProvider"],
|
||||||
|
)
|
||||||
|
|
||||||
|
def init_encode(self, encode):
|
||||||
|
session_opts = ort.SessionOptions()
|
||||||
|
session_opts.inter_op_num_threads = 1
|
||||||
|
session_opts.intra_op_num_threads = 1
|
||||||
|
|
||||||
|
self.encode = ort.InferenceSession(
|
||||||
|
encode,
|
||||||
|
sess_options=session_opts,
|
||||||
|
providers=["CPUExecutionProvider"],
|
||||||
|
)
|
||||||
|
|
||||||
|
def init_uncached_decode(self, uncached_decode):
|
||||||
|
session_opts = ort.SessionOptions()
|
||||||
|
session_opts.inter_op_num_threads = 1
|
||||||
|
session_opts.intra_op_num_threads = 1
|
||||||
|
|
||||||
|
self.uncached_decode = ort.InferenceSession(
|
||||||
|
uncached_decode,
|
||||||
|
sess_options=session_opts,
|
||||||
|
providers=["CPUExecutionProvider"],
|
||||||
|
)
|
||||||
|
|
||||||
|
def init_cached_decode(self, cached_decode):
|
||||||
|
session_opts = ort.SessionOptions()
|
||||||
|
session_opts.inter_op_num_threads = 1
|
||||||
|
session_opts.intra_op_num_threads = 1
|
||||||
|
|
||||||
|
self.cached_decode = ort.InferenceSession(
|
||||||
|
cached_decode,
|
||||||
|
sess_options=session_opts,
|
||||||
|
providers=["CPUExecutionProvider"],
|
||||||
|
)
|
||||||
|
|
||||||
|
def run_preprocess(self, audio):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
audio: (batch_size, num_samples), float32
|
||||||
|
Returns:
|
||||||
|
A tensor of shape (batch_size, T, dim), float32
|
||||||
|
"""
|
||||||
|
return self.preprocess.run(
|
||||||
|
[
|
||||||
|
self.preprocess.get_outputs()[0].name,
|
||||||
|
],
|
||||||
|
{
|
||||||
|
self.preprocess.get_inputs()[0].name: audio,
|
||||||
|
},
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
def run_encode(self, features):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
features: (batch_size, T, dim)
|
||||||
|
Returns:
|
||||||
|
A tensor of shape (batch_size, T, dim)
|
||||||
|
"""
|
||||||
|
features_len = np.array([features.shape[1]], dtype=np.int32)
|
||||||
|
|
||||||
|
return self.encode.run(
|
||||||
|
[
|
||||||
|
self.encode.get_outputs()[0].name,
|
||||||
|
],
|
||||||
|
{
|
||||||
|
self.encode.get_inputs()[0].name: features,
|
||||||
|
self.encode.get_inputs()[1].name: features_len,
|
||||||
|
},
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
def run_uncached_decode(self, token: int, token_len: int, encoder_out: np.ndarray):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
token: The current token
|
||||||
|
token_len: Number of predicted tokens so far
|
||||||
|
encoder_out: A tensor fo shape (batch_size, T, dim)
|
||||||
|
Returns:
|
||||||
|
A a tuple:
|
||||||
|
- a tensor of shape (batch_size, 1, dim)
|
||||||
|
- a list of states
|
||||||
|
"""
|
||||||
|
token_tensor = np.array([[token]], dtype=np.int32)
|
||||||
|
token_len_tensor = np.array([token_len], dtype=np.int32)
|
||||||
|
|
||||||
|
num_outs = len(self.uncached_decode.get_outputs())
|
||||||
|
out_names = [
|
||||||
|
self.uncached_decode.get_outputs()[i].name for i in range(num_outs)
|
||||||
|
]
|
||||||
|
|
||||||
|
out = self.uncached_decode.run(
|
||||||
|
out_names,
|
||||||
|
{
|
||||||
|
self.uncached_decode.get_inputs()[0].name: token_tensor,
|
||||||
|
self.uncached_decode.get_inputs()[1].name: encoder_out,
|
||||||
|
self.uncached_decode.get_inputs()[2].name: token_len_tensor,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
logits = out[0]
|
||||||
|
states = out[1:]
|
||||||
|
|
||||||
|
return logits, states
|
||||||
|
|
||||||
|
def run_cached_decode(
|
||||||
|
self, token: int, token_len: int, encoder_out: np.ndarray, states
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
token: The current token
|
||||||
|
token_len: Number of predicted tokens so far
|
||||||
|
encoder_out: A tensor of shape (batch_size, T, dim)
|
||||||
|
states: previous states
|
||||||
|
Returns:
|
||||||
|
A a tuple:
|
||||||
|
- a tensor of shape (batch_size, 1, dim)
|
||||||
|
- a list of states
|
||||||
|
"""
|
||||||
|
token_tensor = np.array([[token]], dtype=np.int32)
|
||||||
|
token_len_tensor = np.array([token_len], dtype=np.int32)
|
||||||
|
|
||||||
|
num_outs = len(self.cached_decode.get_outputs())
|
||||||
|
out_names = [self.cached_decode.get_outputs()[i].name for i in range(num_outs)]
|
||||||
|
|
||||||
|
states_inputs = {}
|
||||||
|
for i in range(3, len(self.cached_decode.get_inputs())):
|
||||||
|
name = self.cached_decode.get_inputs()[i].name
|
||||||
|
states_inputs[name] = states[i - 3]
|
||||||
|
|
||||||
|
out = self.cached_decode.run(
|
||||||
|
out_names,
|
||||||
|
{
|
||||||
|
self.cached_decode.get_inputs()[0].name: token_tensor,
|
||||||
|
self.cached_decode.get_inputs()[1].name: encoder_out,
|
||||||
|
self.cached_decode.get_inputs()[2].name: token_len_tensor,
|
||||||
|
**states_inputs,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
logits = out[0]
|
||||||
|
states = out[1:]
|
||||||
|
|
||||||
|
return logits, states
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
wave = "./1.wav"
|
||||||
|
id2token = dict()
|
||||||
|
token2id = dict()
|
||||||
|
with open("./tokens.txt", encoding="utf-8") as f:
|
||||||
|
for k, line in enumerate(f):
|
||||||
|
t, idx = line.split("\t")
|
||||||
|
id2token[int(idx)] = t
|
||||||
|
token2id[t] = int(idx)
|
||||||
|
|
||||||
|
model = OnnxModel(
|
||||||
|
preprocess="./preprocess.onnx",
|
||||||
|
encode="./encode.int8.onnx",
|
||||||
|
uncached_decode="./uncached_decode.int8.onnx",
|
||||||
|
cached_decode="./cached_decode.int8.onnx",
|
||||||
|
)
|
||||||
|
|
||||||
|
audio, sample_rate = sf.read(wave, dtype="float32", always_2d=True)
|
||||||
|
audio = audio[:, 0] # only use the first channel
|
||||||
|
if sample_rate != 16000:
|
||||||
|
audio = librosa.resample(
|
||||||
|
audio,
|
||||||
|
orig_sr=sample_rate,
|
||||||
|
target_sr=16000,
|
||||||
|
)
|
||||||
|
sample_rate = 16000
|
||||||
|
audio = audio[None] # (1, num_samples)
|
||||||
|
print("audio.shape", audio.shape) # (1, 159414)
|
||||||
|
|
||||||
|
start_t = dt.datetime.now()
|
||||||
|
|
||||||
|
features = model.run_preprocess(audio) # (1, 413, 288)
|
||||||
|
print("features", features.shape)
|
||||||
|
|
||||||
|
sos = token2id["<s>"]
|
||||||
|
eos = token2id["</s>"]
|
||||||
|
|
||||||
|
tokens = [sos]
|
||||||
|
|
||||||
|
encoder_out = model.run_encode(features)
|
||||||
|
print("encoder_out.shape", encoder_out.shape) # (1, 413, 288)
|
||||||
|
|
||||||
|
logits, states = model.run_uncached_decode(
|
||||||
|
token=tokens[-1],
|
||||||
|
token_len=len(tokens),
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
)
|
||||||
|
|
||||||
|
print("logits.shape", logits.shape) # (1, 1, 32768)
|
||||||
|
print("len(states)", len(states)) # 24
|
||||||
|
|
||||||
|
max_len = int((audio.shape[-1] / 16000) * 6)
|
||||||
|
|
||||||
|
for i in range(max_len):
|
||||||
|
token = logits.squeeze().argmax()
|
||||||
|
if token == eos:
|
||||||
|
break
|
||||||
|
tokens.append(token)
|
||||||
|
|
||||||
|
logits, states = model.run_cached_decode(
|
||||||
|
token=tokens[-1],
|
||||||
|
token_len=len(tokens),
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
states=states,
|
||||||
|
)
|
||||||
|
|
||||||
|
tokens = tokens[1:] # remove sos
|
||||||
|
words = [id2token[i] for i in tokens]
|
||||||
|
underline = "▁"
|
||||||
|
# underline = b"\xe2\x96\x81".decode()
|
||||||
|
text = "".join(words).replace(underline, " ").strip()
|
||||||
|
|
||||||
|
end_t = dt.datetime.now()
|
||||||
|
t = (end_t - start_t).total_seconds()
|
||||||
|
rtf = t * 16000 / audio.shape[-1]
|
||||||
|
|
||||||
|
print(text)
|
||||||
|
print("RTF:", rtf)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
Reference in New Issue
Block a user