diff --git a/.github/workflows/export-nemo-canary-180m-flash.yaml b/.github/workflows/export-nemo-canary-180m-flash.yaml
new file mode 100644
index 00000000..54f0c6bc
--- /dev/null
+++ b/.github/workflows/export-nemo-canary-180m-flash.yaml
@@ -0,0 +1,132 @@
+name: export-nemo-canary-180m-flash
+
+on:
+  push:
+    branches:
+      - export-nemo-canary
+  workflow_dispatch:
+
+concurrency:
+  group: export-nemo-canary-180m-flash-${{ github.ref }}
+  cancel-in-progress: true
+
+jobs:
+  export-nemo-canary-180m-flash:
+    if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj'
+    name: nemo canary 180m flash
+    runs-on: ${{ matrix.os }}
+    strategy:
+      fail-fast: false
+      matrix:
+        os: [macos-latest]
+        python-version: ["3.10"]
+
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Setup Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python-version }}
+
+      - name: Run
+        shell: bash
+        run: |
+          cd scripts/nemo/canary
+          ./run_180m_flash.sh
+
+          ls -lh *.onnx
+          mv -v *.onnx ../../..
+          mv -v tokens.txt ../../..
+          mv -v de.wav ../../..
+          mv -v en.wav ../../..
+
+      - name: Collect files (fp32)
+        shell: bash
+        run: |
+          d=sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr
+          mkdir -p $d
+          cp encoder.onnx $d
+          cp decoder.onnx $d
+          cp tokens.txt $d
+
+          mkdir $d/test_wavs
+          cp de.wav $d/test_wavs
+          cp en.wav $d/test_wavs
+
+          tar cjvf $d.tar.bz2 $d
+
+      - name: Collect files (int8)
+        shell: bash
+        run: |
+          d=sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8
+          mkdir -p $d
+          cp encoder.int8.onnx $d
+          # the decoder is not quantized to int8 since the accuracy drops,
+          # so the int8 package ships the fp16 decoder instead
+          cp decoder.fp16.onnx $d
+          cp tokens.txt $d
+
+          mkdir $d/test_wavs
+          cp de.wav $d/test_wavs
+          cp en.wav $d/test_wavs
+
+          tar cjvf $d.tar.bz2 $d
+
+      - name: Collect files (fp16)
+        shell: bash
+        run: |
+          d=sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-fp16
+          mkdir -p $d
+          cp encoder.fp16.onnx $d
+          cp decoder.fp16.onnx $d
+          cp tokens.txt $d
+
+          mkdir $d/test_wavs
+          cp de.wav $d/test_wavs
+          cp en.wav $d/test_wavs
+
+          tar cjvf $d.tar.bz2 $d
+
+      - name: Publish to huggingface
+        env:
+          HF_TOKEN: ${{ secrets.HF_TOKEN }}
+        uses: nick-fields/retry@v3
+        with:
+          max_attempts: 20
+          timeout_seconds: 200
+          shell: bash
+          command: |
+            git config --global user.email "csukuangfj@gmail.com"
+            git config --global user.name "Fangjun Kuang"
+
+            models=(
+              sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr
+              sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-int8
+              sherpa-onnx-nemo-canary-180m-flash-en-es-de-fr-fp16
+            )
+
+            for m in "${models[@]}"; do
+              rm -rf huggingface
+              export GIT_LFS_SKIP_SMUDGE=1
+              export GIT_CLONE_PROTECTION_ACTIVE=false
+              git clone https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/$m huggingface
+              cp -av $m/* huggingface
+              cd huggingface
+              git lfs track "*.onnx"
+              git lfs track "*.wav"
+              git status
+              git add .
+              git status
+              git commit -m "first commit"
+              git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/$m main
+              cd ..
+            done
+
+      - name: Release
+        uses: svenstaro/upload-release-action@v2
+        with:
+          file_glob: true
+          file: ./*.tar.bz2
+          overwrite: true
+          repo_name: k2-fsa/sherpa-onnx
+          repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
+          tag: asr-models
diff --git a/scripts/nemo/canary/export_onnx_180m_flash.py b/scripts/nemo/canary/export_onnx_180m_flash.py
new file mode 100755
index 00000000..7585c18d
--- /dev/null
+++ b/scripts/nemo/canary/export_onnx_180m_flash.py
@@ -0,0 +1,289 @@
+#!/usr/bin/env python3
+# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)
+
+import os
+from typing import Tuple
+
+import nemo
+import onnxmltools
+import torch
+from nemo.collections.common.parts import NEG_INF
+from onnxmltools.utils.float16_converter import convert_float_to_float16
+from onnxruntime.quantization import QuantType, quantize_dynamic
+
+"""
+Running the exported model with onnxruntime fails with
+
+  NotImplemented: [ONNXRuntimeError] : 9 : NOT_IMPLEMENTED :
+  Could not find an implementation for Trilu(14) node with name '/Trilu'
+
+See also https://github.com/microsoft/onnxruntime/issues/16189#issuecomment-1722219631
+
+As a workaround, we replace the original form_attention_mask()
+with fixed_form_attention_mask() defined below.
+"""
+
+
+def fixed_form_attention_mask(input_mask, diagonal=None):
+    """
+    A fixed copy of nemo.collections.common.parts.form_attention_mask():
+    build an attention mask with optional masking of future tokens we forbid
+    to attend to (e.g., as in a Transformer decoder).  The only change is
+    the dtype passed to torch.tril(); see the comment below.
+
+    Args:
+        input_mask: binary mask of size B x L with 1s corresponding to valid
+            tokens and 0s corresponding to padding tokens
+        diagonal: diagonal where the triangular future mask starts
+            None -- do not mask anything
+            0 -- regular translation or language modeling future masking
+            1 -- query stream masking as in the XLNet architecture
+    Returns:
+        attention_mask: mask of size B x 1 x L x L with 0s corresponding to
+            tokens we plan to attend to and -10000 otherwise
+    """
+
+    if input_mask is None:
+        return None
+    attn_shape = (1, input_mask.shape[1], input_mask.shape[1])
+    attn_mask = input_mask.to(dtype=bool).unsqueeze(1)
+    if diagonal is not None:
+        future_mask = torch.tril(
+            torch.ones(
+                attn_shape,
+                # The original code uses torch.bool here, but onnxruntime
+                # supports neither torch.bool nor torch.int32 in torch.tril,
+                # so we use torch.int64 instead
+                dtype=torch.int64,
+                device=input_mask.device,
+            ),
+            diagonal,
+        ).bool()
+        attn_mask = attn_mask & future_mask
+    attention_mask = (1 - attn_mask.to(torch.float)) * NEG_INF
+    return attention_mask.unsqueeze(1)
+
+
+nemo.collections.common.parts.form_attention_mask = fixed_form_attention_mask
+
+# Note: this import must come after the monkey patch above so that modules
+# importing form_attention_mask() bind the fixed version
+from nemo.collections.asr.models import EncDecMultiTaskModel
+
+
+def export_onnx_fp16(onnx_fp32_path, onnx_fp16_path):
+    onnx_fp32_model = onnxmltools.utils.load_model(onnx_fp32_path)
+    onnx_fp16_model = convert_float_to_float16(onnx_fp32_model, keep_io_types=True)
+    onnxmltools.utils.save_model(onnx_fp16_model, onnx_fp16_path)
+
+
+def lens_to_mask(lens, max_length):
+    """
+    Create a mask from a tensor of lengths.
+
+    Args:
+        lens: (N,), int64
+        max_length: an int containing the mask length T
+    Returns:
+        mask: (N, T), bool, with True for valid positions
+    """
+    batch_size = lens.shape[0]
+    arange = torch.arange(max_length, device=lens.device)
+    mask = arange.expand(batch_size, max_length) < lens.unsqueeze(1)
+    return mask
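+
+
+# A quick worked example of lens_to_mask():
+#
+#   lens_to_mask(torch.tensor([2, 4]), 4) returns
+#     tensor([[ True,  True, False, False],
+#             [ True,  True,  True,  True]])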
+ """ + batch_size = lens.shape[0] + arange = torch.arange(max_length, device=lens.device) + mask = arange.expand(batch_size, max_length) < lens.unsqueeze(1) + return mask + + +class EncoderWrapper(torch.nn.Module): + def __init__(self, m): + super().__init__() + self.encoder = m.encoder + self.encoder_decoder_proj = m.encoder_decoder_proj + + def forward( + self, x: torch.Tensor, x_len: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Args: + x: (N, T, C) + x_len: (N,) + Returns: + - enc_states: (N, T, C) + - encoded_len: (N,) + - enc_mask: (N, T) + """ + x = x.permute(0, 2, 1) + # x: (N, C, T) + encoded, encoded_len = self.encoder(audio_signal=x, length=x_len) + + enc_states = encoded.permute(0, 2, 1) + + enc_states = self.encoder_decoder_proj(enc_states) + + enc_mask = lens_to_mask(encoded_len, enc_states.shape[1]) + + return enc_states, encoded_len, enc_mask + + +class DecoderWrapper(torch.nn.Module): + def __init__(self, m): + super().__init__() + self.decoder = m.transf_decoder + self.log_softmax = m.log_softmax + + # We use only greedy search, so there is no need to compute log_softmax + self.log_softmax.mlp.log_softmax = False + + def forward( + self, + decoder_input_ids: torch.Tensor, + decoder_mems_list_0: torch.Tensor, + decoder_mems_list_1: torch.Tensor, + decoder_mems_list_2: torch.Tensor, + decoder_mems_list_3: torch.Tensor, + decoder_mems_list_4: torch.Tensor, + decoder_mems_list_5: torch.Tensor, + enc_states: torch.Tensor, + enc_mask: torch.Tensor, + ): + """ + Args: + decoder_input_ids: (N, num_tokens), torch.int32 + decoder_mems_list_i: (N, num_tokens, 1024) + enc_states: (N, T, 1024) + enc_mask: (N, T) + Returns: + - logits: (N, 1, vocab_size) + - decoder_mems_list_i: (N, num_tokens_2, 1024) + """ + pos = decoder_input_ids[0][-1].item() + decoder_input_ids = decoder_input_ids[:, :-1] + + decoder_hidden_states = self.decoder.embedding.forward( + decoder_input_ids, start_pos=pos + ) + decoder_input_mask = torch.ones_like(decoder_input_ids).float() + + decoder_mems_list = self.decoder.decoder.forward( + decoder_hidden_states, + decoder_input_mask, + enc_states, + enc_mask, + [ + decoder_mems_list_0, + decoder_mems_list_1, + decoder_mems_list_2, + decoder_mems_list_3, + decoder_mems_list_4, + decoder_mems_list_5, + ], + return_mems=True, + ) + logits = self.log_softmax(hidden_states=decoder_mems_list[-1][:, -1:]) + + return logits, decoder_mems_list + + +def export_encoder(canary_model): + encoder = EncoderWrapper(canary_model) + x = torch.rand(1, 4000, 128) + x_lens = torch.tensor([x.shape[1]], dtype=torch.int64) + + encoder_filename = "encoder.onnx" + torch.onnx.export( + encoder, + (x, x_lens), + encoder_filename, + input_names=["x", "x_len"], + output_names=["enc_states", "enc_len", "enc_mask"], + opset_version=14, + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "x_len": {0: "N"}, + "enc_states": {0: "N", 1: "T"}, + "enc_len": {0: "N"}, + "enc_mask": {0: "N", 1: "T"}, + }, + ) + + +def export_decoder(canary_model): + decoder = DecoderWrapper(canary_model) + decoder_input_ids = torch.tensor([[1, 0]], dtype=torch.int32) + + decoder_mems_list_0 = torch.zeros(1, 1, 1024) + decoder_mems_list_1 = torch.zeros(1, 1, 1024) + decoder_mems_list_2 = torch.zeros(1, 1, 1024) + decoder_mems_list_3 = torch.zeros(1, 1, 1024) + decoder_mems_list_4 = torch.zeros(1, 1, 1024) + decoder_mems_list_5 = torch.zeros(1, 1, 1024) + + enc_states = torch.zeros(1, 1000, 1024) + enc_mask = torch.ones(1, 1000).bool() + + torch.onnx.export( + decoder, + ( + decoder_input_ids, + 
+
+
+def export_decoder(canary_model):
+    decoder = DecoderWrapper(canary_model)
+    decoder_input_ids = torch.tensor([[1, 0]], dtype=torch.int32)
+
+    decoder_mems_list_0 = torch.zeros(1, 1, 1024)
+    decoder_mems_list_1 = torch.zeros(1, 1, 1024)
+    decoder_mems_list_2 = torch.zeros(1, 1, 1024)
+    decoder_mems_list_3 = torch.zeros(1, 1, 1024)
+    decoder_mems_list_4 = torch.zeros(1, 1, 1024)
+    decoder_mems_list_5 = torch.zeros(1, 1, 1024)
+
+    enc_states = torch.zeros(1, 1000, 1024)
+    enc_mask = torch.ones(1, 1000).bool()
+
+    torch.onnx.export(
+        decoder,
+        (
+            decoder_input_ids,
+            decoder_mems_list_0,
+            decoder_mems_list_1,
+            decoder_mems_list_2,
+            decoder_mems_list_3,
+            decoder_mems_list_4,
+            decoder_mems_list_5,
+            enc_states,
+            enc_mask,
+        ),
+        "decoder.onnx",
+        opset_version=14,
+        input_names=[
+            "decoder_input_ids",
+            "decoder_mems_list_0",
+            "decoder_mems_list_1",
+            "decoder_mems_list_2",
+            "decoder_mems_list_3",
+            "decoder_mems_list_4",
+            "decoder_mems_list_5",
+            "enc_states",
+            "enc_mask",
+        ],
+        output_names=[
+            "logits",
+            "next_decoder_mem_list_0",
+            "next_decoder_mem_list_1",
+            "next_decoder_mem_list_2",
+            "next_decoder_mem_list_3",
+            "next_decoder_mem_list_4",
+            "next_decoder_mem_list_5",
+        ],
+        dynamic_axes={
+            "decoder_input_ids": {1: "num_tokens"},
+            "decoder_mems_list_0": {1: "num_tokens"},
+            "decoder_mems_list_1": {1: "num_tokens"},
+            "decoder_mems_list_2": {1: "num_tokens"},
+            "decoder_mems_list_3": {1: "num_tokens"},
+            "decoder_mems_list_4": {1: "num_tokens"},
+            "decoder_mems_list_5": {1: "num_tokens"},
+            "enc_states": {1: "T"},
+            "enc_mask": {1: "T"},
+        },
+    )
+
+
+def export_tokens(canary_model):
+    with open("./tokens.txt", "w", encoding="utf-8") as f:
+        for i in range(canary_model.tokenizer.vocab_size):
+            s = canary_model.tokenizer.ids_to_text([i])
+            f.write(f"{s} {i}\n")
+    print("Saved to tokens.txt")
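+
+
+# Each line of tokens.txt is "<token> <id>" (written as f"{s} {i}" above).
+# A token may itself be or start with a space, in which case the line starts
+# with a space; the tokens.txt parser in test_180m_flash.py relies on this.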
+
+
+@torch.no_grad()
+def main():
+    canary_model = EncDecMultiTaskModel.from_pretrained("nvidia/canary-180m-flash")
+    export_tokens(canary_model)
+    export_encoder(canary_model)
+    export_decoder(canary_model)
+
+    for m in ["encoder", "decoder"]:
+        if m == "encoder":
+            # we don't quantize the decoder with int8 since the accuracy drops
+            quantize_dynamic(
+                model_input=f"./{m}.onnx",
+                model_output=f"./{m}.int8.onnx",
+                weight_type=QuantType.QUInt8,
+            )
+
+        export_onnx_fp16(f"{m}.onnx", f"{m}.fp16.onnx")
+
+    os.system("ls -lh *.onnx")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/nemo/canary/run_180m_flash.sh b/scripts/nemo/canary/run_180m_flash.sh
new file mode 100755
index 00000000..3780246a
--- /dev/null
+++ b/scripts/nemo/canary/run_180m_flash.sh
@@ -0,0 +1,131 @@
+#!/usr/bin/env bash
+# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)
+
+set -ex
+
+log() {
+  # This function is from espnet
+  local fname=${BASH_SOURCE[1]##*/}
+  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/de.wav
+curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/en.wav
+
+pip install \
+  "nemo_toolkit[asr]" \
+  "numpy<2" \
+  ipython \
+  kaldi-native-fbank \
+  librosa \
+  onnx==1.17.0 \
+  onnxmltools \
+  onnxruntime==1.17.1 \
+  soundfile
+
+python3 ./export_onnx_180m_flash.py
+ls -lh *.onnx
+
+
+log "-----fp32------"
+
+python3 ./test_180m_flash.py \
+  --encoder ./encoder.onnx \
+  --decoder ./decoder.onnx \
+  --source-lang en \
+  --target-lang en \
+  --tokens ./tokens.txt \
+  --wav ./en.wav
+
+python3 ./test_180m_flash.py \
+  --encoder ./encoder.onnx \
+  --decoder ./decoder.onnx \
+  --source-lang en \
+  --target-lang de \
+  --tokens ./tokens.txt \
+  --wav ./en.wav
+
+python3 ./test_180m_flash.py \
+  --encoder ./encoder.onnx \
+  --decoder ./decoder.onnx \
+  --source-lang de \
+  --target-lang de \
+  --tokens ./tokens.txt \
+  --wav ./de.wav
+
+python3 ./test_180m_flash.py \
+  --encoder ./encoder.onnx \
+  --decoder ./decoder.onnx \
+  --source-lang de \
+  --target-lang en \
+  --tokens ./tokens.txt \
+  --wav ./de.wav
+
+
+log "-----int8------"
+
+python3 ./test_180m_flash.py \
+  --encoder ./encoder.int8.onnx \
+  --decoder ./decoder.fp16.onnx \
+  --source-lang en \
+  --target-lang en \
+  --tokens ./tokens.txt \
+  --wav ./en.wav
+
+python3 ./test_180m_flash.py \
+  --encoder ./encoder.int8.onnx \
+  --decoder ./decoder.fp16.onnx \
+  --source-lang en \
+  --target-lang de \
+  --tokens ./tokens.txt \
+  --wav ./en.wav
+
+python3 ./test_180m_flash.py \
+  --encoder ./encoder.int8.onnx \
+  --decoder ./decoder.fp16.onnx \
+  --source-lang de \
+  --target-lang de \
+  --tokens ./tokens.txt \
+  --wav ./de.wav
+
+python3 ./test_180m_flash.py \
+  --encoder ./encoder.int8.onnx \
+  --decoder ./decoder.fp16.onnx \
+  --source-lang de \
+  --target-lang en \
+  --tokens ./tokens.txt \
+  --wav ./de.wav
+
+log "-----fp16------"
+
+python3 ./test_180m_flash.py \
+  --encoder ./encoder.fp16.onnx \
+  --decoder ./decoder.fp16.onnx \
+  --source-lang en \
+  --target-lang en \
+  --tokens ./tokens.txt \
+  --wav ./en.wav
+
+python3 ./test_180m_flash.py \
+  --encoder ./encoder.fp16.onnx \
+  --decoder ./decoder.fp16.onnx \
+  --source-lang en \
+  --target-lang de \
+  --tokens ./tokens.txt \
+  --wav ./en.wav
+
+python3 ./test_180m_flash.py \
+  --encoder ./encoder.fp16.onnx \
+  --decoder ./decoder.fp16.onnx \
+  --source-lang de \
+  --target-lang de \
+  --tokens ./tokens.txt \
+  --wav ./de.wav
+
+python3 ./test_180m_flash.py \
+  --encoder ./encoder.fp16.onnx \
+  --decoder ./decoder.fp16.onnx \
+  --source-lang de \
+  --target-lang en \
+  --tokens ./tokens.txt \
+  --wav ./de.wav
diff --git a/scripts/nemo/canary/test_180m_flash.py b/scripts/nemo/canary/test_180m_flash.py
new file mode 100755
index 00000000..cfa04250
--- /dev/null
+++ b/scripts/nemo/canary/test_180m_flash.py
@@ -0,0 +1,299 @@
+#!/usr/bin/env python3
+# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)
+
+import argparse
+import time
+from pathlib import Path
+from typing import List
+
+import kaldi_native_fbank as knf
+import librosa
+import numpy as np
+import onnxruntime as ort
+import soundfile as sf
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--encoder", type=str, required=True, help="Path to encoder.onnx"
+    )
+    parser.add_argument(
+        "--decoder", type=str, required=True, help="Path to decoder.onnx"
+    )
+
+    parser.add_argument("--tokens", type=str, required=True, help="Path to tokens.txt")
+
+    parser.add_argument(
+        "--source-lang",
+        type=str,
+        help="Language of the input wav. Valid values are: en, de, es, fr",
+    )
+    parser.add_argument(
+        "--target-lang",
+        type=str,
+        help="Language of the recognition result. Valid values are: en, de, es, fr",
+    )
+    parser.add_argument(
+        "--use-pnc",
+        type=int,
+        default=1,
+        help="1 to enable casing and punctuation, 0 to disable them",
+    )
+
+    parser.add_argument("--wav", type=str, required=True, help="Path to test.wav")
+
+    return parser.parse_args()
+
+
+def display(sess, model):
+    print(f"=========={model} Input==========")
+    for i in sess.get_inputs():
+        print(i)
+    print(f"=========={model} Output==========")
+    for i in sess.get_outputs():
+        print(i)
+
+
+class OnnxModel:
+    def __init__(
+        self,
+        encoder: str,
+        decoder: str,
+    ):
+        self.init_encoder(encoder)
+        display(self.encoder, "encoder")
+
+        self.init_decoder(decoder)
+        display(self.decoder, "decoder")
+
+    def init_encoder(self, encoder):
+        session_opts = ort.SessionOptions()
+        session_opts.inter_op_num_threads = 1
+        session_opts.intra_op_num_threads = 1
+
+        self.encoder = ort.InferenceSession(
+            encoder,
+            sess_options=session_opts,
+            providers=["CPUExecutionProvider"],
+        )
+
+        meta = self.encoder.get_modelmeta().custom_metadata_map
+        # The exported model does not carry normalize_type in its metadata,
+        # so we hardcode it here instead of reading
+        # self.normalize_type = meta["normalize_type"]
+        self.normalize_type = "per_feature"
+        print(meta)
+
+    def init_decoder(self, decoder):
+        session_opts = ort.SessionOptions()
+        session_opts.inter_op_num_threads = 1
+        session_opts.intra_op_num_threads = 1
+
+        self.decoder = ort.InferenceSession(
+            decoder,
+            sess_options=session_opts,
+            providers=["CPUExecutionProvider"],
+        )
+
+    def run_encoder(self, x: np.ndarray, x_lens: np.ndarray):
+        """
+        Args:
+          x: (N, T, C), np.float32
+          x_lens: (N,), np.int64
+        Returns:
+          enc_states: (N, T, C), np.float32
+          enc_lens: (N,), np.int64
+          enc_masks: (N, T), np.bool
+        """
+        enc_states, enc_lens, enc_masks = self.encoder.run(
+            [
+                self.encoder.get_outputs()[0].name,
+                self.encoder.get_outputs()[1].name,
+                self.encoder.get_outputs()[2].name,
+            ],
+            {
+                self.encoder.get_inputs()[0].name: x,
+                self.encoder.get_inputs()[1].name: x_lens,
+            },
+        )
+        return enc_states, enc_lens, enc_masks
+
+    def run_decoder(
+        self,
+        decoder_input_ids: np.ndarray,
+        decoder_mems_list: List[np.ndarray],
+        enc_states: np.ndarray,
+        enc_mask: np.ndarray,
+    ):
+        """
+        Args:
+          decoder_input_ids: (N, num_tokens), int32.  The last element
+            carries the current decode position, not a token.
+          decoder_mems_list: a list of 6 tensors, each of which is
+            (N, num_tokens, C)
+          enc_states: (N, T, C), float32
+          enc_mask: (N, T), bool
+        Returns:
+          logits: (1, 1, vocab_size), float32
+          new_decoder_mems_list: a list of 6 tensors, each of which is
+            (N, num_tokens_2, C)
+        """
+        (logits, *new_decoder_mems_list) = self.decoder.run(
+            [
+                self.decoder.get_outputs()[0].name,
+                self.decoder.get_outputs()[1].name,
+                self.decoder.get_outputs()[2].name,
+                self.decoder.get_outputs()[3].name,
+                self.decoder.get_outputs()[4].name,
+                self.decoder.get_outputs()[5].name,
+                self.decoder.get_outputs()[6].name,
+            ],
+            {
+                self.decoder.get_inputs()[0].name: decoder_input_ids,
+                self.decoder.get_inputs()[1].name: decoder_mems_list[0],
+                self.decoder.get_inputs()[2].name: decoder_mems_list[1],
+                self.decoder.get_inputs()[3].name: decoder_mems_list[2],
+                self.decoder.get_inputs()[4].name: decoder_mems_list[3],
+                self.decoder.get_inputs()[5].name: decoder_mems_list[4],
+                self.decoder.get_inputs()[6].name: decoder_mems_list[5],
+                self.decoder.get_inputs()[7].name: enc_states,
+                self.decoder.get_inputs()[8].name: enc_mask,
+            },
+        )
+        return logits, new_decoder_mems_list
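+
+
+# Typical calling sequence for OnnxModel (a sketch; main() below is the
+# complete version):
+#
+#   model = OnnxModel("encoder.onnx", "decoder.onnx")
+#   enc_states, enc_lens, enc_masks = model.run_encoder(features, features_len)
+#   mems = [np.zeros((1, 0, 1024), dtype=np.float32) for _ in range(6)]
+#   logits, mems = model.run_decoder(ids, mems, enc_states, enc_masks)
+#   token = logits.argmax()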
+
+
+def create_fbank():
+    opts = knf.FbankOptions()
+    opts.frame_opts.dither = 0
+    opts.frame_opts.remove_dc_offset = False
+    opts.frame_opts.window_type = "hann"
+
+    opts.mel_opts.low_freq = 0
+    opts.mel_opts.num_bins = 128
+
+    opts.mel_opts.is_librosa = True
+
+    fbank = knf.OnlineFbank(opts)
+    return fbank
+
+
+def compute_features(audio, fbank):
+    assert len(audio.shape) == 1, audio.shape
+    fbank.accept_waveform(16000, audio)
+    ans = []
+    processed = 0
+    while processed < fbank.num_frames_ready:
+        ans.append(np.array(fbank.get_frame(processed)))
+        processed += 1
+    ans = np.stack(ans)
+    # ans: (num_frames, 128)
+    return ans
+
+
+def main():
+    args = get_args()
+    assert Path(args.encoder).is_file(), args.encoder
+    assert Path(args.decoder).is_file(), args.decoder
+    assert Path(args.tokens).is_file(), args.tokens
+    assert Path(args.wav).is_file(), args.wav
+
+    print(vars(args))
+
+    id2token = dict()
+    token2id = dict()
+    with open(args.tokens, encoding="utf-8") as f:
+        for line in f:
+            fields = line.split()
+            if len(fields) == 2:
+                t, idx = fields[0], int(fields[1])
+                if line[0] == " ":
+                    # split() drops the leading space of the token,
+                    # so add it back
+                    t = " " + t
+            else:
+                # the token is a whitespace character, so split()
+                # yields only the id
+                t = " "
+                idx = int(fields[0])
+
+            id2token[idx] = t
+            token2id[t] = idx
+
+    model = OnnxModel(args.encoder, args.decoder)
+
+    fbank = create_fbank()
+
+    start = time.time()
+    audio, sample_rate = sf.read(args.wav, dtype="float32", always_2d=True)
+    audio = audio[:, 0]  # only use the first channel
+    if sample_rate != 16000:
+        audio = librosa.resample(
+            audio,
+            orig_sr=sample_rate,
+            target_sr=16000,
+        )
+        sample_rate = 16000
+
+    features = compute_features(audio, fbank)
+    if model.normalize_type != "":
+        assert model.normalize_type == "per_feature", model.normalize_type
+        mean = features.mean(axis=1, keepdims=True)
+        stddev = features.std(axis=1, keepdims=True) + 1e-5
+        features = (features - mean) / stddev
+
+    features = np.expand_dims(features, axis=0)
+    # features: (1, num_frames, 128), e.g., (1, 291, 128)
+
+    features_len = np.array([features.shape[1]], dtype=np.int64)
+
+    enc_states, _, enc_masks = model.run_encoder(features, features_len)
+
+    # Build the decoding prompt expected by the Canary decoder
+    decoder_input_ids = []
+    decoder_input_ids.append(token2id["<|startofcontext|>"])
+    decoder_input_ids.append(token2id["<|startoftranscript|>"])
+    decoder_input_ids.append(token2id["<|emo:undefined|>"])
+    if args.source_lang in ("en", "es", "de", "fr"):
+        decoder_input_ids.append(token2id[f"<|{args.source_lang}|>"])
+    else:
+        decoder_input_ids.append(token2id["<|en|>"])
+
+    if args.target_lang in ("en", "es", "de", "fr"):
+        decoder_input_ids.append(token2id[f"<|{args.target_lang}|>"])
+    else:
+        decoder_input_ids.append(token2id["<|en|>"])
+
+    if args.use_pnc:
+        decoder_input_ids.append(token2id["<|pnc|>"])
+    else:
+        decoder_input_ids.append(token2id["<|nopnc|>"])
+
+    decoder_input_ids.append(token2id["<|noitn|>"])
+    decoder_input_ids.append(token2id["<|notimestamp|>"])
+    decoder_input_ids.append(token2id["<|nodiarize|>"])
+
+    # the start position consumed by DecoderWrapper; it is not a token
+    decoder_input_ids.append(0)
+
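+    # Each layer's self-attention cache starts empty with shape (1, 0, 1024);
+    # run_decoder() returns the grown caches, so the "num_tokens" axis
+    # increases with every call.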
+    decoder_mems_list = [np.zeros((1, 0, 1024), dtype=np.float32) for _ in range(6)]
+
+    logits, decoder_mems_list = model.run_decoder(
+        np.array([decoder_input_ids], dtype=np.int32),
+        decoder_mems_list,
+        enc_states,
+        enc_masks,
+    )
+    tokens = [logits.argmax()]
+    print("decoder_input_ids", decoder_input_ids)
+    eos = token2id["<|endoftext|>"]
+
+    # Auto-regressive greedy search: feed back the best token from the
+    # previous step together with the current decode position
+    for i in range(1, 200):
+        decoder_input_ids = [tokens[-1], i]
+        logits, decoder_mems_list = model.run_decoder(
+            np.array([decoder_input_ids], dtype=np.int32),
+            decoder_mems_list,
+            enc_states,
+            enc_masks,
+        )
+        t = logits.argmax()
+        if t == eos:
+            break
+        tokens.append(t)
+    print("len(tokens)", len(tokens))
+    print("tokens", tokens)
+    text = "".join([id2token[i] for i in tokens])
+    print("text:", text)
+
+
+if __name__ == "__main__":
+    main()