diff --git a/scripts/whisper/test.py b/scripts/whisper/test.py
index 03e5e32c..d93da482 100755
--- a/scripts/whisper/test.py
+++ b/scripts/whisper/test.py
@@ -4,15 +4,14 @@
 Please first run ./export-onnx.py before you run this script
 """

+import argparse
 import base64
 from typing import Tuple

 import kaldi_native_fbank as knf
 import onnxruntime as ort
 import torch
-
-import whisper
-import argparse
+import torchaudio


 def get_args():
@@ -225,16 +224,24 @@ def load_tokens(filename):
     return tokens


-def main():
-    args = get_args()
-    encoder = args.encoder
-    decoder = args.decoder
-
-    audio = whisper.load_audio(args.sound_file)
+def compute_features(filename: str) -> torch.Tensor:
+    """
+    Args:
+      filename:
+        Path to an audio file.
+    Returns:
+      Return a 3-D float32 tensor of shape (1, 80, 3000) containing the features.
+    """
+    wave, sample_rate = torchaudio.load(filename)
+    audio = wave[0].contiguous()  # only use the first channel
+    if sample_rate != 16000:
+        audio = torchaudio.functional.resample(
+            audio, orig_freq=sample_rate, new_freq=16000
+        )

     features = []
     online_whisper_fbank = knf.OnlineWhisperFbank(knf.FrameExtractionOptions())
-    online_whisper_fbank.accept_waveform(16000, audio)
+    online_whisper_fbank.accept_waveform(16000, audio.numpy())
     online_whisper_fbank.input_finished()
     for i in range(online_whisper_fbank.num_frames_ready):
         f = online_whisper_fbank.get_frame(i)
@@ -250,7 +257,14 @@
     mel = torch.nn.functional.pad(mel, (0, 0, 0, target - mel.shape[0]), "constant", 0)
     mel = mel.t().unsqueeze(0)

-    model = OnnxModel(encoder, decoder)
+    return mel
+
+
+def main():
+    args = get_args()
+
+    mel = compute_features(args.sound_file)
+    model = OnnxModel(args.encoder, args.decoder)

     n_layer_cross_k, n_layer_cross_v = model.run_encoder(mel)
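
For reference, a minimal sketch of how the new helper is driven after this change (the ONNX and WAV file names below are placeholders, not part of the diff):

    mel = compute_features("./0.wav")  # 3-D float32 tensor of shape (1, 80, 3000)
    model = OnnxModel("tiny.en-encoder.onnx", "tiny.en-decoder.onnx")
    n_layer_cross_k, n_layer_cross_v = model.run_encoder(mel)

compute_features() now handles audio decoding and resampling to 16 kHz via torchaudio, so the test script no longer needs the whisper package at runtime.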