Support whisper models (#238)
This commit is contained in:
4
scripts/whisper/.gitignore
vendored
Normal file
4
scripts/whisper/.gitignore
vendored
Normal file
@@ -0,0 +1,4 @@
|
||||
*.onnx
|
||||
*.config
|
||||
*.ort
|
||||
*-tokens.txt
|
||||
9
scripts/whisper/README.md
Normal file
9
scripts/whisper/README.md
Normal file
@@ -0,0 +1,9 @@
|
||||
# Introduction
|
||||
|
||||
This folder contains code showing how to convert [Whisper][whisper] models to ONNX
and how to run them with onnxruntime, instead of PyTorch, for speech recognition.
|
||||
|
||||
You can use [sherpa-onnx][sherpa-onnx] to run the converted model.
|
||||
|
||||
[whisper]: https://github.com/openai/whisper
|
||||
[sherpa-onnx]: https://github.com/k2-fsa/sherpa-onnx
|
||||
439
scripts/whisper/export-onnx.py
Executable file
439
scripts/whisper/export-onnx.py
Executable file
@@ -0,0 +1,439 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
# flake8: noqa
|
||||
|
||||
"""
|
||||
Note: Code in this file is modified from
|
||||
https://github.com/TadaoYamaoka/whisper/blob/main/to_onnx.py
|
||||
|
||||
Thanks to https://github.com/TadaoYamaoka
|
||||
for making the onnx export script public.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import onnx
|
||||
import torch
|
||||
from onnxruntime.quantization import QuantType, quantize_dynamic
|
||||
from torch import Tensor, nn
|
||||
|
||||
import whisper
|
||||
from whisper.model import (
|
||||
AudioEncoder,
|
||||
MultiHeadAttention,
|
||||
ResidualAttentionBlock,
|
||||
TextDecoder,
|
||||
)
|
||||
|
||||
|
||||
def get_args():
    """Parse the command-line arguments of this script.

    Returns:
      The namespace produced by ``argparse``. Its ``model`` attribute holds
      the value given to the required ``--model`` option, which must be one
      of the model names listed below.
    """
    # fmt: off
    supported_models = [
        "tiny", "tiny.en", "base", "base.en",
        "small", "small.en", "medium", "medium.en",
        "large", "large-v1", "large-v2",
    ]
    # fmt: on

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        choices=supported_models,
    )

    parsed = parser.parse_args()
    return parsed
|
||||
|
||||
|
||||
def add_meta_data(filename: str, meta_data: Dict[str, Any]):
    """Append key-value metadata to an ONNX model file, rewriting it in place.

    Args:
      filename:
        Path of the ONNX model. It is loaded, updated and then saved back
        under the same name.
      meta_data:
        Entries to add to the model's ``metadata_props``. Every value is
        converted with ``str()`` before it is stored.
    """
    onnx_model = onnx.load(filename)

    for name in meta_data:
        entry = onnx_model.metadata_props.add()
        entry.key = name
        entry.value = str(meta_data[name])

    onnx.save(onnx_model, filename)
|
||||
|
||||
|
||||
class AudioEncoderTensorCache(nn.Module):
    """Audio encoder that also emits the decoder's cross-attention inputs.

    The output of the wrapped encoder is passed through the
    ``cross_attn.key`` and ``cross_attn.value`` layers of every block of the
    wrapped text decoder, and the stacked results are returned.
    """

    def __init__(self, inAudioEncoder: AudioEncoder, inTextDecoder: TextDecoder):
        """
        Args:
          inAudioEncoder:
            The audio encoder to run on the input.
          inTextDecoder:
            The text decoder whose blocks provide the cross-attention
            key/value projection layers.
        """
        super().__init__()
        self.audioEncoder = inAudioEncoder
        self.textDecoder = inTextDecoder

    def forward(self, x: Tensor):
        """Encode ``x`` and project the encoding once per decoder block.

        Args:
          x:
            Input handed unchanged to the audio encoder.

        Returns:
          A tuple ``(keys, values)``. Each element is the result of
          ``torch.stack`` over the per-block projections, so its leading
          dimension indexes the decoder blocks.
        """
        encoded = self.audioEncoder(x)

        cross_keys = []
        cross_values = []
        for decoder_block in self.textDecoder.blocks:
            cross_attention = decoder_block.cross_attn
            cross_keys.append(cross_attention.key(encoded))
            cross_values.append(cross_attention.value(encoded))

        stacked_keys = torch.stack(cross_keys)
        stacked_values = torch.stack(cross_values)
        return stacked_keys, stacked_values
|
||||
|
||||
|
||||
class MultiHeadAttentionCross(nn.Module):
    """Cross-attention that consumes already-projected keys and values.

    Only the ``query`` and ``out`` layers of the wrapped attention module are
    applied here; ``k`` and ``v`` are forwarded to ``qkv_attention`` exactly
    as given.
    """

    def __init__(self, inMultiHeadAttention: MultiHeadAttention):
        """
        Args:
          inMultiHeadAttention:
            The attention module whose layers are reused.
        """
        super().__init__()
        self.multiHeadAttention = inMultiHeadAttention

    def forward(
        self,
        x: Tensor,
        k: Tensor,
        v: Tensor,
        mask: Optional[Tensor] = None,
    ):
        """Attend from ``x`` over the given keys and values.

        Args:
          x:
            Input to the query projection.
          k:
            Keys passed straight to ``qkv_attention``.
          v:
            Values passed straight to ``qkv_attention``.
          mask:
            Optional mask forwarded to ``qkv_attention``.

        Returns:
          The attention result after the output projection.
        """
        attention = self.multiHeadAttention
        query = attention.query(x)
        # The second value returned by qkv_attention is not used.
        weighted_values, _ = attention.qkv_attention(query, k, v, mask)
        return attention.out(weighted_values)
|
||||
|
||||
|
||||
class MultiHeadAttentionSelf(nn.Module):
    """Self-attention that keeps its keys and values in caller-owned caches.

    The keys and values computed for the current input are written into the
    tail of the supplied caches, and attention is then evaluated against the
    whole caches.
    """

    def __init__(self, inMultiHeadAttention: MultiHeadAttention):
        """
        Args:
          inMultiHeadAttention:
            The attention module whose layers are reused.
        """
        super().__init__()
        self.multiHeadAttention = inMultiHeadAttention

    def forward(
        self,
        x: Tensor,  # (b, n_ctx , n_state)
        k_cache: Tensor,  # (b, n_ctx_cache, n_state)
        v_cache: Tensor,  # (b, n_ctx_cache, n_state)
        mask: Tensor,
    ):
        """Run self-attention for ``x`` and update both caches in place.

        Args:
          x:
            Current input, of shape (b, n_ctx, n_state).
          k_cache:
            Key cache. Its last ``n_ctx`` positions along dim 1 are
            overwritten with the keys computed from ``x``.
          v_cache:
            Value cache, updated in the same way as ``k_cache``.
          mask:
            Mask forwarded to ``qkv_attention``.

        Returns:
          A tuple ``(out, k_cache, v_cache)``: the attention output after the
          output projection, and the two cache tensors that were passed in,
          now holding the new entries.
        """
        attention = self.multiHeadAttention

        query = attention.query(x)  # (b, n_ctx, n_state)
        new_keys = attention.key(x)  # (b, n_ctx, n_state)
        new_values = attention.value(x)  # (b, n_ctx, n_state)

        # In-place writes: the caller's cache tensors are modified.
        k_cache[:, -new_keys.shape[1] :, :] = new_keys
        v_cache[:, -new_values.shape[1] :, :] = new_values

        # The second value returned by qkv_attention is not used.
        weighted_values, _ = attention.qkv_attention(query, k_cache, v_cache, mask)
        return attention.out(weighted_values), k_cache, v_cache
|
||||
|
||||
|
||||
class ResidualAttentionBlockTensorCache(nn.Module):
    """Residual attention block that takes its attention caches as inputs.

    The layer norms and the MLP of the wrapped block are reused as they are.
    Its attention modules are wrapped in ``MultiHeadAttentionSelf`` and, when
    the block has one, ``MultiHeadAttentionCross``.
    """

    def __init__(self, inResidualAttentionBlock: ResidualAttentionBlock):
        """
        Args:
          inResidualAttentionBlock:
            The block to wrap.
        """
        super().__init__()
        self.originalBlock = inResidualAttentionBlock
        self.attn = MultiHeadAttentionSelf(inResidualAttentionBlock.attn)

        if inResidualAttentionBlock.cross_attn:
            self.cross_attn = MultiHeadAttentionCross(
                inResidualAttentionBlock.cross_attn
            )
        else:
            self.cross_attn = None

    def forward(
        self,
        x: Tensor,
        self_k_cache: Tensor,
        self_v_cache: Tensor,
        cross_k: Tensor,
        cross_v: Tensor,
        mask: Tensor,
    ):
        """Apply self-attention, optional cross-attention and the MLP.

        Args:
          x:
            Input of the block.
          self_k_cache:
            Self-attention key cache handed to ``self.attn``.
          self_v_cache:
            Self-attention value cache handed to ``self.attn``.
          cross_k:
            Keys for the cross-attention.
          cross_v:
            Values for the cross-attention.
          mask:
            Mask for the self-attention.

        Returns:
          A tuple ``(x, self_k_cache, self_v_cache)`` with the block output
          and the caches returned by the self-attention.
        """
        block = self.originalBlock

        attn_out, k_cache_updated, v_cache_updated = self.attn(
            block.attn_ln(x), self_k_cache, self_v_cache, mask=mask
        )
        x = x + attn_out

        if self.cross_attn:
            cross_out = self.cross_attn(block.cross_attn_ln(x), cross_k, cross_v)
            x = x + cross_out

        x = x + block.mlp(block.mlp_ln(x))
        return x, k_cache_updated, v_cache_updated
|
||||
|
||||
|
||||
class TextDecoderTensorCache(nn.Module):
    """Text decoder whose self-attention caches are explicit tensors.

    Every block of the wrapped decoder is wrapped in a
    ``ResidualAttentionBlockTensorCache``. The token embedding, positional
    embedding, final layer norm and mask are taken from the wrapped decoder.
    """

    def __init__(self, inTextDecoder: TextDecoder, in_n_ctx: int):
        """
        Args:
          inTextDecoder:
            The text decoder to wrap.
          in_n_ctx:
            Stored as ``self.n_ctx``.
        """
        super().__init__()
        self.textDecoder = inTextDecoder
        self.n_ctx = in_n_ctx

        # Kept in a plain Python list, not an nn.ModuleList.
        self.blocks = [
            ResidualAttentionBlockTensorCache(original_block)
            for original_block in self.textDecoder.blocks
        ]

    def forward(
        self,
        tokens: Tensor,
        n_layer_self_k_cache: Tensor,
        n_layer_self_v_cache: Tensor,
        n_layer_cross_k: Tensor,
        n_layer_cross_v: Tensor,
        offset: Tensor,
    ):
        """Decode ``tokens`` starting at position ``offset[0]``.

        Args:
          tokens:
            Token ids to decode; the size of the last dim is the number of
            new positions.
          n_layer_self_k_cache:
            Self-attention key caches, indexed by block along dim 0.
            Updated in place.
          n_layer_self_v_cache:
            Self-attention value caches, indexed by block along dim 0.
            Updated in place.
          n_layer_cross_k:
            Cross-attention keys, indexed by block along dim 0.
          n_layer_cross_v:
            Cross-attention values, indexed by block along dim 0.
          offset:
            Tensor whose first element is the position of the first token
            in ``tokens``.

        Returns:
          A tuple ``(logits, n_layer_self_k_cache, n_layer_self_v_cache)``,
          where ``logits`` is a float tensor and the caches are the (updated)
          input tensors.
        """
        decoder = self.textDecoder

        token_embedding = decoder.token_embedding(tokens)
        positions = decoder.positional_embedding[
            offset[0] : offset[0] + tokens.shape[-1]
        ]
        x = token_embedding + positions
        x = x.to(n_layer_cross_k[0].dtype)

        for layer, block in enumerate(self.blocks):
            # Only the part of the cache that is valid after this step.
            k_view = n_layer_self_k_cache[layer, :, : offset[0] + tokens.shape[-1], :]
            v_view = n_layer_self_v_cache[layer, :, : offset[0] + tokens.shape[-1], :]
            x, k_view, v_view = block(
                x,
                self_k_cache=k_view,
                self_v_cache=v_view,
                cross_k=n_layer_cross_k[layer],
                cross_v=n_layer_cross_v[layer],
                mask=decoder.mask,
            )
            n_layer_self_k_cache[layer, :, : offset[0] + tokens.shape[-1], :] = k_view
            n_layer_self_v_cache[layer, :, : offset[0] + tokens.shape[-1], :] = v_view

        x = decoder.ln(x)

        # The output projection reuses the token embedding matrix.
        embedding_weight = decoder.token_embedding.weight.to(x.dtype)
        logits = (x @ torch.transpose(embedding_weight, 0, 1)).float()

        return logits, n_layer_self_k_cache, n_layer_self_v_cache
|
||||
|
||||
|
||||
# ref: https://github.com/ggerganov/whisper.cpp/blob/master/models/convert-pt-to-ggml.py#L232
def convert_tokens(name, model):
    """Copy the tiktoken vocabulary used by ``model`` to ``{name}-tokens.txt``.

    The vocabulary is read from the ``assets`` directory of the installed
    ``whisper`` package: ``multilingual.tiktoken`` when
    ``model.is_multilingual`` is true, ``gpt2.tiktoken`` otherwise. Each
    non-blank line of that file holds a token and its rank separated by
    whitespace. The token text is written out unchanged, one
    ``<token> <rank>`` pair per line.

    Args:
      name:
        Prefix of the output file; the table is written to
        ``{name}-tokens.txt`` in the current directory.
      model:
        The loaded whisper model. Only ``is_multilingual`` is read.

    Raises:
      ValueError:
        If the vocabulary file does not exist, or a non-blank line does not
        consist of exactly a token and an integer rank.
    """
    assets_dir = Path(whisper.__file__).parent / "assets"
    vocab_name = (
        "multilingual.tiktoken" if model.is_multilingual else "gpt2.tiktoken"
    )
    tokenizer = assets_dir / vocab_name
    if not tokenizer.is_file():
        raise ValueError(f"Cannot find {tokenizer}")

    # Use an explicit encoding so the result does not depend on the locale.
    with open(tokenizer, "r", encoding="utf-8") as f:
        contents = f.read()

    # Skip empty and whitespace-only lines; they carry no token.
    tokens = {
        token: int(rank)
        for token, rank in (
            line.split() for line in contents.splitlines() if line.strip()
        )
    }

    with open(f"{name}-tokens.txt", "w", encoding="utf-8") as f:
        f.writelines(f"{t} {i}\n" for t, i in tokens.items())
|
||||
|
||||
|
||||
@torch.no_grad()
def main():
    """Export the selected Whisper model to ONNX and quantize it to int8.

    With ``<name>`` being the value of ``--model``, the following files are
    written to the current directory:

      - ``<name>-tokens.txt`` (by ``convert_tokens``)
      - ``<name>-encoder.onnx``, with the model metadata attached
      - ``<name>-decoder.onnx``
      - ``<name>-encoder.int8.onnx`` and ``<name>-decoder.int8.onnx``
    """
    name = get_args().model
    opset_version = 13

    model = whisper.load_model(name)
    convert_tokens(name=name, model=model)

    tokenizer = whisper.tokenizer.get_tokenizer(model.is_multilingual)
    model.eval()
    dims = model.dims
    print(dims)

    # Random audio padded to 30 seconds is used as the example input.
    waveform = whisper.pad_or_trim(torch.rand(16000 * 2))
    assert waveform.shape == (16000 * 30,), waveform.shape

    # make log-Mel spectrogram and move to the same device as the model
    mel = whisper.log_mel_spectrogram(waveform).to(model.device).unsqueeze(0)
    batch_size = 1
    assert mel.shape == (batch_size, 80, 30 * 100)

    # ---------------- encoder ----------------
    encoder = AudioEncoderTensorCache(model.encoder, model.decoder)
    n_layer_cross_k, n_layer_cross_v = encoder(mel)

    expected_cross_shape = (
        dims.n_text_layer,
        batch_size,
        dims.n_audio_ctx,
        dims.n_text_state,
    )
    assert n_layer_cross_k.shape == expected_cross_shape, n_layer_cross_k.shape
    assert n_layer_cross_v.shape == expected_cross_shape, n_layer_cross_v.shape

    encoder_filename = f"{name}-encoder.onnx"
    torch.onnx.export(
        encoder,
        mel,
        encoder_filename,
        opset_version=opset_version,
        input_names=["mel"],
        output_names=["n_layer_cross_k", "n_layer_cross_v"],
        dynamic_axes={
            "mel": {0: "n_audio"},  # n_audio is also known as batch_size
            "n_layer_cross_k": {1: "n_audio"},
            "n_layer_cross_v": {1: "n_audio"},
        },
    )

    def join_ids(ids):
        # Comma-separated string of the given integer ids.
        return ",".join(str(i) for i in ids)

    encoder_meta_data = {
        "model_type": f"whisper-{name}",
        "version": "1",
        "maintainer": "k2-fsa",
    }
    # Model dimensions are stored under their attribute names.
    for field in (
        "n_mels",
        "n_audio_ctx",
        "n_audio_state",
        "n_audio_head",
        "n_audio_layer",
        "n_vocab",
        "n_text_ctx",
        "n_text_state",
        "n_text_head",
        "n_text_layer",
    ):
        encoder_meta_data[field] = getattr(dims, field)
    encoder_meta_data.update(
        {
            "sot_sequence": join_ids(tokenizer.sot_sequence),
            "all_language_tokens": join_ids(tokenizer.all_language_tokens),
            "all_language_codes": ",".join(tokenizer.all_language_codes),
            "sot": tokenizer.sot,
            "sot_index": tokenizer.sot_sequence.index(tokenizer.sot),
            "eot": tokenizer.eot,
            "blank_id": tokenizer.encode(" ")[0],
            "is_multilingual": int(model.is_multilingual),
            "no_speech": tokenizer.no_speech,
            "non_speech_tokens": join_ids(tokenizer.non_speech_tokens),
            "transcribe": tokenizer.transcribe,
            "translate": tokenizer.translate,
            "sot_prev": tokenizer.sot_prev,
            "sot_lm": tokenizer.sot_lm,
            "no_timestamps": tokenizer.no_timestamps,
        }
    )
    print(f"encoder_meta_data: {encoder_meta_data}")
    add_meta_data(filename=encoder_filename, meta_data=encoder_meta_data)

    # ---------------- decoder ----------------
    n_audio = mel.shape[0]
    device = mel.device

    tokens = torch.tensor([[tokenizer.sot] * 3] * n_audio).to(device)  # [n_audio, 3]
    decoder = TextDecoderTensorCache(model.decoder, dims.n_text_ctx)

    self_cache_shape = (
        len(model.decoder.blocks),
        n_audio,
        dims.n_text_ctx,
        dims.n_text_state,
    )
    n_layer_self_k_cache = torch.zeros(self_cache_shape, device=device)
    n_layer_self_v_cache = torch.zeros(self_cache_shape, device=device)
    offset = torch.zeros(1, dtype=torch.int64).to(device)

    # First pass: three tokens starting at position 0.
    logits, n_layer_self_k_cache, n_layer_self_v_cache = decoder(
        tokens,
        n_layer_self_k_cache,
        n_layer_self_v_cache,
        n_layer_cross_k,
        n_layer_cross_v,
        offset,
    )
    assert logits.shape == (n_audio, tokens.shape[1], dims.n_vocab)

    expected_self_shape = (
        dims.n_text_layer,
        n_audio,
        dims.n_text_ctx,
        dims.n_text_state,
    )
    assert n_layer_self_k_cache.shape == expected_self_shape
    assert n_layer_self_v_cache.shape == expected_self_shape

    # Second pass: a single token placed right after the first three.
    offset = torch.tensor([tokens.shape[1]], dtype=torch.int64).to(device)
    tokens = torch.tensor([[tokenizer.sot]] * n_audio).to(device)  # [n_audio, 1]

    decoder_inputs = (
        tokens,
        n_layer_self_k_cache,
        n_layer_self_v_cache,
        n_layer_cross_k,
        n_layer_cross_v,
        offset,
    )
    logits, out_n_layer_self_k_cache, out_n_layer_self_v_cache = decoder(
        *decoder_inputs
    )

    decoder_filename = f"{name}-decoder.onnx"
    torch.onnx.export(
        decoder,
        decoder_inputs,
        decoder_filename,
        opset_version=opset_version,
        input_names=[
            "tokens",
            "in_n_layer_self_k_cache",
            "in_n_layer_self_v_cache",
            "n_layer_cross_k",
            "n_layer_cross_v",
            "offset",
        ],
        output_names=["logits", "out_n_layer_self_k_cache", "out_n_layer_self_v_cache"],
        dynamic_axes={
            "tokens": {0: "n_audio", 1: "n_tokens"},
            "in_n_layer_self_k_cache": {1: "n_audio"},
            "in_n_layer_self_v_cache": {1: "n_audio"},
            "n_layer_cross_k": {1: "n_audio"},
            "n_layer_cross_v": {1: "n_audio"},
        },
    )

    # Generate int8 quantization models
    # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection

    print("Generate int8 quantization models")

    quantization_jobs = (
        (encoder_filename, f"{name}-encoder.int8.onnx"),
        (decoder_filename, f"{name}-decoder.int8.onnx"),
    )
    for float_filename, int8_filename in quantization_jobs:
        quantize_dynamic(
            model_input=float_filename,
            model_output=int8_filename,
            op_types_to_quantize=["MatMul"],
            weight_type=QuantType.QInt8,
        )


if __name__ == "__main__":
    main()
|
||||
1
scripts/whisper/requirements.txt
Normal file
1
scripts/whisper/requirements.txt
Normal file
@@ -0,0 +1 @@
|
||||
openai-whisper
|
||||
241
scripts/whisper/test.py
Executable file
241
scripts/whisper/test.py
Executable file
@@ -0,0 +1,241 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
"""
|
||||
Please first run ./export-onnx.py
|
||||
before you run this script
|
||||
"""
|
||||
import base64
|
||||
from typing import Tuple
|
||||
|
||||
import kaldi_native_fbank as knf
|
||||
import onnxruntime as ort
|
||||
import torch
|
||||
|
||||
import whisper
|
||||
import argparse
|
||||
|
||||
|
||||
def get_args():
    """Parse the command-line arguments of this script.

    Returns:
      The namespace produced by ``argparse``. Its ``model`` attribute holds
      the value given to the required ``--model`` option, which must be one
      of the model names listed below.
    """
    # fmt: off
    supported_models = [
        "tiny", "tiny.en", "base", "base.en",
        "small", "small.en", "medium", "medium.en",
        "large", "large-v1", "large-v2",
    ]
    # fmt: on

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        choices=supported_models,
    )

    parsed = parser.parse_args()
    return parsed
|
||||
|
||||
|
||||
class OnnxModel:
    """Runs an exported encoder/decoder pair with onnxruntime.

    The special token ids and the text-decoder dimensions are read from the
    custom metadata stored in the encoder model.
    """

    def __init__(
        self,
        encoder: str,
        decoder: str,
    ):
        """
        Args:
          encoder:
            Path to the encoder ONNX model.
          decoder:
            Path to the decoder ONNX model.
        """
        opts = ort.SessionOptions()
        opts.inter_op_num_threads = 1
        opts.intra_op_num_threads = 4
        self.session_opts = opts

        self.init_encoder(encoder)
        self.init_decoder(decoder)

    def init_encoder(self, encoder: str):
        """Create the encoder session and read the model metadata.

        Args:
          encoder:
            Path to the encoder ONNX model.
        """
        self.encoder = ort.InferenceSession(
            encoder,
            sess_options=self.session_opts,
        )

        meta = self.encoder.get_modelmeta().custom_metadata_map

        # (attribute name, metadata key) for every integer-valued entry.
        integer_fields = (
            ("n_text_layer", "n_text_layer"),
            ("n_text_ctx", "n_text_ctx"),
            ("n_text_state", "n_text_state"),
            ("sot", "sot"),
            ("eot", "eot"),
            ("translate", "translate"),
            ("no_timestamps", "no_timestamps"),
            ("no_speech", "no_speech"),
            ("blank", "blank_id"),
        )
        for attribute, key in integer_fields:
            setattr(self, attribute, int(meta[key]))

        self.sot_sequence = [int(t) for t in meta["sot_sequence"].split(",")]
        self.is_multilingual = int(meta["is_multilingual"]) == 1

    def init_decoder(self, decoder: str):
        """Create the decoder session.

        Args:
          decoder:
            Path to the decoder ONNX model.
        """
        self.decoder = ort.InferenceSession(
            decoder,
            sess_options=self.session_opts,
        )

    def run_encoder(
        self,
        mel: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the encoder on ``mel``.

        Args:
          mel:
            Input features; converted with ``.numpy()`` and fed to the first
            input of the encoder.

        Returns:
          The first two encoder outputs as torch tensors:
          ``(n_layer_cross_k, n_layer_cross_v)``.
        """
        outputs = self.encoder.get_outputs()
        output_names = [outputs[0].name, outputs[1].name]
        feed = {self.encoder.get_inputs()[0].name: mel.numpy()}

        n_layer_cross_k, n_layer_cross_v = self.encoder.run(output_names, feed)
        return torch.from_numpy(n_layer_cross_k), torch.from_numpy(n_layer_cross_v)

    def run_decoder(
        self,
        tokens: torch.Tensor,
        n_layer_self_k_cache: torch.Tensor,
        n_layer_self_v_cache: torch.Tensor,
        n_layer_cross_k: torch.Tensor,
        n_layer_cross_v: torch.Tensor,
        offset: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run one decoder step.

        The six arguments are fed, in this order, to the first six inputs of
        the decoder model.

        Returns:
          The first three decoder outputs as torch tensors:
          ``(logits, out_n_layer_self_k_cache, out_n_layer_self_v_cache)``.
        """
        outputs = self.decoder.get_outputs()
        output_names = [outputs[0].name, outputs[1].name, outputs[2].name]

        inputs = self.decoder.get_inputs()
        arguments = (
            tokens,
            n_layer_self_k_cache,
            n_layer_self_v_cache,
            n_layer_cross_k,
            n_layer_cross_v,
            offset,
        )
        feed = {
            inputs[position].name: tensor.numpy()
            for position, tensor in enumerate(arguments)
        }

        logits, out_k_cache, out_v_cache = self.decoder.run(output_names, feed)
        return (
            torch.from_numpy(logits),
            torch.from_numpy(out_k_cache),
            torch.from_numpy(out_v_cache),
        )

    def get_self_cache(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Create zero-filled self-attention caches for a batch of size one.

        Returns:
          Two separate tensors ``(n_layer_self_k_cache, n_layer_self_v_cache)``
          of shape ``(n_text_layer, 1, n_text_ctx, n_text_state)``.
        """
        batch_size = 1
        shape = (self.n_text_layer, batch_size, self.n_text_ctx, self.n_text_state)
        return torch.zeros(*shape), torch.zeros(*shape)

    def suppress_tokens(self, logits, is_initial: bool) -> None:
        """Set the logits of tokens that must not be selected to ``-inf``.

        ``logits`` is changed in place.

        Args:
          logits:
            Scores indexed by token id.
          is_initial:
            If true, ``eot`` and the blank token are suppressed as well.
        """
        suppressed = []
        if is_initial:
            suppressed.extend([self.eot, self.blank])
        suppressed.extend(
            [self.no_timestamps, self.sot, self.no_speech, self.translate]
        )

        for token_id in suppressed:
            logits[token_id] = float("-inf")
|
||||
|
||||
|
||||
def load_tokens(filename):
    """Load a token table that maps integer ids to token strings.

    Each non-blank line of the file holds a token and its integer id,
    separated by whitespace. Blank and whitespace-only lines are ignored.

    Args:
      filename:
        Path of the token table.

    Returns:
      A dict mapping each id (``int``) to its token (``str``), exactly as
      written in the file.

    Raises:
      ValueError:
        If a non-blank line does not have exactly two fields, or its second
        field is not an integer.
    """
    tokens = {}
    # Use an explicit encoding so the result does not depend on the locale.
    with open(filename, "r", encoding="utf-8") as f:
        for line in f:
            fields = line.split()
            if not fields:
                # Tolerate blank lines, e.g. an extra newline at end of file.
                continue
            token, index = fields
            tokens[int(index)] = token
    return tokens
|
||||
|
||||
|
||||
def main():
    """Transcribe ``0.wav`` with the exported ONNX models using greedy search.

    The encoder, decoder and token table are read from
    ``./<name>-encoder.onnx``, ``./<name>-decoder.onnx`` and
    ``./<name>-tokens.txt``, where ``<name>`` is the value of ``--model``.
    The decoded text, the decoded token ids and the model's ``sot_sequence``
    are printed.
    """
    args = get_args()
    name = args.model

    encoder = f"./{name}-encoder.onnx"
    decoder = f"./{name}-decoder.onnx"
    audio = whisper.load_audio("0.wav")

    online_whisper_fbank = knf.OnlineWhisperFbank(knf.FrameExtractionOptions())
    online_whisper_fbank.accept_waveform(16000, audio)
    online_whisper_fbank.input_finished()
    features = torch.stack(
        [
            torch.from_numpy(online_whisper_fbank.get_frame(i))
            for i in range(online_whisper_fbank.num_frames_ready)
        ]
    )

    log_spec = torch.clamp(features, min=1e-10).log10()
    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
    mel = (log_spec + 4.0) / 4.0
    # Pad along the frame axis up to 3000 frames, then switch to
    # (1, n_mels, n_frames).
    target = 3000
    mel = torch.nn.functional.pad(mel, (0, 0, 0, target - mel.shape[0]), "constant", 0)
    mel = mel.t().unsqueeze(0)

    model = OnnxModel(encoder, decoder)
    n_layer_cross_k, n_layer_cross_v = model.run_encoder(mel)
    n_layer_self_k_cache, n_layer_self_v_cache = model.get_self_cache()

    # Feed the whole start-of-transcript prompt in one pass at position 0.
    tokens = torch.tensor([model.sot_sequence], dtype=torch.int64)
    num_prompt_tokens = tokens.shape[1]
    offset = torch.zeros(1, dtype=torch.int64)
    logits, n_layer_self_k_cache, n_layer_self_v_cache = model.run_decoder(
        tokens=tokens,
        n_layer_self_k_cache=n_layer_self_k_cache,
        n_layer_self_v_cache=n_layer_self_v_cache,
        n_layer_cross_k=n_layer_cross_k,
        n_layer_cross_v=n_layer_cross_v,
        offset=offset,
    )
    # The prompt filled positions [0, num_prompt_tokens), so the next token
    # has to be decoded at position num_prompt_tokens, not at position 1.
    offset += num_prompt_tokens

    # logits.shape (batch_size, tokens.shape[1], vocab_size)
    logits = logits[0, -1]
    model.suppress_tokens(logits, is_initial=True)
    # for greedy search, we don't need to compute softmax or log_softmax
    max_token_id = logits.argmax(dim=-1)

    results = []
    # Every generated token takes one cache position; stop before the
    # offset would run past the n_text_ctx positions of the cache.
    for _ in range(model.n_text_ctx - num_prompt_tokens):
        if max_token_id == model.eot:
            break
        results.append(max_token_id.item())
        tokens = torch.tensor([[results[-1]]], dtype=torch.int64)

        logits, n_layer_self_k_cache, n_layer_self_v_cache = model.run_decoder(
            tokens=tokens,
            n_layer_self_k_cache=n_layer_self_k_cache,
            n_layer_self_v_cache=n_layer_self_v_cache,
            n_layer_cross_k=n_layer_cross_k,
            n_layer_cross_v=n_layer_cross_v,
            offset=offset,
        )
        offset += 1

        logits = logits[0, -1]
        model.suppress_tokens(logits, is_initial=False)
        max_token_id = logits.argmax(dim=-1)

    token_table = load_tokens(f"./{name}-tokens.txt")
    # Tokens are stored base64-encoded. Concatenate the raw bytes first and
    # decode once, so a character split across tokens is decoded correctly.
    pieces = []
    for i in results:
        if i in token_table:
            pieces.append(base64.b64decode(token_table[i]))
        else:
            print("oov", i)
    s = b"".join(pieces)

    print(s.decode().strip())
    print(results)
    print(model.sot_sequence)


if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user