Remove whisper dependency from the whisper Python example (#283)
This commit is contained in:
@@ -4,15 +4,14 @@
|
|||||||
Please first run ./export-onnx.py
|
Please first run ./export-onnx.py
|
||||||
before you run this script
|
before you run this script
|
||||||
"""
|
"""
|
||||||
|
import argparse
|
||||||
import base64
|
import base64
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
import kaldi_native_fbank as knf
|
import kaldi_native_fbank as knf
|
||||||
import onnxruntime as ort
|
import onnxruntime as ort
|
||||||
import torch
|
import torch
|
||||||
|
import torchaudio
|
||||||
import whisper
|
|
||||||
import argparse
|
|
||||||
|
|
||||||
|
|
||||||
def get_args():
|
def get_args():
|
||||||
@@ -225,16 +224,24 @@ def load_tokens(filename):
|
|||||||
return tokens
|
return tokens
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def compute_features(filename: str) -> torch.Tensor:
|
||||||
args = get_args()
|
"""
|
||||||
encoder = args.encoder
|
Args:
|
||||||
decoder = args.decoder
|
filename:
|
||||||
|
Path to an audio file.
|
||||||
audio = whisper.load_audio(args.sound_file)
|
Returns:
|
||||||
|
Return a 1-D float32 tensor of shape (1, 80, 3000) containing the features.
|
||||||
|
"""
|
||||||
|
wave, sample_rate = torchaudio.load(filename)
|
||||||
|
audio = wave[0].contiguous() # only use the first channel
|
||||||
|
if sample_rate != 16000:
|
||||||
|
audio = torchaudio.functional.resample(
|
||||||
|
audio, orig_freq=sample_rate, new_freq=16000
|
||||||
|
)
|
||||||
|
|
||||||
features = []
|
features = []
|
||||||
online_whisper_fbank = knf.OnlineWhisperFbank(knf.FrameExtractionOptions())
|
online_whisper_fbank = knf.OnlineWhisperFbank(knf.FrameExtractionOptions())
|
||||||
online_whisper_fbank.accept_waveform(16000, audio)
|
online_whisper_fbank.accept_waveform(16000, audio.numpy())
|
||||||
online_whisper_fbank.input_finished()
|
online_whisper_fbank.input_finished()
|
||||||
for i in range(online_whisper_fbank.num_frames_ready):
|
for i in range(online_whisper_fbank.num_frames_ready):
|
||||||
f = online_whisper_fbank.get_frame(i)
|
f = online_whisper_fbank.get_frame(i)
|
||||||
@@ -250,7 +257,14 @@ def main():
|
|||||||
mel = torch.nn.functional.pad(mel, (0, 0, 0, target - mel.shape[0]), "constant", 0)
|
mel = torch.nn.functional.pad(mel, (0, 0, 0, target - mel.shape[0]), "constant", 0)
|
||||||
mel = mel.t().unsqueeze(0)
|
mel = mel.t().unsqueeze(0)
|
||||||
|
|
||||||
model = OnnxModel(encoder, decoder)
|
return mel
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = get_args()
|
||||||
|
|
||||||
|
mel = compute_features(args.sound_file)
|
||||||
|
model = OnnxModel(args.encoder, args.decoder)
|
||||||
|
|
||||||
n_layer_cross_k, n_layer_cross_v = model.run_encoder(mel)
|
n_layer_cross_k, n_layer_cross_v = model.run_encoder(mel)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user