Support whisper large/large-v1/large-v2/large-v3 and distil-large-v2 (#1114)

This commit is contained in:
Fangjun Kuang
2024-07-12 23:47:39 +08:00
committed by GitHub
parent d928f77d0e
commit 117cd7bb8c
23 changed files with 152 additions and 85 deletions

View File

@@ -9,9 +9,10 @@ import base64
from typing import Tuple
import kaldi_native_fbank as knf
import numpy as np
import onnxruntime as ort
import soundfile as sf
import torch
import torchaudio
def get_args():
@@ -98,7 +99,6 @@ class OnnxModel:
self.blank = int(meta["blank_id"])
self.sot_sequence = list(map(int, meta["sot_sequence"].split(",")))
self.sot_sequence.append(self.no_timestamps)
self.all_language_tokens = list(
@@ -226,7 +226,18 @@ def load_tokens(filename):
return tokens
def compute_features(filename: str) -> torch.Tensor:
def load_audio(filename: str) -> Tuple[np.ndarray, int]:
    """Read a sound file and return the samples of its first channel.

    Args:
      filename:
        Path to the sound file to read.

    Returns:
      Return a tuple containing:
        - The first channel of the file as a contiguous float32 array
        - The sample rate reported for the file
    """
    # Request a 2-D result so that the first channel can always be selected
    # by indexing the second axis, regardless of the channel count.
    audio, sr = sf.read(filename, always_2d=True, dtype="float32")

    # Slicing a column yields a strided view; copy it into contiguous memory.
    first_channel = np.ascontiguousarray(audio[:, 0])
    return first_channel, sr
def compute_features(filename: str, dim: int = 80) -> torch.Tensor:
"""
Args:
filename:
@@ -234,16 +245,18 @@ def compute_features(filename: str) -> torch.Tensor:
Returns:
Return a 3-D float32 tensor of shape (1, 80, 3000) containing the features.
"""
wave, sample_rate = torchaudio.load(filename)
audio = wave[0].contiguous() # only use the first channel
wave, sample_rate = load_audio(filename)
if sample_rate != 16000:
audio = torchaudio.functional.resample(
audio, orig_freq=sample_rate, new_freq=16000
)
import librosa
wave = librosa.resample(wave, orig_sr=sample_rate, target_sr=16000)
sample_rate = 16000
features = []
online_whisper_fbank = knf.OnlineWhisperFbank(knf.FrameExtractionOptions())
online_whisper_fbank.accept_waveform(16000, audio.numpy())
opts = knf.WhisperFeatureOptions()
opts.dim = dim
online_whisper_fbank = knf.OnlineWhisperFbank(opts)
online_whisper_fbank.accept_waveform(16000, wave)
online_whisper_fbank.input_finished()
for i in range(online_whisper_fbank.num_frames_ready):
f = online_whisper_fbank.get_frame(i)
@@ -280,8 +293,9 @@ def compute_features(filename: str) -> torch.Tensor:
def main():
args = get_args()
mel = compute_features(args.sound_file)
model = OnnxModel(args.encoder, args.decoder)
dim = 80 if "large-v3" not in args.encoder else 128
mel = compute_features(args.sound_file, dim=dim)
n_layer_cross_k, n_layer_cross_v = model.run_encoder(mel)
@@ -313,6 +327,7 @@ def main():
n_layer_self_k_cache, n_layer_self_v_cache = model.get_self_cache()
print(model.sot_sequence)
tokens = torch.tensor([model.sot_sequence], dtype=torch.int64)
offset = torch.zeros(1, dtype=torch.int64)
logits, n_layer_self_k_cache, n_layer_self_v_cache = model.run_decoder(