#!/usr/bin/env python3
# Copyright      2023  Xiaomi Corp.        (authors: Fangjun Kuang)
# flake8: noqa

"""
Note: Code in this file is modified from
https://github.com/TadaoYamaoka/whisper/blob/main/to_onnx.py

Thanks to https://github.com/TadaoYamaoka
for making the onnx export script public.
"""

import argparse
import os
from pathlib import Path
from typing import Any, Dict, Optional

import onnx
import torch
from onnxruntime.quantization import QuantType, quantize_dynamic
from torch import Tensor, nn

import whisper
from whisper.model import (
    AudioEncoder,
    MultiHeadAttention,
    ResidualAttentionBlock,
    TextDecoder,
)


def get_args():
    """Parse the command line.

    Returns:
      The parsed namespace; its only attribute is ``model`` (required).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        # fmt: off
        choices=[
            "tiny", "tiny.en", "base", "base.en",
            "small", "small.en", "medium", "medium.en",
            "large", "large-v1", "large-v2"],
        # fmt: on
        help="Name of the whisper model to export.",
    )
    return parser.parse_args()


def add_meta_data(filename: str, meta_data: Dict[str, Any]):
    """Add meta data to an ONNX model. It is changed in-place.

    Args:
      filename:
        Filename of the ONNX model to be changed.
      meta_data:
        Key-value pairs. Every value is stored as ``str(value)``.
    """
    model = onnx.load(filename)
    for key, value in meta_data.items():
        meta = model.metadata_props.add()
        meta.key = key
        meta.value = str(value)

    onnx.save(model, filename)


class AudioEncoderTensorCache(nn.Module):
    """Run the audio encoder and pre-compute every decoder layer's cross
    attention key/value projections of the encoder output."""

    def __init__(self, inAudioEncoder: AudioEncoder, inTextDecoder: TextDecoder):
        super().__init__()
        self.audioEncoder = inAudioEncoder
        self.textDecoder = inTextDecoder

    def forward(self, x: Tensor):
        """
        Args:
          x:
            Input passed unchanged to the audio encoder.
        Returns:
          A tuple ``(n_layer_cross_k, n_layer_cross_v)``. Each one stacks, along
          a new leading axis, one projection per decoder block.
        """
        audio_features = self.audioEncoder(x)

        n_layer_cross_k_list = []
        n_layer_cross_v_list = []
        for block in self.textDecoder.blocks:
            n_layer_cross_k_list.append(block.cross_attn.key(audio_features))
            n_layer_cross_v_list.append(block.cross_attn.value(audio_features))

        return torch.stack(n_layer_cross_k_list), torch.stack(n_layer_cross_v_list)


class MultiHeadAttentionCross(nn.Module):
    """Cross attention that takes already-projected keys and values."""

    def __init__(self, inMultiHeadAttention: MultiHeadAttention):
        super().__init__()
        self.multiHeadAttention = inMultiHeadAttention

    def forward(
        self,
        x: Tensor,
        k: Tensor,
        v: Tensor,
        mask: Optional[Tensor] = None,
    ):
        """Project ``x`` to queries, attend over the given ``k``/``v`` and
        return the output projection of the result."""
        q = self.multiHeadAttention.query(x)
        # The second return value of qkv_attention is not needed here.
        wv, _ = self.multiHeadAttention.qkv_attention(q, k, v, mask)
        return self.multiHeadAttention.out(wv)


class MultiHeadAttentionSelf(nn.Module):
    """Self attention that keeps its keys/values in caller-provided caches."""

    def __init__(self, inMultiHeadAttention: MultiHeadAttention):
        super().__init__()
        self.multiHeadAttention = inMultiHeadAttention

    def forward(
        self,
        x: Tensor,  # (b, n_ctx      , n_state)
        k_cache: Tensor,  # (b, n_ctx_cache, n_state)
        v_cache: Tensor,  # (b, n_ctx_cache, n_state)
        mask: Tensor,
    ):
        """
        Note: ``k_cache`` and ``v_cache`` are modified in-place.

        Returns:
          A tuple ``(out, k_cache, v_cache)`` where the caches are the very
          tensors that were passed in, after the update.
        """
        q = self.multiHeadAttention.query(x)  # (b, n_ctx, n_state)
        k = self.multiHeadAttention.key(x)  # (b, n_ctx, n_state)
        v = self.multiHeadAttention.value(x)  # (b, n_ctx, n_state)

        # Overwrite the last n_ctx positions of the caches with the new
        # keys/values.
        k_cache[:, -k.shape[1] :, :] = k  # (b, n_ctx_cache + n_ctx, n_state)
        v_cache[:, -v.shape[1] :, :] = v  # (b, n_ctx_cache + n_ctx, n_state)

        # The second return value of qkv_attention is not needed here.
        wv, _ = self.multiHeadAttention.qkv_attention(q, k_cache, v_cache, mask)
        return self.multiHeadAttention.out(wv), k_cache, v_cache


class ResidualAttentionBlockTensorCache(nn.Module):
    """A decoder block that reuses the layers of an existing
    ResidualAttentionBlock but reads/writes explicit key/value caches."""

    def __init__(self, inResidualAttentionBlock: ResidualAttentionBlock):
        super().__init__()
        self.originalBlock = inResidualAttentionBlock
        self.attn = MultiHeadAttentionSelf(inResidualAttentionBlock.attn)
        self.cross_attn = (
            MultiHeadAttentionCross(inResidualAttentionBlock.cross_attn)
            if inResidualAttentionBlock.cross_attn
            else None
        )

    def forward(
        self,
        x: Tensor,
        self_k_cache: Tensor,
        self_v_cache: Tensor,
        cross_k: Tensor,
        cross_v: Tensor,
        mask: Tensor,
    ):
        """
        Returns:
          A tuple ``(x, self_k_cache_updated, self_v_cache_updated)``.
        """
        self_attn_x, self_k_cache_updated, self_v_cache_updated = self.attn(
            self.originalBlock.attn_ln(x), self_k_cache, self_v_cache, mask=mask
        )
        x = x + self_attn_x

        if self.cross_attn:
            x = x + self.cross_attn(
                self.originalBlock.cross_attn_ln(x), cross_k, cross_v
            )

        x = x + self.originalBlock.mlp(self.originalBlock.mlp_ln(x))
        return x, self_k_cache_updated, self_v_cache_updated


class TextDecoderTensorCache(nn.Module):
    """Text decoder whose self attention caches and pre-computed cross
    attention keys/values are explicit inputs and outputs."""

    def __init__(self, inTextDecoder: TextDecoder, in_n_ctx: int):
        super().__init__()
        self.textDecoder = inTextDecoder
        self.n_ctx = in_n_ctx

        # Deliberately a plain list: the wrappers only hold references to
        # blocks that are already sub-modules of self.textDecoder.
        self.blocks = [
            ResidualAttentionBlockTensorCache(original_block)
            for original_block in self.textDecoder.blocks
        ]

    def forward(
        self,
        tokens: Tensor,
        n_layer_self_k_cache: Tensor,
        n_layer_self_v_cache: Tensor,
        n_layer_cross_k: Tensor,
        n_layer_cross_v: Tensor,
        offset: Tensor,
    ):
        """
        Args:
          tokens:
            Token IDs; its last axis is the number of tokens in this call.
          n_layer_self_k_cache:
            Self attention key cache, indexed by layer on axis 0 and by
            position on axis 2. It is modified in-place.
          n_layer_self_v_cache:
            Self attention value cache, same layout. It is modified in-place.
          n_layer_cross_k:
            Cross attention keys, indexed by layer on axis 0.
          n_layer_cross_v:
            Cross attention values, indexed by layer on axis 0.
          offset:
            ``offset[0]`` is the position of the first entry of ``tokens``.
        Returns:
          A tuple ``(logits, n_layer_self_k_cache, n_layer_self_v_cache)``.
        """
        # One past the last position touched by this call.
        end = offset[0] + tokens.shape[-1]

        x = (
            self.textDecoder.token_embedding(tokens)
            + self.textDecoder.positional_embedding[offset[0] : end]
        )
        x = x.to(n_layer_cross_k[0].dtype)

        for i, block in enumerate(self.blocks):
            self_k_cache = n_layer_self_k_cache[i, :, :end, :]
            self_v_cache = n_layer_self_v_cache[i, :, :end, :]
            x, self_k_cache, self_v_cache = block(
                x,
                self_k_cache=self_k_cache,
                self_v_cache=self_v_cache,
                cross_k=n_layer_cross_k[i],
                cross_v=n_layer_cross_v[i],
                mask=self.textDecoder.mask,
            )
            # Write the updated slices back into the per-layer caches.
            n_layer_self_k_cache[i, :, :end, :] = self_k_cache
            n_layer_self_v_cache[i, :, :end, :] = self_v_cache

        x = self.textDecoder.ln(x)

        # Same product as `x @ weight.T`, computed as `weight @ x^T` and
        # permuted back.
        # e.g., x.shape (1, 3, 384), weight.shape (51684, 384)
        logits = (
            torch.matmul(
                self.textDecoder.token_embedding.weight.to(x.dtype),
                x.permute(0, 2, 1),
            )
            .permute(0, 2, 1)
            .float()
        )

        return logits, n_layer_self_k_cache, n_layer_self_v_cache


# ref: https://github.com/ggerganov/whisper.cpp/blob/master/models/convert-pt-to-ggml.py#L232
def convert_tokens(name, model):
    """Write the model's token table to ``{name}-tokens.txt``.

    Args:
      name:
        Prefix of the output filename.
      model:
        A loaded whisper model; only ``model.is_multilingual`` is read.
    Raises:
      ValueError: If the tiktoken file is not found in whisper's assets
        directory.
    """
    whisper_dir = Path(whisper.__file__).parent
    vocab_name = "multilingual.tiktoken" if model.is_multilingual else "gpt2.tiktoken"
    tokenizer = whisper_dir / "assets" / vocab_name
    if not tokenizer.is_file():
        raise ValueError(f"Cannot find {tokenizer}")

    with open(tokenizer, "r", encoding="utf-8") as f:
        contents = f.read()

    # Tokens are kept exactly as they are spelled in the tiktoken file;
    # they are NOT base64-decoded here.
    tokens = {
        token: int(rank)
        for token, rank in (line.split() for line in contents.splitlines() if line)
    }

    with open(f"{name}-tokens.txt", "w", encoding="utf-8") as f:
        for t, i in tokens.items():
            f.write(f"{t} {i}\n")


def _build_encoder_meta_data(name: str, model, tokenizer) -> Dict[str, Any]:
    """Collect the key-value pairs that get attached to the encoder model."""
    dims = model.dims
    return {
        "model_type": f"whisper-{name}",
        "version": "1",
        "maintainer": "k2-fsa",
        "n_mels": dims.n_mels,
        "n_audio_ctx": dims.n_audio_ctx,
        "n_audio_state": dims.n_audio_state,
        "n_audio_head": dims.n_audio_head,
        "n_audio_layer": dims.n_audio_layer,
        "n_vocab": dims.n_vocab,
        "n_text_ctx": dims.n_text_ctx,
        "n_text_state": dims.n_text_state,
        "n_text_head": dims.n_text_head,
        "n_text_layer": dims.n_text_layer,
        "sot_sequence": ",".join(map(str, tokenizer.sot_sequence)),
        # a list of ids
        "all_language_tokens": ",".join(map(str, tokenizer.all_language_tokens)),
        # e.g., en, de, zh, fr
        "all_language_codes": ",".join(tokenizer.all_language_codes),
        "sot": tokenizer.sot,
        "sot_index": tokenizer.sot_sequence.index(tokenizer.sot),
        "eot": tokenizer.eot,
        "blank_id": tokenizer.encode(" ")[0],
        "is_multilingual": int(model.is_multilingual),
        "no_speech": tokenizer.no_speech,
        "non_speech_tokens": ",".join(map(str, tokenizer.non_speech_tokens)),
        "transcribe": tokenizer.transcribe,
        "translate": tokenizer.translate,
        "sot_prev": tokenizer.sot_prev,
        "sot_lm": tokenizer.sot_lm,
        "no_timestamps": tokenizer.no_timestamps,
    }


def _quantize_to_int8(model_input: str, model_output: str) -> None:
    """Write a dynamically quantized copy (int8 weights, MatMul only)."""
    quantize_dynamic(
        model_input=model_input,
        model_output=model_output,
        op_types_to_quantize=["MatMul"],
        weight_type=QuantType.QInt8,
    )


@torch.no_grad()
def main():
    """Export the encoder and decoder of a whisper model to ONNX.

    Files written to the current directory, where ``{name}`` is ``--model``:
      - ``{name}-tokens.txt``
      - ``{name}-encoder.onnx`` (with meta data attached)
      - ``{name}-decoder.onnx``
      - ``{name}-encoder.int8.onnx`` and ``{name}-decoder.int8.onnx``
        (skipped when the model name contains "large")
    """
    args = get_args()
    name = args.model

    opset_version = 13

    model = whisper.load_model(name)
    print(
        f"number of model parameters: {name}",
        sum(p.numel() for p in model.parameters()),
    )
    print(
        f"number of encoder parameters: {name}",
        sum(p.numel() for p in model.encoder.parameters()),
    )
    print(
        f"number of decoder parameters: {name}",
        sum(p.numel() for p in model.decoder.parameters()),
    )

    # write tokens
    convert_tokens(name=name, model=model)

    tokenizer = whisper.tokenizer.get_tokenizer(model.is_multilingual)
    model.eval()
    print(model.dims)

    audio = torch.rand(16000 * 2)
    audio = whisper.pad_or_trim(audio)
    assert audio.shape == (16000 * 30,), audio.shape

    # make log-Mel spectrogram and move to the same device as the model
    mel = whisper.log_mel_spectrogram(audio).to(model.device).unsqueeze(0)
    batch_size = 1
    assert mel.shape == (batch_size, 80, 30 * 100)

    # ---------- encoder ----------
    encoder = AudioEncoderTensorCache(model.encoder, model.decoder)
    n_layer_cross_k, n_layer_cross_v = encoder(mel)

    expected_cross_shape = (
        model.dims.n_text_layer,
        batch_size,
        model.dims.n_audio_ctx,
        model.dims.n_text_state,
    )
    assert n_layer_cross_k.shape == expected_cross_shape, n_layer_cross_k.shape
    assert n_layer_cross_v.shape == expected_cross_shape, n_layer_cross_v.shape

    encoder_filename = f"{name}-encoder.onnx"
    torch.onnx.export(
        encoder,
        mel,
        encoder_filename,
        opset_version=opset_version,
        input_names=["mel"],
        output_names=["n_layer_cross_k", "n_layer_cross_v"],
        dynamic_axes={
            "mel": {0: "n_audio"},  # n_audio is also known as batch_size
            "n_layer_cross_k": {1: "n_audio"},
            "n_layer_cross_v": {1: "n_audio"},
        },
    )

    encoder_meta_data = _build_encoder_meta_data(name, model, tokenizer)
    print(f"encoder_meta_data: {encoder_meta_data}")
    add_meta_data(filename=encoder_filename, meta_data=encoder_meta_data)

    # ---------- decoder ----------
    n_audio = mel.shape[0]
    tokens = torch.tensor([[tokenizer.sot, tokenizer.sot, tokenizer.sot]] * n_audio).to(
        mel.device
    )  # [n_audio, 3]
    decoder = TextDecoderTensorCache(model.decoder, model.dims.n_text_ctx)

    self_cache_shape = (
        len(model.decoder.blocks),
        n_audio,
        model.dims.n_text_ctx,
        model.dims.n_text_state,
    )
    n_layer_self_k_cache = torch.zeros(self_cache_shape, device=mel.device)
    n_layer_self_v_cache = torch.zeros(self_cache_shape, device=mel.device)
    offset = torch.zeros(1, dtype=torch.int64).to(mel.device)

    # First step: three tokens starting at position 0.
    logits, n_layer_self_k_cache, n_layer_self_v_cache = decoder(
        tokens,
        n_layer_self_k_cache,
        n_layer_self_v_cache,
        n_layer_cross_k,
        n_layer_cross_v,
        offset,
    )
    assert logits.shape == (n_audio, tokens.shape[1], model.dims.n_vocab)

    expected_self_shape = (
        model.dims.n_text_layer,
        n_audio,
        model.dims.n_text_ctx,
        model.dims.n_text_state,
    )
    assert n_layer_self_k_cache.shape == expected_self_shape
    assert n_layer_self_v_cache.shape == expected_self_shape

    # Second step: a single token at a non-zero offset. These inputs are also
    # the example inputs used for the export below.
    offset = torch.tensor([tokens.shape[1]], dtype=torch.int64).to(mel.device)
    tokens = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device)  # [n_audio, 1]
    decoder_inputs = (
        tokens,
        n_layer_self_k_cache,
        n_layer_self_v_cache,
        n_layer_cross_k,
        n_layer_cross_v,
        offset,
    )

    # The return value is not used, but the call is kept because it updates
    # the self attention caches in-place before they are handed to the export.
    decoder(*decoder_inputs)

    decoder_filename = f"{name}-decoder.onnx"
    torch.onnx.export(
        decoder,
        decoder_inputs,
        decoder_filename,
        opset_version=opset_version,
        input_names=[
            "tokens",
            "in_n_layer_self_k_cache",
            "in_n_layer_self_v_cache",
            "n_layer_cross_k",
            "n_layer_cross_v",
            "offset",
        ],
        output_names=["logits", "out_n_layer_self_k_cache", "out_n_layer_self_v_cache"],
        dynamic_axes={
            "tokens": {0: "n_audio", 1: "n_tokens"},
            "in_n_layer_self_k_cache": {1: "n_audio"},
            "in_n_layer_self_v_cache": {1: "n_audio"},
            "n_layer_cross_k": {1: "n_audio"},
            "n_layer_cross_v": {1: "n_audio"},
        },
    )

    if "large" in args.model:
        # it causes errors for large models, so skip it.
        return

    # Generate int8 quantization models
    # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection

    print("Generate int8 quantization models")

    _quantize_to_int8(encoder_filename, f"{name}-encoder.int8.onnx")
    _quantize_to_int8(decoder_filename, f"{name}-decoder.int8.onnx")


if __name__ == "__main__":
    main()