521 lines
19 KiB
Plaintext
521 lines
19 KiB
Plaintext
{
|
||
"metadata": {
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": 3
|
||
},
|
||
"orig_nbformat": 2
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 2,
|
||
"cells": [
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"%%capture\n",
|
||
"!pip install datasets==1.4.1\n",
|
||
"!pip install transformers==4.4.0\n",
|
||
"!pip install torchaudio\n",
|
||
"!pip install librosa\n",
|
||
"!pip install jiwer\n",
|
||
"!pip install mecab-python3\n",
|
||
"!pip install unidic-lite\n",
|
||
"!pip isntall audiomentations"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, Wav2Vec2Processor\n",
|
||
"from datasets import load_dataset, load_metric, ClassLabel, Dataset\n",
|
||
"from audiomentations import Compose, AddGaussianNoise, Gain, PitchShift, TimeStretch, Shift\n",
|
||
"from torch.optim.lr_scheduler import LambdaLR\n",
|
||
"from transformers import Wav2Vec2ForCTC, TrainingArguments, Trainer\n",
|
||
"\n",
|
||
"import pandas as pd\n",
|
||
"import numpy as np\n",
|
||
"import soundfile as sf\n",
|
||
"import re\n",
|
||
"import json\n",
|
||
"import torchaudio\n",
|
||
"import librosa\n",
|
||
"import datasets\n",
|
||
"import MeCab\n",
|
||
"import pykakasi\n",
|
||
"import random\n",
|
||
"\n",
|
||
"import torch\n",
|
||
"from dataclasses import dataclass, field\n",
|
||
"from typing import Any, Dict, List, Optional, Union"
|
||
]
|
||
},
|
||
{
|
||
"source": [
|
||
"# Load dataset and prepare processor"
|
||
],
|
||
"cell_type": "markdown",
|
||
"metadata": {}
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# Load public dataset from University of Tokyo\n",
|
||
"!wget http://ss-takashi.sakura.ne.jp/corpus/jsut_ver1.1.zip\n",
|
||
"!unzip jsut_ver1.1.zip\n",
|
||
"\n",
|
||
"path = 'jsut_ver1.1/basic5000/'\n",
|
||
"df = pd.read_csv(path + 'transcript_utf8.txt', header = None, delimiter = \":\", names=[\"path\", \"sentence\"], index_col=False)\n",
|
||
"df[\"path\"] = df[\"path\"].map(lambda x: path + 'wav/' + x + \".wav\")\n",
|
||
"df.head()\n",
|
||
"\n",
|
||
"jsut_voice_train = Dataset.from_pandas(df)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# Import training dataset\n",
|
||
"common_voice_train = load_dataset('common_voice', 'ja',split='train+validation')\n",
|
||
"common_voice_test = load_dataset('common_voice', 'ja', split='test')\n",
|
||
"\n",
|
||
"# Remove unwanted columns\n",
|
||
"common_voice_train = common_voice_train.remove_columns([\"accent\", \"age\", \"client_id\", \"down_votes\", \"gender\", \"locale\", \"segment\", \"up_votes\"])\n",
|
||
"common_voice_test = common_voice_test.remove_columns([\"accent\", \"age\", \"client_id\", \"down_votes\", \"gender\", \"locale\", \"segment\", \"up_votes\"])\n",
|
||
"\n",
|
||
"# Concat common voice and public dataset\n",
|
||
"common_voice_train = datasets.concatenate_datasets([jsut_voice_train, common_voice_train])"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# Parser Japanese sentence. Ex: \"pythonが大好きです\" -> \"python が 大好き です EOS\"\n",
|
||
"wakati = MeCab.Tagger(\"-Owakati\")\n",
|
||
"\n",
|
||
"# Unwanted token\n",
|
||
"chars_to_ignore_regex = '[\\,\\、\\。\\.\\「\\」\\…\\?\\・]'\n",
|
||
"\n",
|
||
"def remove_special_characters(batch):\n",
|
||
" batch[\"sentence\"] = wakati.parse(batch[\"sentence\"]).strip()\n",
|
||
" batch[\"sentence\"] = re.sub(chars_to_ignore_regex,'', batch[\"sentence\"]).strip()\n",
|
||
" return batch\n",
|
||
"\n",
|
||
"common_voice_train = common_voice_train.map(remove_special_characters)\n",
|
||
"common_voice_test = common_voice_test.map(remove_special_characters)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# make vocab file\n",
|
||
"def extract_all_chars(batch):\n",
|
||
" all_text = \" \".join(batch[\"sentence\"])\n",
|
||
" vocab = list(set(all_text))\n",
|
||
" return {\"vocab\": [vocab], \"all_text\": [all_text]}\n",
|
||
"\n",
|
||
"# make vocab list and text\n",
|
||
"vocab_train = common_voice_train.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=common_voice_train.column_names)\n",
|
||
"vocab_test = common_voice_test.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=common_voice_test.column_names)\n",
|
||
"\n",
|
||
"# concate vocab from train and test set\n",
|
||
"vocab_list = list(set(vocab_train[\"vocab\"][0]) | set(vocab_test[\"vocab\"][0]))\n",
|
||
"vocab_dict = {v: k for k, v in enumerate(vocab_list)}\n",
|
||
"print(len(vocab_dict))\n",
|
||
"vocab_dict[\"|\"] = vocab_dict[\" \"]\n",
|
||
"del vocab_dict[\" \"]\n",
|
||
"\n",
|
||
"# create unk and pad token\n",
|
||
"vocab_dict[\"[UNK]\"] = len(vocab_dict)\n",
|
||
"vocab_dict[\"[PAD]\"] = len(vocab_dict)\n",
|
||
"\n",
|
||
"# save to json file\n",
|
||
"with open('vocab.json', 'w') as vocab_file:\n",
|
||
" json.dump(vocab_dict, vocab_file, indent=2, ensure_ascii=False)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"save_dir = \"./output_models\"\n",
|
||
"# wrap tokenizer and feature extractor to processor\n",
|
||
"tokenizer = Wav2Vec2CTCTokenizer(\"./vocab_demo.json\", unk_token=\"[UNK]\", pad_token=\"[PAD]\", word_delimiter_token=\"|\")\n",
|
||
"feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=True)\n",
|
||
"processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)\n",
|
||
"processor.save_pretrained(save_dir)"
|
||
]
|
||
},
|
||
{
|
||
"source": [
|
||
"# Prepare train and test dataset "
|
||
],
|
||
"cell_type": "markdown",
|
||
"metadata": {}
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# convert audio from 48kHz to 16kHz (standard sample rate of wave2vec model)\n",
|
||
"def speech_file_to_array_fn(batch):\n",
|
||
" speech_array, sampling_rate = torchaudio.load(batch[\"path\"])\n",
|
||
" batch[\"speech\"] = librosa.resample(np.asarray(speech_array[0].numpy()), 48_000, 16_000)\n",
|
||
" batch[\"sampling_rate\"] = 16_000\n",
|
||
" batch[\"target_text\"] = batch[\"sentence\"]\n",
|
||
" return batch\n",
|
||
"\n",
|
||
"common_voice_train = common_voice_train.map(speech_file_to_array_fn, remove_columns=common_voice_train.column_names,num_proc=4)\n",
|
||
"common_voice_test = common_voice_test.map(speech_file_to_array_fn,remove_columns=common_voice_test.column_names, num_proc=4) "
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# do augment to enrich common voice dataset \n",
|
||
"augment = Compose([\n",
|
||
" AddGaussianNoise(min_amplitude=0.0001, max_amplitude=0.001, p=0.8),\n",
|
||
" PitchShift(min_semitones=-1, max_semitones=1, p=0.8),\n",
|
||
" Gain(min_gain_in_db=-6, max_gain_in_db=6, p=0.8),\n",
|
||
" TimeStretch(min_rate=0.8, max_rate=1.25, p=0.8)\n",
|
||
"\n",
|
||
"])\n",
|
||
"\n",
|
||
"def augmented_speech(batch, augment):\n",
|
||
" samples = np.array(batch[\"speech\"])\n",
|
||
" batch[\"speech\"] = augment(samples=samples, sample_rate=16000)\n",
|
||
" batch[\"sampling_rate\"] = 16_000\n",
|
||
" batch[\"target_text\"] = batch[\"target_text\"]\n",
|
||
" return batch\n",
|
||
"\n",
|
||
"# augument 50% of trainset\n",
|
||
"common_voice_train_augmented = common_voice_train.train_test_split(test_size = 0.5)['train']\n",
|
||
"common_voice_train_augmented = common_voice_train_augmented.map(lambda batch: augmented_speech(batch, augment), num_proc=4)\n",
|
||
"\n",
|
||
"# concate with trainset\n",
|
||
"common_voice_train = datasets.concatenate_datasets([common_voice_train_augmented, common_voice_train])"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"def prepare_dataset(batch):\n",
|
||
" # check that all files have the correct sampling rate\n",
|
||
" assert (\n",
|
||
" len(set(batch[\"sampling_rate\"])) == 1\n",
|
||
" ), f\"Make sure all inputs have the same sampling rate of {processor.feature_extractor.sampling_rate}.\"\n",
|
||
"\n",
|
||
" batch[\"input_values\"] = processor(batch[\"speech\"], sampling_rate=batch[\"sampling_rate\"][0]).input_values\n",
|
||
" \n",
|
||
" with processor.as_target_processor():\n",
|
||
" batch[\"labels\"] = processor(batch[\"target_text\"]).input_ids\n",
|
||
" return batch\n",
|
||
" \n",
|
||
"# prepare dataset\n",
|
||
"common_voice_train = common_voice_train.map(prepare_dataset, remove_columns=common_voice_train.column_names, batch_size=8, num_proc=4, batched=True)\n",
|
||
"common_voice_test = common_voice_test.map(prepare_dataset, remove_columns=common_voice_test.column_names, batch_size=8, num_proc=4, batched=True)"
|
||
]
|
||
},
|
||
{
|
||
"source": [
|
||
"# Training"
|
||
],
|
||
"cell_type": "markdown",
|
||
"metadata": {}
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# create data collator\n",
|
||
"@dataclass\n",
|
||
"class DataCollatorCTCWithPadding:\n",
|
||
"\n",
|
||
" processor: Wav2Vec2Processor\n",
|
||
" padding: Union[bool, str] = True\n",
|
||
" max_length: Optional[int] = None\n",
|
||
" max_length_labels: Optional[int] = None\n",
|
||
" pad_to_multiple_of: Optional[int] = None\n",
|
||
" pad_to_multiple_of_labels: Optional[int] = None\n",
|
||
"\n",
|
||
" def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:\n",
|
||
" input_features = [{\"input_values\": feature[\"input_values\"]} for feature in features]\n",
|
||
" label_features = [{\"input_ids\": feature[\"labels\"]} for feature in features]\n",
|
||
"\n",
|
||
" batch = self.processor.pad(\n",
|
||
" input_features,\n",
|
||
" padding=self.padding,\n",
|
||
" max_length=self.max_length,\n",
|
||
" pad_to_multiple_of=self.pad_to_multiple_of,\n",
|
||
" return_tensors=\"pt\",\n",
|
||
" )\n",
|
||
" with self.processor.as_target_processor():\n",
|
||
" labels_batch = self.processor.pad(\n",
|
||
" label_features,\n",
|
||
" padding=self.padding,\n",
|
||
" max_length=self.max_length_labels,\n",
|
||
" pad_to_multiple_of=self.pad_to_multiple_of_labels,\n",
|
||
" return_tensors=\"pt\",\n",
|
||
" )\n",
|
||
"\n",
|
||
" # replace padding with -100 to ignore loss correctly\n",
|
||
" labels = labels_batch[\"input_ids\"].masked_fill(labels_batch.attention_mask.ne(1), -100)\n",
|
||
"\n",
|
||
" batch[\"labels\"] = labels\n",
|
||
"\n",
|
||
" return batch\n",
|
||
"\n",
|
||
"data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# make metric function\n",
|
||
"wer_metric = load_metric(\"wer\")\n",
|
||
"\n",
|
||
"def compute_metrics(pred):\n",
|
||
" pred_logits = pred.predictions\n",
|
||
" pred_ids = np.argmax(pred_logits, axis=-1)\n",
|
||
"\n",
|
||
" pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id\n",
|
||
"\n",
|
||
" pred_str = processor.batch_decode(pred_ids)\n",
|
||
" # we do not want to group tokens when computing the metrics\n",
|
||
" label_str = processor.batch_decode(pred.label_ids, group_tokens=False)\n",
|
||
"\n",
|
||
" wer = wer_metric.compute(predictions=pred_str, references=label_str)\n",
|
||
"\n",
|
||
" return {\"wer\": wer}"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# create custom learning scheduler\n",
|
||
"\n",
|
||
"# polynomial decay\n",
|
||
"def get_polynomial_decay_schedule_with_warmup(\n",
|
||
" optimizer, num_warmup_steps, num_training_steps, lr_end=1e-7, power=1.2, last_epoch=-1\n",
|
||
"):\n",
|
||
"\n",
|
||
" lr_init = optimizer.defaults[\"lr\"]\n",
|
||
" assert lr_init > lr_end, f\"lr_end ({lr_end}) must be be smaller than initial lr ({lr_init})\"\n",
|
||
"\n",
|
||
" def lr_lambda(current_step: int):\n",
|
||
" if current_step < num_warmup_steps:\n",
|
||
" return float(current_step) / float(max(1, num_warmup_steps))\n",
|
||
" elif current_step > num_training_steps:\n",
|
||
" return lr_end / lr_init # as LambdaLR multiplies by lr_init\n",
|
||
" else:\n",
|
||
" lr_range = lr_init - lr_end\n",
|
||
" decay_steps = num_training_steps - num_warmup_steps\n",
|
||
" pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps\n",
|
||
" decay = lr_range * pct_remaining ** power + lr_end\n",
|
||
" return decay / lr_init # as LambdaLR multiplies by lr_init\n",
|
||
"\n",
|
||
" return LambdaLR(optimizer, lr_lambda, last_epoch)\n",
|
||
" \n",
|
||
"# wrap custom learning scheduler with trainer\n",
|
||
"class PolyTrainer(Trainer):\n",
|
||
" def __init__(self, *args, **kwargs):\n",
|
||
" super().__init__(*args, **kwargs)\n",
|
||
" \n",
|
||
" def create_scheduler(self, num_training_steps: int):\n",
|
||
" self.lr_scheduler = get_polynomial_decay_schedule_with_warmup(self.optimizer, \n",
|
||
" num_warmup_steps=self.args.warmup_steps,\n",
|
||
" num_training_steps=num_training_steps)\n",
|
||
" def create_optimizer_and_scheduler(self, num_training_steps: int):\n",
|
||
" self.create_optimizer()\n",
|
||
" self.create_scheduler(num_training_steps)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# load pretrain model\n",
|
||
"model = Wav2Vec2ForCTC.from_pretrained(\n",
|
||
" \"facebook/wav2vec2-large-xlsr-53\", \n",
|
||
" attention_dropout=0.1,\n",
|
||
" hidden_dropout=0.1,\n",
|
||
" feat_proj_dropout=0.1,\n",
|
||
" mask_time_prob=0.1, \n",
|
||
" layerdrop=0.1,\n",
|
||
" gradient_checkpointing=True, \n",
|
||
" ctc_loss_reduction=\"mean\", \n",
|
||
" pad_token_id=processor.tokenizer.pad_token_id,\n",
|
||
" vocab_size=len(processor.tokenizer)\n",
|
||
")\n",
|
||
"# free feature extractor\n",
|
||
"model.freeze_feature_extractor()\n",
|
||
"\n",
|
||
"# define train argument\n",
|
||
"training_args = TrainingArguments(\n",
|
||
" output_dir=save_dir,\n",
|
||
" group_by_length=True,\n",
|
||
" per_device_train_batch_size=32,\n",
|
||
" gradient_accumulation_steps=2,\n",
|
||
" evaluation_strategy=\"steps\",\n",
|
||
" num_train_epochs=200,\n",
|
||
" fp16=True,\n",
|
||
" save_steps=2400, \n",
|
||
" eval_steps=800,\n",
|
||
" logging_steps=800, \n",
|
||
" learning_rate=1e-4, \n",
|
||
" warmup_steps=1500, \n",
|
||
" save_total_limit=2,\n",
|
||
" load_best_model_at_end = True, \n",
|
||
" metric_for_best_model='wer', \n",
|
||
" greater_is_better=False\n",
|
||
")\n",
|
||
"\n",
|
||
"# wrap everything to Trainer\n",
|
||
"trainer = PolyTrainer(\n",
|
||
" model=model,\n",
|
||
" data_collator=data_collator,\n",
|
||
" args=training_args,\n",
|
||
" compute_metrics=compute_metrics,\n",
|
||
" train_dataset=common_voice_train,\n",
|
||
" eval_dataset=common_voice_test,\n",
|
||
" tokenizer=processor.feature_extractor,\n",
|
||
")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# training\n",
|
||
"train_result = trainer.train()"
|
||
]
|
||
},
|
||
{
|
||
"source": [
|
||
"# Testing result"
|
||
],
|
||
"cell_type": "markdown",
|
||
"metadata": {}
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"import torch\n",
|
||
"import torchaudio\n",
|
||
"from datasets import load_dataset, load_metric\n",
|
||
"from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor\n",
|
||
"import MeCab\n",
|
||
"import pykakasi\n",
|
||
"import re\n",
|
||
"\n",
|
||
"#config\n",
|
||
"wakati = MeCab.Tagger(\"-Owakati\")\n",
|
||
"chars_to_ignore_regex = '[\\,\\、\\。\\.\\「\\」\\…\\?\\・]'\n",
|
||
"\n",
|
||
"#load model\n",
|
||
"processor = Wav2Vec2Processor.from_pretrained(save_dir)\n",
|
||
"test_model = Wav2Vec2ForCTC.from_pretrained(save_dir)\n",
|
||
"test_model.to(\"cuda\")\n",
|
||
"resampler = torchaudio.transforms.Resample(48_000, 16_000)\n",
|
||
"\n",
|
||
"#load testdata\n",
|
||
"test_dataset = load_dataset(\"common_voice\", \"ja\", split=\"test\")\n",
|
||
"wer = load_metric(\"wer\")\n",
|
||
"\n",
|
||
"# Preprocessing the datasets.\n",
|
||
"def speech_file_to_array_fn(batch):\n",
|
||
" batch[\"sentence\"] = wakati.parse(batch[\"sentence\"]).strip()\n",
|
||
" batch[\"sentence\"] = re.sub(chars_to_ignore_regex,'', batch[\"sentence\"]).strip()\n",
|
||
" speech_array, sampling_rate = torchaudio.load(batch[\"path\"])\n",
|
||
" batch[\"speech\"] = resampler(speech_array).squeeze().numpy()\n",
|
||
" return batch\n",
|
||
"\n",
|
||
"test_dataset = test_dataset.map(speech_file_to_array_fn)\n",
|
||
"\n",
|
||
"# Preprocessing the datasets.\n",
|
||
"# We need to read the aduio files as arrays\n",
|
||
"def evaluate(batch):\n",
|
||
" inputs = processor(batch[\"speech\"], sampling_rate=16_000, return_tensors=\"pt\", padding=True)\n",
|
||
"\n",
|
||
" with torch.no_grad():\n",
|
||
" logits = test_model(inputs.input_values.to(\"cuda\"), attention_mask=inputs.attention_mask.to(\"cuda\")).logits\n",
|
||
" pred_ids = torch.argmax(logits, dim=-1)\n",
|
||
" batch[\"pred_strings\"] = processor.batch_decode(pred_ids)\n",
|
||
" return batch\n",
|
||
"\n",
|
||
"result = test_dataset.map(evaluate, batched=True, batch_size=8)\n",
|
||
"\n",
|
||
"print(\"WER: {:2f}\".format(100 * wer.compute(predictions=result[\"pred_strings\"], references=result[\"sentence\"])))"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# print some reusults\n",
|
||
"pick = random.randint(0, len(common_voice_test_transcription)-1)\n",
|
||
"input_dict = processor(common_voice_test[\"input_values\"][pick], return_tensors=\"pt\", padding=True)\n",
|
||
"logits = test_model(input_dict.input_values.to(\"cuda\")).logits\n",
|
||
"pred_ids = torch.argmax(logits, dim=-1)[0]\n",
|
||
"\n",
|
||
"print(\"Prediction:\")\n",
|
||
"print(processor.decode(pred_ids).strip())\n",
|
||
"\n",
|
||
"print(\"\\nLabel:\")\n",
|
||
"print(processor.decode(common_voice_test['labels'][pick]))\n"
|
||
]
|
||
}
|
||
]
|
||
} |