Support whisper large/large-v1/large-v2/large-v3 and distil-large-v2 (#1114)
This commit is contained in:
@@ -9,9 +9,10 @@ import base64
|
||||
from typing import Tuple
|
||||
|
||||
import kaldi_native_fbank as knf
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
import soundfile as sf
|
||||
import torch
|
||||
import torchaudio
|
||||
|
||||
|
||||
def get_args():
|
||||
@@ -98,7 +99,6 @@ class OnnxModel:
|
||||
self.blank = int(meta["blank_id"])
|
||||
|
||||
self.sot_sequence = list(map(int, meta["sot_sequence"].split(",")))
|
||||
|
||||
self.sot_sequence.append(self.no_timestamps)
|
||||
|
||||
self.all_language_tokens = list(
|
||||
@@ -226,7 +226,18 @@ def load_tokens(filename):
|
||||
return tokens
|
||||
|
||||
|
||||
def compute_features(filename: str) -> torch.Tensor:
|
||||
def load_audio(filename: str) -> Tuple[np.ndarray, int]:
|
||||
data, sample_rate = sf.read(
|
||||
filename,
|
||||
always_2d=True,
|
||||
dtype="float32",
|
||||
)
|
||||
data = data[:, 0] # use only the first channel
|
||||
samples = np.ascontiguousarray(data)
|
||||
return samples, sample_rate
|
||||
|
||||
|
||||
def compute_features(filename: str, dim: int = 80) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
filename:
|
||||
@@ -234,16 +245,18 @@ def compute_features(filename: str) -> torch.Tensor:
|
||||
Returns:
|
||||
Return a 1-D float32 tensor of shape (1, 80, 3000) containing the features.
|
||||
"""
|
||||
wave, sample_rate = torchaudio.load(filename)
|
||||
audio = wave[0].contiguous() # only use the first channel
|
||||
wave, sample_rate = load_audio(filename)
|
||||
if sample_rate != 16000:
|
||||
audio = torchaudio.functional.resample(
|
||||
audio, orig_freq=sample_rate, new_freq=16000
|
||||
)
|
||||
import librosa
|
||||
|
||||
wave = librosa.resample(wave, orig_sr=sample_rate, target_sr=16000)
|
||||
sample_rate = 16000
|
||||
|
||||
features = []
|
||||
online_whisper_fbank = knf.OnlineWhisperFbank(knf.FrameExtractionOptions())
|
||||
online_whisper_fbank.accept_waveform(16000, audio.numpy())
|
||||
opts = knf.WhisperFeatureOptions()
|
||||
opts.dim = dim
|
||||
online_whisper_fbank = knf.OnlineWhisperFbank(opts)
|
||||
online_whisper_fbank.accept_waveform(16000, wave)
|
||||
online_whisper_fbank.input_finished()
|
||||
for i in range(online_whisper_fbank.num_frames_ready):
|
||||
f = online_whisper_fbank.get_frame(i)
|
||||
@@ -280,8 +293,9 @@ def compute_features(filename: str) -> torch.Tensor:
|
||||
def main():
|
||||
args = get_args()
|
||||
|
||||
mel = compute_features(args.sound_file)
|
||||
model = OnnxModel(args.encoder, args.decoder)
|
||||
dim = 80 if "large-v3" not in args.encoder else 128
|
||||
mel = compute_features(args.sound_file, dim=dim)
|
||||
|
||||
n_layer_cross_k, n_layer_cross_v = model.run_encoder(mel)
|
||||
|
||||
@@ -313,6 +327,7 @@ def main():
|
||||
|
||||
n_layer_self_k_cache, n_layer_self_v_cache = model.get_self_cache()
|
||||
|
||||
print(model.sot_sequence)
|
||||
tokens = torch.tensor([model.sot_sequence], dtype=torch.int64)
|
||||
offset = torch.zeros(1, dtype=torch.int64)
|
||||
logits, n_layer_self_k_cache, n_layer_self_v_cache = model.run_decoder(
|
||||
|
||||
Reference in New Issue
Block a user